diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 5b7262e5d55..05831a2eea5 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -32,6 +32,7 @@ from aws_lambda_powertools.event_handler import content_types from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError from aws_lambda_powertools.event_handler.openapi.constants import DEFAULT_API_VERSION, DEFAULT_OPENAPI_VERSION +from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError from aws_lambda_powertools.event_handler.openapi.swagger_ui.html import generate_swagger_html from aws_lambda_powertools.event_handler.openapi.types import ( COMPONENT_REF_PREFIX, @@ -1972,6 +1973,17 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[Resp except ServiceError as service_error: exp = service_error + if isinstance(exp, RequestValidationError): + return self._response_builder_class( + response=Response( + status_code=HTTPStatus.UNPROCESSABLE_ENTITY, + content_type=content_types.APPLICATION_JSON, + body={"statusCode": HTTPStatus.UNPROCESSABLE_ENTITY, "message": exp.errors()}, + ), + serializer=self._serializer, + route=route, + ) + if isinstance(exp, ServiceError): return self._response_builder_class( response=Response( diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 131f9d267a3..34011b64384 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -62,50 +62,43 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> values: Dict[str, Any] = {} errors: List[Any] = [] - try: - # Process path values, which can be found on the route_args - path_values, path_errors = _request_params_to_args( - route.dependant.path_params, - app.context["_route_args"], + # Process path values, which can be found on the route_args + path_values, path_errors = _request_params_to_args( + route.dependant.path_params, + app.context["_route_args"], + ) + + # Process query values + query_values, query_errors = _request_params_to_args( + route.dependant.query_params, + app.current_event.query_string_parameters or {}, + ) + + values.update(path_values) + values.update(query_values) + errors += path_errors + query_errors + + # Process the request body, if it exists + if route.dependant.body_params: + (body_values, body_errors) = _request_body_to_args( + required_params=route.dependant.body_params, + received_body=self._get_body(app), ) + values.update(body_values) + errors.extend(body_errors) - # Process query values - query_values, query_errors = _request_params_to_args( - route.dependant.query_params, - app.current_event.query_string_parameters or {}, - ) - - values.update(path_values) - values.update(query_values) - errors += path_errors + query_errors + if errors: + # Raise the validation errors + raise RequestValidationError(_normalize_errors(errors)) + else: + # Re-write the route_args with the validated values, and call the next middleware + app.context["_route_args"] = values - # Process the request body, if it exists - if route.dependant.body_params: - (body_values, body_errors) = _request_body_to_args( - required_params=route.dependant.body_params, - received_body=self._get_body(app), - ) - values.update(body_values) - errors.extend(body_errors) + # Call the handler by calling the next middleware + response = next_middleware(app) - if errors: - # Raise the validation errors - raise RequestValidationError(_normalize_errors(errors)) - else: - # Re-write the route_args with the validated values, and call the next middleware - app.context["_route_args"] = values - - # Call the handler by calling the next middleware - response = next_middleware(app) - - # Process the response - return self._handle_response(route=route, response=response) - except RequestValidationError as e: - return Response( - status_code=422, - content_type="application/json", - body=json.dumps({"detail": e.errors()}), - ) + # Process the response + return self._handle_response(route=route, response=response) def _handle_response(self, *, route: Route, response: Response): # Process the response body if it exists diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 570de9ec808..d4c88b541aa 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -30,6 +30,7 @@ ServiceError, UnauthorizedError, ) +from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError from aws_lambda_powertools.shared import constants from aws_lambda_powertools.shared.cookies import Cookie from aws_lambda_powertools.shared.json_encoder import Encoder @@ -1458,6 +1459,51 @@ def get_lambda() -> Response: assert result["body"] == "Foo!" +def test_exception_handler_with_data_validation(): + # GIVEN a resolver with an exception handler defined for RequestValidationError + app = ApiGatewayResolver(enable_validation=True) + + @app.exception_handler(RequestValidationError) + def handle_validation_error(ex: RequestValidationError): + print(f"request path is '{app.current_event.path}'") + return Response( + status_code=422, + content_type=content_types.TEXT_PLAIN, + body=f"Invalid data. Number of errors: {len(ex.errors())}", + ) + + @app.get("/my/path") + def get_lambda(param: int): + ... + + # WHEN calling the event handler + # AND a RequestValidationError is raised + result = app(LOAD_GW_EVENT, {}) + + # THEN call the exception_handler + assert result["statusCode"] == 422 + assert result["multiValueHeaders"]["Content-Type"] == [content_types.TEXT_PLAIN] + assert result["body"] == "Invalid data. Number of errors: 1" + + +def test_data_validation_error(): + # GIVEN a resolver without an exception handler + app = ApiGatewayResolver(enable_validation=True) + + @app.get("/my/path") + def get_lambda(param: int): + ... + + # WHEN calling the event handler + # AND a RequestValidationError is raised + result = app(LOAD_GW_EVENT, {}) + + # THEN call the exception_handler + assert result["statusCode"] == 422 + assert result["multiValueHeaders"]["Content-Type"] == [content_types.APPLICATION_JSON] + assert "missing" in result["body"] + + def test_exception_handler_service_error(): # GIVEN app = ApiGatewayResolver() diff --git a/tests/functional/event_handler/test_openapi_validation_middleware.py b/tests/functional/event_handler/test_openapi_validation_middleware.py index 9c7ca371d54..f558bd23ced 100644 --- a/tests/functional/event_handler/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/test_openapi_validation_middleware.py @@ -343,7 +343,7 @@ class Model(BaseModel): # WHEN a handler is defined with a body parameter @app.post("/") def handler(user: Model) -> Response[Model]: - return Response(body=user, status_code=200) + return Response(body=user, status_code=200, content_type="application/json") LOAD_GW_EVENT["httpMethod"] = "POST" LOAD_GW_EVENT["path"] = "/" @@ -353,7 +353,7 @@ def handler(user: Model) -> Response[Model]: # THEN the body must be a dict result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 200 - assert result["body"] == {"name": "John", "age": 30} + assert json.loads(result["body"]) == {"name": "John", "age": 30} def test_validate_response_invalid_return():