Skip to content

Commit

Permalink
feat(event_handler): support Header parameter validation in OpenAPI s…
Browse files Browse the repository at this point in the history
…chema (#3687)

* Adding header - Initial commit

* Adding header - Fix VPC Lattice Payload

* Adding header - tests and final changes

* Making sonarqube happy

* Adding documentation

* Rafactoring to be complaint with RFC

* Adding tests

* Adding test with Uppercase variables

* Revert event changes

* Adding HTTP RFC

* Adding getter/setter to clean the code

* Adding getter/setter to clean the code

* Addressing Ruben's feedback
  • Loading branch information
leandrodamascena authored Feb 1, 2024
1 parent ced0a3d commit 33820d1
Show file tree
Hide file tree
Showing 18 changed files with 873 additions and 213 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,22 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
query_string,
)

# Normalize header values before validate this
headers = _normalize_multi_header_values_with_param(
app.current_event.resolved_headers_field,
route.dependant.header_params,
)

# Process header values
header_values, header_errors = _request_params_to_args(
route.dependant.header_params,
headers,
)

values.update(path_values)
values.update(query_values)
errors += path_errors + query_errors
values.update(header_values)
errors += path_errors + query_errors + header_errors

# Process the request body, if it exists
if route.dependant.body_params:
Expand Down Expand Up @@ -243,12 +256,14 @@ def _request_params_to_args(
errors = []

for field in required_params:
value = received_params.get(field.alias)

field_info = field.field_info

# To ensure early failure, we check if it's not an instance of Param.
if not isinstance(field_info, Param):
raise AssertionError(f"Expected Param field_info, got {field_info}")

value = received_params.get(field.alias)

loc = (field_info.in_.value, field.alias)

# If we don't have a value, see if it's required or has a default
Expand Down Expand Up @@ -377,3 +392,30 @@ def _normalize_multi_query_string_with_param(query_string: Optional[Dict[str, st
except KeyError:
pass
return query_string


def _normalize_multi_header_values_with_param(headers: Optional[Dict[str, str]], params: Sequence[ModelField]):
"""
Extract and normalize resolved_headers_field
Parameters
----------
headers: Dict
A dictionary containing the initial header parameters.
params: Sequence[ModelField]
A sequence of ModelField objects representing parameters.
Returns
-------
A dictionary containing the processed headers.
"""
if headers:
for param in filter(is_scalar_field, params):
try:
if len(headers[param.alias]) == 1:
# if the target parameter is a scalar and the list contains only 1 element
# we keep the first value of the headers regardless if there are more in the payload
headers[param.alias] = headers[param.alias][0]
except KeyError:
pass
return headers
27 changes: 16 additions & 11 deletions aws_lambda_powertools/event_handler/openapi/dependant.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
from aws_lambda_powertools.event_handler.openapi.params import (
Body,
Dependant,
Header,
Param,
ParamTypes,
Query,
_File,
_Form,
_Header,
analyze_param,
create_response_field,
get_flat_dependant,
Expand Down Expand Up @@ -59,16 +59,21 @@ def add_param_to_fields(
"""
field_info = cast(Param, field.field_info)
if field_info.in_ == ParamTypes.path:
dependant.path_params.append(field)
elif field_info.in_ == ParamTypes.query:
dependant.query_params.append(field)
elif field_info.in_ == ParamTypes.header:
dependant.header_params.append(field)

# Dictionary to map ParamTypes to their corresponding lists in dependant
param_type_map = {
ParamTypes.path: dependant.path_params,
ParamTypes.query: dependant.query_params,
ParamTypes.header: dependant.header_params,
ParamTypes.cookie: dependant.cookie_params,
}

# Check if field_info.in_ is a valid key in param_type_map and append the field to the corresponding list
# or raise an exception if it's not a valid key.
if field_info.in_ in param_type_map:
param_type_map[field_info.in_].append(field)
else:
if field_info.in_ != ParamTypes.cookie:
raise AssertionError(f"Unsupported param type: {field_info.in_}")
dependant.cookie_params.append(field)
raise AssertionError(f"Unsupported param type: {field_info.in_}")


def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
Expand Down Expand Up @@ -265,7 +270,7 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
return False
elif is_scalar_field(field=param_field):
return False
elif isinstance(param_field.field_info, (Query, _Header)) and is_scalar_sequence_field(param_field):
elif isinstance(param_field.field_info, (Query, Header)) and is_scalar_sequence_field(param_field):
return False
else:
if not isinstance(param_field.field_info, Body):
Expand Down
79 changes: 77 additions & 2 deletions aws_lambda_powertools/event_handler/openapi/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ def __init__(
)


class _Header(Param):
class Header(Param):
"""
A class used internally to represent a header parameter in a path operation.
"""
Expand Down Expand Up @@ -527,12 +527,75 @@ def __init__(
json_schema_extra: Union[Dict[str, Any], None] = None,
**extra: Any,
):
"""
Constructs a new Query param.
Parameters
----------
default: Any
The default value of the parameter
default_factory: Callable[[], Any], optional
Callable that will be called when a default value is needed for this field
annotation: Any, optional
The type annotation of the parameter
alias: str, optional
The public name of the field
alias_priority: int, optional
Priority of the alias. This affects whether an alias generator is used
validation_alias: str | AliasPath | AliasChoices | None, optional
Alias to be used for validation only
serialization_alias: str | AliasPath | AliasChoices | None, optional
Alias to be used for serialization only
convert_underscores: bool
If true convert "_" to "-"
See RFC: https://www.rfc-editor.org/rfc/rfc9110.html#name-field-name-registry
title: str, optional
The title of the parameter
description: str, optional
The description of the parameter
gt: float, optional
Only applies to numbers, required the field to be "greater than"
ge: float, optional
Only applies to numbers, required the field to be "greater than or equal"
lt: float, optional
Only applies to numbers, required the field to be "less than"
le: float, optional
Only applies to numbers, required the field to be "less than or equal"
min_length: int, optional
Only applies to strings, required the field to have a minimum length
max_length: int, optional
Only applies to strings, required the field to have a maximum length
pattern: str, optional
Only applies to strings, requires the field match against a regular expression pattern string
discriminator: str, optional
Parameter field name for discriminating the type in a tagged union
strict: bool, optional
Enables Pydantic's strict mode for the field
multiple_of: float, optional
Only applies to numbers, requires the field to be a multiple of the given value
allow_inf_nan: bool, optional
Only applies to numbers, requires the field to allow infinity and NaN values
max_digits: int, optional
Only applies to Decimals, requires the field to have a maxmium number of digits within the decimal.
decimal_places: int, optional
Only applies to Decimals, requires the field to have at most a number of decimal places
examples: List[Any], optional
A list of examples for the parameter
deprecated: bool, optional
If `True`, the parameter will be marked as deprecated
include_in_schema: bool, optional
If `False`, the parameter will be excluded from the generated OpenAPI schema
json_schema_extra: Dict[str, Any], optional
Extra values to include in the generated OpenAPI schema
"""
self.convert_underscores = convert_underscores
self._alias = alias

super().__init__(
default=default,
default_factory=default_factory,
annotation=annotation,
alias=alias,
alias=self._alias,
alias_priority=alias_priority,
validation_alias=validation_alias,
serialization_alias=serialization_alias,
Expand All @@ -558,6 +621,18 @@ def __init__(
**extra,
)

@property
def alias(self):
return self._alias

@alias.setter
def alias(self, value: Optional[str] = None):
if value is not None:
# Headers are case-insensitive according to RFC 7540 (HTTP/2), so we lower the parameter name
# This ensures that customers can access headers with any casing, as per the RFC guidelines.
# Reference: https://www.rfc-editor.org/rfc/rfc7540#section-8.1.2
self._alias = value.lower()


class Body(FieldInfo):
"""
Expand Down
11 changes: 11 additions & 0 deletions aws_lambda_powertools/utilities/data_classes/alb_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,17 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]:

return self.query_string_parameters

@property
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
headers: Dict[str, Any] = {}

if self.multi_value_headers:
headers = self.multi_value_headers
else:
headers = self.headers

return {key.lower(): value for key, value in headers.items()}

@property
def multi_value_headers(self) -> Optional[Dict[str, List[str]]]:
return self.get("multiValueHeaders")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,17 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]:

return self.query_string_parameters

@property
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
headers: Dict[str, Any] = {}

if self.multi_value_headers:
headers = self.multi_value_headers
else:
headers = self.headers

return {key.lower(): value for key, value in headers.items()}

@property
def request_context(self) -> APIGatewayEventRequestContext:
return APIGatewayEventRequestContext(self._data)
Expand Down Expand Up @@ -316,3 +327,11 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]:
return query_string

return {}

@property
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
if self.headers is not None:
headers = {key.lower(): value.split(",") if "," in value else value for key, value in self.headers.items()}
return headers

return {}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent, DictWrapper

Expand Down Expand Up @@ -112,3 +112,7 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]:
@property
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
return self.query_string_parameters

@property
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
return {}
15 changes: 15 additions & 0 deletions aws_lambda_powertools/utilities/data_classes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,21 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
"""
return self.query_string_parameters

@property
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
"""
This property determines the appropriate header to be used
as a trusted source for validating OpenAPI.
This is necessary because different resolvers use different formats to encode
headers parameters.
Headers are case-insensitive according to RFC 7540 (HTTP/2), so we lower the header name
This ensures that customers can access headers with any casing, as per the RFC guidelines.
Reference: https://www.rfc-editor.org/rfc/rfc7540#section-8.1.2
"""
return self.headers

@property
def is_base64_encoded(self) -> Optional[bool]:
return self.get("isBase64Encoded")
Expand Down
15 changes: 15 additions & 0 deletions aws_lambda_powertools/utilities/data_classes/vpc_lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,14 @@ def query_string_parameters(self) -> Dict[str, str]:
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
return self.query_string_parameters

@property
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
if self.headers is not None:
headers = {key.lower(): value.split(",") if "," in value else value for key, value in self.headers.items()}
return headers

return {}


class vpcLatticeEventV2Identity(DictWrapper):
@property
Expand Down Expand Up @@ -259,3 +267,10 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]:
@property
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
return self.query_string_parameters

@property
def resolved_headers_field(self) -> Optional[Dict[str, str]]:
if self.headers is not None:
return {key.lower(): value for key, value in self.headers.items()}

return {}
Loading

0 comments on commit 33820d1

Please sign in to comment.