From 8765206dc6e0ab374b31b867d4a838e1561c49e1 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Fri, 8 Mar 2024 10:37:47 +0100 Subject: [PATCH] feat(event_handler): use custom serializer during openapi serialization (#3900) * feat(event_handler): use custom serializer during openapi serialization * fix: comments --- .../event_handler/api_gateway.py | 4 +++- .../middlewares/openapi_validation.py | 17 +++++++++++-- .../event_handler/openapi/encoders.py | 9 ++++++- .../test_openapi_serialization.py | 24 +++++++++++++++++++ 4 files changed, 50 insertions(+), 4 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 342d6227dd3..e72d39ba821 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -1401,7 +1401,9 @@ def __init__( if self._enable_validation: from aws_lambda_powertools.event_handler.middlewares.openapi_validation import OpenAPIValidationMiddleware - self.use([OpenAPIValidationMiddleware()]) + # Note the serializer argument: only use custom serializer if provided by the caller + # Otherwise, fully rely on the internal Pydantic based mechanism to serialize responses for validation. + self.use([OpenAPIValidationMiddleware(validation_serializer=serializer)]) def get_openapi_schema( self, diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 241a9972953..a57560a3ad1 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -2,7 +2,7 @@ import json import logging from copy import deepcopy -from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple from pydantic import BaseModel @@ -55,6 +55,18 @@ def get_todos(): List[Todo]: ``` """ + def __init__(self, validation_serializer: Optional[Callable[[Any], str]] = None): + """ + Initialize the OpenAPIValidationMiddleware. + + Parameters + ---------- + validation_serializer : Callable, optional + Optional serializer to use when serializing the response for validation. + Use it when you have a custom type that cannot be serialized by the default jsonable_encoder. + """ + self._validation_serializer = validation_serializer + def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response: logger.debug("OpenAPIValidationMiddleware handler") @@ -181,10 +193,11 @@ def _serialize_response( exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, exclude_none=exclude_none, + custom_serializer=self._validation_serializer, ) else: # Just serialize the response content returned from the handler - return jsonable_encoder(response_content) + return jsonable_encoder(response_content, custom_serializer=self._validation_serializer) def _prepare_response_content( self, diff --git a/aws_lambda_powertools/event_handler/openapi/encoders.py b/aws_lambda_powertools/event_handler/openapi/encoders.py index 94c1cb5d659..c12aa0164e1 100644 --- a/aws_lambda_powertools/event_handler/openapi/encoders.py +++ b/aws_lambda_powertools/event_handler/openapi/encoders.py @@ -29,6 +29,7 @@ def jsonable_encoder( # noqa: PLR0911 exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, + custom_serializer: Optional[Callable[[Any], str]] = None, ) -> Any: """ JSON encodes an arbitrary Python object into JSON serializable data types. @@ -55,6 +56,8 @@ def jsonable_encoder( # noqa: PLR0911 by default False exclude_none : bool, optional Whether fields that are equal to None should be excluded, by default False + custom_serializer : Callable, optional + A custom serializer to use for encoding the object, when everything else fails. Returns ------- @@ -134,6 +137,10 @@ def jsonable_encoder( # noqa: PLR0911 if isinstance(obj, classes_tuple): return encoder(obj) + # Use custom serializer if present + if custom_serializer: + return custom_serializer(obj) + # Default return _dump_other( obj=obj, @@ -259,7 +266,7 @@ def _dump_other( exclude_defaults: bool = False, ) -> Any: """ - Dump an object to ah hashable object, using the same parameters as jsonable_encoder + Dump an object to a hashable object, using the same parameters as jsonable_encoder """ try: data = dict(obj) diff --git a/tests/functional/event_handler/test_openapi_serialization.py b/tests/functional/event_handler/test_openapi_serialization.py index 63f1c0e4f9d..91e345260e8 100644 --- a/tests/functional/event_handler/test_openapi_serialization.py +++ b/tests/functional/event_handler/test_openapi_serialization.py @@ -37,3 +37,27 @@ def handler(): # THEN we should get a dictionary assert isinstance(schema, Dict) + + +def test_openapi_serialize_other(gw_event): + # GIVEN a custom serializer + def serializer(_): + return "hello world" + + # GIVEN APIGatewayRestResolver is initialized with enable_validation=True and the custom serializer + app = APIGatewayRestResolver(enable_validation=True, serializer=serializer) + + # GIVEN a custom class + class CustomClass(object): + __slots__ = [] + + # GIVEN a handler that returns an instance of that class + @app.get("/my/path") + def handler(): + return CustomClass() + + # WHEN we invoke the handler + response = app(gw_event, {}) + + # THEN we the custom serializer should be used + assert response["body"] == "hello world"