From 1605ab82ddb5671155209455560d9aaf6a924b9c Mon Sep 17 00:00:00 2001 From: Max Goltzsche Date: Thu, 17 Oct 2024 00:39:13 +0200 Subject: [PATCH] [python-fastapi] support oneOf the pydantic v2 way * Support oneOf and anyOf schemas the pydantic v2 way by generating them as Unions. * Generate model constructor that forcefully sets the discriminator field to ensure it is included in the marshalled representation. --- .../languages/AbstractPythonCodegen.java | 59 +++++- .../python-fastapi/model_anyof.mustache | 174 +++------------- .../python-fastapi/model_generic.mustache | 7 + .../python-fastapi/model_oneof.mustache | 194 +++--------------- .../python/PythonFastapiCodegenTest.java | 81 +++++++- .../petstore/python-aiohttp/docs/BasquePig.md | 2 +- .../petstore/python-aiohttp/docs/DanishPig.md | 2 +- .../petstore_api/models/basque_pig.py | 4 +- .../petstore_api/models/danish_pig.py | 4 +- .../tests/test_deserialization.py | 4 +- .../client/petstore/python/docs/BasquePig.md | 2 +- .../client/petstore/python/docs/DanishPig.md | 2 +- .../python/petstore_api/models/basque_pig.py | 4 +- .../python/petstore_api/models/danish_pig.py | 4 +- .../python/tests/test_deserialization.py | 4 +- .../src/openapi_server/models/api_response.py | 2 + .../src/openapi_server/models/category.py | 2 + .../src/openapi_server/models/order.py | 2 + .../src/openapi_server/models/pet.py | 2 + .../src/openapi_server/models/tag.py | 2 + .../src/openapi_server/models/user.py | 2 + 21 files changed, 218 insertions(+), 341 deletions(-) diff --git a/modules/openapi-generator/src/main/java/org/openapitools/codegen/languages/AbstractPythonCodegen.java b/modules/openapi-generator/src/main/java/org/openapitools/codegen/languages/AbstractPythonCodegen.java index 8288294ffbdef..49075f099d16d 100644 --- a/modules/openapi-generator/src/main/java/org/openapitools/codegen/languages/AbstractPythonCodegen.java +++ b/modules/openapi-generator/src/main/java/org/openapitools/codegen/languages/AbstractPythonCodegen.java @@ -853,6 +853,8 @@ public Map postProcessAllModels(Map objs) codegenModelMap.put(cm.classname, ModelUtils.getModelByName(entry.getKey(), objs)); } + propagateDiscriminatorValuesToProperties(processed); + // create circular import for (String m : codegenModelMap.keySet()) { createImportMapOfSet(m, codegenModelMap); @@ -1046,6 +1048,52 @@ private ModelsMap postProcessModelsMap(ModelsMap objs) { return objs; } + private void propagateDiscriminatorValuesToProperties(Map objMap) { + HashMap modelMap = new HashMap<>(); + for (Map.Entry entry : objMap.entrySet()) { + for (ModelMap m : entry.getValue().getModels()) { + modelMap.put("#/components/schemas/" + entry.getKey(), m.getModel()); + } + } + + for (Map.Entry entry : objMap.entrySet()) { + for (ModelMap m : entry.getValue().getModels()) { + CodegenModel model = m.getModel(); + if (model.discriminator != null && !model.oneOf.isEmpty()) { + // Populate default, implicit discriminator values + for (String typeName : model.oneOf) { + ModelsMap obj = objMap.get(typeName); + if (obj == null) { + continue; + } + for (ModelMap m1 : obj.getModels()) { + for (CodegenProperty p : m1.getModel().vars) { + if (p.baseName.equals(model.discriminator.getPropertyBaseName())) { + p.isDiscriminator = true; + p.discriminatorValue = typeName; + } + } + } + } + // Populate explicit discriminator values from mapping, overwriting default values + if (model.discriminator.getMapping() != null) { + for (Map.Entry discrEntry : model.discriminator.getMapping().entrySet()) { + CodegenModel resolved = modelMap.get(discrEntry.getValue()); + if (resolved != null) { + for (CodegenProperty p : resolved.vars) { + if (p.baseName.equals(model.discriminator.getPropertyBaseName())) { + p.isDiscriminator = true; + p.discriminatorValue = discrEntry.getKey(); + } + } + } + } + } + } + } + } + } + /* * Gets the pydantic type given a Codegen Property @@ -2160,7 +2208,16 @@ private PythonType getType(CodegenProperty cp) { } private String finalizeType(CodegenProperty cp, PythonType pt) { - if (!cp.required || cp.isNullable) { + if (cp.isDiscriminator && cp.discriminatorValue != null) { + moduleImports.add("typing", "Literal"); + PythonType literal = new PythonType("Literal"); + String literalValue = '"'+escapeText(cp.discriminatorValue)+'"'; + PythonType valueType = new PythonType(literalValue); + literal.addTypeParam(valueType); + literal.setDefaultValue(literalValue); + cp.setDefaultValue(literalValue); + pt = literal; + } else if (!cp.required || cp.isNullable) { moduleImports.add("typing", "Optional"); PythonType opt = new PythonType("Optional"); opt.addTypeParam(pt); diff --git a/modules/openapi-generator/src/main/resources/python-fastapi/model_anyof.mustache b/modules/openapi-generator/src/main/resources/python-fastapi/model_anyof.mustache index b145f73ad13b8..dc945bdaf1ad1 100644 --- a/modules/openapi-generator/src/main/resources/python-fastapi/model_anyof.mustache +++ b/modules/openapi-generator/src/main/resources/python-fastapi/model_anyof.mustache @@ -14,174 +14,56 @@ import re # noqa: F401 {{/vendorExtensions.x-py-model-imports}} from typing import Union, Any, List, TYPE_CHECKING, Optional, Dict from typing_extensions import Literal -from pydantic import StrictStr, Field +from pydantic import StrictStr, Field, RootModel try: from typing import Self except ImportError: from typing_extensions import Self -{{#lambda.uppercase}}{{{classname}}}{{/lambda.uppercase}}_ANY_OF_SCHEMAS = [{{#anyOf}}"{{.}}"{{^-last}}, {{/-last}}{{/anyOf}}] - -class {{classname}}({{#parent}}{{{.}}}{{/parent}}{{^parent}}BaseModel{{/parent}}): +class {{classname}}({{#parent}}{{{.}}}{{/parent}}{{^parent}}RootModel{{/parent}}): """ {{{description}}}{{^description}}{{{classname}}}{{/description}} """ -{{#composedSchemas.anyOf}} - # data type: {{{dataType}}} - {{vendorExtensions.x-py-name}}: {{{vendorExtensions.x-py-typing}}} -{{/composedSchemas.anyOf}} - if TYPE_CHECKING: - actual_instance: Optional[Union[{{#anyOf}}{{{.}}}{{^-last}}, {{/-last}}{{/anyOf}}]] = None - else: - actual_instance: Any = None - any_of_schemas: List[str] = Literal[{{#lambda.uppercase}}{{{classname}}}{{/lambda.uppercase}}_ANY_OF_SCHEMAS] + root: Union[{{#anyOf}}{{{.}}}{{^-last}}, {{/-last}}{{/anyOf}}] = None model_config = { "validate_assignment": True, "protected_namespaces": (), } -{{#discriminator}} - - discriminator_value_class_map: Dict[str, str] = { -{{#children}} - '{{^vendorExtensions.x-discriminator-value}}{{name}}{{/vendorExtensions.x-discriminator-value}}{{#vendorExtensions.x-discriminator-value}}{{{vendorExtensions.x-discriminator-value}}}{{/vendorExtensions.x-discriminator-value}}': '{{{classname}}}'{{^-last}},{{/-last}} -{{/children}} - } -{{/discriminator}} - - def __init__(self, *args, **kwargs) -> None: - if args: - if len(args) > 1: - raise ValueError("If a position argument is used, only 1 is allowed to set `actual_instance`") - if kwargs: - raise ValueError("If a position argument is used, keyword arguments cannot be used.") - super().__init__(actual_instance=args[0]) - else: - super().__init__(**kwargs) - - @field_validator('actual_instance') - def actual_instance_must_validate_anyof(cls, v): - {{#isNullable}} - if v is None: - return v - - {{/isNullable}} - instance = {{{classname}}}.model_construct() - error_messages = [] - {{#composedSchemas.anyOf}} - # validate data type: {{{dataType}}} - {{#isContainer}} - try: - instance.{{vendorExtensions.x-py-name}} = v - return v - except (ValidationError, ValueError) as e: - error_messages.append(str(e)) - {{/isContainer}} - {{^isContainer}} - {{#isPrimitiveType}} - try: - instance.{{vendorExtensions.x-py-name}} = v - return v - except (ValidationError, ValueError) as e: - error_messages.append(str(e)) - {{/isPrimitiveType}} - {{^isPrimitiveType}} - if not isinstance(v, {{{dataType}}}): - error_messages.append(f"Error! Input type `{type(v)}` is not `{{{dataType}}}`") - else: - return v - - {{/isPrimitiveType}} - {{/isContainer}} - {{/composedSchemas.anyOf}} - if error_messages: - # no match - raise ValueError("No match found when setting the actual_instance in {{{classname}}} with anyOf schemas: {{#anyOf}}{{{.}}}{{^-last}}, {{/-last}}{{/anyOf}}. Details: " + ", ".join(error_messages)) - else: - return v - - @classmethod - def from_dict(cls, obj: dict) -> Self: - return cls.from_json(json.dumps(obj)) - @classmethod - def from_json(cls, json_str: str) -> Self: - """Returns the object represented by the json string""" - instance = cls.model_construct() - {{#isNullable}} - if json_str is None: - return instance - - {{/isNullable}} - error_messages = [] - {{#composedSchemas.anyOf}} - {{#isContainer}} - # deserialize data into {{{dataType}}} - try: - # validation - instance.{{vendorExtensions.x-py-name}} = json.loads(json_str) - # assign value to actual_instance - instance.actual_instance = instance.{{vendorExtensions.x-py-name}} - return instance - except (ValidationError, ValueError) as e: - error_messages.append(str(e)) - {{/isContainer}} - {{^isContainer}} - {{#isPrimitiveType}} - # deserialize data into {{{dataType}}} - try: - # validation - instance.{{vendorExtensions.x-py-name}} = json.loads(json_str) - # assign value to actual_instance - instance.actual_instance = instance.{{vendorExtensions.x-py-name}} - return instance - except (ValidationError, ValueError) as e: - error_messages.append(str(e)) - {{/isPrimitiveType}} - {{^isPrimitiveType}} - # {{vendorExtensions.x-py-name}}: {{{vendorExtensions.x-py-typing}}} - try: - instance.actual_instance = {{{dataType}}}.from_json(json_str) - return instance - except (ValidationError, ValueError) as e: - error_messages.append(str(e)) - {{/isPrimitiveType}} - {{/isContainer}} - {{/composedSchemas.anyOf}} - - if error_messages: - # no match - raise ValueError("No match found when deserializing the JSON string into {{{classname}}} with anyOf schemas: {{#anyOf}}{{{.}}}{{^-last}}, {{/-last}}{{/anyOf}}. Details: " + ", ".join(error_messages)) - else: - return instance + def to_str(self) -> str: + """Returns the string representation of the model using alias""" + return pprint.pformat(self.model_dump(by_alias=True)) def to_json(self) -> str: - """Returns the JSON representation of the actual instance""" - if self.actual_instance is None: - return "null" + """Returns the JSON representation of the model using alias""" + return self.model_dump_json(by_alias=True, exclude_unset=True) - to_json = getattr(self.actual_instance, "to_json", None) - if callable(to_json): - return self.actual_instance.to_json() + @classmethod + def from_json(cls, json_str: str) -> {{^hasChildren}}Self{{/hasChildren}}{{#hasChildren}}{{#discriminator}}Union[{{#children}}Self{{^-last}}, {{/-last}}{{/children}}]{{/discriminator}}{{^discriminator}}Self{{/discriminator}}{{/hasChildren}}: + """Create an instance of {{{classname}}} from a JSON string""" + return cls.from_dict(json.loads(json_str)) + + def to_dict(self) -> Dict[str, Any]: + """Return the dictionary representation of the model using alias""" + to_dict = getattr(self.root, "to_dict", None) + if callable(to_dict): + return self.model_dump(by_alias=True, exclude_unset=True) else: - return json.dumps(self.actual_instance) + # primitive type + return self.root - def to_dict(self) -> Dict: - """Returns the dict representation of the actual instance""" - if self.actual_instance is None: - return "null" + @classmethod + def from_dict(cls, obj: Dict) -> {{^hasChildren}}Self{{/hasChildren}}{{#hasChildren}}{{#discriminator}}Union[{{#children}}Self{{^-last}}, {{/-last}}{{/children}}]{{/discriminator}}{{^discriminator}}Self{{/discriminator}}{{/hasChildren}}: + """Create an instance of {{{classname}}} from a dict""" + if obj is None: + return None - to_json = getattr(self.actual_instance, "to_json", None) - if callable(to_json): - return self.actual_instance.to_dict() - else: - # primitive type - return self.actual_instance + if not isinstance(obj, dict): + return cls.model_validate(obj) - def to_str(self) -> str: - """Returns the string representation of the actual instance""" - return pprint.pformat(self.model_dump()) + return cls.parse_obj(obj) {{#vendorExtensions.x-py-postponed-model-imports.size}} {{#vendorExtensions.x-py-postponed-model-imports}} diff --git a/modules/openapi-generator/src/main/resources/python-fastapi/model_generic.mustache b/modules/openapi-generator/src/main/resources/python-fastapi/model_generic.mustache index 1b4e2c5be1ed5..0bf343806fe0a 100644 --- a/modules/openapi-generator/src/main/resources/python-fastapi/model_generic.mustache +++ b/modules/openapi-generator/src/main/resources/python-fastapi/model_generic.mustache @@ -86,6 +86,13 @@ class {{classname}}({{#parent}}{{{.}}}{{/parent}}{{^parent}}BaseModel{{/parent}} {{/isAdditionalPropertiesTrue}} } + def __init__(self, *a, **kw): + super().__init__(*a, **kw) + {{#vars}} + {{#isDiscriminator}} + self.{{name}} = self.{{name}} + {{/isDiscriminator}} + {{/vars}} def to_str(self) -> str: """Returns the string representation of the model using alias""" diff --git a/modules/openapi-generator/src/main/resources/python-fastapi/model_oneof.mustache b/modules/openapi-generator/src/main/resources/python-fastapi/model_oneof.mustache index b87c42cf2b9be..ed2166a627769 100644 --- a/modules/openapi-generator/src/main/resources/python-fastapi/model_oneof.mustache +++ b/modules/openapi-generator/src/main/resources/python-fastapi/model_oneof.mustache @@ -13,198 +13,52 @@ import re # noqa: F401 {{{.}}} {{/vendorExtensions.x-py-model-imports}} from typing import Union, Any, List, TYPE_CHECKING, Optional, Dict -from typing_extensions import Literal -from pydantic import StrictStr, Field +from typing_extensions import Annotated, Literal +from pydantic import StrictStr, Field, Discriminator, Tag, RootModel try: from typing import Self except ImportError: from typing_extensions import Self -{{#lambda.uppercase}}{{{classname}}}{{/lambda.uppercase}}_ONE_OF_SCHEMAS = [{{#oneOf}}"{{.}}"{{^-last}}, {{/-last}}{{/oneOf}}] - -class {{classname}}({{#parent}}{{{.}}}{{/parent}}{{^parent}}BaseModel{{/parent}}): +class {{classname}}({{#parent}}{{{.}}}{{/parent}}{{^parent}}RootModel{{/parent}}): """ {{{description}}}{{^description}}{{{classname}}}{{/description}} """ -{{#composedSchemas.oneOf}} - # data type: {{{dataType}}} - {{vendorExtensions.x-py-name}}: {{{vendorExtensions.x-py-typing}}} -{{/composedSchemas.oneOf}} - actual_instance: Optional[Union[{{#oneOf}}{{{.}}}{{^-last}}, {{/-last}}{{/oneOf}}]] = None - one_of_schemas: List[str] = Literal[{{#oneOf}}"{{.}}"{{^-last}}, {{/-last}}{{/oneOf}}] + + root: Union[{{#oneOf}}{{{.}}}{{^-last}}, {{/-last}}{{/oneOf}}] = Field({{#discriminator}}{{#mappedModels}}{{#-first}}discriminator='{{{propertyName}}}'{{/-first}}{{/mappedModels}}{{/discriminator}}) model_config = { "validate_assignment": True, "protected_namespaces": (), } -{{#discriminator}} - - discriminator_value_class_map: Dict[str, str] = { -{{#children}} - '{{^vendorExtensions.x-discriminator-value}}{{name}}{{/vendorExtensions.x-discriminator-value}}{{#vendorExtensions.x-discriminator-value}}{{{vendorExtensions.x-discriminator-value}}}{{/vendorExtensions.x-discriminator-value}}': '{{{classname}}}'{{^-last}},{{/-last}} -{{/children}} - } -{{/discriminator}} - - def __init__(self, *args, **kwargs) -> None: - if args: - if len(args) > 1: - raise ValueError("If a position argument is used, only 1 is allowed to set `actual_instance`") - if kwargs: - raise ValueError("If a position argument is used, keyword arguments cannot be used.") - super().__init__(actual_instance=args[0]) - else: - super().__init__(**kwargs) - - @field_validator('actual_instance') - def actual_instance_must_validate_oneof(cls, v): - {{#isNullable}} - if v is None: - return v - - {{/isNullable}} - instance = {{{classname}}}.model_construct() - error_messages = [] - match = 0 - {{#composedSchemas.oneOf}} - # validate data type: {{{dataType}}} - {{#isContainer}} - try: - instance.{{vendorExtensions.x-py-name}} = v - match += 1 - except (ValidationError, ValueError) as e: - error_messages.append(str(e)) - {{/isContainer}} - {{^isContainer}} - {{#isPrimitiveType}} - try: - instance.{{vendorExtensions.x-py-name}} = v - match += 1 - except (ValidationError, ValueError) as e: - error_messages.append(str(e)) - {{/isPrimitiveType}} - {{^isPrimitiveType}} - if not isinstance(v, {{{dataType}}}): - error_messages.append(f"Error! Input type `{type(v)}` is not `{{{dataType}}}`") - else: - match += 1 - {{/isPrimitiveType}} - {{/isContainer}} - {{/composedSchemas.oneOf}} - if match > 1: - # more than 1 match - raise ValueError("Multiple matches found when setting `actual_instance` in {{{classname}}} with oneOf schemas: {{#oneOf}}{{{.}}}{{^-last}}, {{/-last}}{{/oneOf}}. Details: " + ", ".join(error_messages)) - elif match == 0: - # no match - raise ValueError("No match found when setting `actual_instance` in {{{classname}}} with oneOf schemas: {{#oneOf}}{{{.}}}{{^-last}}, {{/-last}}{{/oneOf}}. Details: " + ", ".join(error_messages)) - else: - return v + def to_str(self) -> str: + """Returns the string representation of the model using alias""" + return pprint.pformat(self.model_dump(by_alias=True)) - @classmethod - def from_dict(cls, obj: dict) -> Self: - return cls.from_json(json.dumps(obj)) + def to_json(self) -> str: + """Returns the JSON representation of the model using alias""" + return self.model_dump_json(by_alias=True, exclude_unset=True) @classmethod - def from_json(cls, json_str: str) -> Self: - """Returns the object represented by the json string""" - instance = cls.model_construct() - {{#isNullable}} - if json_str is None: - return instance + def from_json(cls, json_str: str) -> {{^hasChildren}}Self{{/hasChildren}}{{#hasChildren}}{{#discriminator}}Union[{{#children}}Self{{^-last}}, {{/-last}}{{/children}}]{{/discriminator}}{{^discriminator}}Self{{/discriminator}}{{/hasChildren}}: + """Create an instance of {{{classname}}} from a JSON string""" + return cls.from_dict(json.loads(json_str)) - {{/isNullable}} - error_messages = [] - match = 0 + def to_dict(self) -> Dict[str, Any]: + """Return the dictionary representation of the model using alias""" + return self.model_dump(by_alias=True, exclude_unset=True) - {{#useOneOfDiscriminatorLookup}} - {{#discriminator}} - {{#mappedModels}} - {{#-first}} - # use oneOf discriminator to lookup the data type - _data_type = json.loads(json_str).get("{{{propertyBaseName}}}") - if not _data_type: - raise ValueError("Failed to lookup data type from the field `{{{propertyBaseName}}}` in the input.") - - {{/-first}} - # check if data type is `{{{modelName}}}` - if _data_type == "{{{mappingName}}}": - instance.actual_instance = {{{modelName}}}.from_json(json_str) - return instance - - {{/mappedModels}} - {{/discriminator}} - {{/useOneOfDiscriminatorLookup}} - {{#composedSchemas.oneOf}} - {{#isContainer}} - # deserialize data into {{{dataType}}} - try: - # validation - instance.{{vendorExtensions.x-py-name}} = json.loads(json_str) - # assign value to actual_instance - instance.actual_instance = instance.{{vendorExtensions.x-py-name}} - match += 1 - except (ValidationError, ValueError) as e: - error_messages.append(str(e)) - {{/isContainer}} - {{^isContainer}} - {{#isPrimitiveType}} - # deserialize data into {{{dataType}}} - try: - # validation - instance.{{vendorExtensions.x-py-name}} = json.loads(json_str) - # assign value to actual_instance - instance.actual_instance = instance.{{vendorExtensions.x-py-name}} - match += 1 - except (ValidationError, ValueError) as e: - error_messages.append(str(e)) - {{/isPrimitiveType}} - {{^isPrimitiveType}} - # deserialize data into {{{dataType}}} - try: - instance.actual_instance = {{{dataType}}}.from_json(json_str) - match += 1 - except (ValidationError, ValueError) as e: - error_messages.append(str(e)) - {{/isPrimitiveType}} - {{/isContainer}} - {{/composedSchemas.oneOf}} - - if match > 1: - # more than 1 match - raise ValueError("Multiple matches found when deserializing the JSON string into {{{classname}}} with oneOf schemas: {{#oneOf}}{{{.}}}{{^-last}}, {{/-last}}{{/oneOf}}. Details: " + ", ".join(error_messages)) - elif match == 0: - # no match - raise ValueError("No match found when deserializing the JSON string into {{{classname}}} with oneOf schemas: {{#oneOf}}{{{.}}}{{^-last}}, {{/-last}}{{/oneOf}}. Details: " + ", ".join(error_messages)) - else: - return instance - - def to_json(self) -> str: - """Returns the JSON representation of the actual instance""" - if self.actual_instance is None: - return "null" - - to_json = getattr(self.actual_instance, "to_json", None) - if callable(to_json): - return self.actual_instance.to_json() - else: - return json.dumps(self.actual_instance) - - def to_dict(self) -> Dict: - """Returns the dict representation of the actual instance""" - if self.actual_instance is None: + @classmethod + def from_dict(cls, obj: Dict) -> {{^hasChildren}}Self{{/hasChildren}}{{#hasChildren}}{{#discriminator}}Union[{{#children}}Self{{^-last}}, {{/-last}}{{/children}}]{{/discriminator}}{{^discriminator}}Self{{/discriminator}}{{/hasChildren}}: + """Create an instance of {{{classname}}} from a dict""" + if obj is None: return None - to_dict = getattr(self.actual_instance, "to_dict", None) - if callable(to_dict): - return self.actual_instance.to_dict() - else: - # primitive type - return self.actual_instance + if not isinstance(obj, dict): + return cls.model_validate(obj) - def to_str(self) -> str: - """Returns the string representation of the actual instance""" - return pprint.pformat(self.model_dump()) + return cls.parse_obj(obj) {{#vendorExtensions.x-py-postponed-model-imports.size}} {{#vendorExtensions.x-py-postponed-model-imports}} diff --git a/modules/openapi-generator/src/test/java/org/openapitools/codegen/python/PythonFastapiCodegenTest.java b/modules/openapi-generator/src/test/java/org/openapitools/codegen/python/PythonFastapiCodegenTest.java index a27450fea1868..533eacf09d925 100644 --- a/modules/openapi-generator/src/test/java/org/openapitools/codegen/python/PythonFastapiCodegenTest.java +++ b/modules/openapi-generator/src/test/java/org/openapitools/codegen/python/PythonFastapiCodegenTest.java @@ -1,12 +1,17 @@ package org.openapitools.codegen.python; import com.google.common.collect.Sets; +import io.swagger.v3.oas.models.Components; import io.swagger.v3.oas.models.OpenAPI; import io.swagger.v3.oas.models.media.ArraySchema; +import io.swagger.v3.oas.models.media.ComposedSchema; +import io.swagger.v3.oas.models.media.Discriminator; import io.swagger.v3.oas.models.media.Schema; import org.openapitools.codegen.*; import org.openapitools.codegen.config.CodegenConfigurator; -import org.openapitools.codegen.languages.PythonClientCodegen; +import org.openapitools.codegen.languages.PythonFastAPIServerCodegen; +import org.openapitools.codegen.model.ModelMap; +import org.openapitools.codegen.model.ModelsMap; import org.testng.Assert; import org.testng.annotations.Test; @@ -14,8 +19,9 @@ import java.io.IOException; import java.nio.file.Files; import java.nio.file.Paths; -import java.util.Collections; import java.util.List; +import java.util.Set; +import java.util.TreeMap; public class PythonFastapiCodegenTest { @Test @@ -60,22 +66,81 @@ public void testEndpointSpecsWithoutDescription() throws IOException { } @Test(description = "additionalProperties should not let container type inherit their type") - public void additionalPropertiesModelTest() { - final Schema model = new ArraySchema() - //.description() + public void testAdditionalProperties() { + Schema model = new ArraySchema() .items(new Schema().type("object").additionalProperties(new Schema().type("string"))) .description("model with additionalProperties"); - final DefaultCodegen codegen = new PythonClientCodegen(); + DefaultCodegen codegen = new PythonFastAPIServerCodegen(); OpenAPI openAPI = TestUtils.createOpenAPIWithOneSchema("sample", model); codegen.setOpenAPI(openAPI); - final CodegenModel cm = codegen.fromModel("sample", model); + CodegenModel cm = codegen.fromModel("sample", model); Assert.assertEquals(cm.name, "sample"); Assert.assertEquals(cm.classname, "Sample"); Assert.assertEquals(cm.description, "model with additionalProperties"); Assert.assertEquals(cm.vars.size(), 0); - Assert.assertEquals(cm.parent, "null"); + Assert.assertNull(cm.parent, null); Assert.assertEquals(cm.imports.size(), 0); Assert.assertEquals(Sets.intersection(cm.imports, Sets.newHashSet()).size(), 0); } + + @Test(description = "oneOf discriminator mapping values are propagated to vars") + public void testOneOfDiscriminator() { + TreeMap properties1 = new TreeMap<>(); + properties1.put("objectType", new Schema().type("string")); + TreeMap properties2 = new TreeMap<>(properties1); + properties1.put("someProp", new Schema().type("string")); + Schema typeA = new Schema().type("object").properties(properties1); + Schema typeB = new Schema().type("object").properties(properties2); + Schema typeC = new ComposedSchema().oneOf(List.of(typeA, typeB)) + .discriminator(new Discriminator() + .propertyName("objectType") + .mapping("type-a", "#/components/schemas/TypeA")); + Schema typeD = new Schema().type("object").properties(properties2); + + OpenAPI openAPI = TestUtils.createOpenAPI(); + openAPI.setComponents(new Components()); + openAPI.getComponents().addSchemas("TypeA", typeA); + openAPI.getComponents().addSchemas("TypeB", typeB); + openAPI.getComponents().addSchemas("TypeC", typeC); + openAPI.getComponents().addSchemas("TypeD", typeD); + + DefaultCodegen codegen = new PythonFastAPIServerCodegen(); + codegen.setOpenAPI(openAPI); + + TreeMap allModels = new TreeMap<>(); + String[] typeNames = new String[]{"TypeA", "TypeB", "TypeC", "TypeD"}; + CodegenModel[] models = new CodegenModel[]{null, null, null, null}; + for (int i = 0; i < typeNames.length; i++) { + String key = typeNames[i]; + CodegenModel cm = codegen.fromModel(key, openAPI.getComponents().getSchemas().get(key)); + if (key.equals("TypeC")) { + cm.oneOf = Set.of("TypeA", "TypeB"); + } + ModelMap mo = new ModelMap(); + mo.setModel(cm); + ModelsMap objs = new ModelsMap(); + objs.setModels(List.of(mo)); + allModels.put(key, objs); + models[i] = cm; + } + + codegen.postProcessAllModels(allModels); + + CodegenModel typeAModel = models[0]; + CodegenModel typeBModel = models[1]; + CodegenModel typeDModel = models[3]; + Assert.assertEquals(typeAModel.vars.size(), 2); + Assert.assertTrue(typeAModel.vars.get(0).isDiscriminator); + Assert.assertEquals(typeAModel.vars.get(0).discriminatorValue, "type-a"); // explicitly mapped value + Assert.assertTrue(typeBModel.vars.get(0).isDiscriminator); + Assert.assertEquals(typeBModel.vars.get(0).discriminatorValue, "TypeB"); // implicit value + Assert.assertNull(typeAModel.parent); + Assert.assertEquals(typeAModel.imports.size(), 0); + Assert.assertEquals(Sets.intersection(typeAModel.imports, Sets.newHashSet()).size(), 0); + + Assert.assertEquals(typeDModel.vars.size(), 1); + Assert.assertFalse(typeDModel.vars.get(0).isDiscriminator); + Assert.assertNull(typeDModel.vars.get(0).discriminatorValue); + } } diff --git a/samples/openapi3/client/petstore/python-aiohttp/docs/BasquePig.md b/samples/openapi3/client/petstore/python-aiohttp/docs/BasquePig.md index ee28d628722f8..ee2b2551e4f57 100644 --- a/samples/openapi3/client/petstore/python-aiohttp/docs/BasquePig.md +++ b/samples/openapi3/client/petstore/python-aiohttp/docs/BasquePig.md @@ -5,7 +5,7 @@ Name | Type | Description | Notes ------------ | ------------- | ------------- | ------------- -**class_name** | **str** | | +**class_name** | **str** | | [default to "BasquePig"] **color** | **str** | | ## Example diff --git a/samples/openapi3/client/petstore/python-aiohttp/docs/DanishPig.md b/samples/openapi3/client/petstore/python-aiohttp/docs/DanishPig.md index 16941388832a9..cd7666b41739c 100644 --- a/samples/openapi3/client/petstore/python-aiohttp/docs/DanishPig.md +++ b/samples/openapi3/client/petstore/python-aiohttp/docs/DanishPig.md @@ -5,7 +5,7 @@ Name | Type | Description | Notes ------------ | ------------- | ------------- | ------------- -**class_name** | **str** | | +**class_name** | **str** | | [default to "DanishPig"] **size** | **int** | | ## Example diff --git a/samples/openapi3/client/petstore/python-aiohttp/petstore_api/models/basque_pig.py b/samples/openapi3/client/petstore/python-aiohttp/petstore_api/models/basque_pig.py index a1f32a6edcfcd..7f557c368d940 100644 --- a/samples/openapi3/client/petstore/python-aiohttp/petstore_api/models/basque_pig.py +++ b/samples/openapi3/client/petstore/python-aiohttp/petstore_api/models/basque_pig.py @@ -18,7 +18,7 @@ import json from pydantic import BaseModel, ConfigDict, Field, StrictStr -from typing import Any, ClassVar, Dict, List +from typing import Any, ClassVar, Dict, List, Literal from typing import Optional, Set from typing_extensions import Self @@ -26,7 +26,7 @@ class BasquePig(BaseModel): """ BasquePig """ # noqa: E501 - class_name: StrictStr = Field(alias="className") + class_name: Literal["BasquePig"] = Field(default="BasquePig", alias="className") color: StrictStr __properties: ClassVar[List[str]] = ["className", "color"] diff --git a/samples/openapi3/client/petstore/python-aiohttp/petstore_api/models/danish_pig.py b/samples/openapi3/client/petstore/python-aiohttp/petstore_api/models/danish_pig.py index 061e16a486a5b..70bd0a2b6b49e 100644 --- a/samples/openapi3/client/petstore/python-aiohttp/petstore_api/models/danish_pig.py +++ b/samples/openapi3/client/petstore/python-aiohttp/petstore_api/models/danish_pig.py @@ -18,7 +18,7 @@ import json from pydantic import BaseModel, ConfigDict, Field, StrictInt, StrictStr -from typing import Any, ClassVar, Dict, List +from typing import Any, ClassVar, Dict, List, Literal from typing import Optional, Set from typing_extensions import Self @@ -26,7 +26,7 @@ class DanishPig(BaseModel): """ DanishPig """ # noqa: E501 - class_name: StrictStr = Field(alias="className") + class_name: Literal["DanishPig"] = Field(default="DanishPig", alias="className") size: StrictInt __properties: ClassVar[List[str]] = ["className", "size"] diff --git a/samples/openapi3/client/petstore/python-pydantic-v1/tests/test_deserialization.py b/samples/openapi3/client/petstore/python-pydantic-v1/tests/test_deserialization.py index c5fb68663821f..0f9677ebd327d 100644 --- a/samples/openapi3/client/petstore/python-pydantic-v1/tests/test_deserialization.py +++ b/samples/openapi3/client/petstore/python-pydantic-v1/tests/test_deserialization.py @@ -246,7 +246,7 @@ def test_deserialize_none(self): def test_deserialize_pig(self): """ deserialize pig (oneOf) """ data = { - "className": "BasqueBig", + "className": "BasquePig", "color": "white" } @@ -254,7 +254,7 @@ def test_deserialize_pig(self): deserialized = self.deserialize(response, "Pig") self.assertTrue(isinstance(deserialized.actual_instance, petstore_api.BasquePig)) - self.assertEqual(deserialized.actual_instance.class_name, "BasqueBig") + self.assertEqual(deserialized.actual_instance.class_name, "BasquePig") self.assertEqual(deserialized.actual_instance.color, "white") def test_deserialize_animal(self): diff --git a/samples/openapi3/client/petstore/python/docs/BasquePig.md b/samples/openapi3/client/petstore/python/docs/BasquePig.md index ee28d628722f8..ee2b2551e4f57 100644 --- a/samples/openapi3/client/petstore/python/docs/BasquePig.md +++ b/samples/openapi3/client/petstore/python/docs/BasquePig.md @@ -5,7 +5,7 @@ Name | Type | Description | Notes ------------ | ------------- | ------------- | ------------- -**class_name** | **str** | | +**class_name** | **str** | | [default to "BasquePig"] **color** | **str** | | ## Example diff --git a/samples/openapi3/client/petstore/python/docs/DanishPig.md b/samples/openapi3/client/petstore/python/docs/DanishPig.md index 16941388832a9..cd7666b41739c 100644 --- a/samples/openapi3/client/petstore/python/docs/DanishPig.md +++ b/samples/openapi3/client/petstore/python/docs/DanishPig.md @@ -5,7 +5,7 @@ Name | Type | Description | Notes ------------ | ------------- | ------------- | ------------- -**class_name** | **str** | | +**class_name** | **str** | | [default to "DanishPig"] **size** | **int** | | ## Example diff --git a/samples/openapi3/client/petstore/python/petstore_api/models/basque_pig.py b/samples/openapi3/client/petstore/python/petstore_api/models/basque_pig.py index 4a5b9e3bcb9dc..e5c9dbc4cd6e0 100644 --- a/samples/openapi3/client/petstore/python/petstore_api/models/basque_pig.py +++ b/samples/openapi3/client/petstore/python/petstore_api/models/basque_pig.py @@ -18,7 +18,7 @@ import json from pydantic import BaseModel, ConfigDict, Field, StrictStr -from typing import Any, ClassVar, Dict, List +from typing import Any, ClassVar, Dict, List, Literal from typing import Optional, Set from typing_extensions import Self @@ -26,7 +26,7 @@ class BasquePig(BaseModel): """ BasquePig """ # noqa: E501 - class_name: StrictStr = Field(alias="className") + class_name: Literal["BasquePig"] = Field(default="BasquePig", alias="className") color: StrictStr additional_properties: Dict[str, Any] = {} __properties: ClassVar[List[str]] = ["className", "color"] diff --git a/samples/openapi3/client/petstore/python/petstore_api/models/danish_pig.py b/samples/openapi3/client/petstore/python/petstore_api/models/danish_pig.py index df4a80d339086..f70bde48b0b0d 100644 --- a/samples/openapi3/client/petstore/python/petstore_api/models/danish_pig.py +++ b/samples/openapi3/client/petstore/python/petstore_api/models/danish_pig.py @@ -18,7 +18,7 @@ import json from pydantic import BaseModel, ConfigDict, Field, StrictInt, StrictStr -from typing import Any, ClassVar, Dict, List +from typing import Any, ClassVar, Dict, List, Literal from typing import Optional, Set from typing_extensions import Self @@ -26,7 +26,7 @@ class DanishPig(BaseModel): """ DanishPig """ # noqa: E501 - class_name: StrictStr = Field(alias="className") + class_name: Literal["DanishPig"] = Field(default="DanishPig", alias="className") size: StrictInt additional_properties: Dict[str, Any] = {} __properties: ClassVar[List[str]] = ["className", "size"] diff --git a/samples/openapi3/client/petstore/python/tests/test_deserialization.py b/samples/openapi3/client/petstore/python/tests/test_deserialization.py index 8db2929be3542..35a01d10010f1 100644 --- a/samples/openapi3/client/petstore/python/tests/test_deserialization.py +++ b/samples/openapi3/client/petstore/python/tests/test_deserialization.py @@ -254,7 +254,7 @@ def test_deserialize_none(self): def test_deserialize_pig(self): """ deserialize pig (oneOf) """ data = { - "className": "BasqueBig", + "className": "BasquePig", "color": "white" } @@ -262,7 +262,7 @@ def test_deserialize_pig(self): deserialized = self.deserialize(response, "Pig", 'application/json') self.assertTrue(isinstance(deserialized.actual_instance, petstore_api.BasquePig)) - self.assertEqual(deserialized.actual_instance.class_name, "BasqueBig") + self.assertEqual(deserialized.actual_instance.class_name, "BasquePig") self.assertEqual(deserialized.actual_instance.color, "white") def test_deserialize_animal(self): diff --git a/samples/server/petstore/python-fastapi/src/openapi_server/models/api_response.py b/samples/server/petstore/python-fastapi/src/openapi_server/models/api_response.py index 1d6fe59b60507..84bb9716530e7 100644 --- a/samples/server/petstore/python-fastapi/src/openapi_server/models/api_response.py +++ b/samples/server/petstore/python-fastapi/src/openapi_server/models/api_response.py @@ -42,6 +42,8 @@ class ApiResponse(BaseModel): "protected_namespaces": (), } + def __init__(self, *a, **kw): + super().__init__(*a, **kw) def to_str(self) -> str: """Returns the string representation of the model using alias""" diff --git a/samples/server/petstore/python-fastapi/src/openapi_server/models/category.py b/samples/server/petstore/python-fastapi/src/openapi_server/models/category.py index cab85b6fdec6c..57da112510838 100644 --- a/samples/server/petstore/python-fastapi/src/openapi_server/models/category.py +++ b/samples/server/petstore/python-fastapi/src/openapi_server/models/category.py @@ -52,6 +52,8 @@ def name_validate_regular_expression(cls, value): "protected_namespaces": (), } + def __init__(self, *a, **kw): + super().__init__(*a, **kw) def to_str(self) -> str: """Returns the string representation of the model using alias""" diff --git a/samples/server/petstore/python-fastapi/src/openapi_server/models/order.py b/samples/server/petstore/python-fastapi/src/openapi_server/models/order.py index 1aa18ee5ebd75..db5b809a037ba 100644 --- a/samples/server/petstore/python-fastapi/src/openapi_server/models/order.py +++ b/samples/server/petstore/python-fastapi/src/openapi_server/models/order.py @@ -56,6 +56,8 @@ def status_validate_enum(cls, value): "protected_namespaces": (), } + def __init__(self, *a, **kw): + super().__init__(*a, **kw) def to_str(self) -> str: """Returns the string representation of the model using alias""" diff --git a/samples/server/petstore/python-fastapi/src/openapi_server/models/pet.py b/samples/server/petstore/python-fastapi/src/openapi_server/models/pet.py index 5e74857ba5600..cdc860368906f 100644 --- a/samples/server/petstore/python-fastapi/src/openapi_server/models/pet.py +++ b/samples/server/petstore/python-fastapi/src/openapi_server/models/pet.py @@ -57,6 +57,8 @@ def status_validate_enum(cls, value): "protected_namespaces": (), } + def __init__(self, *a, **kw): + super().__init__(*a, **kw) def to_str(self) -> str: """Returns the string representation of the model using alias""" diff --git a/samples/server/petstore/python-fastapi/src/openapi_server/models/tag.py b/samples/server/petstore/python-fastapi/src/openapi_server/models/tag.py index 31b40d3dcf489..fc4b5a4a7fd4c 100644 --- a/samples/server/petstore/python-fastapi/src/openapi_server/models/tag.py +++ b/samples/server/petstore/python-fastapi/src/openapi_server/models/tag.py @@ -41,6 +41,8 @@ class Tag(BaseModel): "protected_namespaces": (), } + def __init__(self, *a, **kw): + super().__init__(*a, **kw) def to_str(self) -> str: """Returns the string representation of the model using alias""" diff --git a/samples/server/petstore/python-fastapi/src/openapi_server/models/user.py b/samples/server/petstore/python-fastapi/src/openapi_server/models/user.py index b0e8d044c6410..b361946973bd7 100644 --- a/samples/server/petstore/python-fastapi/src/openapi_server/models/user.py +++ b/samples/server/petstore/python-fastapi/src/openapi_server/models/user.py @@ -47,6 +47,8 @@ class User(BaseModel): "protected_namespaces": (), } + def __init__(self, *a, **kw): + super().__init__(*a, **kw) def to_str(self) -> str: """Returns the string representation of the model using alias"""