Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-fastapi] Fix additionalProperties support & support pydantic v2 #19312

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -853,6 +853,8 @@ public Map<String, ModelsMap> postProcessAllModels(Map<String, ModelsMap> objs)
codegenModelMap.put(cm.classname, ModelUtils.getModelByName(entry.getKey(), objs));
}

propagateDiscriminatorValuesToProperties(processed);

// create circular import
for (String m : codegenModelMap.keySet()) {
createImportMapOfSet(m, codegenModelMap);
Expand Down Expand Up @@ -1046,6 +1048,52 @@ private ModelsMap postProcessModelsMap(ModelsMap objs) {
return objs;
}

private void propagateDiscriminatorValuesToProperties(Map<String, ModelsMap> objMap) {
HashMap<String, CodegenModel> modelMap = new HashMap<>();
for (Map.Entry<String, ModelsMap> entry : objMap.entrySet()) {
for (ModelMap m : entry.getValue().getModels()) {
modelMap.put("#/components/schemas/" + entry.getKey(), m.getModel());
}
}

for (Map.Entry<String, ModelsMap> entry : objMap.entrySet()) {
for (ModelMap m : entry.getValue().getModels()) {
CodegenModel model = m.getModel();
if (model.discriminator != null && !model.oneOf.isEmpty()) {
// Populate default, implicit discriminator values
for (String typeName : model.oneOf) {
ModelsMap obj = objMap.get(typeName);
if (obj == null) {
continue;
}
for (ModelMap m1 : obj.getModels()) {
for (CodegenProperty p : m1.getModel().vars) {
if (p.baseName.equals(model.discriminator.getPropertyBaseName())) {
p.isDiscriminator = true;
p.discriminatorValue = typeName;
}
}
}
}
// Populate explicit discriminator values from mapping, overwriting default values
if (model.discriminator.getMapping() != null) {
for (Map.Entry<String, String> discrEntry : model.discriminator.getMapping().entrySet()) {
CodegenModel resolved = modelMap.get(discrEntry.getValue());
if (resolved != null) {
for (CodegenProperty p : resolved.vars) {
if (p.baseName.equals(model.discriminator.getPropertyBaseName())) {
p.isDiscriminator = true;
p.discriminatorValue = discrEntry.getKey();
}
}
}
}
}
}
}
}
}


/*
* Gets the pydantic type given a Codegen Property
Expand Down Expand Up @@ -2160,7 +2208,16 @@ private PythonType getType(CodegenProperty cp) {
}

private String finalizeType(CodegenProperty cp, PythonType pt) {
if (!cp.required || cp.isNullable) {
if (cp.isDiscriminator && cp.discriminatorValue != null) {
moduleImports.add("typing", "Literal");
PythonType literal = new PythonType("Literal");
String literalValue = '"'+escapeText(cp.discriminatorValue)+'"';
PythonType valueType = new PythonType(literalValue);
literal.addTypeParam(valueType);
literal.setDefaultValue(literalValue);
cp.setDefaultValue(literalValue);
pt = literal;
} else if (!cp.required || cp.isNullable) {
moduleImports.add("typing", "Optional");
PythonType opt = new PythonType("Optional");
opt.addTypeParam(pt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ public PythonFastAPIServerCodegen() {
.defaultValue(implPackage));
}

@Override
protected void addParentFromContainer(CodegenModel model, Schema schema) {
// we do not want to inherit simply because additionalProperties is set
}

@Override
public void processOpts() {
super.processOpts();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ for _, name, _ in pkgutil.iter_modules(ns_pkg.__path__, ns_pkg.__name__ + "."):
description = "{{.}}",
{{/description}}
response_model_by_alias=True,
response_model_exclude_unset=True,
)
async def {{operationId}}(
{{#allParams}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,174 +14,56 @@ import re # noqa: F401
{{/vendorExtensions.x-py-model-imports}}
from typing import Union, Any, List, TYPE_CHECKING, Optional, Dict
from typing_extensions import Literal
from pydantic import StrictStr, Field
from pydantic import StrictStr, Field, RootModel
try:
from typing import Self
except ImportError:
from typing_extensions import Self

{{#lambda.uppercase}}{{{classname}}}{{/lambda.uppercase}}_ANY_OF_SCHEMAS = [{{#anyOf}}"{{.}}"{{^-last}}, {{/-last}}{{/anyOf}}]

class {{classname}}({{#parent}}{{{.}}}{{/parent}}{{^parent}}BaseModel{{/parent}}):
class {{classname}}({{#parent}}{{{.}}}{{/parent}}{{^parent}}RootModel{{/parent}}):
"""
{{{description}}}{{^description}}{{{classname}}}{{/description}}
"""

{{#composedSchemas.anyOf}}
# data type: {{{dataType}}}
{{vendorExtensions.x-py-name}}: {{{vendorExtensions.x-py-typing}}}
{{/composedSchemas.anyOf}}
if TYPE_CHECKING:
actual_instance: Optional[Union[{{#anyOf}}{{{.}}}{{^-last}}, {{/-last}}{{/anyOf}}]] = None
else:
actual_instance: Any = None
any_of_schemas: List[str] = Literal[{{#lambda.uppercase}}{{{classname}}}{{/lambda.uppercase}}_ANY_OF_SCHEMAS]
root: Union[{{#anyOf}}{{{.}}}{{^-last}}, {{/-last}}{{/anyOf}}] = None

model_config = {
"validate_assignment": True,
"protected_namespaces": (),
}
{{#discriminator}}

discriminator_value_class_map: Dict[str, str] = {
{{#children}}
'{{^vendorExtensions.x-discriminator-value}}{{name}}{{/vendorExtensions.x-discriminator-value}}{{#vendorExtensions.x-discriminator-value}}{{{vendorExtensions.x-discriminator-value}}}{{/vendorExtensions.x-discriminator-value}}': '{{{classname}}}'{{^-last}},{{/-last}}
{{/children}}
}
{{/discriminator}}

def __init__(self, *args, **kwargs) -> None:
if args:
if len(args) > 1:
raise ValueError("If a position argument is used, only 1 is allowed to set `actual_instance`")
if kwargs:
raise ValueError("If a position argument is used, keyword arguments cannot be used.")
super().__init__(actual_instance=args[0])
else:
super().__init__(**kwargs)

@field_validator('actual_instance')
def actual_instance_must_validate_anyof(cls, v):
{{#isNullable}}
if v is None:
return v

{{/isNullable}}
instance = {{{classname}}}.model_construct()
error_messages = []
{{#composedSchemas.anyOf}}
# validate data type: {{{dataType}}}
{{#isContainer}}
try:
instance.{{vendorExtensions.x-py-name}} = v
return v
except (ValidationError, ValueError) as e:
error_messages.append(str(e))
{{/isContainer}}
{{^isContainer}}
{{#isPrimitiveType}}
try:
instance.{{vendorExtensions.x-py-name}} = v
return v
except (ValidationError, ValueError) as e:
error_messages.append(str(e))
{{/isPrimitiveType}}
{{^isPrimitiveType}}
if not isinstance(v, {{{dataType}}}):
error_messages.append(f"Error! Input type `{type(v)}` is not `{{{dataType}}}`")
else:
return v

{{/isPrimitiveType}}
{{/isContainer}}
{{/composedSchemas.anyOf}}
if error_messages:
# no match
raise ValueError("No match found when setting the actual_instance in {{{classname}}} with anyOf schemas: {{#anyOf}}{{{.}}}{{^-last}}, {{/-last}}{{/anyOf}}. Details: " + ", ".join(error_messages))
else:
return v

@classmethod
def from_dict(cls, obj: dict) -> Self:
return cls.from_json(json.dumps(obj))

@classmethod
def from_json(cls, json_str: str) -> Self:
"""Returns the object represented by the json string"""
instance = cls.model_construct()
{{#isNullable}}
if json_str is None:
return instance

{{/isNullable}}
error_messages = []
{{#composedSchemas.anyOf}}
{{#isContainer}}
# deserialize data into {{{dataType}}}
try:
# validation
instance.{{vendorExtensions.x-py-name}} = json.loads(json_str)
# assign value to actual_instance
instance.actual_instance = instance.{{vendorExtensions.x-py-name}}
return instance
except (ValidationError, ValueError) as e:
error_messages.append(str(e))
{{/isContainer}}
{{^isContainer}}
{{#isPrimitiveType}}
# deserialize data into {{{dataType}}}
try:
# validation
instance.{{vendorExtensions.x-py-name}} = json.loads(json_str)
# assign value to actual_instance
instance.actual_instance = instance.{{vendorExtensions.x-py-name}}
return instance
except (ValidationError, ValueError) as e:
error_messages.append(str(e))
{{/isPrimitiveType}}
{{^isPrimitiveType}}
# {{vendorExtensions.x-py-name}}: {{{vendorExtensions.x-py-typing}}}
try:
instance.actual_instance = {{{dataType}}}.from_json(json_str)
return instance
except (ValidationError, ValueError) as e:
error_messages.append(str(e))
{{/isPrimitiveType}}
{{/isContainer}}
{{/composedSchemas.anyOf}}

if error_messages:
# no match
raise ValueError("No match found when deserializing the JSON string into {{{classname}}} with anyOf schemas: {{#anyOf}}{{{.}}}{{^-last}}, {{/-last}}{{/anyOf}}. Details: " + ", ".join(error_messages))
else:
return instance
def to_str(self) -> str:
"""Returns the string representation of the model using alias"""
return pprint.pformat(self.model_dump(by_alias=True))

def to_json(self) -> str:
"""Returns the JSON representation of the actual instance"""
if self.actual_instance is None:
return "null"
"""Returns the JSON representation of the model using alias"""
return self.model_dump_json(by_alias=True, exclude_unset=True)

to_json = getattr(self.actual_instance, "to_json", None)
if callable(to_json):
return self.actual_instance.to_json()
@classmethod
def from_json(cls, json_str: str) -> {{^hasChildren}}Self{{/hasChildren}}{{#hasChildren}}{{#discriminator}}Union[{{#children}}Self{{^-last}}, {{/-last}}{{/children}}]{{/discriminator}}{{^discriminator}}Self{{/discriminator}}{{/hasChildren}}:
"""Create an instance of {{{classname}}} from a JSON string"""
return cls.from_dict(json.loads(json_str))

def to_dict(self) -> Dict[str, Any]:
"""Return the dictionary representation of the model using alias"""
to_dict = getattr(self.root, "to_dict", None)
if callable(to_dict):
return self.model_dump(by_alias=True, exclude_unset=True)
else:
return json.dumps(self.actual_instance)
# primitive type
return self.root

def to_dict(self) -> Dict:
"""Returns the dict representation of the actual instance"""
if self.actual_instance is None:
return "null"
@classmethod
def from_dict(cls, obj: Dict) -> {{^hasChildren}}Self{{/hasChildren}}{{#hasChildren}}{{#discriminator}}Union[{{#children}}Self{{^-last}}, {{/-last}}{{/children}}]{{/discriminator}}{{^discriminator}}Self{{/discriminator}}{{/hasChildren}}:
"""Create an instance of {{{classname}}} from a dict"""
if obj is None:
return None

to_json = getattr(self.actual_instance, "to_json", None)
if callable(to_json):
return self.actual_instance.to_dict()
else:
# primitive type
return self.actual_instance
if not isinstance(obj, dict):
return cls.model_validate(obj)

def to_str(self) -> str:
"""Returns the string representation of the actual instance"""
return pprint.pformat(self.model_dump())
return cls.parse_obj(obj)

{{#vendorExtensions.x-py-postponed-model-imports.size}}
{{#vendorExtensions.x-py-postponed-model-imports}}
Expand Down
Loading
Loading