Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(event_handler): generate OpenAPI specifications and validate input/output #3109

Merged
merged 78 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
78 commits
Select commit Hold shift + click to select a range
86bce91
feat: generate OpenAPI spec from event handler
rubenfonseca Sep 19, 2023
e74cee6
fix: resolver circular dependencies
rubenfonseca Oct 9, 2023
510ad25
fix: rebase
rubenfonseca Oct 9, 2023
a2c1c92
fix: document the new methods
rubenfonseca Sep 25, 2023
ba333d3
fix: linter
rubenfonseca Sep 26, 2023
303fb2e
fix: remove unneeded code
rubenfonseca Sep 26, 2023
d1be57b
fix: reduce duplication
rubenfonseca Sep 26, 2023
40fcca1
fix: types and sonarcube
rubenfonseca Sep 26, 2023
079f3d7
chore: refactor complex function
rubenfonseca Sep 26, 2023
44bc067
fix: typing extensions
rubenfonseca Sep 26, 2023
c11dda4
fix: tests
rubenfonseca Sep 27, 2023
9c7c37f
fix: mypy
rubenfonseca Sep 27, 2023
f02f189
fix: security baseline
rubenfonseca Sep 27, 2023
2d30443
feat: add simultaneous support for Pydantic v2
rubenfonseca Sep 27, 2023
3b0037f
fix: disable mypy and ruff on openapi compat
rubenfonseca Sep 27, 2023
633ceb4
chore: add explanation to imports
rubenfonseca Sep 27, 2023
b4fcde6
chore: add first test
rubenfonseca Sep 28, 2023
bca3c71
fix: test
rubenfonseca Sep 28, 2023
88ec111
fix: test
rubenfonseca Sep 28, 2023
ba2e8f0
fix: don't require pydantic to run normal things
rubenfonseca Sep 28, 2023
c97d016
chore: added first tests
rubenfonseca Sep 28, 2023
e4de16c
fix: refactored tests to remove code smell
rubenfonseca Sep 28, 2023
c92b8c0
fix: customize the handler methods
rubenfonseca Oct 2, 2023
a80f53b
fix: tests
rubenfonseca Oct 2, 2023
79ea082
feat: add a validation middleware
rubenfonseca Oct 9, 2023
9b8ce4a
fix: uniontype
rubenfonseca Oct 9, 2023
c3f25f8
fix: types
rubenfonseca Oct 9, 2023
13ccd5f
fix: ignore unused-ignore
rubenfonseca Oct 9, 2023
cf1b866
fix: moved things around
rubenfonseca Oct 9, 2023
f4d9446
fix: compatibility with pydantic v2
rubenfonseca Oct 9, 2023
24a9818
chore: add tests on the body request
rubenfonseca Oct 9, 2023
d17cc64
chore: add tests for validation middleware
rubenfonseca Oct 9, 2023
280abf5
fix: assorted fixes
rubenfonseca Oct 9, 2023
1bb73c6
fix: make tests pass in both pydantic versions
rubenfonseca Oct 9, 2023
a559ed6
fix: remove assert
rubenfonseca Oct 9, 2023
d7317ec
fix: complexity
rubenfonseca Oct 10, 2023
6b44575
fix: move Response class back
rubenfonseca Oct 10, 2023
eb90c56
fix: more fix
rubenfonseca Oct 10, 2023
31dca10
fix: more fix
rubenfonseca Oct 10, 2023
550528d
fix: one more fix
rubenfonseca Oct 10, 2023
cdfbfbf
fix: refactor OpenAPI validation middleware
rubenfonseca Oct 10, 2023
5ff491b
fix: refactor dependant.py
rubenfonseca Oct 10, 2023
bbb9c25
fix: beautify encoders
rubenfonseca Oct 10, 2023
5bd4a50
fix: move things around
rubenfonseca Oct 10, 2023
a3cef34
fix: costmetic changes
rubenfonseca Oct 10, 2023
de22a93
fix: add more comments
rubenfonseca Oct 10, 2023
e60f7df
fix: format
rubenfonseca Oct 10, 2023
0cd690e
fix: cyclomatic
rubenfonseca Oct 10, 2023
eebdc2f
fix: change method of generating operation id
rubenfonseca Oct 11, 2023
b308f63
fix: allow validation in all resolvers
rubenfonseca Oct 11, 2023
2c7367e
fix: use proper resolver in tests
rubenfonseca Oct 11, 2023
c87e47e
fix: move from flake8 to ruff
rubenfonseca Oct 11, 2023
9427ed6
fix: customizing responses
rubenfonseca Oct 11, 2023
2cb7c67
fix: add documentation to a method
rubenfonseca Oct 11, 2023
0a69582
fix: more explicit comments
rubenfonseca Oct 11, 2023
ab21cb3
fix: typo
rubenfonseca Oct 11, 2023
2fa15a4
fix: add extra comment
rubenfonseca Oct 11, 2023
efd339c
fix: comment
rubenfonseca Oct 11, 2023
0c2db13
fix: add comments
rubenfonseca Oct 11, 2023
c2d7bc3
fix: comments
rubenfonseca Oct 11, 2023
a0a9adc
fix: typo
rubenfonseca Oct 11, 2023
526d9f7
fix: remove leftover comment
rubenfonseca Oct 11, 2023
76f3a32
fix: addressing comments
rubenfonseca Oct 17, 2023
e243200
fix: pydantic2 models
rubenfonseca Oct 17, 2023
64c6192
fix: typing extension problems
rubenfonseca Oct 17, 2023
006f854
Adding more tests and fixing small things
leandrodamascena Oct 18, 2023
0e79b81
Adding more tests and fixing small things
leandrodamascena Oct 18, 2023
acf8928
Adding more tests and fixing small things
leandrodamascena Oct 18, 2023
4779d39
Removing flaky tests
leandrodamascena Oct 18, 2023
96e2d17
Merge branch 'develop' into rf/openapi-v2
leandrodamascena Oct 18, 2023
d116be2
Merge branch 'develop' into rf/openapi-v2
leandrodamascena Oct 23, 2023
2e115dd
fix: improve coverage of encoders
rubenfonseca Oct 23, 2023
bf0aaae
fix: mark test as pydantic v1 only
rubenfonseca Oct 23, 2023
80375e4
fix: make sonarcube happy
rubenfonseca Oct 23, 2023
18f4418
fix: improve coverage of params.py
rubenfonseca Oct 23, 2023
fdadd6b
fix: add codecov.yml file to ignore compat.py
rubenfonseca Oct 23, 2023
9f4672a
Increasing coverage
Oct 24, 2023
45573a0
Merge branch 'develop' into rf/openapi-v2
leandrodamascena Oct 24, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
794 changes: 769 additions & 25 deletions aws_lambda_powertools/event_handler/api_gateway.py

Large diffs are not rendered by default.

10 changes: 9 additions & 1 deletion aws_lambda_powertools/event_handler/lambda_function_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,13 @@ def __init__(
debug: Optional[bool] = None,
serializer: Optional[Callable[[Dict], str]] = None,
strip_prefixes: Optional[List[Union[str, Pattern]]] = None,
enable_validation: bool = False,
):
super().__init__(ProxyEventType.LambdaFunctionUrlEvent, cors, debug, serializer, strip_prefixes)
super().__init__(
ProxyEventType.LambdaFunctionUrlEvent,
cors,
debug,
serializer,
strip_prefixes,
enable_validation,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,342 @@
import dataclasses
import json
import logging
from copy import deepcopy
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple

from pydantic import BaseModel

from aws_lambda_powertools.event_handler import Response
from aws_lambda_powertools.event_handler.api_gateway import Route
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware
from aws_lambda_powertools.event_handler.openapi.compat import (
ModelField,
_model_dump,
_normalize_errors,
_regenerate_error_with_loc,
get_missing_field_error,
)
from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError
from aws_lambda_powertools.event_handler.openapi.params import Param
from aws_lambda_powertools.event_handler.openapi.types import IncEx
from aws_lambda_powertools.event_handler.types import EventHandlerInstance

logger = logging.getLogger(__name__)


class OpenAPIValidationMiddleware(BaseMiddlewareHandler):
"""
OpenAPIValidationMiddleware is a middleware that validates the request against the OpenAPI schema defined by the
Lambda handler. It also validates the response against the OpenAPI schema defined by the Lambda handler. It
should not be used directly, but rather through the `enable_validation` parameter of the `ApiGatewayResolver`.

Examples
--------

```python
from typing import List

from pydantic import BaseModel

from aws_lambda_powertools.event_handler.api_gateway import (
APIGatewayRestResolver,
)

class Todo(BaseModel):
name: str

app = APIGatewayRestResolver(enable_validation=True)

@app.get("/todos")
def get_todos(): List[Todo]:
return [Todo(name="hello world")]
```
"""

def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
logger.debug("OpenAPIValidationMiddleware handler")

route: Route = app.context["_route"]
leandrodamascena marked this conversation as resolved.
Show resolved Hide resolved

values: Dict[str, Any] = {}
errors: List[Any] = []

try:
# Process path values, which can be found on the route_args
path_values, path_errors = _request_params_to_args(
route.dependant.path_params,
app.context["_route_args"],
)

# Process query values
query_values, query_errors = _request_params_to_args(
route.dependant.query_params,
app.current_event.query_string_parameters or {},
)

values.update(path_values)
values.update(query_values)
errors += path_errors + query_errors

# Process the request body, if it exists
if route.dependant.body_params:
(body_values, body_errors) = _request_body_to_args(
required_params=route.dependant.body_params,
received_body=self._get_body(app),
)
values.update(body_values)
errors.extend(body_errors)

if errors:
# Raise the validation errors
raise RequestValidationError(_normalize_errors(errors))
else:
# Re-write the route_args with the validated values, and call the next middleware
app.context["_route_args"] = values
leandrodamascena marked this conversation as resolved.
Show resolved Hide resolved
response = next_middleware(app)

# Process the response body if it exists
raw_response = jsonable_encoder(response.body)

# Validate and serialize the response
return self._serialize_response(field=route.dependant.return_param, response_content=raw_response)
except RequestValidationError as e:
return Response(
status_code=422,
content_type="application/json",
body=json.dumps({"detail": e.errors()}),
)

def _serialize_response(
self,
*,
field: Optional[ModelField] = None,
response_content: Any,
include: Optional[IncEx] = None,
exclude: Optional[IncEx] = None,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> Any:
"""
Serialize the response content according to the field type.
"""
if field:
errors: List[Dict[str, Any]] = []
# MAINTENANCE: remove this when we drop pydantic v1
if not hasattr(field, "serializable"):
response_content = self._prepare_response_content(
response_content,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)

value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors)
if errors:
raise RequestValidationError(errors=_normalize_errors(errors), body=response_content)

if hasattr(field, "serialize"):
return field.serialize(
value,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)

return jsonable_encoder(
value,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
leandrodamascena marked this conversation as resolved.
Show resolved Hide resolved
else:
# Just serialize the response content returned from the handler
return jsonable_encoder(response_content)

def _prepare_response_content(
self,
res: Any,
*,
exclude_unset: bool,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> Any:
"""
Prepares the response content for serialization.
"""
if isinstance(res, BaseModel):
return _model_dump(
res,
by_alias=True,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
elif isinstance(res, list):
return [
self._prepare_response_content(item, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults)
for item in res
]
elif isinstance(res, dict):
return {
k: self._prepare_response_content(v, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults)
for k, v in res.items()
}
elif dataclasses.is_dataclass(res):
return dataclasses.asdict(res)
return res
leandrodamascena marked this conversation as resolved.
Show resolved Hide resolved

def _get_body(self, app: EventHandlerInstance) -> Dict[str, Any]:
"""
Get the request body from the event, and parse it as JSON.
"""

content_type_value = app.current_event.get_header_value("content-type")
if not content_type_value or content_type_value.startswith("application/json"):
try:
return app.current_event.json_body
except json.JSONDecodeError as e:
raise RequestValidationError(
[
{
"type": "json_invalid",
"loc": ("body", e.pos),
"msg": "JSON decode error",
"input": {},
"ctx": {"error": e.msg},
},
],
body=e.doc,
) from e
else:
raise NotImplementedError("Only JSON body is supported")


def _request_params_to_args(
required_params: Sequence[ModelField],
received_params: Mapping[str, Any],
) -> Tuple[Dict[str, Any], List[Any]]:
"""
Convert the request params to a dictionary of values using validation, and returns a list of errors.
"""
values = {}
errors = []

for field in required_params:
value = received_params.get(field.alias)

field_info = field.field_info
if not isinstance(field_info, Param):
raise AssertionError(f"Expected Param field_info, got {field_info}")

loc = (field_info.in_.value, field.alias)

# If we don't have a value, see if it's required or has a default
if value is None:
if field.required:
errors.append(get_missing_field_error(loc=loc))
else:
values[field.name] = deepcopy(field.default)
continue

# Finally, validate the value
values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors)

return values, errors


def _request_body_to_args(
required_params: List[ModelField],
received_body: Optional[Dict[str, Any]],
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
"""
Convert the request body to a dictionary of values using validation, and returns a list of errors.
"""
values: Dict[str, Any] = {}
errors: List[Dict[str, Any]] = []

received_body, field_alias_omitted = _get_embed_body(
field=required_params[0],
required_params=required_params,
received_body=received_body,
)

for field in required_params:
# This sets the location to:
# { "user": { object } } if field.alias == user
# { { object } if field_alias is omitted
loc: Tuple[str, ...] = ("body", field.alias)
if field_alias_omitted:
loc = ("body",)

value: Optional[Any] = None

# Now that we know what to look for, try to get the value from the received body
if received_body is not None:
try:
value = received_body.get(field.alias)
except AttributeError:
errors.append(get_missing_field_error(loc))
continue

# Determine if the field is required
if value is None:
if field.required:
errors.append(get_missing_field_error(loc))
else:
values[field.name] = deepcopy(field.default)
continue

# MAINTENANCE: Handle byte and file fields

# Finally, validate the value
values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors)

return values, errors


def _validate_field(
*,
field: ModelField,
value: Any,
loc: Tuple[str, ...],
existing_errors: List[Dict[str, Any]],
):
"""
Validate a field, and append any errors to the existing_errors list.
"""
validated_value, errors = field.validate(value, value, loc=loc)

if isinstance(errors, list):
processed_errors = _regenerate_error_with_loc(errors=errors, loc_prefix=())
existing_errors.extend(processed_errors)
elif errors:
existing_errors.append(errors)

return validated_value


def _get_embed_body(
*,
field: ModelField,
required_params: List[ModelField],
received_body: Optional[Dict[str, Any]],
) -> Tuple[Optional[Dict[str, Any]], bool]:
field_info = field.field_info
embed = getattr(field_info, "embed", None)

# If the field is an embed, and the field alias is omitted, we need to wrap the received body in the field alias.
field_alias_omitted = len(required_params) == 1 and not embed
if field_alias_omitted:
received_body = {field.alias: received_body}

return received_body, field_alias_omitted
Empty file.
Loading
Loading