Skip to content

Commit

Permalink
fix(event_handler): apply serialization as the last operation for mid…
Browse files Browse the repository at this point in the history
…dlewares (#3392)
  • Loading branch information
rubenfonseca authored Nov 22, 2023
1 parent f94526f commit 6a47ee8
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 43 deletions.
59 changes: 33 additions & 26 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,8 +699,14 @@ def _generate_operation_id(self) -> str:
class ResponseBuilder(Generic[ResponseEventT]):
"""Internally used Response builder"""

def __init__(self, response: Response, route: Optional[Route] = None):
def __init__(
self,
response: Response,
serializer: Callable[[Any], str] = json.dumps,
route: Optional[Route] = None,
):
self.response = response
self.serializer = serializer
self.route = route

def _add_cors(self, event: ResponseEventT, cors: CORSConfig):
Expand Down Expand Up @@ -783,6 +789,11 @@ def build(self, event: ResponseEventT, cors: Optional[CORSConfig] = None) -> Dic
self.response.base64_encoded = True
self.response.body = base64.b64encode(self.response.body).decode()

# We only apply the serializer when the content type is JSON and the
# body is not a str, to avoid double encoding
elif self.response.is_json() and not isinstance(self.response.body, str):
self.response.body = self.serializer(self.response.body)

return {
"statusCode": self.response.status_code,
"body": self.response.body,
Expand Down Expand Up @@ -1332,14 +1343,6 @@ def __init__(

self.use([OpenAPIValidationMiddleware()])

# When using validation, we need to skip the serializer, as the middleware is doing it automatically.
# However, if the user is using a custom serializer, we need to abort.
if serializer:
raise ValueError("Cannot use a custom serializer when using validation")

# Install a dummy serializer
self._serializer = lambda args: args # type: ignore

def get_openapi_schema(
self,
*,
Expand Down Expand Up @@ -1717,7 +1720,7 @@ def resolve(self, event, context) -> Dict[str, Any]:
event = event.raw_event

if self._debug:
print(self._json_dump(event))
print(self._serializer(event))

# Populate router(s) dependencies without keeping a reference to each registered router
BaseRouter.current_event = self._to_proxy_event(event)
Expand Down Expand Up @@ -1881,19 +1884,23 @@ def _not_found(self, method: str) -> ResponseBuilder:
if method == "OPTIONS":
logger.debug("Pre-flight request detected. Returning CORS with null response")
headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods))
return ResponseBuilder(Response(status_code=204, content_type=None, headers=headers, body=""))
return ResponseBuilder(
response=Response(status_code=204, content_type=None, headers=headers, body=""),
serializer=self._serializer,
)

handler = self._lookup_exception_handler(NotFoundError)
if handler:
return self._response_builder_class(handler(NotFoundError()))
return self._response_builder_class(response=handler(NotFoundError()), serializer=self._serializer)

return self._response_builder_class(
Response(
response=Response(
status_code=HTTPStatus.NOT_FOUND.value,
content_type=content_types.APPLICATION_JSON,
headers=headers,
body=self._json_dump({"statusCode": HTTPStatus.NOT_FOUND.value, "message": "Not found"}),
body={"statusCode": HTTPStatus.NOT_FOUND.value, "message": "Not found"},
),
serializer=self._serializer,
)

def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> ResponseBuilder:
Expand All @@ -1903,10 +1910,11 @@ def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> Response
self._reset_processed_stack()

return self._response_builder_class(
self._to_response(
response=self._to_response(
route(router_middlewares=self._router_middlewares, app=self, route_arguments=route_arguments),
),
route,
serializer=self._serializer,
route=route,
)
except Exception as exc:
# If exception is handled then return the response builder to reduce noise
Expand All @@ -1920,12 +1928,13 @@ def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> Response
# we'll let the original exception propagate, so
# they get more information about what went wrong.
return self._response_builder_class(
Response(
response=Response(
status_code=500,
content_type=content_types.TEXT_PLAIN,
body="".join(traceback.format_exc()),
),
route,
serializer=self._serializer,
route=route,
)

raise
Expand Down Expand Up @@ -1958,18 +1967,19 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[Resp
handler = self._lookup_exception_handler(type(exp))
if handler:
try:
return self._response_builder_class(handler(exp), route)
return self._response_builder_class(response=handler(exp), serializer=self._serializer, route=route)
except ServiceError as service_error:
exp = service_error

if isinstance(exp, ServiceError):
return self._response_builder_class(
Response(
response=Response(
status_code=exp.status_code,
content_type=content_types.APPLICATION_JSON,
body=self._json_dump({"statusCode": exp.status_code, "message": exp.msg}),
body={"statusCode": exp.status_code, "message": exp.msg},
),
route,
serializer=self._serializer,
route=route,
)

return None
Expand All @@ -1995,12 +2005,9 @@ def _to_response(self, result: Union[Dict, Tuple, Response]) -> Response:
return Response(
status_code=status_code,
content_type=content_types.APPLICATION_JSON,
body=self._json_dump(result),
body=result,
)

def _json_dump(self, obj: Any) -> str:
return self._serializer(obj)

def include_router(self, router: "Router", prefix: Optional[str] = None) -> None:
"""Adds all routes and context defined in a router
Expand Down
6 changes: 5 additions & 1 deletion aws_lambda_powertools/event_handler/bedrock_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ def build(self, event: BedrockAgentEvent, *args) -> Dict[str, Any]:
"""Build the full response dict to be returned by the lambda"""
self._route(event, None)

body = self.response.body
if self.response.is_json() and not isinstance(self.response.body, str):
body = self.serializer(self.response.body)

return {
"messageVersion": "1.0",
"response": {
Expand All @@ -32,7 +36,7 @@ def build(self, event: BedrockAgentEvent, *args) -> Dict[str, Any]:
"httpStatusCode": self.response.status_code,
"responseBody": {
self.response.content_type: {
"body": self.response.body,
"body": body,
},
},
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ def _handle_response(self, *, route: Route, response: Response):
if response.body:
# Validate and serialize the response, if it's JSON
if response.is_json():
response.body = json.dumps(
self._serialize_response(field=route.dependant.return_param, response_content=response.body),
sort_keys=True,
response.body = self._serialize_response(
field=route.dependant.return_param,
response_content=response.body,
)

return response
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/event_handler/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def handler(event, context):
# WHEN calling the event handler
result = handler(mock_event, None)

# THEN then the response is not compressed
# THEN the response is not compressed
assert result["isBase64Encoded"] is False
assert result["body"] == expected_value
assert result["multiValueHeaders"].get("Content-Encoding") is None
Expand Down
4 changes: 2 additions & 2 deletions tests/functional/event_handler/test_bedrock_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def claims() -> Dict[str, Any]:
assert result["response"]["httpStatusCode"] == 200

body = result["response"]["responseBody"]["application/json"]["body"]
assert body == json.dumps({"output": claims_response})
assert json.loads(body) == {"output": claims_response}


def test_bedrock_agent_with_path_params():
Expand Down Expand Up @@ -79,7 +79,7 @@ def claims():
assert result["response"]["httpStatusCode"] == 200

body = result["response"]["responseBody"]["application/json"]["body"]
assert body == json.dumps(output)
assert json.loads(body) == output


def test_bedrock_agent_event_with_no_matches():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from pathlib import PurePath
from typing import List, Tuple

import pytest
from pydantic import BaseModel

from aws_lambda_powertools.event_handler import APIGatewayRestResolver
Expand All @@ -15,11 +14,6 @@
LOAD_GW_EVENT = load_event("apiGatewayProxyEvent.json")


def test_validate_with_customn_serializer():
with pytest.raises(ValueError):
APIGatewayRestResolver(enable_validation=True, serializer=json.dumps)


def test_validate_scalars():
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)
Expand Down Expand Up @@ -128,7 +122,7 @@ def handler() -> List[int]:
# THEN the body must be [123, 234]
result = app(LOAD_GW_EVENT, {})
assert result["statusCode"] == 200
assert result["body"] == "[123, 234]"
assert json.loads(result["body"]) == [123, 234]


def test_validate_return_tuple():
Expand All @@ -148,7 +142,7 @@ def handler() -> Tuple:
# THEN the body must be a tuple
result = app(LOAD_GW_EVENT, {})
assert result["statusCode"] == 200
assert result["body"] == "[1, 2, 3]"
assert json.loads(result["body"]) == [1, 2, 3]


def test_validate_return_purepath():
Expand All @@ -169,7 +163,7 @@ def handler() -> str:
# THEN the body must be a string
result = app(LOAD_GW_EVENT, {})
assert result["statusCode"] == 200
assert result["body"] == json.dumps(sample_path.as_posix())
assert result["body"] == sample_path.as_posix()


def test_validate_return_enum():
Expand All @@ -190,7 +184,7 @@ def handler() -> Model:
# THEN the body must be a string
result = app(LOAD_GW_EVENT, {})
assert result["statusCode"] == 200
assert result["body"] == '"powertools"'
assert result["body"] == "powertools"


def test_validate_return_dataclass():
Expand Down

0 comments on commit 6a47ee8

Please sign in to comment.