Skip to content

Commit

Permalink
Fix the use of extendable model as query params into fastapi
Browse files Browse the repository at this point in the history
  • Loading branch information
lmignon committed Jan 8, 2025
1 parent 3512838 commit 3ef2db7
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 1 deletion.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ requires-python = ">=3.7"
test = [
"pytest",
"coverage[toml]",
"fastapi>=0.111",
"httpx",
]
mypy = [
"mypy>=1.4.1",
Expand Down
18 changes: 17 additions & 1 deletion src/extendable_pydantic/_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

def _resolve_model_fields_annotation(model_fields):
registry = context.extendable_registry.get()
if registry:
if registry and registry.ready:
for field in model_fields:
field_info = field.field_info
new_type = resolve_annotation(field_info.annotation)
Expand Down Expand Up @@ -77,3 +77,19 @@ def _create_response_field_wrapper(wrapped, instance, args, kwargs):
wrapt.wrap_function_wrapper(
utils, "create_model_field", _create_response_field_wrapper
)


@wrapt.when_imported("fastapi.dependencies.utils")
def hook_fastapi_dependencies_utils(utils):
def _analyze_param_wrapper(wrapped, instance, args, kwargs):
registry = context.extendable_registry.get()
if registry and registry.ready:
annotation = kwargs.get("annotation")
if annotation:
new_type = resolve_annotation(annotation)
if not all_identical(annotation, new_type):
kwargs["annotation"] = new_type
return wrapped(*args, **kwargs)

if hasattr(utils, "analyze_param"):
wrapt.wrap_function_wrapper(utils, "analyze_param", _analyze_param_wrapper)
59 changes: 59 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from extendable_pydantic import _patch # noqa: F401
import pytest
import sys
from extendable import context, main, registry
from fastapi import FastAPI, APIRouter
from fastapi.testclient import TestClient
from typing import Annotated

from fastapi import Depends
from extendable_pydantic import ExtendableBaseModel


skip_not_supported_version_for_generics = pytest.mark.skipif(
Expand All @@ -21,3 +28,55 @@ def test_registry() -> registry.ExtendableClassesRegistry:
finally:
main._extendable_class_defs_by_module = initial_class_defs
context.extendable_registry.reset(token)


@pytest.fixture
def test_fastapi(test_registry) -> TestClient:
app = FastAPI()
my_router = APIRouter()

class TestRequest(ExtendableBaseModel):
name: str = "rqst"

def get_type(self) -> str:
return "request"

class TestResponse(ExtendableBaseModel):
name: str = "resp"

def get_type(self) -> str:
return "response"

@my_router.get("/")
def get() -> TestResponse:
"""Get method."""
resp = TestResponse(name="World")
assert hasattr(resp, "id")
return resp

@my_router.post("/")
def post(rqst: TestRequest) -> TestResponse:
"""Post method."""
resp = TestResponse(**rqst.model_dump())
assert hasattr(resp, "id")
return resp

@my_router.get("/extended")
def get_with_params(rqst: Annotated[TestRequest, Depends()]) -> TestResponse:
"""Get method with parameters."""
resp = TestResponse(**rqst.model_dump())
assert hasattr(resp, "id")
return resp

class ExtendedTestRequest(TestRequest, extends=TestRequest):
id: int = 1

class ExtendedTestResponse(TestResponse, extends=TestResponse):
id: int = 2

test_registry.init_registry()

app.include_router(my_router)

with TestClient(app) as client:
yield client
57 changes: 57 additions & 0 deletions tests/test_fastapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""Test fastapi integration."""


def test_open_api_schema(test_fastapi):
client = test_fastapi
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
schema = response.json()
rqst_schema = schema["components"]["schemas"]["TestRequest"]
assert rqst_schema["properties"] == {
"name": {"title": "Name", "type": "string", "default": "rqst"},
"id": {"title": "Id", "type": "integer", "default": 1},
}
resp_schema = schema["components"]["schemas"]["TestResponse"]
assert resp_schema["properties"] == {
"name": {"title": "Name", "type": "string", "default": "resp"},
"id": {"title": "Id", "type": "integer", "default": 2},
}

extended_get_params = schema["paths"]["/extended"]["get"]["parameters"]
assert len(extended_get_params) == 2
assert extended_get_params[0] == {
"in": "query",
"name": "name",
"required": False,
"schema": {"title": "Name", "type": "string", "default": "rqst"},
}
assert extended_get_params[1] == {
"in": "query",
"name": "id",
"required": False,
"schema": {"title": "Id", "type": "integer", "default": 1},
}


def test_extended_response(test_fastapi):
"""Test extended pydantic model as response."""
client = test_fastapi
response = client.get("/")
assert response.status_code == 200
assert response.json() == {"name": "World", "id": 2}


def test_extended_request(test_fastapi):
"""Test extended pydantic model as json request."""
client = test_fastapi
response = client.post("/", json={"name": "Hello", "id": 3})
assert response.status_code == 200
assert response.json() == {"name": "Hello", "id": 3}


def test_extended_request_with_params(test_fastapi):
"""Test extended pydantic model as request with parameters."""
client = test_fastapi
response = client.get("/extended", params={"name": "echo", "id": 3})
assert response.status_code == 200
assert response.json() == {"name": "echo", "id": 3}

0 comments on commit 3ef2db7

Please sign in to comment.