Skip to content

Commit

Permalink
feat(event_handler): Ensure Bedrock Agents resolver works with Pydant…
Browse files Browse the repository at this point in the history
…ic v2 (#5156)

Make sure Bedrock Agent works with Pydantic v2
  • Loading branch information
leandrodamascena authored Sep 11, 2024
1 parent dbfb0db commit 9acfee4
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 1 deletion.
107 changes: 107 additions & 0 deletions aws_lambda_powertools/event_handler/bedrock_agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any, Callable

from typing_extensions import override
Expand All @@ -10,10 +11,12 @@
ProxyEventType,
ResponseBuilder,
)
from aws_lambda_powertools.event_handler.openapi.constants import DEFAULT_API_VERSION, DEFAULT_OPENAPI_VERSION

if TYPE_CHECKING:
from re import Match

from aws_lambda_powertools.event_handler.openapi.models import Contact, License, SecurityScheme, Server, Tag
from aws_lambda_powertools.event_handler.openapi.types import OpenAPIResponse
from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent

Expand Down Expand Up @@ -273,3 +276,107 @@ def _convert_matches_into_route_keys(self, match: Match) -> dict[str, str]:
if match.groupdict() and self.current_event.parameters:
parameters = {parameter["name"]: parameter["value"] for parameter in self.current_event.parameters}
return parameters

@override
def get_openapi_json_schema(
self,
*,
title: str = "Powertools API",
version: str = DEFAULT_API_VERSION,
openapi_version: str = DEFAULT_OPENAPI_VERSION,
summary: str | None = None,
description: str | None = None,
tags: list[Tag | str] | None = None,
servers: list[Server] | None = None,
terms_of_service: str | None = None,
contact: Contact | None = None,
license_info: License | None = None,
security_schemes: dict[str, SecurityScheme] | None = None,
security: list[dict[str, list[str]]] | None = None,
) -> str:
"""
Returns the OpenAPI schema as a JSON serializable dict.
Since Bedrock Agents only support OpenAPI 3.0.0, we convert OpenAPI 3.1.0 schemas
and enforce 3.0.0 compatibility for seamless integration.
Parameters
----------
title: str
The title of the application.
version: str
The version of the OpenAPI document (which is distinct from the OpenAPI Specification version or the API
openapi_version: str, default = "3.0.0"
The version of the OpenAPI Specification (which the document uses).
summary: str, optional
A short summary of what the application does.
description: str, optional
A verbose explanation of the application behavior.
tags: list[Tag, str], optional
A list of tags used by the specification with additional metadata.
servers: list[Server], optional
An array of Server Objects, which provide connectivity information to a target server.
terms_of_service: str, optional
A URL to the Terms of Service for the API. MUST be in the format of a URL.
contact: Contact, optional
The contact information for the exposed API.
license_info: License, optional
The license information for the exposed API.
security_schemes: dict[str, SecurityScheme]], optional
A declaration of the security schemes available to be used in the specification.
security: list[dict[str, list[str]]], optional
A declaration of which security mechanisms are applied globally across the API.
Returns
-------
str
The OpenAPI schema as a JSON serializable dict.
"""
from aws_lambda_powertools.event_handler.openapi.compat import model_json

schema = super().get_openapi_schema(
title=title,
version=version,
openapi_version=openapi_version,
summary=summary,
description=description,
tags=tags,
servers=servers,
terms_of_service=terms_of_service,
contact=contact,
license_info=license_info,
security_schemes=security_schemes,
security=security,
)
schema.openapi = "3.0.3"

# Transform OpenAPI 3.1 into 3.0
def inner(yaml_dict):
if isinstance(yaml_dict, dict):
if "anyOf" in yaml_dict and isinstance((anyOf := yaml_dict["anyOf"]), list):
for i, item in enumerate(anyOf):
if isinstance(item, dict) and item.get("type") == "null":
anyOf.pop(i)
yaml_dict["nullable"] = True
if "examples" in yaml_dict:
examples = yaml_dict["examples"]
del yaml_dict["examples"]
if isinstance(examples, list) and len(examples):
yaml_dict["example"] = examples[0]
for value in yaml_dict.values():
inner(value)
elif isinstance(yaml_dict, list):
for item in yaml_dict:
inner(item)

model = json.loads(
model_json(
schema,
by_alias=True,
exclude_none=True,
indent=2,
),
)

inner(model)

return json.dumps(model)
21 changes: 20 additions & 1 deletion tests/functional/event_handler/test_bedrock_agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
from typing import Any, Dict
from typing import Any, Dict, Optional

import pytest
from typing_extensions import Annotated

from aws_lambda_powertools.event_handler import BedrockAgentResolver, Response, content_types
Expand Down Expand Up @@ -181,3 +182,21 @@ def send_reminders(
# THEN return the correct result
body = result["response"]["responseBody"]["application/json"]["body"]
assert json.loads(body) is True


@pytest.mark.usefixtures("pydanticv2_only")
def test_openapi_schema_for_pydanticv2(openapi30_schema):
# GIVEN BedrockAgentResolver is initialized with enable_validation=True
app = BedrockAgentResolver(enable_validation=True)

# WHEN we have a simple handler
@app.get("/", description="Testing")
def handler() -> Optional[Dict]:
pass

# WHEN we get the schema
schema = json.loads(app.get_openapi_json_schema())

# THEN the schema must be a valid 3.0.3 version
assert openapi30_schema(schema)
assert schema.get("openapi") == "3.0.3"

0 comments on commit 9acfee4

Please sign in to comment.