Skip to content

Commit

Permalink
feat(event_handler): use custom serializer during openapi serializati…
Browse files Browse the repository at this point in the history
…on (#3900)

* feat(event_handler): use custom serializer during openapi serialization

* fix: comments
  • Loading branch information
rubenfonseca authored Mar 8, 2024
1 parent e79eef4 commit 8765206
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 4 deletions.
4 changes: 3 additions & 1 deletion aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -1401,7 +1401,9 @@ def __init__(
if self._enable_validation:
from aws_lambda_powertools.event_handler.middlewares.openapi_validation import OpenAPIValidationMiddleware

self.use([OpenAPIValidationMiddleware()])
# Note the serializer argument: only use custom serializer if provided by the caller
# Otherwise, fully rely on the internal Pydantic based mechanism to serialize responses for validation.
self.use([OpenAPIValidationMiddleware(validation_serializer=serializer)])

def get_openapi_schema(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import logging
from copy import deepcopy
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple

from pydantic import BaseModel

Expand Down Expand Up @@ -55,6 +55,18 @@ def get_todos(): List[Todo]:
```
"""

def __init__(self, validation_serializer: Optional[Callable[[Any], str]] = None):
"""
Initialize the OpenAPIValidationMiddleware.
Parameters
----------
validation_serializer : Callable, optional
Optional serializer to use when serializing the response for validation.
Use it when you have a custom type that cannot be serialized by the default jsonable_encoder.
"""
self._validation_serializer = validation_serializer

def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
logger.debug("OpenAPIValidationMiddleware handler")

Expand Down Expand Up @@ -181,10 +193,11 @@ def _serialize_response(
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_serializer=self._validation_serializer,
)
else:
# Just serialize the response content returned from the handler
return jsonable_encoder(response_content)
return jsonable_encoder(response_content, custom_serializer=self._validation_serializer)

def _prepare_response_content(
self,
Expand Down
9 changes: 8 additions & 1 deletion aws_lambda_powertools/event_handler/openapi/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def jsonable_encoder( # noqa: PLR0911
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
custom_serializer: Optional[Callable[[Any], str]] = None,
) -> Any:
"""
JSON encodes an arbitrary Python object into JSON serializable data types.
Expand All @@ -55,6 +56,8 @@ def jsonable_encoder( # noqa: PLR0911
by default False
exclude_none : bool, optional
Whether fields that are equal to None should be excluded, by default False
custom_serializer : Callable, optional
A custom serializer to use for encoding the object, when everything else fails.
Returns
-------
Expand Down Expand Up @@ -134,6 +137,10 @@ def jsonable_encoder( # noqa: PLR0911
if isinstance(obj, classes_tuple):
return encoder(obj)

# Use custom serializer if present
if custom_serializer:
return custom_serializer(obj)

# Default
return _dump_other(
obj=obj,
Expand Down Expand Up @@ -259,7 +266,7 @@ def _dump_other(
exclude_defaults: bool = False,
) -> Any:
"""
Dump an object to ah hashable object, using the same parameters as jsonable_encoder
Dump an object to a hashable object, using the same parameters as jsonable_encoder
"""
try:
data = dict(obj)
Expand Down
24 changes: 24 additions & 0 deletions tests/functional/event_handler/test_openapi_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,27 @@ def handler():

# THEN we should get a dictionary
assert isinstance(schema, Dict)


def test_openapi_serialize_other(gw_event):
# GIVEN a custom serializer
def serializer(_):
return "hello world"

# GIVEN APIGatewayRestResolver is initialized with enable_validation=True and the custom serializer
app = APIGatewayRestResolver(enable_validation=True, serializer=serializer)

# GIVEN a custom class
class CustomClass(object):
__slots__ = []

# GIVEN a handler that returns an instance of that class
@app.get("/my/path")
def handler():
return CustomClass()

# WHEN we invoke the handler
response = app(gw_event, {})

# THEN we the custom serializer should be used
assert response["body"] == "hello world"

0 comments on commit 8765206

Please sign in to comment.