Skip to content

Commit

Permalink
fix(event_handler): allow fine grained Response with data validation (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
rubenfonseca authored Nov 22, 2023
1 parent 9ce3ed1 commit 270060d
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 4 deletions.
5 changes: 3 additions & 2 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
_ROUTE_REGEX = "^{}$"

ResponseEventT = TypeVar("ResponseEventT", bound=BaseProxyEvent)
ResponseT = TypeVar("ResponseT")

if TYPE_CHECKING:
from aws_lambda_powertools.event_handler.openapi.compat import (
Expand Down Expand Up @@ -207,14 +208,14 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]:
return headers


class Response:
class Response(Generic[ResponseT]):
"""Response data class that provides greater control over what is returned from the proxy event"""

def __init__(
self,
status_code: int,
content_type: Optional[str] = None,
body: Any = None,
body: Optional[ResponseT] = None,
headers: Optional[Dict[str, Union[str, List[str]]]] = None,
cookies: Optional[List[Cookie]] = None,
compress: Optional[bool] = None,
Expand Down
12 changes: 12 additions & 0 deletions aws_lambda_powertools/event_handler/openapi/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pydantic import BaseConfig
from pydantic.fields import FieldInfo

from aws_lambda_powertools.event_handler import Response
from aws_lambda_powertools.event_handler.openapi.compat import (
ModelField,
Required,
Expand Down Expand Up @@ -724,13 +725,24 @@ def get_field_info_and_type_annotation(annotation, value, is_path_param: bool) -
# If the annotation is an Annotated type, we need to extract the type annotation and the FieldInfo
if get_origin(annotation) is Annotated:
field_info, type_annotation = get_field_info_annotated_type(annotation, value, is_path_param)
# If the annotation is a Response type, we recursively call this function with the inner type
elif get_origin(annotation) is Response:
field_info, type_annotation = get_field_info_response_type(annotation, value)
# If the annotation is not an Annotated type, we use it as the type annotation
else:
type_annotation = annotation

return field_info, type_annotation


def get_field_info_response_type(annotation, value) -> Tuple[Optional[FieldInfo], Any]:
# Example: get_args(Response[inner_type]) == (inner_type,) # noqa: ERA001
(inner_type,) = get_args(annotation)

# Recursively resolve the inner type
return get_field_info_and_type_annotation(inner_type, value, False)


def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> Tuple[Optional[FieldInfo], Any]:
"""
Get the FieldInfo and type annotation from an Annotated type.
Expand Down
20 changes: 19 additions & 1 deletion tests/functional/event_handler/test_openapi_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pydantic import BaseModel

from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver
from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver, Response
from aws_lambda_powertools.event_handler.openapi.models import (
Example,
Parameter,
Expand Down Expand Up @@ -153,6 +153,24 @@ def handler() -> str:
assert response.schema_.type == "string"


def test_openapi_with_response_returns():
app = APIGatewayRestResolver()

@app.get("/")
def handler() -> Response[Annotated[str, Body(title="Response title")]]:
return Response(body="Hello, world", status_code=200)

schema = app.get_openapi_schema()
assert len(schema.paths.keys()) == 1

get = schema.paths["/"].get
assert get.parameters is None

response = get.responses[200].content[JSON_CONTENT_TYPE]
assert response.schema_.title == "Response title"
assert response.schema_.type == "string"


def test_openapi_with_omitted_param():
app = APIGatewayRestResolver()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from pydantic import BaseModel

from aws_lambda_powertools.event_handler import APIGatewayRestResolver
from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response
from aws_lambda_powertools.event_handler.openapi.params import Body
from aws_lambda_powertools.shared.types import Annotated
from tests.functional.utils import load_event
Expand Down Expand Up @@ -330,3 +330,51 @@ def handler(user: Annotated[Model, Body(embed=True)]) -> Model:
LOAD_GW_EVENT["body"] = json.dumps({"user": {"name": "John", "age": 30}})
result = app(LOAD_GW_EVENT, {})
assert result["statusCode"] == 200


def test_validate_response_return():
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

class Model(BaseModel):
name: str
age: int

# WHEN a handler is defined with a body parameter
@app.post("/")
def handler(user: Model) -> Response[Model]:
return Response(body=user, status_code=200)

LOAD_GW_EVENT["httpMethod"] = "POST"
LOAD_GW_EVENT["path"] = "/"
LOAD_GW_EVENT["body"] = json.dumps({"name": "John", "age": 30})

# THEN the handler should be invoked and return 200
# THEN the body must be a dict
result = app(LOAD_GW_EVENT, {})
assert result["statusCode"] == 200
assert result["body"] == {"name": "John", "age": 30}


def test_validate_response_invalid_return():
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

class Model(BaseModel):
name: str
age: int

# WHEN a handler is defined with a body parameter
@app.post("/")
def handler(user: Model) -> Response[Model]:
return Response(body=user, status_code=200)

LOAD_GW_EVENT["httpMethod"] = "POST"
LOAD_GW_EVENT["path"] = "/"
LOAD_GW_EVENT["body"] = json.dumps({})

# THEN the handler should be invoked and return 422
# THEN the body should have the word missing
result = app(LOAD_GW_EVENT, {})
assert result["statusCode"] == 422
assert "missing" in result["body"]

0 comments on commit 270060d

Please sign in to comment.