From ebb10ea5300a7ef0c96768229c7bf0f473aa10f4 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Fri, 15 Dec 2023 14:36:41 +0100 Subject: [PATCH] fix(event_handler): allow responses and metadata when using Router (#3514) Co-authored-by: Heitor Lessa --- .../event_handler/api_gateway.py | 9 +++-- aws_lambda_powertools/event_handler/util.py | 13 +++++++ .../event_handler/test_openapi_params.py | 34 ++++++++++++++++++- 3 files changed, 52 insertions(+), 4 deletions(-) create mode 100644 aws_lambda_powertools/event_handler/util.py diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 31536457344..f68f186c333 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -40,6 +40,7 @@ validation_error_definition, validation_error_response_definition, ) +from aws_lambda_powertools.event_handler.util import _FrozenDict from aws_lambda_powertools.shared.cookies import Cookie from aws_lambda_powertools.shared.functions import powertools_dev_is_set from aws_lambda_powertools.shared.json_encoder import Encoder @@ -2130,8 +2131,10 @@ def route( middlewares: Optional[List[Callable[..., Any]]] = None, ): def register_route(func: Callable): - # Convert methods to tuple. It needs to be hashable as its part of the self._routes dict key + # All dict keys needs to be hashable. So we'll need to do some conversions: methods = (method,) if isinstance(method, str) else tuple(method) + frozen_responses = _FrozenDict(responses) if responses else None + frozen_tags = frozenset(tags) if tags else None route_key = ( rule, @@ -2141,9 +2144,9 @@ def register_route(func: Callable): cache_control, summary, description, - responses, + frozen_responses, response_description, - tags, + frozen_tags, operation_id, include_in_schema, ) diff --git a/aws_lambda_powertools/event_handler/util.py b/aws_lambda_powertools/event_handler/util.py new file mode 100644 index 00000000000..2832f8102ee --- /dev/null +++ b/aws_lambda_powertools/event_handler/util.py @@ -0,0 +1,13 @@ +class _FrozenDict(dict): + """ + A dictionary that can be used as a key in another dictionary. + + This is needed because the default dict implementation is not hashable. + The only usage for this right now is to store dicts as part of the Router key. + The implementation only takes into consideration the keys of the dictionary. + + MAINTENANCE: this is a temporary solution until we refactor the route key into a class. + """ + + def __hash__(self): + return hash(frozenset(self.keys())) diff --git a/tests/functional/event_handler/test_openapi_params.py b/tests/functional/event_handler/test_openapi_params.py index 9209cb9decd..4ac425c9fdd 100644 --- a/tests/functional/event_handler/test_openapi_params.py +++ b/tests/functional/event_handler/test_openapi_params.py @@ -4,7 +4,7 @@ from pydantic import BaseModel -from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver, Response +from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver, Response, Router from aws_lambda_powertools.event_handler.openapi.models import ( Example, Parameter, @@ -392,6 +392,38 @@ def secret(): assert len(schema.paths.keys()) == 0 +def test_openapi_with_router_response(): + router = Router() + + @router.put("/example-resource", responses={200: {"description": "Custom response"}}) + def handler(): + pass + + app = APIGatewayRestResolver(enable_validation=True) + app.include_router(router) + + schema = app.get_openapi_schema() + put = schema.paths["/example-resource"].put + assert 200 in put.responses.keys() + assert put.responses[200].description == "Custom response" + + +def test_openapi_with_router_tags(): + router = Router() + + @router.put("/example-resource", tags=["Example"]) + def handler(): + pass + + app = APIGatewayRestResolver(enable_validation=True) + app.include_router(router) + + schema = app.get_openapi_schema() + tags = schema.paths["/example-resource"].put.tags + assert len(tags) == 1 + assert tags[0].name == "Example" + + def test_create_header(): header = _Header(convert_underscores=True) assert header.convert_underscores is True