diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py
index fd7507603de..54c48189282 100644
--- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py
+++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py
@@ -81,9 +81,22 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
query_string,
)
+ # Normalize header values before validate this
+ headers = _normalize_multi_header_values_with_param(
+ app.current_event.resolved_headers_field,
+ route.dependant.header_params,
+ )
+
+ # Process header values
+ header_values, header_errors = _request_params_to_args(
+ route.dependant.header_params,
+ headers,
+ )
+
values.update(path_values)
values.update(query_values)
- errors += path_errors + query_errors
+ values.update(header_values)
+ errors += path_errors + query_errors + header_errors
# Process the request body, if it exists
if route.dependant.body_params:
@@ -243,12 +256,14 @@ def _request_params_to_args(
errors = []
for field in required_params:
- value = received_params.get(field.alias)
-
field_info = field.field_info
+
+ # To ensure early failure, we check if it's not an instance of Param.
if not isinstance(field_info, Param):
raise AssertionError(f"Expected Param field_info, got {field_info}")
+ value = received_params.get(field.alias)
+
loc = (field_info.in_.value, field.alias)
# If we don't have a value, see if it's required or has a default
@@ -377,3 +392,30 @@ def _normalize_multi_query_string_with_param(query_string: Optional[Dict[str, st
except KeyError:
pass
return query_string
+
+
+def _normalize_multi_header_values_with_param(headers: Optional[Dict[str, str]], params: Sequence[ModelField]):
+ """
+ Extract and normalize resolved_headers_field
+
+ Parameters
+ ----------
+ headers: Dict
+ A dictionary containing the initial header parameters.
+ params: Sequence[ModelField]
+ A sequence of ModelField objects representing parameters.
+
+ Returns
+ -------
+ A dictionary containing the processed headers.
+ """
+ if headers:
+ for param in filter(is_scalar_field, params):
+ try:
+ if len(headers[param.alias]) == 1:
+ # if the target parameter is a scalar and the list contains only 1 element
+ # we keep the first value of the headers regardless if there are more in the payload
+ headers[param.alias] = headers[param.alias][0]
+ except KeyError:
+ pass
+ return headers
diff --git a/aws_lambda_powertools/event_handler/openapi/dependant.py b/aws_lambda_powertools/event_handler/openapi/dependant.py
index 418a86e083c..abcb91e90dd 100644
--- a/aws_lambda_powertools/event_handler/openapi/dependant.py
+++ b/aws_lambda_powertools/event_handler/openapi/dependant.py
@@ -14,12 +14,12 @@
from aws_lambda_powertools.event_handler.openapi.params import (
Body,
Dependant,
+ Header,
Param,
ParamTypes,
Query,
_File,
_Form,
- _Header,
analyze_param,
create_response_field,
get_flat_dependant,
@@ -59,16 +59,21 @@ def add_param_to_fields(
"""
field_info = cast(Param, field.field_info)
- if field_info.in_ == ParamTypes.path:
- dependant.path_params.append(field)
- elif field_info.in_ == ParamTypes.query:
- dependant.query_params.append(field)
- elif field_info.in_ == ParamTypes.header:
- dependant.header_params.append(field)
+
+ # Dictionary to map ParamTypes to their corresponding lists in dependant
+ param_type_map = {
+ ParamTypes.path: dependant.path_params,
+ ParamTypes.query: dependant.query_params,
+ ParamTypes.header: dependant.header_params,
+ ParamTypes.cookie: dependant.cookie_params,
+ }
+
+ # Check if field_info.in_ is a valid key in param_type_map and append the field to the corresponding list
+ # or raise an exception if it's not a valid key.
+ if field_info.in_ in param_type_map:
+ param_type_map[field_info.in_].append(field)
else:
- if field_info.in_ != ParamTypes.cookie:
- raise AssertionError(f"Unsupported param type: {field_info.in_}")
- dependant.cookie_params.append(field)
+ raise AssertionError(f"Unsupported param type: {field_info.in_}")
def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
@@ -265,7 +270,7 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
return False
elif is_scalar_field(field=param_field):
return False
- elif isinstance(param_field.field_info, (Query, _Header)) and is_scalar_sequence_field(param_field):
+ elif isinstance(param_field.field_info, (Query, Header)) and is_scalar_sequence_field(param_field):
return False
else:
if not isinstance(param_field.field_info, Body):
diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py
index 78426cbc7c9..d5665a48d30 100644
--- a/aws_lambda_powertools/event_handler/openapi/params.py
+++ b/aws_lambda_powertools/event_handler/openapi/params.py
@@ -486,7 +486,7 @@ def __init__(
)
-class _Header(Param):
+class Header(Param):
"""
A class used internally to represent a header parameter in a path operation.
"""
@@ -527,12 +527,75 @@ def __init__(
json_schema_extra: Union[Dict[str, Any], None] = None,
**extra: Any,
):
+ """
+ Constructs a new Query param.
+
+ Parameters
+ ----------
+ default: Any
+ The default value of the parameter
+ default_factory: Callable[[], Any], optional
+ Callable that will be called when a default value is needed for this field
+ annotation: Any, optional
+ The type annotation of the parameter
+ alias: str, optional
+ The public name of the field
+ alias_priority: int, optional
+ Priority of the alias. This affects whether an alias generator is used
+ validation_alias: str | AliasPath | AliasChoices | None, optional
+ Alias to be used for validation only
+ serialization_alias: str | AliasPath | AliasChoices | None, optional
+ Alias to be used for serialization only
+ convert_underscores: bool
+ If true convert "_" to "-"
+ See RFC: https://www.rfc-editor.org/rfc/rfc9110.html#name-field-name-registry
+ title: str, optional
+ The title of the parameter
+ description: str, optional
+ The description of the parameter
+ gt: float, optional
+ Only applies to numbers, required the field to be "greater than"
+ ge: float, optional
+ Only applies to numbers, required the field to be "greater than or equal"
+ lt: float, optional
+ Only applies to numbers, required the field to be "less than"
+ le: float, optional
+ Only applies to numbers, required the field to be "less than or equal"
+ min_length: int, optional
+ Only applies to strings, required the field to have a minimum length
+ max_length: int, optional
+ Only applies to strings, required the field to have a maximum length
+ pattern: str, optional
+ Only applies to strings, requires the field match against a regular expression pattern string
+ discriminator: str, optional
+ Parameter field name for discriminating the type in a tagged union
+ strict: bool, optional
+ Enables Pydantic's strict mode for the field
+ multiple_of: float, optional
+ Only applies to numbers, requires the field to be a multiple of the given value
+ allow_inf_nan: bool, optional
+ Only applies to numbers, requires the field to allow infinity and NaN values
+ max_digits: int, optional
+ Only applies to Decimals, requires the field to have a maxmium number of digits within the decimal.
+ decimal_places: int, optional
+ Only applies to Decimals, requires the field to have at most a number of decimal places
+ examples: List[Any], optional
+ A list of examples for the parameter
+ deprecated: bool, optional
+ If `True`, the parameter will be marked as deprecated
+ include_in_schema: bool, optional
+ If `False`, the parameter will be excluded from the generated OpenAPI schema
+ json_schema_extra: Dict[str, Any], optional
+ Extra values to include in the generated OpenAPI schema
+ """
self.convert_underscores = convert_underscores
+ self._alias = alias
+
super().__init__(
default=default,
default_factory=default_factory,
annotation=annotation,
- alias=alias,
+ alias=self._alias,
alias_priority=alias_priority,
validation_alias=validation_alias,
serialization_alias=serialization_alias,
@@ -558,6 +621,18 @@ def __init__(
**extra,
)
+ @property
+ def alias(self):
+ return self._alias
+
+ @alias.setter
+ def alias(self, value: Optional[str] = None):
+ if value is not None:
+ # Headers are case-insensitive according to RFC 7540 (HTTP/2), so we lower the parameter name
+ # This ensures that customers can access headers with any casing, as per the RFC guidelines.
+ # Reference: https://www.rfc-editor.org/rfc/rfc7540#section-8.1.2
+ self._alias = value.lower()
+
class Body(FieldInfo):
"""
diff --git a/aws_lambda_powertools/utilities/data_classes/alb_event.py b/aws_lambda_powertools/utilities/data_classes/alb_event.py
index 688c9567efa..98f37b4f415 100644
--- a/aws_lambda_powertools/utilities/data_classes/alb_event.py
+++ b/aws_lambda_powertools/utilities/data_classes/alb_event.py
@@ -42,6 +42,17 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]:
return self.query_string_parameters
+ @property
+ def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
+ headers: Dict[str, Any] = {}
+
+ if self.multi_value_headers:
+ headers = self.multi_value_headers
+ else:
+ headers = self.headers
+
+ return {key.lower(): value for key, value in headers.items()}
+
@property
def multi_value_headers(self) -> Optional[Dict[str, List[str]]]:
return self.get("multiValueHeaders")
diff --git a/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py
index 9e013eac038..c37bd22ca53 100644
--- a/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py
+++ b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py
@@ -125,6 +125,17 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]:
return self.query_string_parameters
+ @property
+ def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
+ headers: Dict[str, Any] = {}
+
+ if self.multi_value_headers:
+ headers = self.multi_value_headers
+ else:
+ headers = self.headers
+
+ return {key.lower(): value for key, value in headers.items()}
+
@property
def request_context(self) -> APIGatewayEventRequestContext:
return APIGatewayEventRequestContext(self._data)
@@ -316,3 +327,11 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]:
return query_string
return {}
+
+ @property
+ def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
+ if self.headers is not None:
+ headers = {key.lower(): value.split(",") if "," in value else value for key, value in self.headers.items()}
+ return headers
+
+ return {}
diff --git a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py
index d9b45242376..0fa97036a3e 100644
--- a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py
+++ b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent, DictWrapper
@@ -112,3 +112,7 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]:
@property
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
return self.query_string_parameters
+
+ @property
+ def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
+ return {}
diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py
index d2cf57d4af5..0560159ecc5 100644
--- a/aws_lambda_powertools/utilities/data_classes/common.py
+++ b/aws_lambda_powertools/utilities/data_classes/common.py
@@ -114,6 +114,21 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
"""
return self.query_string_parameters
+ @property
+ def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
+ """
+ This property determines the appropriate header to be used
+ as a trusted source for validating OpenAPI.
+
+ This is necessary because different resolvers use different formats to encode
+ headers parameters.
+
+ Headers are case-insensitive according to RFC 7540 (HTTP/2), so we lower the header name
+ This ensures that customers can access headers with any casing, as per the RFC guidelines.
+ Reference: https://www.rfc-editor.org/rfc/rfc7540#section-8.1.2
+ """
+ return self.headers
+
@property
def is_base64_encoded(self) -> Optional[bool]:
return self.get("isBase64Encoded")
diff --git a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py
index 633ce068f6e..f12c53d841a 100644
--- a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py
+++ b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py
@@ -145,6 +145,14 @@ def query_string_parameters(self) -> Dict[str, str]:
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
return self.query_string_parameters
+ @property
+ def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
+ if self.headers is not None:
+ headers = {key.lower(): value.split(",") if "," in value else value for key, value in self.headers.items()}
+ return headers
+
+ return {}
+
class vpcLatticeEventV2Identity(DictWrapper):
@property
@@ -259,3 +267,10 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]:
@property
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
return self.query_string_parameters
+
+ @property
+ def resolved_headers_field(self) -> Optional[Dict[str, str]]:
+ if self.headers is not None:
+ return {key.lower(): value for key, value in self.headers.items()}
+
+ return {}
diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md
index 86b97c87e4b..32631ac867e 100644
--- a/docs/core/event_handler/api_gateway.md
+++ b/docs/core/event_handler/api_gateway.md
@@ -368,13 +368,13 @@ We use the `Annotated` and OpenAPI `Body` type to instruct Event Handler that ou
!!! info "We will automatically validate and inject incoming query strings via type annotation."
-We use the `Annotated` type to tell Event Handler that a particular parameter is not only an optional string, but also a query string with constraints.
+We use the `Annotated` type to tell the Event Handler that a particular parameter is not only an optional string, but also a query string with constraints.
In the following example, we use a new `Query` OpenAPI type to add [one out of many possible constraints](#customizing-openapi-parameters), which should read as:
* `completed` is a query string with a `None` as its default value
* `completed`, when set, should have at minimum 4 characters
-* Doesn't match? Event Handler will return a validation error response
+* No match? Event Handler will return a validation error response
@@ -386,7 +386,7 @@ In the following example, we use a new `Query` OpenAPI type to add [one out of m
1. If you're not using Python 3.9 or higher, you can install and use [`typing_extensions`](https://pypi.org/project/typing-extensions/){target="_blank" rel="nofollow"} to the same effect
2. `Query` is a special OpenAPI type that can add constraints to a query string as well as document them
- 3. **First time seeing the `Annotated`?**
This special type uses the first argument as the actual type, and subsequent arguments are metadata.
At runtime, static checkers will also see the first argument, but anyone receiving them could inspect them to fetch their metadata.
+ 3. **First time seeing `Annotated`?**
This special type uses the first argument as the actual type, and subsequent arguments as metadata.
At runtime, static checkers will also see the first argument, but any receiver can inspect it to get the metadata.
=== "skip_validating_query_strings.py"
@@ -424,6 +424,40 @@ For example, we could validate that `` dynamic path should be no greate
1. `Path` is a special OpenAPI type that allows us to constrain todo_id to be less than 999.
+#### Validating headers
+
+We use the `Annotated` type to tell the Event Handler that a particular parameter is a header that needs to be validated.
+
+!!! info "We adhere to [HTTP RFC standards](https://www.rfc-editor.org/rfc/rfc7540#section-8.1.2){target="_blank" rel="nofollow"}, which means we treat HTTP headers as case-insensitive."
+
+In the following example, we use a new `Header` OpenAPI type to add [one out of many possible constraints](#customizing-openapi-parameters), which should read as:
+
+* `correlation_id` is a header that must be present in the request
+* `correlation_id` should have 16 characters
+* No match? Event Handler will return a validation error response
+
+
+
+=== "validating_headers.py"
+
+ ```python hl_lines="8 10 27"
+ --8<-- "examples/event_handler_rest/src/validating_headers.py"
+ ```
+
+ 1. If you're not using Python 3.9 or higher, you can install and use [`typing_extensions`](https://pypi.org/project/typing-extensions/){target="_blank" rel="nofollow"} to the same effect
+ 2. `Header` is a special OpenAPI type that can add constraints and documentation to a header
+ 3. **First time seeing `Annotated`?**
This special type uses the first argument as the actual type, and subsequent arguments as metadata.
At runtime, static checkers will also see the first argument, but any receiver can inspect it to get the metadata.
+
+=== "working_with_headers_multi_value.py"
+
+ You can handle multi-value headers by declaring it as a list of the desired type.
+
+ ```python hl_lines="23"
+ --8<-- "examples/event_handler_rest/src/working_with_headers_multi_value.py"
+ ```
+
+ 1. `cloudfront_viewer_country` is a list that must contain values from the `CountriesAllowed` enumeration.
+
### Accessing request details
Event Handler integrates with [Event Source Data Classes utilities](../../utilities/data_classes.md){target="_blank"}, and it exposes their respective resolver request details and convenient methods under `app.current_event`.
diff --git a/examples/event_handler_rest/src/validating_headers.py b/examples/event_handler_rest/src/validating_headers.py
new file mode 100644
index 00000000000..e830a49c38c
--- /dev/null
+++ b/examples/event_handler_rest/src/validating_headers.py
@@ -0,0 +1,39 @@
+from typing import List, Optional
+
+import requests
+from pydantic import BaseModel, Field
+
+from aws_lambda_powertools import Logger, Tracer
+from aws_lambda_powertools.event_handler import APIGatewayRestResolver
+from aws_lambda_powertools.event_handler.openapi.params import Header # (2)!
+from aws_lambda_powertools.logging import correlation_paths
+from aws_lambda_powertools.shared.types import Annotated # (1)!
+from aws_lambda_powertools.utilities.typing import LambdaContext
+
+tracer = Tracer()
+logger = Logger()
+app = APIGatewayRestResolver(enable_validation=True)
+
+
+class Todo(BaseModel):
+ userId: int
+ id_: Optional[int] = Field(alias="id", default=None)
+ title: str
+ completed: bool
+
+
+@app.get("/todos")
+@tracer.capture_method
+def get_todos(correlation_id: Annotated[str, Header(min_length=16, max_length=16)]) -> List[Todo]: # (3)!
+ url = "https://jsonplaceholder.typicode.com/todos"
+
+ todo = requests.get(url, headers={"correlation_id": correlation_id})
+ todo.raise_for_status()
+
+ return todo.json()
+
+
+@logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_HTTP)
+@tracer.capture_lambda_handler
+def lambda_handler(event: dict, context: LambdaContext) -> dict:
+ return app.resolve(event, context)
diff --git a/examples/event_handler_rest/src/working_with_headers_multi_value.py b/examples/event_handler_rest/src/working_with_headers_multi_value.py
new file mode 100644
index 00000000000..956fd58b14d
--- /dev/null
+++ b/examples/event_handler_rest/src/working_with_headers_multi_value.py
@@ -0,0 +1,34 @@
+from enum import Enum
+from typing import List
+
+from aws_lambda_powertools.event_handler import APIGatewayRestResolver
+from aws_lambda_powertools.event_handler.openapi.params import Header
+from aws_lambda_powertools.shared.types import Annotated
+from aws_lambda_powertools.utilities.typing import LambdaContext
+
+app = APIGatewayRestResolver(enable_validation=True)
+
+
+class CountriesAllowed(Enum):
+ """Example of an Enum class."""
+
+ US = "US"
+ PT = "PT"
+ BR = "BR"
+
+
+@app.get("/hello")
+def get(
+ cloudfront_viewer_country: Annotated[
+ List[CountriesAllowed], # (1)!
+ Header(
+ description="This is multi value header parameter.",
+ ),
+ ],
+):
+ """Return validated multi-value header values."""
+ return cloudfront_viewer_country
+
+
+def lambda_handler(event: dict, context: LambdaContext) -> dict:
+ return app.resolve(event, context)
diff --git a/tests/events/albMultiValueQueryStringEvent.json b/tests/events/albMultiValueQueryStringEvent.json
index 4584ba7c477..d5cdf18f023 100644
--- a/tests/events/albMultiValueQueryStringEvent.json
+++ b/tests/events/albMultiValueQueryStringEvent.json
@@ -14,6 +14,13 @@
"accept": [
"*/*"
],
+ "header2": [
+ "value1",
+ "value2"
+ ],
+ "header1": [
+ "value1"
+ ],
"host": [
"alb-c-LoadB-14POFKYCLBNSF-1815800096.eu-central-1.elb.amazonaws.com"
],
diff --git a/tests/events/apiGatewayProxyEvent.json b/tests/events/apiGatewayProxyEvent.json
index 3f095e28e45..da814c91100 100644
--- a/tests/events/apiGatewayProxyEvent.json
+++ b/tests/events/apiGatewayProxyEvent.json
@@ -78,4 +78,4 @@
"stageVariables": null,
"body": "Hello from Lambda!",
"isBase64Encoded": false
-}
\ No newline at end of file
+}
diff --git a/tests/events/lambdaFunctionUrlEventWithHeaders.json b/tests/events/lambdaFunctionUrlEventWithHeaders.json
index e453690d9b3..d1cc50630a8 100644
--- a/tests/events/lambdaFunctionUrlEventWithHeaders.json
+++ b/tests/events/lambdaFunctionUrlEventWithHeaders.json
@@ -23,7 +23,9 @@
"cache-control":"max-age=0",
"accept-encoding":"gzip, deflate, br",
"sec-fetch-dest":"document",
- "user-agent":"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/102.0.0.0 Safari/537.36"
+ "user-agent":"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/102.0.0.0 Safari/537.36",
+ "header1": "value1",
+ "header2": "value1,value2"
},
"queryStringParameters": {
"parameter1": "value1,value2",
diff --git a/tests/events/vpcLatticeEvent.json b/tests/events/vpcLatticeEvent.json
index 936bfb22d1b..fa9031f7dc4 100644
--- a/tests/events/vpcLatticeEvent.json
+++ b/tests/events/vpcLatticeEvent.json
@@ -5,7 +5,9 @@
"user_agent": "curl/7.64.1",
"x-forwarded-for": "10.213.229.10",
"host": "test-lambda-service-3908sdf9u3u.dkfjd93.vpc-lattice-svcs.us-east-2.on.aws",
- "accept": "*/*"
+ "accept": "*/*",
+ "header1": "value1",
+ "header2": "value1,value2"
},
"query_string_parameters": {
"order-id": "1"
diff --git a/tests/events/vpcLatticeV2EventWithHeaders.json b/tests/events/vpcLatticeV2EventWithHeaders.json
index 11b36ef118b..fdaf7dc7891 100644
--- a/tests/events/vpcLatticeV2EventWithHeaders.json
+++ b/tests/events/vpcLatticeV2EventWithHeaders.json
@@ -2,12 +2,31 @@
"version": "2.0",
"path": "/newpath",
"method": "GET",
- "headers": {
- "user_agent": "curl/7.64.1",
- "x-forwarded-for": "10.213.229.10",
- "host": "test-lambda-service-3908sdf9u3u.dkfjd93.vpc-lattice-svcs.us-east-2.on.aws",
- "accept": "*/*"
- },
+ "headers":{
+ "user-agent":[
+ "curl/8.3.0"
+ ],
+ "accept":[
+ "*/*"
+ ],
+ "powertools":[
+ "a",
+ "b"
+ ],
+ "x-forwarded-for":[
+ "172.31.40.143"
+ ],
+ "host":[
+ "lattice-svc-027b423199122da5f.7d67968.vpc-lattice-svcs.us-east-1.on.aws"
+ ],
+ "Header1": [
+ "value1"
+ ],
+ "Header2": [
+ "value1",
+ "value2"
+ ]
+ },
"queryStringParameters": {
"parameter1": [
"value1",
diff --git a/tests/functional/event_handler/test_openapi_params.py b/tests/functional/event_handler/test_openapi_params.py
index 2f48f5aa534..38b0cbed307 100644
--- a/tests/functional/event_handler/test_openapi_params.py
+++ b/tests/functional/event_handler/test_openapi_params.py
@@ -13,11 +13,11 @@
)
from aws_lambda_powertools.event_handler.openapi.params import (
Body,
+ Header,
Param,
ParamTypes,
Query,
_create_model_field,
- _Header,
)
from aws_lambda_powertools.shared.types import Annotated
@@ -431,7 +431,7 @@ def handler():
def test_create_header():
- header = _Header(convert_underscores=True)
+ header = Header(convert_underscores=True)
assert header.convert_underscores is True
@@ -456,7 +456,7 @@ def test_create_model_field_with_empty_in():
# Tests that when we try to create a model field with convert_underscore, we convert the field name
def test_create_model_field_convert_underscore():
- field_info = _Header(alias=None, convert_underscores=True)
+ field_info = Header(alias=None, convert_underscores=True)
result = _create_model_field(field_info, int, "user_id", False)
assert result.alias == "user-id"
diff --git a/tests/functional/event_handler/test_openapi_validation_middleware.py b/tests/functional/event_handler/test_openapi_validation_middleware.py
index 23fa131ab9f..07e2a34ac42 100644
--- a/tests/functional/event_handler/test_openapi_validation_middleware.py
+++ b/tests/functional/event_handler/test_openapi_validation_middleware.py
@@ -4,6 +4,7 @@
from pathlib import PurePath
from typing import List, Tuple
+import pytest
from pydantic import BaseModel
from aws_lambda_powertools.event_handler import (
@@ -12,9 +13,10 @@
APIGatewayRestResolver,
LambdaFunctionUrlResolver,
Response,
+ VPCLatticeResolver,
VPCLatticeV2Resolver,
)
-from aws_lambda_powertools.event_handler.openapi.params import Body, Query
+from aws_lambda_powertools.event_handler.openapi.params import Body, Header, Query
from aws_lambda_powertools.shared.types import Annotated
from tests.functional.utils import load_event
@@ -23,6 +25,7 @@
LOAD_GW_EVENT_ALB = load_event("albMultiValueQueryStringEvent.json")
LOAD_GW_EVENT_LAMBDA_URL = load_event("lambdaFunctionUrlEventWithHeaders.json")
LOAD_GW_EVENT_VPC_LATTICE = load_event("vpcLatticeV2EventWithHeaders.json")
+LOAD_GW_EVENT_VPC_LATTICE_V1 = load_event("vpcLatticeEvent.json")
def test_validate_scalars():
@@ -417,267 +420,601 @@ def handler(user: Model) -> Response[Model]:
assert "missing" in result["body"]
-def test_validate_rest_api_resolver_with_multi_query_params():
- # GIVEN an APIGatewayRestResolver with validation enabled
+########### TEST WITH QUERY PARAMS
+@pytest.mark.parametrize(
+ "handler_func, expected_status_code, expected_error_text",
+ [
+ ("handler1_with_correct_params", 200, None),
+ ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"),
+ ("handler3_without_query_params", 200, None),
+ ],
+)
+def test_validation_query_string_with_api_rest_resolver(handler_func, expected_status_code, expected_error_text):
+ # GIVEN a APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)
- # WHEN a handler is defined with a default scalar parameter and a list
- @app.get("/users")
- def handler(parameter1: Annotated[List[str], Query()], parameter2: str):
- print(parameter2)
-
LOAD_GW_EVENT["httpMethod"] = "GET"
LOAD_GW_EVENT["path"] = "/users"
+ # WHEN a handler is defined with various parameters and routes
- # THEN the handler should be invoked and return 200
- result = app(LOAD_GW_EVENT, {})
- assert result["statusCode"] == 200
+ # Define handler1 with correct params
+ if handler_func == "handler1_with_correct_params":
+ @app.get("/users")
+ def handler1(parameter1: Annotated[List[str], Query()], parameter2: str):
+ print(parameter2)
-def test_validate_rest_api_resolver_with_multi_query_params_fail():
- # GIVEN an APIGatewayRestResolver with validation enabled
- app = APIGatewayRestResolver(enable_validation=True)
+ # Define handler2 with wrong params
+ if handler_func == "handler2_with_wrong_params":
- # WHEN a handler is defined with a default scalar parameter and a list with wrong type
- @app.get("/users")
- def handler(parameter1: Annotated[List[int], Query()], parameter2: str):
- print(parameter2)
+ @app.get("/users")
+ def handler2(parameter1: Annotated[List[int], Query()], parameter2: str):
+ print(parameter2)
- LOAD_GW_EVENT["httpMethod"] = "GET"
- LOAD_GW_EVENT["path"] = "/users"
+ # Define handler3 without params
+ if handler_func == "handler3_without_query_params":
+ LOAD_GW_EVENT["queryStringParameters"] = None
+ LOAD_GW_EVENT["multiValueQueryStringParameters"] = None
- # THEN the handler should be invoked and return 422
+ @app.get("/users")
+ def handler3():
+ return None
+
+ # THEN the handler should be invoked with the expected result
+ # AND the status code should match the expected_status_code
result = app(LOAD_GW_EVENT, {})
- assert result["statusCode"] == 422
- assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"])
+ assert result["statusCode"] == expected_status_code
+ # IF expected_error_text is provided, THEN check for its presence in the response body
+ if expected_error_text:
+ assert any(text in result["body"] for text in expected_error_text)
-def test_validate_rest_api_resolver_without_query_params():
- # GIVEN an APIGatewayRestResolver with validation enabled
- app = APIGatewayRestResolver(enable_validation=True)
- # WHEN a handler is defined with a default scalar parameter and a list with wrong type
- @app.get("/users")
- def handler():
- return None
+@pytest.mark.parametrize(
+ "handler_func, expected_status_code, expected_error_text",
+ [
+ ("handler1_with_correct_params", 200, None),
+ ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"),
+ ("handler3_without_query_params", 200, None),
+ ],
+)
+def test_validation_query_string_with_api_http_resolver(handler_func, expected_status_code, expected_error_text):
+ # GIVEN a APIGatewayHttpResolver with validation enabled
+ app = APIGatewayHttpResolver(enable_validation=True)
- LOAD_GW_EVENT["httpMethod"] = "GET"
- LOAD_GW_EVENT["path"] = "/users"
- LOAD_GW_EVENT["queryStringParameters"] = None
- LOAD_GW_EVENT["multiValueQueryStringParameters"] = None
+ LOAD_GW_EVENT_HTTP["rawPath"] = "/users"
+ LOAD_GW_EVENT_HTTP["requestContext"]["http"]["method"] = "GET"
+ LOAD_GW_EVENT_HTTP["requestContext"]["http"]["path"] = "/users"
+ # WHEN a handler is defined with various parameters and routes
- # THEN the handler should be invoked and return 422
- result = app(LOAD_GW_EVENT, {})
- assert result["statusCode"] == 200
+ # Define handler1 with correct params
+ if handler_func == "handler1_with_correct_params":
+ @app.get("/users")
+ def handler1(parameter1: Annotated[List[str], Query()], parameter2: str):
+ print(parameter2)
-def test_validate_http_resolver_with_multi_query_params():
- # GIVEN an APIGatewayHttpResolver with validation enabled
- app = APIGatewayHttpResolver(enable_validation=True)
+ # Define handler2 with wrong params
+ if handler_func == "handler2_with_wrong_params":
- # WHEN a handler is defined with a default scalar parameter and a list
- @app.get("/users")
- def handler(parameter1: Annotated[List[str], Query()], parameter2: str):
- print(parameter2)
+ @app.get("/users")
+ def handler2(parameter1: Annotated[List[int], Query()], parameter2: str):
+ print(parameter2)
- LOAD_GW_EVENT_HTTP["rawPath"] = "/users"
- LOAD_GW_EVENT_HTTP["requestContext"]["http"]["method"] = "GET"
- LOAD_GW_EVENT_HTTP["requestContext"]["http"]["path"] = "/users"
+ # Define handler3 without params
+ if handler_func == "handler3_without_query_params":
+ LOAD_GW_EVENT_HTTP["queryStringParameters"] = None
- # THEN the handler should be invoked and return 200
+ @app.get("/users")
+ def handler3():
+ return None
+
+ # THEN the handler should be invoked with the expected result
+ # AND the status code should match the expected_status_code
result = app(LOAD_GW_EVENT_HTTP, {})
- assert result["statusCode"] == 200
+ assert result["statusCode"] == expected_status_code
+ # IF expected_error_text is provided, THEN check for its presence in the response body
+ if expected_error_text:
+ assert any(text in result["body"] for text in expected_error_text)
-def test_validate_http_resolver_with_multi_query_values_fail():
- # GIVEN an APIGatewayHttpResolver with validation enabled
- app = APIGatewayHttpResolver(enable_validation=True)
- # WHEN a handler is defined with a default scalar parameter and a list with wrong type
- @app.get("/users")
- def handler(parameter1: Annotated[List[int], Query()], parameter2: str):
- print(parameter2)
+@pytest.mark.parametrize(
+ "handler_func, expected_status_code, expected_error_text",
+ [
+ ("handler1_with_correct_params", 200, None),
+ ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"),
+ ("handler3_without_query_params", 200, None),
+ ],
+)
+def test_validation_query_string_with_alb_resolver(handler_func, expected_status_code, expected_error_text):
+ # GIVEN a ALBResolver with validation enabled
+ app = ALBResolver(enable_validation=True)
- LOAD_GW_EVENT_HTTP["rawPath"] = "/users"
- LOAD_GW_EVENT_HTTP["requestContext"]["http"]["method"] = "GET"
- LOAD_GW_EVENT_HTTP["requestContext"]["http"]["path"] = "/users"
+ LOAD_GW_EVENT_ALB["path"] = "/users"
+ # WHEN a handler is defined with various parameters and routes
- # THEN the handler should be invoked and return 422
- result = app(LOAD_GW_EVENT_HTTP, {})
- assert result["statusCode"] == 422
- assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"])
+ # Define handler1 with correct params
+ if handler_func == "handler1_with_correct_params":
+ @app.get("/users")
+ def handler1(parameter1: Annotated[List[str], Query()], parameter2: str):
+ print(parameter2)
-def test_validate_http_resolver_without_query_params():
- # GIVEN an APIGatewayHttpResolver with validation enabled
- app = APIGatewayHttpResolver(enable_validation=True)
+ # Define handler2 with wrong params
+ if handler_func == "handler2_with_wrong_params":
+
+ @app.get("/users")
+ def handler2(parameter1: Annotated[List[int], Query()], parameter2: str):
+ print(parameter2)
+
+ # Define handler3 without params
+ if handler_func == "handler3_without_query_params":
+ LOAD_GW_EVENT_HTTP["multiValueQueryStringParameters"] = None
+
+ @app.get("/users")
+ def handler3():
+ return None
+
+ # THEN the handler should be invoked with the expected result
+ # AND the status code should match the expected_status_code
+ result = app(LOAD_GW_EVENT_ALB, {})
+ assert result["statusCode"] == expected_status_code
+
+ # IF expected_error_text is provided, THEN check for its presence in the response body
+ if expected_error_text:
+ assert any(text in result["body"] for text in expected_error_text)
+
+
+@pytest.mark.parametrize(
+ "handler_func, expected_status_code, expected_error_text",
+ [
+ ("handler1_with_correct_params", 200, None),
+ ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"),
+ ("handler3_without_query_params", 200, None),
+ ],
+)
+def test_validation_query_string_with_lambda_url_resolver(handler_func, expected_status_code, expected_error_text):
+ # GIVEN a LambdaFunctionUrlResolver with validation enabled
+ app = LambdaFunctionUrlResolver(enable_validation=True)
+
+ LOAD_GW_EVENT_LAMBDA_URL["rawPath"] = "/users"
+ LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["method"] = "GET"
+ LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["path"] = "/users"
+ # WHEN a handler is defined with various parameters and routes
+
+ # Define handler1 with correct params
+ if handler_func == "handler1_with_correct_params":
+
+ @app.get("/users")
+ def handler1(parameter1: Annotated[List[str], Query()], parameter2: str):
+ print(parameter2)
+
+ # Define handler2 with wrong params
+ if handler_func == "handler2_with_wrong_params":
+
+ @app.get("/users")
+ def handler2(parameter1: Annotated[List[int], Query()], parameter2: str):
+ print(parameter2)
+
+ # Define handler3 without params
+ if handler_func == "handler3_without_query_params":
+ LOAD_GW_EVENT_LAMBDA_URL["queryStringParameters"] = None
+
+ @app.get("/users")
+ def handler3():
+ return None
+
+ # THEN the handler should be invoked with the expected result
+ # AND the status code should match the expected_status_code
+ result = app(LOAD_GW_EVENT_LAMBDA_URL, {})
+ assert result["statusCode"] == expected_status_code
+
+ # IF expected_error_text is provided, THEN check for its presence in the response body
+ if expected_error_text:
+ assert any(text in result["body"] for text in expected_error_text)
+
+
+@pytest.mark.parametrize(
+ "handler_func, expected_status_code, expected_error_text",
+ [
+ ("handler1_with_correct_params", 200, None),
+ ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"),
+ ("handler3_without_query_params", 200, None),
+ ],
+)
+def test_validation_query_string_with_vpc_lattice_resolver(handler_func, expected_status_code, expected_error_text):
+ # GIVEN a VPCLatticeV2Resolver with validation enabled
+ app = VPCLatticeV2Resolver(enable_validation=True)
+
+ LOAD_GW_EVENT_VPC_LATTICE["path"] = "/users"
+
+ # WHEN a handler is defined with various parameters and routes
+
+ # Define handler1 with correct params
+ if handler_func == "handler1_with_correct_params":
+
+ @app.get("/users")
+ def handler1(parameter1: Annotated[List[str], Query()], parameter2: str):
+ print(parameter2)
+
+ # Define handler2 with wrong params
+ if handler_func == "handler2_with_wrong_params":
+
+ @app.get("/users")
+ def handler2(parameter1: Annotated[List[int], Query()], parameter2: str):
+ print(parameter2)
+
+ # Define handler3 without params
+ if handler_func == "handler3_without_query_params":
+ LOAD_GW_EVENT_VPC_LATTICE["queryStringParameters"] = None
+
+ @app.get("/users")
+ def handler3():
+ return None
+
+ # THEN the handler should be invoked with the expected result
+ # AND the status code should match the expected_status_code
+ result = app(LOAD_GW_EVENT_VPC_LATTICE, {})
+ assert result["statusCode"] == expected_status_code
+
+ # IF expected_error_text is provided, THEN check for its presence in the response body
+ if expected_error_text:
+ assert any(text in result["body"] for text in expected_error_text)
+
+
+########### TEST WITH HEADER PARAMS
+@pytest.mark.parametrize(
+ "handler_func, expected_status_code, expected_error_text",
+ [
+ ("handler1_with_correct_params", 200, None),
+ ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"),
+ ("handler3_with_uppercase_params", 200, None),
+ ("handler4_without_header_params", 200, None),
+ ],
+)
+def test_validation_header_with_api_rest_resolver(handler_func, expected_status_code, expected_error_text):
+ # GIVEN a APIGatewayRestResolver with validation enabled
+ app = APIGatewayRestResolver(enable_validation=True)
- # WHEN a handler is defined without any query params
- @app.get("/users")
- def handler():
- return None
+ LOAD_GW_EVENT["httpMethod"] = "GET"
+ LOAD_GW_EVENT["path"] = "/users"
+ # WHEN a handler is defined with various parameters and routes
+
+ # Define handler1 with correct params
+ if handler_func == "handler1_with_correct_params":
+
+ @app.get("/users")
+ def handler1(header2: Annotated[List[str], Header()], header1: Annotated[str, Header()]):
+ print(header2)
+
+ # Define handler2 with wrong params
+ if handler_func == "handler2_with_wrong_params":
+
+ @app.get("/users")
+ def handler2(header2: Annotated[List[int], Header()], header1: Annotated[str, Header()]):
+ print(header2)
+
+ # Define handler3 with uppercase parameters
+ if handler_func == "handler3_with_uppercase_params":
+
+ @app.get("/users")
+ def handler3(
+ header2: Annotated[List[str], Header(name="Header2")],
+ header1: Annotated[str, Header(name="Header1")],
+ ):
+ print(header2)
+
+ # Define handler4 without params
+ if handler_func == "handler4_without_header_params":
+ LOAD_GW_EVENT["headers"] = None
+ LOAD_GW_EVENT["multiValueHeaders"] = None
+
+ @app.get("/users")
+ def handler4():
+ return None
+
+ # THEN the handler should be invoked with the expected result
+ # AND the status code should match the expected_status_code
+ result = app(LOAD_GW_EVENT, {})
+ assert result["statusCode"] == expected_status_code
+
+ # IF expected_error_text is provided, THEN check for its presence in the response body
+ if expected_error_text:
+ assert any(text in result["body"] for text in expected_error_text)
+
+
+@pytest.mark.parametrize(
+ "handler_func, expected_status_code, expected_error_text",
+ [
+ ("handler1_with_correct_params", 200, None),
+ ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"),
+ ("handler3_with_uppercase_params", 200, None),
+ ("handler4_without_header_params", 200, None),
+ ],
+)
+def test_validation_header_with_http_rest_resolver(handler_func, expected_status_code, expected_error_text):
+ # GIVEN a APIGatewayHttpResolver with validation enabled
+ app = APIGatewayHttpResolver(enable_validation=True)
LOAD_GW_EVENT_HTTP["rawPath"] = "/users"
LOAD_GW_EVENT_HTTP["requestContext"]["http"]["method"] = "GET"
LOAD_GW_EVENT_HTTP["requestContext"]["http"]["path"] = "/users"
- LOAD_GW_EVENT_HTTP["queryStringParameters"] = None
+ # WHEN a handler is defined with various parameters and routes
- # THEN the handler should be invoked and return 200
- result = app(LOAD_GW_EVENT_HTTP, {})
- assert result["statusCode"] == 200
+ # Define handler1 with correct params
+ if handler_func == "handler1_with_correct_params":
+ @app.get("/users")
+ def handler1(header2: Annotated[List[str], Header()], header1: Annotated[str, Header()]):
+ print(header2)
-def test_validate_alb_resolver_with_multi_query_values():
- # GIVEN an ALBResolver with validation enabled
- app = ALBResolver(enable_validation=True)
+ # Define handler2 with wrong params
+ if handler_func == "handler2_with_wrong_params":
- # WHEN a handler is defined with a default scalar parameter and a list
- @app.get("/users")
- def handler(parameter1: Annotated[List[str], Query()], parameter2: str):
- print(parameter2)
+ @app.get("/users")
+ def handler2(header2: Annotated[List[int], Header()], header1: Annotated[str, Header()]):
+ print(header2)
- LOAD_GW_EVENT_ALB["path"] = "/users"
+ # Define handler3 with uppercase parameters
+ if handler_func == "handler3_with_uppercase_params":
- # THEN the handler should be invoked and return 200
- result = app(LOAD_GW_EVENT_ALB, {})
- assert result["statusCode"] == 200
+ @app.get("/users")
+ def handler3(
+ header2: Annotated[List[str], Header(name="Header2")],
+ header1: Annotated[str, Header(name="Header1")],
+ ):
+ print(header2)
+ # Define handler4 without params
+ if handler_func == "handler4_without_header_params":
+ LOAD_GW_EVENT_HTTP["headers"] = None
-def test_validate_alb_resolver_with_multi_query_values_fail():
- # GIVEN an ALBResolver with validation enabled
- app = ALBResolver(enable_validation=True)
-
- # WHEN a handler is defined with a default scalar parameter and a list with wrong type
- @app.get("/users")
- def handler(parameter1: Annotated[List[int], Query()], parameter2: str):
- print(parameter2)
+ @app.get("/users")
+ def handler4():
+ return None
- LOAD_GW_EVENT_ALB["path"] = "/users"
+ # THEN the handler should be invoked with the expected result
+ # AND the status code should match the expected_status_code
+ result = app(LOAD_GW_EVENT_HTTP, {})
+ assert result["statusCode"] == expected_status_code
- # THEN the handler should be invoked and return 422
- result = app(LOAD_GW_EVENT_ALB, {})
- assert result["statusCode"] == 422
- assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"])
+ # IF expected_error_text is provided, THEN check for its presence in the response body
+ if expected_error_text:
+ assert any(text in result["body"] for text in expected_error_text)
-def test_validate_alb_resolver_without_query_params():
- # GIVEN an ALBResolver with validation enabled
+@pytest.mark.parametrize(
+ "handler_func, expected_status_code, expected_error_text",
+ [
+ ("handler1_with_correct_params", 200, None),
+ ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"),
+ ("handler3_with_uppercase_params", 200, None),
+ ("handler4_without_header_params", 200, None),
+ ],
+)
+def test_validation_header_with_alb_resolver(handler_func, expected_status_code, expected_error_text):
+ # GIVEN a ALBResolver with validation enabled
app = ALBResolver(enable_validation=True)
- # WHEN a handler is defined without any query params
- @app.get("/users")
- def handler(parameter1: Annotated[List[str], Query()], parameter2: str):
- print(parameter2)
-
LOAD_GW_EVENT_ALB["path"] = "/users"
- LOAD_GW_EVENT_HTTP["multiValueQueryStringParameters"] = None
+ # WHEN a handler is defined with various parameters and routes
- # THEN the handler should be invoked and return 200
- result = app(LOAD_GW_EVENT_ALB, {})
- assert result["statusCode"] == 200
+ # Define handler1 with correct params
+ if handler_func == "handler1_with_correct_params":
+ @app.get("/users")
+ def handler1(header2: Annotated[List[str], Header()], header1: Annotated[str, Header()]):
+ print(header2)
-def test_validate_lambda_url_resolver_with_multi_query_params():
- # GIVEN an LambdaFunctionUrlResolver with validation enabled
- app = LambdaFunctionUrlResolver(enable_validation=True)
+ # Define handler2 with wrong params
+ if handler_func == "handler2_with_wrong_params":
- # WHEN a handler is defined with a default scalar parameter and a list
- @app.get("/users")
- def handler(parameter1: Annotated[List[str], Query()], parameter2: str):
- print(parameter2)
+ @app.get("/users")
+ def handler2(header2: Annotated[List[int], Header()], header1: Annotated[str, Header()]):
+ print(header2)
- LOAD_GW_EVENT_LAMBDA_URL["rawPath"] = "/users"
- LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["method"] = "GET"
- LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["path"] = "/users"
+ # Define handler3 with uppercase parameters
+ if handler_func == "handler3_with_uppercase_params":
- # THEN the handler should be invoked and return 200
- result = app(LOAD_GW_EVENT_LAMBDA_URL, {})
- assert result["statusCode"] == 200
+ @app.get("/users")
+ def handler3(
+ header2: Annotated[List[str], Header(name="Header2")],
+ header1: Annotated[str, Header(name="Header1")],
+ ):
+ print(header2)
+ # Define handler4 without params
+ if handler_func == "handler4_without_header_params":
+ LOAD_GW_EVENT_ALB["multiValueHeaders"] = None
-def test_validate_lambda_url_resolver_with_multi_query_params_fail():
- # GIVEN an LambdaFunctionUrlResolver with validation enabled
- app = LambdaFunctionUrlResolver(enable_validation=True)
+ @app.get("/users")
+ def handler4():
+ return None
- # WHEN a handler is defined with a default scalar parameter and a list with wrong type
- @app.get("/users")
- def handler(parameter1: Annotated[List[int], Query()], parameter2: str):
- print(parameter2)
+ # THEN the handler should be invoked with the expected result
+ # AND the status code should match the expected_status_code
+ result = app(LOAD_GW_EVENT_ALB, {})
+ assert result["statusCode"] == expected_status_code
- LOAD_GW_EVENT_LAMBDA_URL["rawPath"] = "/users"
- LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["method"] = "GET"
- LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["path"] = "/users"
+ # IF expected_error_text is provided, THEN check for its presence in the response body
+ if expected_error_text:
+ assert any(text in result["body"] for text in expected_error_text)
- # THEN the handler should be invoked and return 422
- result = app(LOAD_GW_EVENT_LAMBDA_URL, {})
- assert result["statusCode"] == 422
- assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"])
-
-def test_validate_lambda_url_resolver_without_query_params():
- # GIVEN an LambdaFunctionUrlResolver with validation enabled
+@pytest.mark.parametrize(
+ "handler_func, expected_status_code, expected_error_text",
+ [
+ ("handler1_with_correct_params", 200, None),
+ ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"),
+ ("handler3_with_uppercase_params", 200, None),
+ ("handler4_without_header_params", 200, None),
+ ],
+)
+def test_validation_header_with_lambda_url_resolver(handler_func, expected_status_code, expected_error_text):
+ # GIVEN a LambdaFunctionUrlResolver with validation enabled
app = LambdaFunctionUrlResolver(enable_validation=True)
- # WHEN a handler is defined without any query params
- @app.get("/users")
- def handler():
- return None
-
LOAD_GW_EVENT_LAMBDA_URL["rawPath"] = "/users"
LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["method"] = "GET"
LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["path"] = "/users"
- LOAD_GW_EVENT_LAMBDA_URL["queryStringParameters"] = None
+ # WHEN a handler is defined with various parameters and routes
- # THEN the handler should be invoked and return 200
- result = app(LOAD_GW_EVENT_LAMBDA_URL, {})
- assert result["statusCode"] == 200
+ # Define handler1 with correct params
+ if handler_func == "handler1_with_correct_params":
+ @app.get("/users")
+ def handler1(header2: Annotated[List[str], Header()], header1: Annotated[str, Header()]):
+ print(header2)
-def test_validate_vpc_lattice_resolver_with_multi_params_values():
- # GIVEN an VPCLatticeV2Resolver with validation enabled
- app = VPCLatticeV2Resolver(enable_validation=True)
+ # Define handler2 with wrong params
+ if handler_func == "handler2_with_wrong_params":
- # WHEN a handler is defined with a default scalar parameter and a list
- @app.get("/users")
- def handler(parameter1: Annotated[List[str], Query()], parameter2: str):
- print(parameter2)
+ @app.get("/users")
+ def handler2(header2: Annotated[List[int], Header()], header1: Annotated[str, Header()]):
+ print(header2)
- LOAD_GW_EVENT_VPC_LATTICE["path"] = "/users"
+ # Define handler3 with uppercase parameters
+ if handler_func == "handler3_with_uppercase_params":
- # THEN the handler should be invoked and return 200
- result = app(LOAD_GW_EVENT_VPC_LATTICE, {})
- assert result["statusCode"] == 200
+ @app.get("/users")
+ def handler3(
+ header2: Annotated[List[str], Header(name="Header2")],
+ header1: Annotated[str, Header(name="Header1")],
+ ):
+ print(header2)
+ # Define handler4 without params
+ if handler_func == "handler4_without_header_params":
+ LOAD_GW_EVENT_LAMBDA_URL["headers"] = None
-def test_validate_vpc_lattice_resolver_with_multi_query_params_fail():
- # GIVEN an VPCLatticeV2Resolver with validation enabled
- app = VPCLatticeV2Resolver(enable_validation=True)
+ @app.get("/users")
+ def handler4():
+ return None
+
+ # THEN the handler should be invoked with the expected result
+ # AND the status code should match the expected_status_code
+ result = app(LOAD_GW_EVENT_LAMBDA_URL, {})
+ assert result["statusCode"] == expected_status_code
+
+ # IF expected_error_text is provided, THEN check for its presence in the response body
+ if expected_error_text:
+ assert any(text in result["body"] for text in expected_error_text)
- # WHEN a handler is defined with a default scalar parameter and a list with wrong type
- @app.get("/users")
- def handler(parameter1: Annotated[List[int], Query()], parameter2: str):
- print(parameter2)
+
+@pytest.mark.parametrize(
+ "handler_func, expected_status_code, expected_error_text",
+ [
+ ("handler1_with_correct_params", 200, None),
+ ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"),
+ ("handler3_with_uppercase_params", 200, None),
+ ("handler4_without_header_params", 200, None),
+ ],
+)
+def test_validation_header_with_vpc_lattice_v1_resolver(handler_func, expected_status_code, expected_error_text):
+ # GIVEN a VPCLatticeResolver with validation enabled
+ app = VPCLatticeResolver(enable_validation=True)
+
+ LOAD_GW_EVENT_VPC_LATTICE_V1["raw_path"] = "/users"
+ LOAD_GW_EVENT_VPC_LATTICE_V1["method"] = "GET"
+ # WHEN a handler is defined with various parameters and routes
+
+ # Define handler1 with correct params
+ if handler_func == "handler1_with_correct_params":
+
+ @app.get("/users")
+ def handler1(header2: Annotated[List[str], Header()], header1: Annotated[str, Header()]):
+ print(header2)
+
+ # Define handler2 with wrong params
+ if handler_func == "handler2_with_wrong_params":
+
+ @app.get("/users")
+ def handler2(header2: Annotated[List[int], Header()], header1: Annotated[str, Header()]):
+ print(header2)
+
+ # Define handler3 with uppercase parameters
+ if handler_func == "handler3_with_uppercase_params":
+
+ @app.get("/users")
+ def handler3(
+ header2: Annotated[List[str], Header(name="Header2")],
+ header1: Annotated[str, Header(name="Header1")],
+ ):
+ print(header2)
+
+ # Define handler4 without params
+ if handler_func == "handler4_without_header_params":
+ LOAD_GW_EVENT_VPC_LATTICE_V1["headers"] = None
+
+ @app.get("/users")
+ def handler4():
+ return None
+
+ # THEN the handler should be invoked with the expected result
+ # AND the status code should match the expected_status_code
+ result = app(LOAD_GW_EVENT_VPC_LATTICE_V1, {})
+ assert result["statusCode"] == expected_status_code
+
+ # IF expected_error_text is provided, THEN check for its presence in the response body
+ if expected_error_text:
+ assert any(text in result["body"] for text in expected_error_text)
+
+
+@pytest.mark.parametrize(
+ "handler_func, expected_status_code, expected_error_text",
+ [
+ ("handler1_with_correct_params", 200, None),
+ ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"),
+ ("handler3_with_uppercase_params", 200, None),
+ ("handler4_without_header_params", 200, None),
+ ],
+)
+def test_validation_header_with_vpc_lattice_v2_resolver(handler_func, expected_status_code, expected_error_text):
+ # GIVEN a VPCLatticeV2Resolver with validation enabled
+ app = VPCLatticeV2Resolver(enable_validation=True)
LOAD_GW_EVENT_VPC_LATTICE["path"] = "/users"
+ LOAD_GW_EVENT_VPC_LATTICE["method"] = "GET"
+ # WHEN a handler is defined with various parameters and routes
- # THEN the handler should be invoked and return 422
- result = app(LOAD_GW_EVENT_VPC_LATTICE, {})
- assert result["statusCode"] == 422
- assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"])
+ # Define handler1 with correct params
+ if handler_func == "handler1_with_correct_params":
+ @app.get("/users")
+ def handler1(header2: Annotated[List[str], Header()], header1: Annotated[str, Header()]):
+ print(header2)
-def test_validate_vpc_lattice_resolver_without_query_params():
- # GIVEN an VPCLatticeV2Resolver with validation enabled
- app = VPCLatticeV2Resolver(enable_validation=True)
+ # Define handler2 with wrong params
+ if handler_func == "handler2_with_wrong_params":
- # WHEN a handler is defined without any query params
- @app.get("/users")
- def handler():
- return None
+ @app.get("/users")
+ def handler1(header2: Annotated[List[int], Header()], header1: Annotated[str, Header()]):
+ print(header2)
- LOAD_GW_EVENT_VPC_LATTICE["path"] = "/users"
- LOAD_GW_EVENT_VPC_LATTICE["queryStringParameters"] = None
+ # Define handler3 with uppercase parameters
+ if handler_func == "handler3_with_uppercase_params":
- # THEN the handler should be invoked and return 200
+ @app.get("/users")
+ def handler3(
+ header2: Annotated[List[str], Header(name="Header2")],
+ header1: Annotated[str, Header(name="Header1")],
+ ):
+ print(header2)
+
+ # Define handler4 without params
+ if handler_func == "handler4_without_header_params":
+ LOAD_GW_EVENT_VPC_LATTICE["headers"] = None
+
+ @app.get("/users")
+ def handler3():
+ return None
+
+ # THEN the handler should be invoked with the expected result
+ # AND the status code should match the expected_status_code
result = app(LOAD_GW_EVENT_VPC_LATTICE, {})
- assert result["statusCode"] == 200
+ assert result["statusCode"] == expected_status_code
+
+ # IF expected_error_text is provided, THEN check for its presence in the response body
+ if expected_error_text:
+ assert any(text in result["body"] for text in expected_error_text)