From 9cd21aa4edebf013526bd8787a79388cda3469d6 Mon Sep 17 00:00:00 2001 From: Tony Sherman <100969281+TonySherman@users.noreply.github.com> Date: Mon, 26 Feb 2024 07:14:05 -0500 Subject: [PATCH] feat(event-source): add function to get multi-value query string params by name (#3846) Co-authored-by: Heitor Lessa Co-authored-by: heitorlessa --- .../utilities/data_classes/alb_event.py | 4 +-- .../utilities/data_classes/common.py | 30 ++++++++++++++++++ .../data_classes/shared_functions.py | 31 +++++++++++++++++-- .../src/accessing_request_details.py | 5 ++- tests/unit/data_classes/test_alb_event.py | 4 ++- tests/unit/test_data_classes.py | 19 ++++++++++++ 6 files changed, 87 insertions(+), 6 deletions(-) diff --git a/aws_lambda_powertools/utilities/data_classes/alb_event.py b/aws_lambda_powertools/utilities/data_classes/alb_event.py index 1ec2535850b..a1ee3424a94 100644 --- a/aws_lambda_powertools/utilities/data_classes/alb_event.py +++ b/aws_lambda_powertools/utilities/data_classes/alb_event.py @@ -32,8 +32,8 @@ def request_context(self) -> ALBEventRequestContext: return ALBEventRequestContext(self._data) @property - def multi_value_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: - return self.get("multiValueQueryStringParameters") + def multi_value_query_string_parameters(self) -> Dict[str, List[str]]: + return self.get("multiValueQueryStringParameters") or {} @property def resolved_query_string_parameters(self) -> Dict[str, List[str]]: diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index 067706140fd..41fbe5cd0de 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -7,6 +7,7 @@ from aws_lambda_powertools.shared.headers_serializer import BaseHeadersSerializer from aws_lambda_powertools.utilities.data_classes.shared_functions import ( get_header_value, + get_multi_value_query_string_values, get_query_string_value, ) @@ -103,6 +104,10 @@ def headers(self) -> Dict[str, str]: def query_string_parameters(self) -> Optional[Dict[str, str]]: return self.get("queryStringParameters") + @property + def multi_value_query_string_parameters(self) -> Dict[str, List[str]]: + return self.get("multiValueQueryStringParameters") or {} + @property def resolved_query_string_parameters(self) -> Dict[str, List[str]]: """ @@ -184,6 +189,31 @@ def get_query_string_value(self, name: str, default_value: Optional[str] = None) default_value=default_value, ) + def get_multi_value_query_string_values( + self, + name: str, + default_values: Optional[List[str]] = None, + ) ->List[str]: + """Get multi-value query string parameter values by name + + Parameters + ---------- + name: str + Multi-Value query string parameter name + default_values: List[str], optional + Default values is no values are found by name + Returns + ------- + List[str], optional + List of query string values + + """ + return get_multi_value_query_string_values( + multi_value_query_string_parameters=self.multi_value_query_string_parameters, + name=name, + default_values=default_values, + ) + @overload def get_header_value( self, diff --git a/aws_lambda_powertools/utilities/data_classes/shared_functions.py b/aws_lambda_powertools/utilities/data_classes/shared_functions.py index 594ea35bea7..43a3aad281b 100644 --- a/aws_lambda_powertools/utilities/data_classes/shared_functions.py +++ b/aws_lambda_powertools/utilities/data_classes/shared_functions.py @@ -1,7 +1,7 @@ from __future__ import annotations import base64 -from typing import Any +from typing import Any, Dict def base64_decode(value: str) -> str: @@ -63,7 +63,7 @@ def get_header_value( def get_query_string_value( - query_string_parameters: dict[str, str] | None, + query_string_parameters: Dict[str, str] | None, name: str, default_value: str | None = None, ) -> str | None: @@ -84,3 +84,30 @@ def get_query_string_value( """ params = query_string_parameters return default_value if params is None else params.get(name, default_value) + + +def get_multi_value_query_string_values( + multi_value_query_string_parameters: Dict[str, list[str]] | None, + name: str, + default_values: list[str] | None = None, +) -> list[str]: + """ + Retrieves the values of a multi-value string parameters specified by the given name. + + Parameters + ---------- + name: str + The name of the query string parameter to retrieve. + default_value: list[str], optional + The default value to return if the parameter is not found. Defaults to None. + + Returns + ------- + List[str]. optional + The values of the query string parameter if found, or the default values if not found. + """ + + default = default_values or [] + params = multi_value_query_string_parameters or {} + + return params.get(name) or default diff --git a/examples/event_handler_rest/src/accessing_request_details.py b/examples/event_handler_rest/src/accessing_request_details.py index 9929b601db0..037b76daa66 100644 --- a/examples/event_handler_rest/src/accessing_request_details.py +++ b/examples/event_handler_rest/src/accessing_request_details.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import List, Optional import requests from requests import Response @@ -20,6 +20,9 @@ def get_todos(): # alternatively _: Optional[str] = app.current_event.query_string_parameters.get("id") + # or multi-value query string parameters; ?category="red"&?category="blue" + _: List[str] = app.current_event.get_multi_value_query_string_values(name="category") + # Payload _: Optional[str] = app.current_event.body # raw str | None diff --git a/tests/unit/data_classes/test_alb_event.py b/tests/unit/data_classes/test_alb_event.py index c55af8e91ef..47048ab9407 100644 --- a/tests/unit/data_classes/test_alb_event.py +++ b/tests/unit/data_classes/test_alb_event.py @@ -11,7 +11,9 @@ def test_alb_event(): assert parsed_event.path == raw_event["path"] assert parsed_event.query_string_parameters == raw_event["queryStringParameters"] assert parsed_event.headers == raw_event["headers"] - assert parsed_event.multi_value_query_string_parameters == raw_event.get("multiValueQueryStringParameters") + + assert parsed_event.multi_value_query_string_parameters == raw_event.get("multiValueQueryStringParameters", {}) + assert parsed_event.multi_value_headers == raw_event.get("multiValueHeaders") assert parsed_event.body == raw_event["body"] assert parsed_event.is_base64_encoded == raw_event["isBase64Encoded"] diff --git a/tests/unit/test_data_classes.py b/tests/unit/test_data_classes.py index c810f653f5b..c8f0c1fc932 100644 --- a/tests/unit/test_data_classes.py +++ b/tests/unit/test_data_classes.py @@ -259,6 +259,25 @@ def test_base_proxy_event_get_query_string_value(): assert value is None +def test_base_proxy_event_get_multi_value_query_string_values(): + default_values = ["default_1", "default_2"] + set_values = ["value_1", "value_2"] + + event = BaseProxyEvent({}) + values = event.get_multi_value_query_string_values("test", default_values) + assert values == default_values + + event._data["multiValueQueryStringParameters"] = {"test": set_values} + values = event.get_multi_value_query_string_values("test", default_values) + assert values == set_values + + values = event.get_multi_value_query_string_values("unknown", default_values) + assert values == default_values + + values = event.get_multi_value_query_string_values("unknown") + assert values == [] + + def test_base_proxy_event_get_header_value(): default_value = "default" set_value = "value"