diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 172b57de277..19082fd0955 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -80,6 +80,7 @@ License, OpenAPI, Server, + Tag, ) from aws_lambda_powertools.event_handler.openapi.params import Dependant from aws_lambda_powertools.event_handler.openapi.types import ( @@ -1360,7 +1361,7 @@ def get_openapi_schema( openapi_version: str = DEFAULT_OPENAPI_VERSION, summary: Optional[str] = None, description: Optional[str] = None, - tags: Optional[List[str]] = None, + tags: Optional[List[Union["Tag", str]]] = None, servers: Optional[List["Server"]] = None, terms_of_service: Optional[str] = None, contact: Optional["Contact"] = None, @@ -1381,7 +1382,7 @@ def get_openapi_schema( A short summary of what the application does. description: str, optional A verbose explanation of the application behavior. - tags: List[str], optional + tags: List[Tag | str], optional A list of tags used by the specification with additional metadata. servers: List[Server], optional An array of Server Objects, which provide connectivity information to a target server. @@ -1403,7 +1404,7 @@ def get_openapi_schema( get_compat_model_name_map, get_definitions, ) - from aws_lambda_powertools.event_handler.openapi.models import OpenAPI, PathItem, Server + from aws_lambda_powertools.event_handler.openapi.models import OpenAPI, PathItem, Server, Tag from aws_lambda_powertools.event_handler.openapi.types import ( COMPONENT_REF_TEMPLATE, ) @@ -1468,7 +1469,7 @@ def get_openapi_schema( if components: output["components"] = components if tags: - output["tags"] = [{"name": tag} for tag in tags] + output["tags"] = [Tag(name=tag) if isinstance(tag, str) else tag for tag in tags] output["paths"] = {k: PathItem(**v) for k, v in paths.items()} @@ -1482,7 +1483,7 @@ def get_openapi_json_schema( openapi_version: str = DEFAULT_OPENAPI_VERSION, summary: Optional[str] = None, description: Optional[str] = None, - tags: Optional[List[str]] = None, + tags: Optional[List[Union["Tag", str]]] = None, servers: Optional[List["Server"]] = None, terms_of_service: Optional[str] = None, contact: Optional["Contact"] = None, @@ -1503,7 +1504,7 @@ def get_openapi_json_schema( A short summary of what the application does. description: str, optional A verbose explanation of the application behavior. - tags: List[str], optional + tags: List[Tag, str], optional A list of tags used by the specification with additional metadata. servers: List[Server], optional An array of Server Objects, which provide connectivity information to a target server. @@ -1548,7 +1549,7 @@ def enable_swagger( openapi_version: str = DEFAULT_OPENAPI_VERSION, summary: Optional[str] = None, description: Optional[str] = None, - tags: Optional[List[str]] = None, + tags: Optional[List[Union["Tag", str]]] = None, servers: Optional[List["Server"]] = None, terms_of_service: Optional[str] = None, contact: Optional["Contact"] = None, @@ -1573,7 +1574,7 @@ def enable_swagger( A short summary of what the application does. description: str, optional A verbose explanation of the application behavior. - tags: List[str], optional + tags: List[Tag, str], optional A list of tags used by the specification with additional metadata. servers: List[Server], optional An array of Server Objects, which provide connectivity information to a target server. diff --git a/tests/functional/event_handler/test_openapi_params.py b/tests/functional/event_handler/test_openapi_params.py index 1fbdbef2bde..702bb937571 100644 --- a/tests/functional/event_handler/test_openapi_params.py +++ b/tests/functional/event_handler/test_openapi_params.py @@ -349,37 +349,6 @@ def handler(user: Annotated[User, Body(embed=True)]): assert body_post_handler_schema.properties["user"].ref == "#/components/schemas/User" -def test_openapi_with_tags(): - app = APIGatewayRestResolver() - - @app.get("/users") - def handler(): - raise NotImplementedError() - - schema = app.get_openapi_schema(tags=["Orders"]) - assert len(schema.tags) == 1 - - tag = schema.tags[0] - assert tag.name == "Orders" - - -def test_openapi_operation_with_tags(): - app = APIGatewayRestResolver() - - @app.get("/users", tags=["Users"]) - def handler(): - raise NotImplementedError() - - schema = app.get_openapi_schema() - assert len(schema.paths.keys()) == 1 - - get = schema.paths["/users"].get - assert len(get.tags) == 1 - - tag = get.tags[0] - assert tag == "Users" - - def test_openapi_with_excluded_operations(): app = APIGatewayRestResolver() diff --git a/tests/functional/event_handler/test_openapi_tags.py b/tests/functional/event_handler/test_openapi_tags.py new file mode 100644 index 00000000000..daa30b193ff --- /dev/null +++ b/tests/functional/event_handler/test_openapi_tags.py @@ -0,0 +1,53 @@ +from aws_lambda_powertools.event_handler import APIGatewayRestResolver +from aws_lambda_powertools.event_handler.openapi.models import Tag + + +def test_openapi_with_tags(): + app = APIGatewayRestResolver() + + @app.get("/users") + def handler(): + raise NotImplementedError() + + schema = app.get_openapi_schema(tags=["Orders"]) + assert schema.tags is not None + assert len(schema.tags) == 1 + + tag = schema.tags[0] + assert tag.name == "Orders" + + +def test_openapi_with_object_tags(): + app = APIGatewayRestResolver() + + @app.get("/users") + def handler(): + raise NotImplementedError() + + schema = app.get_openapi_schema(tags=[Tag(name="Orders", description="Order description tag")]) + assert schema.tags is not None + assert len(schema.tags) == 1 + + tag = schema.tags[0] + assert tag.name == "Orders" + assert tag.description == "Order description tag" + + +def test_openapi_operation_with_tags(): + app = APIGatewayRestResolver() + + @app.get("/users", tags=["Users"]) + def handler(): + raise NotImplementedError() + + schema = app.get_openapi_schema() + assert schema.paths is not None + assert len(schema.paths.keys()) == 1 + + get = schema.paths["/users"].get + assert get is not None + assert get.tags is not None + assert len(get.tags) == 1 + + tag = get.tags[0] + assert tag == "Users"