Skip to content

Commit

Permalink
feat(event_handler): add support for additional response models (#3591)
Browse files Browse the repository at this point in the history
* feat(event_handler): add support for additional response models

* fix: I hate sonarcube

* fix: pydantic 2

* fix: refactor

* fix: increase coverage

* chore: update docs

---------

Co-authored-by: Leandro Damascena <[email protected]>
  • Loading branch information
rubenfonseca and leandrodamascena authored Jan 16, 2024
1 parent 76dc016 commit 1409499
Show file tree
Hide file tree
Showing 6 changed files with 215 additions and 27 deletions.
79 changes: 64 additions & 15 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
from aws_lambda_powertools.event_handler.openapi.types import (
COMPONENT_REF_PREFIX,
METHODS_WITH_BODY,
OpenAPIResponse,
OpenAPIResponseContentModel,
OpenAPIResponseContentSchema,
validation_error_definition,
validation_error_response_definition,
)
Expand Down Expand Up @@ -273,7 +276,7 @@ def __init__(
cache_control: Optional[str],
summary: Optional[str],
description: Optional[str],
responses: Optional[Dict[int, Dict[str, Any]]],
responses: Optional[Dict[int, OpenAPIResponse]],
response_description: Optional[str],
tags: Optional[List[str]],
operation_id: Optional[str],
Expand Down Expand Up @@ -303,7 +306,7 @@ def __init__(
The OpenAPI summary for this route
description: Optional[str]
The OpenAPI description for this route
responses: Optional[Dict[int, Dict[str, Any]]]
responses: Optional[Dict[int, OpenAPIResponse]]
The OpenAPI responses for this route
response_description: Optional[str]
The OpenAPI response description for this route
Expand Down Expand Up @@ -442,7 +445,7 @@ def dependant(self) -> "Dependant":
if self._dependant is None:
from aws_lambda_powertools.event_handler.openapi.dependant import get_dependant

self._dependant = get_dependant(path=self.openapi_path, call=self.func)
self._dependant = get_dependant(path=self.openapi_path, call=self.func, responses=self.responses)

return self._dependant

Expand Down Expand Up @@ -501,11 +504,54 @@ def _get_openapi_path(

# Add the response to the OpenAPI operation
if self.responses:
# If the user supplied responses, we use them and don't set a default 200 response
for status_code in list(self.responses):
response = self.responses[status_code]

# Case 1: there is not 'content' key
if "content" not in response:
response["content"] = {
"application/json": self._openapi_operation_return(
param=dependant.return_param,
model_name_map=model_name_map,
field_mapping=field_mapping,
),
}

# Case 2: there is a 'content' key
else:
# Need to iterate to transform any 'model' into a 'schema'
for content_type, payload in response["content"].items():
new_payload: OpenAPIResponseContentSchema

# Case 2.1: the 'content' has a model
if "model" in payload:
# Find the model in the dependant's extra models
return_field = next(
filter(
lambda model: model.type_ is cast(OpenAPIResponseContentModel, payload)["model"],
self.dependant.response_extra_models,
),
)
if not return_field:
raise AssertionError("Model declared in custom responses was not found")

new_payload = self._openapi_operation_return(
param=return_field,
model_name_map=model_name_map,
field_mapping=field_mapping,
)

# Case 2.2: the 'content' has a schema
else:
# Do nothing! We already have what we need!
new_payload = payload

response["content"][content_type] = new_payload

operation["responses"] = self.responses
else:
# Set the default 200 response
responses = operation.setdefault("responses", self.responses or {})
responses = operation.setdefault("responses", {})
success_response = responses.setdefault(200, {})
success_response["description"] = self.response_description or _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION
success_response["content"] = {"application/json": {"schema": {}}}
Expand Down Expand Up @@ -682,7 +728,7 @@ def _openapi_operation_return(
Tuple["ModelField", Literal["validation", "serialization"]],
"JsonSchemaValue",
],
) -> Dict[str, Any]:
) -> OpenAPIResponseContentSchema:
"""
Returns the OpenAPI operation return.
"""
Expand Down Expand Up @@ -832,7 +878,7 @@ def route(
cache_control: Optional[str] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
responses: Optional[Dict[int, Dict[str, Any]]] = None,
responses: Optional[Dict[int, OpenAPIResponse]] = None,
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
tags: Optional[List[str]] = None,
operation_id: Optional[str] = None,
Expand Down Expand Up @@ -890,7 +936,7 @@ def get(
cache_control: Optional[str] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
responses: Optional[Dict[int, Dict[str, Any]]] = None,
responses: Optional[Dict[int, OpenAPIResponse]] = None,
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
tags: Optional[List[str]] = None,
operation_id: Optional[str] = None,
Expand Down Expand Up @@ -943,7 +989,7 @@ def post(
cache_control: Optional[str] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
responses: Optional[Dict[int, Dict[str, Any]]] = None,
responses: Optional[Dict[int, OpenAPIResponse]] = None,
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
tags: Optional[List[str]] = None,
operation_id: Optional[str] = None,
Expand Down Expand Up @@ -997,7 +1043,7 @@ def put(
cache_control: Optional[str] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
responses: Optional[Dict[int, Dict[str, Any]]] = None,
responses: Optional[Dict[int, OpenAPIResponse]] = None,
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
tags: Optional[List[str]] = None,
operation_id: Optional[str] = None,
Expand Down Expand Up @@ -1051,7 +1097,7 @@ def delete(
cache_control: Optional[str] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
responses: Optional[Dict[int, Dict[str, Any]]] = None,
responses: Optional[Dict[int, OpenAPIResponse]] = None,
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
tags: Optional[List[str]] = None,
operation_id: Optional[str] = None,
Expand Down Expand Up @@ -1104,7 +1150,7 @@ def patch(
cache_control: Optional[str] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
responses: Optional[Dict[int, Dict[str, Any]]] = None,
responses: Optional[Dict[int, OpenAPIResponse]] = None,
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
tags: Optional[List[str]] = None,
operation_id: Optional[str] = None,
Expand Down Expand Up @@ -1662,7 +1708,7 @@ def route(
cache_control: Optional[str] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
responses: Optional[Dict[int, Dict[str, Any]]] = None,
responses: Optional[Dict[int, OpenAPIResponse]] = None,
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
tags: Optional[List[str]] = None,
operation_id: Optional[str] = None,
Expand Down Expand Up @@ -2110,6 +2156,9 @@ def _get_fields_from_routes(routes: Sequence[Route]) -> List["ModelField"]:
if route.dependant.return_param:
responses_from_routes.append(route.dependant.return_param)

if route.dependant.response_extra_models:
responses_from_routes.extend(route.dependant.response_extra_models)

flat_models = list(responses_from_routes + request_fields_from_routes + body_fields_from_routes)
return flat_models

Expand All @@ -2132,7 +2181,7 @@ def route(
cache_control: Optional[str] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
responses: Optional[Dict[int, Dict[str, Any]]] = None,
responses: Optional[Dict[int, OpenAPIResponse]] = None,
response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
tags: Optional[List[str]] = None,
operation_id: Optional[str] = None,
Expand Down Expand Up @@ -2221,7 +2270,7 @@ def route(
cache_control: Optional[str] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
responses: Optional[Dict[int, Dict[str, Any]]] = None,
responses: Optional[Dict[int, OpenAPIResponse]] = None,
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
tags: Optional[List[str]] = None,
operation_id: Optional[str] = None,
Expand Down
34 changes: 32 additions & 2 deletions aws_lambda_powertools/event_handler/openapi/dependant.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
create_response_field,
get_flat_dependant,
)
from aws_lambda_powertools.event_handler.openapi.types import OpenAPIResponse, OpenAPIResponseContentModel

"""
This turns the opaque function signature into typed, validated models.
Expand Down Expand Up @@ -145,6 +146,7 @@ def get_dependant(
path: str,
call: Callable[..., Any],
name: Optional[str] = None,
responses: Optional[Dict[int, OpenAPIResponse]] = None,
) -> Dependant:
"""
Returns a dependant model for a handler function. A dependant model is a model that contains
Expand All @@ -158,6 +160,8 @@ def get_dependant(
The handler function
name: str, optional
The name of the handler function
responses: List[Dict[int, OpenAPIResponse]], optional
The list of extra responses for the handler function
Returns
-------
Expand Down Expand Up @@ -195,6 +199,34 @@ def get_dependant(
else:
add_param_to_fields(field=param_field, dependant=dependant)

_add_return_annotation(dependant, endpoint_signature)
_add_extra_responses(dependant, responses)

return dependant


def _add_extra_responses(dependant: Dependant, responses: Optional[Dict[int, OpenAPIResponse]]):
# Also add the optional extra responses to the dependant model.
if not responses:
return

for response in responses.values():
for schema in response.get("content", {}).values():
if "model" in schema:
response_field = analyze_param(
param_name="return",
annotation=cast(OpenAPIResponseContentModel, schema)["model"],
value=None,
is_path_param=False,
is_response_param=True,
)
if response_field is None:
raise AssertionError("Response field is None for response model")

dependant.response_extra_models.append(response_field)


def _add_return_annotation(dependant: Dependant, endpoint_signature: inspect.Signature):
# If the return annotation is not empty, add it to the dependant model.
return_annotation = endpoint_signature.return_annotation
if return_annotation is not inspect.Signature.empty:
Expand All @@ -210,8 +242,6 @@ def get_dependant(

dependant.return_param = param_field

return dependant


def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
"""
Expand Down
2 changes: 2 additions & 0 deletions aws_lambda_powertools/event_handler/openapi/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
cookie_params: Optional[List[ModelField]] = None,
body_params: Optional[List[ModelField]] = None,
return_param: Optional[ModelField] = None,
response_extra_models: Optional[List[ModelField]] = None,
name: Optional[str] = None,
call: Optional[Callable[..., Any]] = None,
request_param_name: Optional[str] = None,
Expand All @@ -64,6 +65,7 @@ def __init__(
self.cookie_params = cookie_params or []
self.body_params = body_params or []
self.return_param = return_param or None
self.response_extra_models = response_extra_models or []
self.request_param_name = request_param_name
self.websocket_param_name = websocket_param_name
self.http_connection_param_name = http_connection_param_name
Expand Down
15 changes: 15 additions & 0 deletions aws_lambda_powertools/event_handler/openapi/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Type, Union

from aws_lambda_powertools.shared.types import NotRequired, TypedDict

if TYPE_CHECKING:
from pydantic import BaseModel # noqa: F401

Expand Down Expand Up @@ -43,3 +45,16 @@
},
},
}


class OpenAPIResponseContentSchema(TypedDict, total=False):
schema: Dict


class OpenAPIResponseContentModel(TypedDict):
model: Any


class OpenAPIResponse(TypedDict):
description: str
content: NotRequired[Dict[str, Union[OpenAPIResponseContentSchema, OpenAPIResponseContentModel]]]
18 changes: 9 additions & 9 deletions docs/core/event_handler/api_gateway.md
Original file line number Diff line number Diff line change
Expand Up @@ -955,15 +955,15 @@ Customize your API endpoints by adding metadata to endpoint definitions. This pr

Here's a breakdown of various customizable fields:

| Field Name | Type | Description |
| ---------------------- | --------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `summary` | `str` | A concise overview of the main functionality of the endpoint. This brief introduction is usually displayed in autogenerated API documentation and helps consumers quickly understand what the endpoint does. |
| `description` | `str` | A more detailed explanation of the endpoint, which can include information about the operation's behavior, including side effects, error states, and other operational guidelines. |
| `responses` | `Dict[int, Dict[str, Any]]` | A dictionary that maps each HTTP status code to a Response Object as defined by the [OpenAPI Specification](https://swagger.io/specification/#response-object). This allows you to describe expected responses, including default or error messages, and their corresponding schemas for different status codes. |
| `response_description` | `str` | Provides the default textual description of the response sent by the endpoint when the operation is successful. It is intended to give a human-readable understanding of the result. |
| `tags` | `List[str]` | Tags are a way to categorize and group endpoints within the API documentation. They can help organize the operations by resources or other heuristic. |
| `operation_id` | `str` | A unique identifier for the operation, which can be used for referencing this operation in documentation or code. This ID must be unique across all operations described in the API. |
| `include_in_schema` | `bool` | A boolean value that determines whether or not this operation should be included in the OpenAPI schema. Setting it to `False` can hide the endpoint from generated documentation and schema exports, which might be useful for private or experimental endpoints. |
| Field Name | Type | Description |
| ---------------------- |-----------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `summary` | `str` | A concise overview of the main functionality of the endpoint. This brief introduction is usually displayed in autogenerated API documentation and helps consumers quickly understand what the endpoint does. |
| `description` | `str` | A more detailed explanation of the endpoint, which can include information about the operation's behavior, including side effects, error states, and other operational guidelines. |
| `responses` | `Dict[int, Dict[str, OpenAPIResponse]]` | A dictionary that maps each HTTP status code to a Response Object as defined by the [OpenAPI Specification](https://swagger.io/specification/#response-object). This allows you to describe expected responses, including default or error messages, and their corresponding schemas or models for different status codes. |
| `response_description` | `str` | Provides the default textual description of the response sent by the endpoint when the operation is successful. It is intended to give a human-readable understanding of the result. |
| `tags` | `List[str]` | Tags are a way to categorize and group endpoints within the API documentation. They can help organize the operations by resources or other heuristic. |
| `operation_id` | `str` | A unique identifier for the operation, which can be used for referencing this operation in documentation or code. This ID must be unique across all operations described in the API. |
| `include_in_schema` | `bool` | A boolean value that determines whether or not this operation should be included in the OpenAPI schema. Setting it to `False` can hide the endpoint from generated documentation and schema exports, which might be useful for private or experimental endpoints. |

To implement these customizations, include extra parameters when defining your routes:

Expand Down
Loading

0 comments on commit 1409499

Please sign in to comment.