Skip to content

Commit

Permalink
fix(event-handler): multi-value query string and validation of scalar…
Browse files Browse the repository at this point in the history
… parameters (#3795)
  • Loading branch information
rubenfonseca authored Feb 19, 2024
1 parent 770f023 commit 36905b5
Show file tree
Hide file tree
Showing 8 changed files with 287 additions and 198 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,10 @@ def _get_embed_body(
return received_body, field_alias_omitted


def _normalize_multi_query_string_with_param(query_string: Optional[Dict[str, str]], params: Sequence[ModelField]):
def _normalize_multi_query_string_with_param(
query_string: Dict[str, List[str]],
params: Sequence[ModelField],
) -> Dict[str, Any]:
"""
Extract and normalize resolved_query_string_parameters
Expand All @@ -383,15 +386,15 @@ def _normalize_multi_query_string_with_param(query_string: Optional[Dict[str, st
-------
A dictionary containing the processed multi_query_string_parameters.
"""
if query_string:
for param in filter(is_scalar_field, params):
try:
# if the target parameter is a scalar, we keep the first value of the query string
# regardless if there are more in the payload
query_string[param.alias] = query_string[param.alias][0]
except KeyError:
pass
return query_string
resolved_query_string: Dict[str, Any] = query_string
for param in filter(is_scalar_field, params):
try:
# if the target parameter is a scalar, we keep the first value of the query string
# regardless if there are more in the payload
resolved_query_string[param.alias] = query_string[param.alias][0]
except KeyError:
pass
return resolved_query_string


def _normalize_multi_header_values_with_param(headers: Optional[Dict[str, str]], params: Sequence[ModelField]):
Expand Down
4 changes: 2 additions & 2 deletions aws_lambda_powertools/utilities/data_classes/alb_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ def multi_value_query_string_parameters(self) -> Optional[Dict[str, List[str]]]:
return self.get("multiValueQueryStringParameters")

@property
def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]:
def resolved_query_string_parameters(self) -> Dict[str, List[str]]:
if self.multi_value_query_string_parameters:
return self.multi_value_query_string_parameters

return self.query_string_parameters
return super().resolved_query_string_parameters

@property
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,11 @@ def multi_value_query_string_parameters(self) -> Optional[Dict[str, List[str]]]:
return self.get("multiValueQueryStringParameters")

@property
def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]:
def resolved_query_string_parameters(self) -> Dict[str, List[str]]:
if self.multi_value_query_string_parameters:
return self.multi_value_query_string_parameters

return self.query_string_parameters
return super().resolved_query_string_parameters

@property
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
Expand Down Expand Up @@ -318,16 +318,6 @@ def http_method(self) -> str:
def header_serializer(self):
return HttpApiHeadersSerializer()

@property
def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]:
if self.query_string_parameters is not None:
query_string = {
key: value.split(",") if "," in value else value for key, value in self.query_string_parameters.items()
}
return query_string

return {}

@property
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
if self.headers is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,6 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]:
# together with the other parameters. So we just return all parameters here.
return {x["name"]: x["value"] for x in self["parameters"]} if self.get("parameters") else None

@property
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
return self.query_string_parameters

@property
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
return {}
14 changes: 8 additions & 6 deletions aws_lambda_powertools/utilities/data_classes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,19 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]:
return self.get("queryStringParameters")

@property
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
def resolved_query_string_parameters(self) -> Dict[str, List[str]]:
"""
This property determines the appropriate query string parameter to be used
as a trusted source for validating OpenAPI.
This is necessary because different resolvers use different formats to encode
multi query string parameters.
"""
return self.query_string_parameters
if self.query_string_parameters is not None:
query_string = {key: value.split(",") for key, value in self.query_string_parameters.items()}
return query_string

return {}

@property
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
Expand Down Expand Up @@ -186,17 +190,15 @@ def get_header_value(
name: str,
default_value: str,
case_sensitive: Optional[bool] = False,
) -> str:
...
) -> str: ...

@overload
def get_header_value(
self,
name: str,
default_value: Optional[str] = None,
case_sensitive: Optional[bool] = False,
) -> Optional[str]:
...
) -> Optional[str]: ...

def get_header_value(
self,
Expand Down
26 changes: 12 additions & 14 deletions aws_lambda_powertools/utilities/data_classes/vpc_lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,15 @@ def get_header_value(
name: str,
default_value: str,
case_sensitive: Optional[bool] = False,
) -> str:
...
) -> str: ...

@overload
def get_header_value(
self,
name: str,
default_value: Optional[str] = None,
case_sensitive: Optional[bool] = False,
) -> Optional[str]:
...
) -> Optional[str]: ...

def get_header_value(
self,
Expand Down Expand Up @@ -140,10 +138,6 @@ def query_string_parameters(self) -> Dict[str, str]:
"""The request query string parameters."""
return self["query_string_parameters"]

@property
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
return self.query_string_parameters

@property
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
if self.headers is not None:
Expand Down Expand Up @@ -255,17 +249,21 @@ def path(self) -> str:

@property
def request_context(self) -> vpcLatticeEventV2RequestContext:
"""he VPC Lattice v2 Event request context."""
"""The VPC Lattice v2 Event request context."""
return vpcLatticeEventV2RequestContext(self["requestContext"])

@property
def query_string_parameters(self) -> Optional[Dict[str, str]]:
"""The request query string parameters."""
return self.get("queryStringParameters")
"""The request query string parameters.
@property
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
return self.query_string_parameters
For VPC Lattice V2, the queryStringParameters will contain a Dict[str, List[str]]
so to keep compatibility with existing utilities, we merge all the values with a comma.
"""
params = self.get("queryStringParameters")
if params:
return {key: ",".join(value) for key, value in params.items()}
else:
return None

@property
def resolved_headers_field(self) -> Optional[Dict[str, str]]:
Expand Down
32 changes: 32 additions & 0 deletions tests/functional/event_handler/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import pytest

from tests.functional.utils import load_event


@pytest.fixture
def json_dump():
Expand Down Expand Up @@ -39,3 +41,33 @@ def validation_schema():
@pytest.fixture
def raw_event():
return {"message": "hello hello", "username": "blah blah"}


@pytest.fixture
def gw_event():
return load_event("apiGatewayProxyEvent.json")


@pytest.fixture
def gw_event_http():
return load_event("apiGatewayProxyV2Event.json")


@pytest.fixture
def gw_event_alb():
return load_event("albMultiValueQueryStringEvent.json")


@pytest.fixture
def gw_event_lambda_url():
return load_event("lambdaFunctionUrlEventWithHeaders.json")


@pytest.fixture
def gw_event_vpc_lattice():
return load_event("vpcLatticeV2EventWithHeaders.json")


@pytest.fixture
def gw_event_vpc_lattice_v1():
return load_event("vpcLatticeEvent.json")
Loading

0 comments on commit 36905b5

Please sign in to comment.