From 4c6d1969effdbf3ecfb76036bcad95ceeaf3fd19 Mon Sep 17 00:00:00 2001 From: Aleksei Tcysin <24791800+tcysin@users.noreply.github.com> Date: Fri, 20 Dec 2024 00:04:12 +0100 Subject: [PATCH] feat(event_handler): mark API operation as deprecated for OpenAPI documentation (#5732) * Add deprecated parameter with default to BaseRouter.get * Add parameter with default to BaseRouter.route * Pass deprecated param from .get() into .route() * Add param and pass along for post, put, delete, patch, head * Add param and pass along for ApiGatewayRestResolver.route * Ditto for Route.__init__, use when creating operation metadata * Add param and pass along in ApiGatewayResolver.route * Add param and pass along in Router.route, workaround for include_router * Functional tests * Formatting * Refactor to use defaultdict * Move deprecated operation tests into separate test case * Simplify test case * Put 'deprecated' param before 'middlewares' * Remove workaround * Add test case for deprecated POST operation * Add 'deprecated' param to BedrockAgentResolver methods * Small changes + trigger pipeline --------- Co-authored-by: Leandro Damascena --- .../event_handler/api_gateway.py | 28 ++++++++++++- .../event_handler/bedrock_agent.py | 11 ++++- .../provider/cloudwatch_emf/cloudwatch.py | 3 -- .../parser/models/apigw_websocket.py | 1 + .../_openapi_customization_operations.md | 1 + .../_pydantic/test_openapi_params.py | 41 +++++++++++++++++++ tests/unit/metrics/test_functions.py | 18 ++++---- .../parser/_pydantic/test_apigw_websockets.py | 2 +- 8 files changed, 89 insertions(+), 16 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 001fcceac72..f4ef22019e5 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -309,6 +309,7 @@ def __init__( include_in_schema: bool = True, security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, + deprecated: bool = False, middlewares: list[Callable[..., Response]] | None = None, ): """ @@ -348,6 +349,8 @@ def __init__( The OpenAPI security for this route openapi_extensions: dict[str, Any], optional Additional OpenAPI extensions as a dictionary. + deprecated: bool + Whether or not to mark this route as deprecated in the OpenAPI schema middlewares: list[Callable[..., Response]] | None The list of route middlewares to be called in order. """ @@ -374,6 +377,7 @@ def __init__( self.openapi_extensions = openapi_extensions self.middlewares = middlewares or [] self.operation_id = operation_id or self._generate_operation_id() + self.deprecated = deprecated # _middleware_stack_built is used to ensure the middleware stack is only built once. self._middleware_stack_built = False @@ -670,6 +674,9 @@ def _openapi_operation_metadata(self, operation_ids: set[str]) -> dict[str, Any] operation_ids.add(self.operation_id) operation["operationId"] = self.operation_id + # Mark as deprecated if necessary + operation["deprecated"] = self.deprecated or None + return operation @staticmethod @@ -924,6 +931,7 @@ def route( include_in_schema: bool = True, security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, + deprecated: bool = False, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: raise NotImplementedError() @@ -984,6 +992,7 @@ def get( include_in_schema: bool = True, security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, + deprecated: bool = False, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Get route decorator with GET `method` @@ -1023,6 +1032,7 @@ def lambda_handler(event, context): include_in_schema, security, openapi_extensions, + deprecated, middlewares, ) @@ -1041,6 +1051,7 @@ def post( include_in_schema: bool = True, security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, + deprecated: bool = False, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Post route decorator with POST `method` @@ -1081,6 +1092,7 @@ def lambda_handler(event, context): include_in_schema, security, openapi_extensions, + deprecated, middlewares, ) @@ -1099,6 +1111,7 @@ def put( include_in_schema: bool = True, security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, + deprecated: bool = False, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Put route decorator with PUT `method` @@ -1139,6 +1152,7 @@ def lambda_handler(event, context): include_in_schema, security, openapi_extensions, + deprecated, middlewares, ) @@ -1157,6 +1171,7 @@ def delete( include_in_schema: bool = True, security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, + deprecated: bool = False, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Delete route decorator with DELETE `method` @@ -1196,6 +1211,7 @@ def lambda_handler(event, context): include_in_schema, security, openapi_extensions, + deprecated, middlewares, ) @@ -1214,6 +1230,7 @@ def patch( include_in_schema: bool = True, security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, + deprecated: bool = False, middlewares: list[Callable] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Patch route decorator with PATCH `method` @@ -1256,6 +1273,7 @@ def lambda_handler(event, context): include_in_schema, security, openapi_extensions, + deprecated, middlewares, ) @@ -1274,6 +1292,7 @@ def head( include_in_schema: bool = True, security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, + deprecated: bool = False, middlewares: list[Callable] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Head route decorator with HEAD `method` @@ -1315,6 +1334,7 @@ def lambda_handler(event, context): include_in_schema, security, openapi_extensions, + deprecated, middlewares, ) @@ -1629,7 +1649,6 @@ def get_openapi_schema( # Add routes to the OpenAPI schema for route in all_routes: - if route.security and not _validate_openapi_security_parameters( security=route.security, security_schemes=security_schemes, @@ -1694,7 +1713,6 @@ def _get_openapi_security( @staticmethod def _determine_openapi_version(openapi_version: str): - # Pydantic V2 has no support for OpenAPI schema 3.0 if not openapi_version.startswith("3.1"): warnings.warn( @@ -1950,6 +1968,7 @@ def route( include_in_schema: bool = True, security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, + deprecated: bool = False, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: """Route decorator includes parameter `method`""" @@ -1978,6 +1997,7 @@ def register_resolver(func: AnyCallableT) -> AnyCallableT: include_in_schema, security, openapi_extensions, + deprecated, middlewares, ) @@ -2492,6 +2512,7 @@ def route( include_in_schema: bool = True, security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, + deprecated: bool = False, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: def register_route(func: AnyCallableT) -> AnyCallableT: @@ -2517,6 +2538,7 @@ def register_route(func: AnyCallableT) -> AnyCallableT: include_in_schema, frozen_security, fronzen_openapi_extensions, + deprecated, ) # Collate Middleware for routes @@ -2598,6 +2620,7 @@ def route( include_in_schema: bool = True, security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, + deprecated: bool = False, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: # NOTE: see #1552 for more context. @@ -2616,6 +2639,7 @@ def route( include_in_schema, security, openapi_extensions, + deprecated, middlewares, ) diff --git a/aws_lambda_powertools/event_handler/bedrock_agent.py b/aws_lambda_powertools/event_handler/bedrock_agent.py index 8af5520a188..215199e0022 100644 --- a/aws_lambda_powertools/event_handler/bedrock_agent.py +++ b/aws_lambda_powertools/event_handler/bedrock_agent.py @@ -108,9 +108,9 @@ def get( # type: ignore[override] tags: list[str] | None = None, operation_id: str | None = None, include_in_schema: bool = True, + deprecated: bool = False, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: - openapi_extensions = None security = None @@ -128,6 +128,7 @@ def get( # type: ignore[override] include_in_schema, security, openapi_extensions, + deprecated, middlewares, ) @@ -146,6 +147,7 @@ def post( # type: ignore[override] tags: list[str] | None = None, operation_id: str | None = None, include_in_schema: bool = True, + deprecated: bool = False, middlewares: list[Callable[..., Any]] | None = None, ): openapi_extensions = None @@ -165,6 +167,7 @@ def post( # type: ignore[override] include_in_schema, security, openapi_extensions, + deprecated, middlewares, ) @@ -183,6 +186,7 @@ def put( # type: ignore[override] tags: list[str] | None = None, operation_id: str | None = None, include_in_schema: bool = True, + deprecated: bool = False, middlewares: list[Callable[..., Any]] | None = None, ): openapi_extensions = None @@ -202,6 +206,7 @@ def put( # type: ignore[override] include_in_schema, security, openapi_extensions, + deprecated, middlewares, ) @@ -220,6 +225,7 @@ def patch( # type: ignore[override] tags: list[str] | None = None, operation_id: str | None = None, include_in_schema: bool = True, + deprecated: bool = False, middlewares: list[Callable] | None = None, ): openapi_extensions = None @@ -239,6 +245,7 @@ def patch( # type: ignore[override] include_in_schema, security, openapi_extensions, + deprecated, middlewares, ) @@ -257,6 +264,7 @@ def delete( # type: ignore[override] tags: list[str] | None = None, operation_id: str | None = None, include_in_schema: bool = True, + deprecated: bool = False, middlewares: list[Callable[..., Any]] | None = None, ): openapi_extensions = None @@ -276,6 +284,7 @@ def delete( # type: ignore[override] include_in_schema, security, openapi_extensions, + deprecated, middlewares, ) diff --git a/aws_lambda_powertools/metrics/provider/cloudwatch_emf/cloudwatch.py b/aws_lambda_powertools/metrics/provider/cloudwatch_emf/cloudwatch.py index 50ad1871953..cd9a90a0d19 100644 --- a/aws_lambda_powertools/metrics/provider/cloudwatch_emf/cloudwatch.py +++ b/aws_lambda_powertools/metrics/provider/cloudwatch_emf/cloudwatch.py @@ -24,7 +24,6 @@ from aws_lambda_powertools.shared.functions import resolve_env_var_choice from aws_lambda_powertools.warnings import PowertoolsUserWarning - if TYPE_CHECKING: from aws_lambda_powertools.metrics.provider.cloudwatch_emf.types import CloudWatchEMFOutput from aws_lambda_powertools.metrics.types import MetricNameUnitResolution @@ -295,8 +294,6 @@ def add_dimension(self, name: str, value: str) -> None: self.dimension_set[name] = value - - def add_metadata(self, key: str, value: Any) -> None: """Adds high cardinal metadata for metrics object diff --git a/aws_lambda_powertools/utilities/parser/models/apigw_websocket.py b/aws_lambda_powertools/utilities/parser/models/apigw_websocket.py index 0655825e776..b9e7ecd68c7 100644 --- a/aws_lambda_powertools/utilities/parser/models/apigw_websocket.py +++ b/aws_lambda_powertools/utilities/parser/models/apigw_websocket.py @@ -9,6 +9,7 @@ class APIGatewayWebSocketEventIdentity(BaseModel): source_ip: IPvAnyNetwork = Field(alias="sourceIp") user_agent: Optional[str] = Field(None, alias="userAgent") + class APIGatewayWebSocketEventRequestContextBase(BaseModel): extended_request_id: str = Field(alias="extendedRequestId") request_time: str = Field(alias="requestTime") diff --git a/docs/core/event_handler/_openapi_customization_operations.md b/docs/core/event_handler/_openapi_customization_operations.md index df842b2b7fc..0072ec1fae4 100644 --- a/docs/core/event_handler/_openapi_customization_operations.md +++ b/docs/core/event_handler/_openapi_customization_operations.md @@ -13,3 +13,4 @@ Here's a breakdown of various customizable fields: | `tags` | `List[str]` | Tags are a way to categorize and group endpoints within the API documentation. They can help organize the operations by resources or other heuristic. | | `operation_id` | `str` | A unique identifier for the operation, which can be used for referencing this operation in documentation or code. This ID must be unique across all operations described in the API. | | `include_in_schema` | `bool` | A boolean value that determines whether or not this operation should be included in the OpenAPI schema. Setting it to `False` can hide the endpoint from generated documentation and schema exports, which might be useful for private or experimental endpoints. | +| `deprecated` | `bool` | A boolean value that determines whether or not this operation should be marked as deprecated in the OpenAPI schema. | diff --git a/tests/functional/event_handler/_pydantic/test_openapi_params.py b/tests/functional/event_handler/_pydantic/test_openapi_params.py index 710627922f6..a57156db130 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_params.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_params.py @@ -44,6 +44,7 @@ def handler(): get = path.get assert get.summary == "GET /" assert get.operationId == "handler__get" + assert get.deprecated is None assert get.responses is not None assert 200 in get.responses.keys() @@ -388,6 +389,46 @@ def handler(user: Annotated[User, Body(description="This is a user")]): assert request_body.content[JSON_CONTENT_TYPE].schema_.description == "This is a user" +def test_openapi_with_deprecated_operations(): + app = APIGatewayRestResolver() + + @app.get("/", deprecated=True) + def _get(): + raise NotImplementedError() + + @app.post("/", deprecated=True) + def _post(): + raise NotImplementedError() + + schema = app.get_openapi_schema() + + get = schema.paths["/"].get + assert get.deprecated is True + + post = schema.paths["/"].post + assert post.deprecated is True + + +def test_openapi_without_deprecated_operations(): + app = APIGatewayRestResolver() + + @app.get("/") + def _get(): + raise NotImplementedError() + + @app.post("/", deprecated=False) + def _post(): + raise NotImplementedError() + + schema = app.get_openapi_schema() + + get = schema.paths["/"].get + assert get.deprecated is None + + post = schema.paths["/"].post + assert post.deprecated is None + + def test_openapi_with_excluded_operations(): app = APIGatewayRestResolver() diff --git a/tests/unit/metrics/test_functions.py b/tests/unit/metrics/test_functions.py index 142be729ae6..e7647852a49 100644 --- a/tests/unit/metrics/test_functions.py +++ b/tests/unit/metrics/test_functions.py @@ -1,6 +1,8 @@ -import pytest import warnings +import pytest + +from aws_lambda_powertools.metrics import Metrics from aws_lambda_powertools.metrics.functions import ( extract_cloudwatch_metric_resolution_value, extract_cloudwatch_metric_unit_value, @@ -10,9 +12,9 @@ MetricUnitError, ) from aws_lambda_powertools.metrics.provider.cloudwatch_emf.metric_properties import MetricResolution, MetricUnit -from aws_lambda_powertools.metrics import Metrics from aws_lambda_powertools.warnings import PowertoolsUserWarning + @pytest.fixture def warning_catcher(monkeypatch): caught_warnings = [] @@ -20,7 +22,7 @@ def warning_catcher(monkeypatch): def custom_warn(message, category=None, stacklevel=1, source=None): caught_warnings.append(PowertoolsUserWarning(message)) - monkeypatch.setattr(warnings, 'warn', custom_warn) + monkeypatch.setattr(warnings, "warn", custom_warn) return caught_warnings @@ -78,13 +80,13 @@ def test_extract_valid_cloudwatch_metric_unit_value(): def test_add_dimension_overwrite_warning(warning_catcher): """ - Adds a dimension and then tries to add another with the same name - but a different value. Verifies if the dimension is updated with - the new value and warning is issued when an existing dimension + Adds a dimension and then tries to add another with the same name + but a different value. Verifies if the dimension is updated with + the new value and warning is issued when an existing dimension is overwritten. """ metrics = Metrics(namespace="TestNamespace") - + # GIVEN default dimension dimension_name = "test-dimension" value1 = "test-value-1" @@ -100,5 +102,3 @@ def test_add_dimension_overwrite_warning(warning_catcher): # AND a warning should be issued with the exact message expected_warning = f"Dimension '{dimension_name}' has already been added. The previous value will be overwritten." assert any(str(w) == expected_warning for w in warning_catcher) - - diff --git a/tests/unit/parser/_pydantic/test_apigw_websockets.py b/tests/unit/parser/_pydantic/test_apigw_websockets.py index aea77217d93..7b8a3c9ba46 100644 --- a/tests/unit/parser/_pydantic/test_apigw_websockets.py +++ b/tests/unit/parser/_pydantic/test_apigw_websockets.py @@ -114,4 +114,4 @@ def test_apigw_websocket_disconnect_event(): assert parsed_event.is_base64_encoded == raw_event["isBase64Encoded"] assert parsed_event.headers == raw_event["headers"] - assert parsed_event.multi_value_headers == raw_event["multiValueHeaders"] \ No newline at end of file + assert parsed_event.multi_value_headers == raw_event["multiValueHeaders"]