From 270060de13cdb5a96d2d2ccac3125df54d410f40 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 22 Nov 2023 15:32:00 +0100 Subject: [PATCH] fix(event_handler): allow fine grained Response with data validation (#3394) --- .../event_handler/api_gateway.py | 5 +- .../event_handler/openapi/params.py | 12 +++++ .../event_handler/test_openapi_params.py | 20 +++++++- .../test_openapi_validation_middleware.py | 50 ++++++++++++++++++- 4 files changed, 83 insertions(+), 4 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index a2b81974a21..5b7262e5d55 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -66,6 +66,7 @@ _ROUTE_REGEX = "^{}$" ResponseEventT = TypeVar("ResponseEventT", bound=BaseProxyEvent) +ResponseT = TypeVar("ResponseT") if TYPE_CHECKING: from aws_lambda_powertools.event_handler.openapi.compat import ( @@ -207,14 +208,14 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]: return headers -class Response: +class Response(Generic[ResponseT]): """Response data class that provides greater control over what is returned from the proxy event""" def __init__( self, status_code: int, content_type: Optional[str] = None, - body: Any = None, + body: Optional[ResponseT] = None, headers: Optional[Dict[str, Union[str, List[str]]]] = None, cookies: Optional[List[Cookie]] = None, compress: Optional[bool] = None, diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index c8099d20404..28154466ff6 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -5,6 +5,7 @@ from pydantic import BaseConfig from pydantic.fields import FieldInfo +from aws_lambda_powertools.event_handler import Response from aws_lambda_powertools.event_handler.openapi.compat import ( ModelField, Required, @@ -724,6 +725,9 @@ def get_field_info_and_type_annotation(annotation, value, is_path_param: bool) - # If the annotation is an Annotated type, we need to extract the type annotation and the FieldInfo if get_origin(annotation) is Annotated: field_info, type_annotation = get_field_info_annotated_type(annotation, value, is_path_param) + # If the annotation is a Response type, we recursively call this function with the inner type + elif get_origin(annotation) is Response: + field_info, type_annotation = get_field_info_response_type(annotation, value) # If the annotation is not an Annotated type, we use it as the type annotation else: type_annotation = annotation @@ -731,6 +735,14 @@ def get_field_info_and_type_annotation(annotation, value, is_path_param: bool) - return field_info, type_annotation +def get_field_info_response_type(annotation, value) -> Tuple[Optional[FieldInfo], Any]: + # Example: get_args(Response[inner_type]) == (inner_type,) # noqa: ERA001 + (inner_type,) = get_args(annotation) + + # Recursively resolve the inner type + return get_field_info_and_type_annotation(inner_type, value, False) + + def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> Tuple[Optional[FieldInfo], Any]: """ Get the FieldInfo and type annotation from an Annotated type. diff --git a/tests/functional/event_handler/test_openapi_params.py b/tests/functional/event_handler/test_openapi_params.py index ec31bb14236..6e4f0395aff 100644 --- a/tests/functional/event_handler/test_openapi_params.py +++ b/tests/functional/event_handler/test_openapi_params.py @@ -4,7 +4,7 @@ from pydantic import BaseModel -from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver +from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver, Response from aws_lambda_powertools.event_handler.openapi.models import ( Example, Parameter, @@ -153,6 +153,24 @@ def handler() -> str: assert response.schema_.type == "string" +def test_openapi_with_response_returns(): + app = APIGatewayRestResolver() + + @app.get("/") + def handler() -> Response[Annotated[str, Body(title="Response title")]]: + return Response(body="Hello, world", status_code=200) + + schema = app.get_openapi_schema() + assert len(schema.paths.keys()) == 1 + + get = schema.paths["/"].get + assert get.parameters is None + + response = get.responses[200].content[JSON_CONTENT_TYPE] + assert response.schema_.title == "Response title" + assert response.schema_.type == "string" + + def test_openapi_with_omitted_param(): app = APIGatewayRestResolver() diff --git a/tests/functional/event_handler/test_openapi_validation_middleware.py b/tests/functional/event_handler/test_openapi_validation_middleware.py index 2e14979acce..9c7ca371d54 100644 --- a/tests/functional/event_handler/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/test_openapi_validation_middleware.py @@ -6,7 +6,7 @@ from pydantic import BaseModel -from aws_lambda_powertools.event_handler import APIGatewayRestResolver +from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response from aws_lambda_powertools.event_handler.openapi.params import Body from aws_lambda_powertools.shared.types import Annotated from tests.functional.utils import load_event @@ -330,3 +330,51 @@ def handler(user: Annotated[Model, Body(embed=True)]) -> Model: LOAD_GW_EVENT["body"] = json.dumps({"user": {"name": "John", "age": 30}}) result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 200 + + +def test_validate_response_return(): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + # WHEN a handler is defined with a body parameter + @app.post("/") + def handler(user: Model) -> Response[Model]: + return Response(body=user, status_code=200) + + LOAD_GW_EVENT["httpMethod"] = "POST" + LOAD_GW_EVENT["path"] = "/" + LOAD_GW_EVENT["body"] = json.dumps({"name": "John", "age": 30}) + + # THEN the handler should be invoked and return 200 + # THEN the body must be a dict + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200 + assert result["body"] == {"name": "John", "age": 30} + + +def test_validate_response_invalid_return(): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + # WHEN a handler is defined with a body parameter + @app.post("/") + def handler(user: Model) -> Response[Model]: + return Response(body=user, status_code=200) + + LOAD_GW_EVENT["httpMethod"] = "POST" + LOAD_GW_EVENT["path"] = "/" + LOAD_GW_EVENT["body"] = json.dumps({}) + + # THEN the handler should be invoked and return 422 + # THEN the body should have the word missing + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 422 + assert "missing" in result["body"]