Skip to content

Commit

Permalink
Merge pull request #19 from lmignon/master-recursive-mode-rebuild
Browse files Browse the repository at this point in the history
Ensure recursive field annotation resolution
  • Loading branch information
lmignon authored Nov 23, 2023
2 parents 55cd466 + c954629 commit c2aeed0
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 8 deletions.
39 changes: 39 additions & 0 deletions news/19.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
Fix problem with unresolved annotated types in the aggregated model.

At the end of the registry building process, the registry contains the aggregated
model. Each aggregated model is the result of the build of a new class based on
a hierarchy of all the classes defined as 'extends' of the same base class. The
order of the classes hierarchy is defined by the order in which the classes are
loaded by the class loader or by a specific order defined by the developer when
the registry is built.

The last step of the build process is to resolve all the annotated types in the
aggregated model and rebuild the pydantic schema validator. This step is necessary
because when the developer defines a model, some fields can be annotated with a
type that refers to a class that is an extendable class. It's therefore necessary
to update the annotated type with the aggregated result of the specified
extendable class and rebuild the pydantic schema validator to take into account
the new annotated types.

Prior to this commit, the resolution of the annotated types was not done in a
recursive way and the rebuild of the pydantic schema validator was only done
just after the resolution of an aggregated class. This means that if a class A
is an extendable defining a fields annotated with a type that refers to a class
B, and if the class B is an extendable class defining a field of type C,
the annotated type of the field of the class A was resolved with the aggregated
model of the class B but we didn't resolve th annotated type of the field ot type
B with the aggregated model of the type C. Therefore when the pydantic schema
validator was rebuilt after the resolution of the class A, if the class B was
not yet resolved and therefore the pydantic schema validator was not rebuilt,
the new schema validator for the class A was not correct because it didn't take
into account the aggregated model of the class C nor the definition of extra
fields of the aggregated model of the class B.

This commit changes the resolution of the annotated types to be recursive. Therefore
when the pydantic schema validator is rebuilt, we are sure that all referenced
subtypes are resolved and defines a correct schema validator. In the
same time, when an aggregated class is resolved, it's marked as resolved to avoid
to resolve it again and rebuild the pydantic schema validator again for nothing.
In addition to resolve the initial problem, this commit also improves
the performance of the build process because schema validators rebuilds are
done only once per aggregated class.
12 changes: 6 additions & 6 deletions src/extendable_pydantic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@


class ExtendableModelMeta(ExtendableMeta, ModelMetaclass):
__xreg_fields_resolved__: bool = False

@no_type_check
@classmethod
def _build_original_class(metacls, name, bases, namespace, **kwargs):
Expand All @@ -45,6 +47,7 @@ def _prepare_namespace(
namespace = super()._prepare_namespace(
name=name, bases=bases, namespace=namespace, extends=extends, **kwargs
)
namespace["__xreg_fields_resolved__"] = False
return namespace

@no_type_check
Expand Down Expand Up @@ -93,6 +96,9 @@ def _resolve_submodel_fields(
"""Replace the original field type into the definition of the field by the one
from the registry."""
registry = registry if registry else context.extendable_registry.get()
if cls.__xreg_fields_resolved__:
return
cls.__xreg_fields_resolved__ = True
to_rebuild = False
if issubclass(cls, BaseModel):
for field_name, field_info in cast(BaseModel, cls).model_fields.items():
Expand All @@ -105,13 +111,11 @@ def _resolve_submodel_fields(
if to_rebuild:
delattr(cls, "__pydantic_core_schema__")
cast(BaseModel, cls).model_rebuild(force=True)
return


class RegistryListener(ExtendableRegistryListener):
def on_registry_initialized(self, registry: ExtendableClassesRegistry) -> None:
self.resolve_submodel_fields(registry)
self.rebuild_models(registry)

def before_init_registry(
self,
Expand All @@ -124,10 +128,6 @@ def before_init_registry(
if "extendable_pydantic" not in module_matchings:
module_matchings.insert(0, "extendable_pydantic.models")

def rebuild_models(self, registry: ExtendableClassesRegistry) -> None:
for cls in registry._extendable_classes.values():
cast(BaseModel, cls).model_rebuild(force=True)

def resolve_submodel_fields(self, registry: ExtendableClassesRegistry) -> None:
for cls in registry._extendable_classes.values():
if issubclass(type(cls), ExtendableModelMeta):
Expand Down
5 changes: 4 additions & 1 deletion src/extendable_pydantic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,10 @@ def resolve_annotation(
# semantics as "typing" classes or generic aliases

if not origin_type and issubclass(type(type_), ExtendableMeta):
return type_._get_assembled_cls(registry)
final_type = type_._get_assembled_cls(registry)
if final_type is not type_:
final_type._resolve_submodel_fields(registry)
return final_type

# Handle special case for typehints that can have lists as arguments.
# `typing.Callable[[int, str], int]` is an example for this.
Expand Down
61 changes: 60 additions & 1 deletion tests/test_generics_inheritance.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Test generics model inheritance."""
from typing import Generic, List, TypeVar
from typing import Generic, List, TypeVar, Optional

try:
from typing import Literal
Expand All @@ -9,6 +9,7 @@
from pydantic.main import BaseModel

from extendable_pydantic import ExtendableModelMeta
from extendable_pydantic.models import ExtendableBaseModel

from .conftest import skip_not_supported_version_for_generics

Expand Down Expand Up @@ -117,6 +118,64 @@ class SearchResultExtended(SearchResult[T], Generic[T], extends=SearchResult[T])
}


@skip_not_supported_version_for_generics
def test_generic_with_nested_extended(test_registry):
T = TypeVar("T")

class SearchResult(ExtendableBaseModel, Generic[T]):
total: int
results: List[T]

class Level(ExtendableBaseModel):
val: int

class SearchLevelResult(SearchResult[Level]):
pass

class Level11(ExtendableBaseModel):
val: int

class Level1(ExtendableBaseModel):
val: int
level11: Optional[Level11]

class Level11Extended(Level11, extends=True):
name: str = "level11"

class Level1Extended(Level1, extends=True):
name: str = "level1"

class LevelExtended(Level, extends=True):
name: str = "level"
level1: Optional[Level1]

test_registry.init_registry()

assert Level11(val=3).model_dump() == {"val": 3, "name": "level11"}

item = SearchLevelResult(
total=0,
results=[Level(val=1, level1=Level1(val=2, level11=Level11(val=3)))],
)
assert item.model_dump() == {
"total": 0,
"results": [
{
"val": 1,
"level1": {
"val": 2,
"level11": {
"val": 3,
"name": "level11",
},
"name": "level1",
},
"name": "level",
}
],
}


@skip_not_supported_version_for_generics
def test_extended_generics_of_extended_model(test_registry):
"""In this test we check that the extension of a genrics of extended model
Expand Down

0 comments on commit c2aeed0

Please sign in to comment.