Skip to content

Commit

Permalink
feat(event_handler): allow customers to catch request validation erro…
Browse files Browse the repository at this point in the history
…rs (#3396)
  • Loading branch information
rubenfonseca authored Nov 22, 2023
1 parent 270060d commit 481905e
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 42 deletions.
12 changes: 12 additions & 0 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from aws_lambda_powertools.event_handler import content_types
from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError
from aws_lambda_powertools.event_handler.openapi.constants import DEFAULT_API_VERSION, DEFAULT_OPENAPI_VERSION
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError
from aws_lambda_powertools.event_handler.openapi.swagger_ui.html import generate_swagger_html
from aws_lambda_powertools.event_handler.openapi.types import (
COMPONENT_REF_PREFIX,
Expand Down Expand Up @@ -1972,6 +1973,17 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[Resp
except ServiceError as service_error:
exp = service_error

if isinstance(exp, RequestValidationError):
return self._response_builder_class(
response=Response(
status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
content_type=content_types.APPLICATION_JSON,
body={"statusCode": HTTPStatus.UNPROCESSABLE_ENTITY, "message": exp.errors()},
),
serializer=self._serializer,
route=route,
)

if isinstance(exp, ServiceError):
return self._response_builder_class(
response=Response(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,50 +62,43 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
values: Dict[str, Any] = {}
errors: List[Any] = []

try:
# Process path values, which can be found on the route_args
path_values, path_errors = _request_params_to_args(
route.dependant.path_params,
app.context["_route_args"],
# Process path values, which can be found on the route_args
path_values, path_errors = _request_params_to_args(
route.dependant.path_params,
app.context["_route_args"],
)

# Process query values
query_values, query_errors = _request_params_to_args(
route.dependant.query_params,
app.current_event.query_string_parameters or {},
)

values.update(path_values)
values.update(query_values)
errors += path_errors + query_errors

# Process the request body, if it exists
if route.dependant.body_params:
(body_values, body_errors) = _request_body_to_args(
required_params=route.dependant.body_params,
received_body=self._get_body(app),
)
values.update(body_values)
errors.extend(body_errors)

# Process query values
query_values, query_errors = _request_params_to_args(
route.dependant.query_params,
app.current_event.query_string_parameters or {},
)

values.update(path_values)
values.update(query_values)
errors += path_errors + query_errors
if errors:
# Raise the validation errors
raise RequestValidationError(_normalize_errors(errors))
else:
# Re-write the route_args with the validated values, and call the next middleware
app.context["_route_args"] = values

# Process the request body, if it exists
if route.dependant.body_params:
(body_values, body_errors) = _request_body_to_args(
required_params=route.dependant.body_params,
received_body=self._get_body(app),
)
values.update(body_values)
errors.extend(body_errors)
# Call the handler by calling the next middleware
response = next_middleware(app)

if errors:
# Raise the validation errors
raise RequestValidationError(_normalize_errors(errors))
else:
# Re-write the route_args with the validated values, and call the next middleware
app.context["_route_args"] = values

# Call the handler by calling the next middleware
response = next_middleware(app)

# Process the response
return self._handle_response(route=route, response=response)
except RequestValidationError as e:
return Response(
status_code=422,
content_type="application/json",
body=json.dumps({"detail": e.errors()}),
)
# Process the response
return self._handle_response(route=route, response=response)

def _handle_response(self, *, route: Route, response: Response):
# Process the response body if it exists
Expand Down
46 changes: 46 additions & 0 deletions tests/functional/event_handler/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ServiceError,
UnauthorizedError,
)
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError
from aws_lambda_powertools.shared import constants
from aws_lambda_powertools.shared.cookies import Cookie
from aws_lambda_powertools.shared.json_encoder import Encoder
Expand Down Expand Up @@ -1458,6 +1459,51 @@ def get_lambda() -> Response:
assert result["body"] == "Foo!"


def test_exception_handler_with_data_validation():
# GIVEN a resolver with an exception handler defined for RequestValidationError
app = ApiGatewayResolver(enable_validation=True)

@app.exception_handler(RequestValidationError)
def handle_validation_error(ex: RequestValidationError):
print(f"request path is '{app.current_event.path}'")
return Response(
status_code=422,
content_type=content_types.TEXT_PLAIN,
body=f"Invalid data. Number of errors: {len(ex.errors())}",
)

@app.get("/my/path")
def get_lambda(param: int):
...

# WHEN calling the event handler
# AND a RequestValidationError is raised
result = app(LOAD_GW_EVENT, {})

# THEN call the exception_handler
assert result["statusCode"] == 422
assert result["multiValueHeaders"]["Content-Type"] == [content_types.TEXT_PLAIN]
assert result["body"] == "Invalid data. Number of errors: 1"


def test_data_validation_error():
# GIVEN a resolver without an exception handler
app = ApiGatewayResolver(enable_validation=True)

@app.get("/my/path")
def get_lambda(param: int):
...

# WHEN calling the event handler
# AND a RequestValidationError is raised
result = app(LOAD_GW_EVENT, {})

# THEN call the exception_handler
assert result["statusCode"] == 422
assert result["multiValueHeaders"]["Content-Type"] == [content_types.APPLICATION_JSON]
assert "missing" in result["body"]


def test_exception_handler_service_error():
# GIVEN
app = ApiGatewayResolver()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ class Model(BaseModel):
# WHEN a handler is defined with a body parameter
@app.post("/")
def handler(user: Model) -> Response[Model]:
return Response(body=user, status_code=200)
return Response(body=user, status_code=200, content_type="application/json")

LOAD_GW_EVENT["httpMethod"] = "POST"
LOAD_GW_EVENT["path"] = "/"
Expand All @@ -353,7 +353,7 @@ def handler(user: Model) -> Response[Model]:
# THEN the body must be a dict
result = app(LOAD_GW_EVENT, {})
assert result["statusCode"] == 200
assert result["body"] == {"name": "John", "age": 30}
assert json.loads(result["body"]) == {"name": "John", "age": 30}


def test_validate_response_invalid_return():
Expand Down

0 comments on commit 481905e

Please sign in to comment.