From 1409499b49ac7ac5e9cd84248c75163d9e6c19a0 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 16 Jan 2024 17:35:59 +0100 Subject: [PATCH] feat(event_handler): add support for additional response models (#3591) * feat(event_handler): add support for additional response models * fix: I hate sonarcube * fix: pydantic 2 * fix: refactor * fix: increase coverage * chore: update docs --------- Co-authored-by: Leandro Damascena --- .../event_handler/api_gateway.py | 79 +++++++++++++--- .../event_handler/openapi/dependant.py | 34 ++++++- .../event_handler/openapi/params.py | 2 + .../event_handler/openapi/types.py | 15 +++ docs/core/event_handler/api_gateway.md | 18 ++-- .../event_handler/test_openapi_responses.py | 94 ++++++++++++++++++- 6 files changed, 215 insertions(+), 27 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 7a41c99d053..9260ede43e9 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -37,6 +37,9 @@ from aws_lambda_powertools.event_handler.openapi.types import ( COMPONENT_REF_PREFIX, METHODS_WITH_BODY, + OpenAPIResponse, + OpenAPIResponseContentModel, + OpenAPIResponseContentSchema, validation_error_definition, validation_error_response_definition, ) @@ -273,7 +276,7 @@ def __init__( cache_control: Optional[str], summary: Optional[str], description: Optional[str], - responses: Optional[Dict[int, Dict[str, Any]]], + responses: Optional[Dict[int, OpenAPIResponse]], response_description: Optional[str], tags: Optional[List[str]], operation_id: Optional[str], @@ -303,7 +306,7 @@ def __init__( The OpenAPI summary for this route description: Optional[str] The OpenAPI description for this route - responses: Optional[Dict[int, Dict[str, Any]]] + responses: Optional[Dict[int, OpenAPIResponse]] The OpenAPI responses for this route response_description: Optional[str] The OpenAPI response description for this route @@ -442,7 +445,7 @@ def dependant(self) -> "Dependant": if self._dependant is None: from aws_lambda_powertools.event_handler.openapi.dependant import get_dependant - self._dependant = get_dependant(path=self.openapi_path, call=self.func) + self._dependant = get_dependant(path=self.openapi_path, call=self.func, responses=self.responses) return self._dependant @@ -501,11 +504,54 @@ def _get_openapi_path( # Add the response to the OpenAPI operation if self.responses: - # If the user supplied responses, we use them and don't set a default 200 response + for status_code in list(self.responses): + response = self.responses[status_code] + + # Case 1: there is not 'content' key + if "content" not in response: + response["content"] = { + "application/json": self._openapi_operation_return( + param=dependant.return_param, + model_name_map=model_name_map, + field_mapping=field_mapping, + ), + } + + # Case 2: there is a 'content' key + else: + # Need to iterate to transform any 'model' into a 'schema' + for content_type, payload in response["content"].items(): + new_payload: OpenAPIResponseContentSchema + + # Case 2.1: the 'content' has a model + if "model" in payload: + # Find the model in the dependant's extra models + return_field = next( + filter( + lambda model: model.type_ is cast(OpenAPIResponseContentModel, payload)["model"], + self.dependant.response_extra_models, + ), + ) + if not return_field: + raise AssertionError("Model declared in custom responses was not found") + + new_payload = self._openapi_operation_return( + param=return_field, + model_name_map=model_name_map, + field_mapping=field_mapping, + ) + + # Case 2.2: the 'content' has a schema + else: + # Do nothing! We already have what we need! + new_payload = payload + + response["content"][content_type] = new_payload + operation["responses"] = self.responses else: # Set the default 200 response - responses = operation.setdefault("responses", self.responses or {}) + responses = operation.setdefault("responses", {}) success_response = responses.setdefault(200, {}) success_response["description"] = self.response_description or _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION success_response["content"] = {"application/json": {"schema": {}}} @@ -682,7 +728,7 @@ def _openapi_operation_return( Tuple["ModelField", Literal["validation", "serialization"]], "JsonSchemaValue", ], - ) -> Dict[str, Any]: + ) -> OpenAPIResponseContentSchema: """ Returns the OpenAPI operation return. """ @@ -832,7 +878,7 @@ def route( cache_control: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, - responses: Optional[Dict[int, Dict[str, Any]]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List[str]] = None, operation_id: Optional[str] = None, @@ -890,7 +936,7 @@ def get( cache_control: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, - responses: Optional[Dict[int, Dict[str, Any]]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List[str]] = None, operation_id: Optional[str] = None, @@ -943,7 +989,7 @@ def post( cache_control: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, - responses: Optional[Dict[int, Dict[str, Any]]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List[str]] = None, operation_id: Optional[str] = None, @@ -997,7 +1043,7 @@ def put( cache_control: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, - responses: Optional[Dict[int, Dict[str, Any]]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List[str]] = None, operation_id: Optional[str] = None, @@ -1051,7 +1097,7 @@ def delete( cache_control: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, - responses: Optional[Dict[int, Dict[str, Any]]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List[str]] = None, operation_id: Optional[str] = None, @@ -1104,7 +1150,7 @@ def patch( cache_control: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, - responses: Optional[Dict[int, Dict[str, Any]]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List[str]] = None, operation_id: Optional[str] = None, @@ -1662,7 +1708,7 @@ def route( cache_control: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, - responses: Optional[Dict[int, Dict[str, Any]]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List[str]] = None, operation_id: Optional[str] = None, @@ -2110,6 +2156,9 @@ def _get_fields_from_routes(routes: Sequence[Route]) -> List["ModelField"]: if route.dependant.return_param: responses_from_routes.append(route.dependant.return_param) + if route.dependant.response_extra_models: + responses_from_routes.extend(route.dependant.response_extra_models) + flat_models = list(responses_from_routes + request_fields_from_routes + body_fields_from_routes) return flat_models @@ -2132,7 +2181,7 @@ def route( cache_control: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, - responses: Optional[Dict[int, Dict[str, Any]]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List[str]] = None, operation_id: Optional[str] = None, @@ -2221,7 +2270,7 @@ def route( cache_control: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, - responses: Optional[Dict[int, Dict[str, Any]]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List[str]] = None, operation_id: Optional[str] = None, diff --git a/aws_lambda_powertools/event_handler/openapi/dependant.py b/aws_lambda_powertools/event_handler/openapi/dependant.py index e22eb535a7e..418a86e083c 100644 --- a/aws_lambda_powertools/event_handler/openapi/dependant.py +++ b/aws_lambda_powertools/event_handler/openapi/dependant.py @@ -24,6 +24,7 @@ create_response_field, get_flat_dependant, ) +from aws_lambda_powertools.event_handler.openapi.types import OpenAPIResponse, OpenAPIResponseContentModel """ This turns the opaque function signature into typed, validated models. @@ -145,6 +146,7 @@ def get_dependant( path: str, call: Callable[..., Any], name: Optional[str] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, ) -> Dependant: """ Returns a dependant model for a handler function. A dependant model is a model that contains @@ -158,6 +160,8 @@ def get_dependant( The handler function name: str, optional The name of the handler function + responses: List[Dict[int, OpenAPIResponse]], optional + The list of extra responses for the handler function Returns ------- @@ -195,6 +199,34 @@ def get_dependant( else: add_param_to_fields(field=param_field, dependant=dependant) + _add_return_annotation(dependant, endpoint_signature) + _add_extra_responses(dependant, responses) + + return dependant + + +def _add_extra_responses(dependant: Dependant, responses: Optional[Dict[int, OpenAPIResponse]]): + # Also add the optional extra responses to the dependant model. + if not responses: + return + + for response in responses.values(): + for schema in response.get("content", {}).values(): + if "model" in schema: + response_field = analyze_param( + param_name="return", + annotation=cast(OpenAPIResponseContentModel, schema)["model"], + value=None, + is_path_param=False, + is_response_param=True, + ) + if response_field is None: + raise AssertionError("Response field is None for response model") + + dependant.response_extra_models.append(response_field) + + +def _add_return_annotation(dependant: Dependant, endpoint_signature: inspect.Signature): # If the return annotation is not empty, add it to the dependant model. return_annotation = endpoint_signature.return_annotation if return_annotation is not inspect.Signature.empty: @@ -210,8 +242,6 @@ def get_dependant( dependant.return_param = param_field - return dependant - def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool: """ diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index bd542ba7932..78426cbc7c9 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -49,6 +49,7 @@ def __init__( cookie_params: Optional[List[ModelField]] = None, body_params: Optional[List[ModelField]] = None, return_param: Optional[ModelField] = None, + response_extra_models: Optional[List[ModelField]] = None, name: Optional[str] = None, call: Optional[Callable[..., Any]] = None, request_param_name: Optional[str] = None, @@ -64,6 +65,7 @@ def __init__( self.cookie_params = cookie_params or [] self.body_params = body_params or [] self.return_param = return_param or None + self.response_extra_models = response_extra_models or [] self.request_param_name = request_param_name self.websocket_param_name = websocket_param_name self.http_connection_param_name = http_connection_param_name diff --git a/aws_lambda_powertools/event_handler/openapi/types.py b/aws_lambda_powertools/event_handler/openapi/types.py index 0d166de1131..beafa0e566c 100644 --- a/aws_lambda_powertools/event_handler/openapi/types.py +++ b/aws_lambda_powertools/event_handler/openapi/types.py @@ -2,6 +2,8 @@ from enum import Enum from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Type, Union +from aws_lambda_powertools.shared.types import NotRequired, TypedDict + if TYPE_CHECKING: from pydantic import BaseModel # noqa: F401 @@ -43,3 +45,16 @@ }, }, } + + +class OpenAPIResponseContentSchema(TypedDict, total=False): + schema: Dict + + +class OpenAPIResponseContentModel(TypedDict): + model: Any + + +class OpenAPIResponse(TypedDict): + description: str + content: NotRequired[Dict[str, Union[OpenAPIResponseContentSchema, OpenAPIResponseContentModel]]] diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index 4602231a63e..a34a94975bc 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -955,15 +955,15 @@ Customize your API endpoints by adding metadata to endpoint definitions. This pr Here's a breakdown of various customizable fields: -| Field Name | Type | Description | -| ---------------------- | --------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `summary` | `str` | A concise overview of the main functionality of the endpoint. This brief introduction is usually displayed in autogenerated API documentation and helps consumers quickly understand what the endpoint does. | -| `description` | `str` | A more detailed explanation of the endpoint, which can include information about the operation's behavior, including side effects, error states, and other operational guidelines. | -| `responses` | `Dict[int, Dict[str, Any]]` | A dictionary that maps each HTTP status code to a Response Object as defined by the [OpenAPI Specification](https://swagger.io/specification/#response-object). This allows you to describe expected responses, including default or error messages, and their corresponding schemas for different status codes. | -| `response_description` | `str` | Provides the default textual description of the response sent by the endpoint when the operation is successful. It is intended to give a human-readable understanding of the result. | -| `tags` | `List[str]` | Tags are a way to categorize and group endpoints within the API documentation. They can help organize the operations by resources or other heuristic. | -| `operation_id` | `str` | A unique identifier for the operation, which can be used for referencing this operation in documentation or code. This ID must be unique across all operations described in the API. | -| `include_in_schema` | `bool` | A boolean value that determines whether or not this operation should be included in the OpenAPI schema. Setting it to `False` can hide the endpoint from generated documentation and schema exports, which might be useful for private or experimental endpoints. | +| Field Name | Type | Description | +| ---------------------- |-----------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `summary` | `str` | A concise overview of the main functionality of the endpoint. This brief introduction is usually displayed in autogenerated API documentation and helps consumers quickly understand what the endpoint does. | +| `description` | `str` | A more detailed explanation of the endpoint, which can include information about the operation's behavior, including side effects, error states, and other operational guidelines. | +| `responses` | `Dict[int, Dict[str, OpenAPIResponse]]` | A dictionary that maps each HTTP status code to a Response Object as defined by the [OpenAPI Specification](https://swagger.io/specification/#response-object). This allows you to describe expected responses, including default or error messages, and their corresponding schemas or models for different status codes. | +| `response_description` | `str` | Provides the default textual description of the response sent by the endpoint when the operation is successful. It is intended to give a human-readable understanding of the result. | +| `tags` | `List[str]` | Tags are a way to categorize and group endpoints within the API documentation. They can help organize the operations by resources or other heuristic. | +| `operation_id` | `str` | A unique identifier for the operation, which can be used for referencing this operation in documentation or code. This ID must be unique across all operations described in the API. | +| `include_in_schema` | `bool` | A boolean value that determines whether or not this operation should be included in the OpenAPI schema. Setting it to `False` can hide the endpoint from generated documentation and schema exports, which might be useful for private or experimental endpoints. | To implement these customizations, include extra parameters when defining your routes: diff --git a/tests/functional/event_handler/test_openapi_responses.py b/tests/functional/event_handler/test_openapi_responses.py index bd470867428..be5d9bca288 100644 --- a/tests/functional/event_handler/test_openapi_responses.py +++ b/tests/functional/event_handler/test_openapi_responses.py @@ -1,4 +1,9 @@ -from aws_lambda_powertools.event_handler import APIGatewayRestResolver +from secrets import randbelow +from typing import Union + +from pydantic import BaseModel + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response def test_openapi_default_response(): @@ -47,3 +52,90 @@ def handler(): assert 200 not in responses.keys() assert 422 not in responses.keys() + + +def test_openapi_200_custom_schema(): + app = APIGatewayRestResolver(enable_validation=True) + + class User(BaseModel): + pass + + @app.get( + "/", + responses={200: {"description": "Custom response", "content": {"application/json": {"schema": User.schema()}}}}, + ) + def handler(): + return {"message": "hello world"} + + schema = app.get_openapi_schema() + responses = schema.paths["/"].get.responses + assert 200 in responses.keys() + + assert responses[200].description == "Custom response" + assert responses[200].content["application/json"].schema_.title == "User" + + +def test_openapi_union_response(): + app = APIGatewayRestResolver(enable_validation=True) + + class User(BaseModel): + pass + + class Order(BaseModel): + pass + + @app.get( + "/", + responses={ + 200: {"description": "200 Response", "content": {"application/json": {"model": User}}}, + 202: {"description": "202 Response", "content": {"application/json": {"model": Order}}}, + }, + ) + def handler() -> Response[Union[User, Order]]: + if randbelow(2) > 0: + return Response(status_code=200, body=User()) + else: + return Response(status_code=202, body=Order()) + + schema = app.get_openapi_schema() + responses = schema.paths["/"].get.responses + assert 200 in responses.keys() + assert responses[200].description == "200 Response" + assert responses[200].content["application/json"].schema_.ref == "#/components/schemas/User" + + assert 202 in responses.keys() + assert responses[202].description == "202 Response" + assert responses[202].content["application/json"].schema_.ref == "#/components/schemas/Order" + + +def test_openapi_union_partial_response(): + app = APIGatewayRestResolver(enable_validation=True) + + class User(BaseModel): + pass + + class Order(BaseModel): + pass + + @app.get( + "/", + responses={ + 200: {"description": "200 Response"}, + 202: {"description": "202 Response", "content": {"application/json": {"model": Order}}}, + }, + ) + def handler() -> Response[Union[User, Order]]: + if randbelow(2) > 0: + return Response(status_code=200, body=User()) + else: + return Response(status_code=202, body=Order()) + + schema = app.get_openapi_schema() + responses = schema.paths["/"].get.responses + assert 200 in responses.keys() + assert responses[200].description == "200 Response" + assert responses[200].content["application/json"].schema_.anyOf is not None + + assert 202 in responses.keys() + assert responses[202].description == "202 Response" + assert responses[202].content["application/json"].schema_.ref == "#/components/schemas/Order"