Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[csharp][java] Fix enum discriminator default value #19614

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3080,6 +3080,7 @@ public CodegenModel fromModel(String name, Schema schema) {
listOLists.add(m.requiredVars);
listOLists.add(m.vars);
listOLists.add(m.allVars);
listOLists.add(m.readWriteVars);
for (List<CodegenProperty> theseVars : listOLists) {
for (CodegenProperty requiredVar : theseVars) {
if (discPropName.equals(requiredVar.baseName)) {
Expand Down Expand Up @@ -3113,6 +3114,63 @@ public CodegenModel fromModel(String name, Schema schema) {
return m;
}

/**
* Sets the default value for an enum discriminator property in the provided {@link CodegenModel}.
* <p>
* If the model's discriminator is defined, this method identifies the discriminator properties among the model's
* variables and assigns the default value to reflect the corresponding enum value for the model type.
* </p>
* <p>
* Example: If the discriminator is for type `Animal`, and the model is `Cat`, the default value
* will be set to `Animal.Cat` for the properties that have the same name as the discriminator.
* </p>
*
* @param model the {@link CodegenModel} whose discriminator property default value is to be set
*/
protected static void setEnumDiscriminatorDefaultValue(CodegenModel model) {
david-marconis marked this conversation as resolved.
Show resolved Hide resolved
if (model.discriminator == null) {
return;
}
String discPropName = model.discriminator.getPropertyBaseName();
Stream.of(model.requiredVars, model.vars, model.allVars, model.readWriteVars)
.flatMap(List::stream)
.filter(v -> discPropName.equals(v.baseName))
.forEach(v -> v.defaultValue = getEnumValueForProperty(model.schemaName, model.discriminator, v));
}

/**
* Retrieves the appropriate default value for an enum discriminator property based on the model name.
* <p>
* If the discriminator has a mapping defined, it attempts to find a mapping for the model name.
* Otherwise, it defaults to one of the allowable enum value associated with the property.
* If no suitable value is found, the original default value of the property is returned.
* </p>
*
* @param modelName the name of the model to determine the default value for
* @param discriminator the {@link CodegenDiscriminator} containing the mapping and enum details
* @param var the {@link CodegenProperty} representing the discriminator property
* @return the default value for the enum discriminator property, or its original default value if none is found
*/
protected static String getEnumValueForProperty(
david-marconis marked this conversation as resolved.
Show resolved Hide resolved
String modelName, CodegenDiscriminator discriminator, CodegenProperty var) {
if (!discriminator.getIsEnum() && !var.isEnum) {
return var.defaultValue;
}
Map<String, String> mapping = Optional.ofNullable(discriminator.getMapping()).orElseGet(Collections::emptyMap);
for (Map.Entry<String, String> e : mapping.entrySet()) {
String schemaName = e.getValue().indexOf('/') < 0 ? e.getValue() : ModelUtils.getSimpleRef(e.getValue());
if (modelName.equals(schemaName)) {
return e.getKey();
}
}
Object values = var.allowableValues.get("values");
if (!(values instanceof List<?>)) {
return var.defaultValue;
}
List<?> valueList = (List<?>) values;
return valueList.stream().filter(o -> o.equals(modelName)).map(o -> (String) o).findAny().orElse(var.defaultValue);
}

protected void SortModelPropertiesByRequiredFlag(CodegenModel model) {
Comparator<CodegenProperty> comparator = new Comparator<CodegenProperty>() {
@Override
Expand Down Expand Up @@ -3183,15 +3241,19 @@ protected void setAddProps(Schema schema, IJsonSchemaValidationProperties proper
* @param visitedSchemas A set of visited schema names
*/
private CodegenProperty discriminatorFound(String composedSchemaName, Schema sc, String discPropName, Set<String> visitedSchemas) {
if (visitedSchemas.contains(composedSchemaName)) { // recursive schema definition found
Schema refSchema = ModelUtils.getReferencedSchema(openAPI, sc);
String schemaName = Optional.ofNullable(composedSchemaName)
.or(() -> Optional.ofNullable(refSchema.getName()))
.or(() -> Optional.ofNullable(sc.get$ref()).map(ModelUtils::getSimpleRef))
.orElseGet(sc::toString);
if (visitedSchemas.contains(schemaName)) { // recursive schema definition found
return null;
} else {
visitedSchemas.add(composedSchemaName);
visitedSchemas.add(schemaName);
}

Schema refSchema = ModelUtils.getReferencedSchema(openAPI, sc);
if (refSchema.getProperties() != null && refSchema.getProperties().get(discPropName) != null) {
Schema discSchema = (Schema) refSchema.getProperties().get(discPropName);
Schema discSchema = ModelUtils.getReferencedSchema(openAPI, (Schema)refSchema.getProperties().get(discPropName));
CodegenProperty cp = new CodegenProperty();
if (ModelUtils.isStringSchema(discSchema)) {
cp.isString = true;
Expand All @@ -3200,14 +3262,16 @@ private CodegenProperty discriminatorFound(String composedSchemaName, Schema sc,
if (refSchema.getRequired() != null && refSchema.getRequired().contains(discPropName)) {
cp.setRequired(true);
}
cp.setIsEnum(discSchema.getEnum() != null && !discSchema.getEnum().isEmpty());
return cp;
}
if (ModelUtils.isComposedSchema(refSchema)) {
Schema composedSchema = refSchema;
if (composedSchema.getAllOf() != null) {
// If our discriminator is in one of the allOf schemas break when we find it
for (Object allOf : composedSchema.getAllOf()) {
CodegenProperty cp = discriminatorFound(composedSchemaName, (Schema) allOf, discPropName, visitedSchemas);
Schema allOfSchema = (Schema) allOf;
CodegenProperty cp = discriminatorFound(allOfSchema.getName(), allOfSchema, discPropName, visitedSchemas);
if (cp != null) {
return cp;
}
Expand All @@ -3217,8 +3281,11 @@ private CodegenProperty discriminatorFound(String composedSchemaName, Schema sc,
// All oneOf definitions must contain the discriminator
CodegenProperty cp = new CodegenProperty();
for (Object oneOf : composedSchema.getOneOf()) {
String modelName = ModelUtils.getSimpleRef(((Schema) oneOf).get$ref());
CodegenProperty thisCp = discriminatorFound(composedSchemaName, (Schema) oneOf, discPropName, visitedSchemas);
Schema oneOfSchema = (Schema) oneOf;
String modelName = ModelUtils.getSimpleRef((oneOfSchema).get$ref());
// Must use a copied set as the oneOf schemas can point to the same discriminator.
Set<String> visitedSchemasCopy = new TreeSet<>(visitedSchemas);
CodegenProperty thisCp = discriminatorFound(oneOfSchema.getName(), oneOfSchema, discPropName, visitedSchemasCopy);
if (thisCp == null) {
once(LOGGER).warn(
"'{}' defines discriminator '{}', but the referenced OneOf schema '{}' is missing {}",
Expand All @@ -3240,8 +3307,11 @@ private CodegenProperty discriminatorFound(String composedSchemaName, Schema sc,
// All anyOf definitions must contain the discriminator because a min of one must be selected
CodegenProperty cp = new CodegenProperty();
for (Object anyOf : composedSchema.getAnyOf()) {
String modelName = ModelUtils.getSimpleRef(((Schema) anyOf).get$ref());
CodegenProperty thisCp = discriminatorFound(composedSchemaName, (Schema) anyOf, discPropName, visitedSchemas);
Schema anyOfSchema = (Schema) anyOf;
String modelName = ModelUtils.getSimpleRef(anyOfSchema.get$ref());
// Must use a copied set as the anyOf schemas can point to the same discriminator.
Set<String> visitedSchemasCopy = new TreeSet<>(visitedSchemas);
CodegenProperty thisCp = discriminatorFound(anyOfSchema.getName(), anyOfSchema, discPropName, visitedSchemasCopy);
if (thisCp == null) {
once(LOGGER).warn(
"'{}' defines discriminator '{}', but the referenced AnyOf schema '{}' is missing {}",
Expand Down Expand Up @@ -3524,13 +3594,11 @@ protected CodegenDiscriminator createDiscriminator(String schemaName, Schema sch
discriminator.setPropertyType(propertyType);

// check to see if the discriminator property is an enum string
if (schema.getProperties() != null &&
schema.getProperties().get(discriminatorPropertyName) instanceof StringSchema) {
StringSchema s = (StringSchema) schema.getProperties().get(discriminatorPropertyName);
if (s.getEnum() != null && !s.getEnum().isEmpty()) { // it's an enum string
discriminator.setIsEnum(true);
}
}
boolean isEnum = Optional
.ofNullable(discriminatorFound(schemaName, schema, discriminatorPropertyName, new TreeSet<>()))
.map(CodegenProperty::getIsEnum)
.orElse(false);
discriminator.setIsEnum(isEnum);

discriminator.setMapping(sourceDiscriminator.getMapping());
List<MappedModel> uniqueDescendants = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1700,6 +1700,7 @@ public CodegenModel fromModel(String name, Schema model) {

// additional import for different cases
addAdditionalImports(codegenModel, codegenModel.getComposedSchemas());
setEnumDiscriminatorDefaultValue(codegenModel);
return codegenModel;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,7 @@ public String apiTestFileFolder() {
public CodegenModel fromModel(String name, Schema model) {
Map<String, Schema> allDefinitions = ModelUtils.getSchemas(this.openAPI);
CodegenModel codegenModel = super.fromModel(name, model);
setEnumDiscriminatorDefaultValue(codegenModel);
if (allDefinitions != null && codegenModel != null && codegenModel.parent != null) {
final Schema<?> parentModel = allDefinitions.get(toModelName(codegenModel.parent));
if (parentModel != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1595,6 +1595,12 @@ public static String getParentName(Schema composedSchema, Map<String, Schema> al
* @return the name of the parent model
*/
public static List<String> getAllParentsName(Schema composedSchema, Map<String, Schema> allSchemas, boolean includeAncestors) {
return getAllParentsName(composedSchema, allSchemas, includeAncestors, new HashSet<>());
}

// Use a set of seen names to avoid infinite recursion
private static List<String> getAllParentsName(
Schema composedSchema, Map<String, Schema> allSchemas, boolean includeAncestors, Set<String> seenNames) {
List<Schema> interfaces = getInterfaces(composedSchema);
List<String> names = new ArrayList<String>();

Expand All @@ -1603,6 +1609,10 @@ public static List<String> getAllParentsName(Schema composedSchema, Map<String,
// get the actual schema
if (StringUtils.isNotEmpty(schema.get$ref())) {
String parentName = getSimpleRef(schema.get$ref());
if (seenNames.contains(parentName)) {
continue;
}
seenNames.add(parentName);
Schema s = allSchemas.get(parentName);
if (s == null) {
LOGGER.error("Failed to obtain schema from {}", parentName);
Expand All @@ -1611,7 +1621,7 @@ public static List<String> getAllParentsName(Schema composedSchema, Map<String,
// discriminator.propertyName is used or x-parent is used
names.add(parentName);
if (includeAncestors && isComposedSchema(s)) {
names.addAll(getAllParentsName(s, allSchemas, true));
names.addAll(getAllParentsName(s, allSchemas, true, seenNames));
}
} else {
// not a parent since discriminator.propertyName is not set
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ public class {{classname}} {{#parent}}extends {{{.}}} {{/parent}}{{#vendorExtens
{{/parcelableModel}}
{{/parent}}
{{#discriminator}}
{{#discriminator.isEnum}}
{{#readWriteVars}}{{#isDiscriminator}}{{#defaultValue}}
this.{{name}} = {{defaultValue}};
{{/defaultValue}}{{/isDiscriminator}}{{/readWriteVars}}
{{/discriminator.isEnum}}
{{^discriminator.isEnum}}
this.{{{discriminatorName}}} = this.getClass().getSimpleName();
{{/discriminator.isEnum}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,4 +140,46 @@ public void test31specAdditionalPropertiesOfOneOf() throws IOException {
assertFileContains(modelFile.toPath(),
" Dictionary<string, ResponseResultsValue> results = default(Dictionary<string, ResponseResultsValue>");
}

@Test
public void testEnumDiscriminatorDefaultValueIsNotString() throws IOException {
File output = Files.createTempDirectory("test").toFile().getCanonicalFile();
output.deleteOnExit();
final OpenAPI openAPI = TestUtils.parseFlattenSpec(
"src/test/resources/3_0/enum_discriminator_inheritance.yaml");
final DefaultGenerator defaultGenerator = new DefaultGenerator();
final ClientOptInput clientOptInput = new ClientOptInput();
clientOptInput.openAPI(openAPI);
CSharpClientCodegen cSharpClientCodegen = new CSharpClientCodegen();
cSharpClientCodegen.setOutputDir(output.getAbsolutePath());
cSharpClientCodegen.setAutosetConstants(true);
clientOptInput.config(cSharpClientCodegen);
defaultGenerator.opts(clientOptInput);

Map<String, File> files = defaultGenerator.generate().stream()
.collect(Collectors.toMap(File::getPath, Function.identity()));

Map<String, String> expectedContents = Map.of(
"Cat", "PetTypeEnum petType = PetTypeEnum.Catty",
"Dog", "PetTypeEnum petType = PetTypeEnum.Dog",
"Gecko", "PetTypeEnum petType = PetTypeEnum.Gecko",
"Chameleon", "PetTypeEnum petType = PetTypeEnum.Camo",
"MiniVan", "CarType carType = CarType.MiniVan",
"CargoVan", "CarType carType = CarType.CargoVan",
"SUV", "CarType carType = CarType.SUV",
"Truck", "CarType carType = CarType.Truck",
"Sedan", "CarType carType = CarType.Sedan"

);
for (Map.Entry<String, String> e : expectedContents.entrySet()) {
String modelName = e.getKey();
String expectedContent = e.getValue();
File file = files.get(Paths
.get(output.getAbsolutePath(), "src", "Org.OpenAPITools", "Model", modelName + ".cs")
.toString()
);
assertNotNull(file, "Could not find file for model: " + modelName);
assertFileContains(file.toPath(), expectedContent);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2287,6 +2287,39 @@ public void testAllOfWithSinglePrimitiveTypeRef() {
assertNull(files.get("pom.xml"));
}

@Test
public void testEnumDiscriminatorDefaultValueIsNotString() {
final Path output = newTempFolder();
final OpenAPI openAPI = TestUtils.parseFlattenSpec(
"src/test/resources/3_0/enum_discriminator_inheritance.yaml");
JavaClientCodegen codegen = new JavaClientCodegen();
codegen.setOutputDir(output.toString());

Map<String, File> files = new DefaultGenerator().opts(new ClientOptInput().openAPI(openAPI).config(codegen))
.generate().stream().collect(Collectors.toMap(File::getName, Function.identity()));

Map<String, String> expectedContents = Map.of(
"Cat", "this.petType = PetTypeEnum.CATTY",
"Dog", "this.petType = PetTypeEnum.DOG",
"Gecko", "this.petType = PetTypeEnum.GECKO",
"Chameleon", "this.petType = PetTypeEnum.CAMO",
"MiniVan", "this.carType = CarType.MINI_VAN",
"CargoVan", "this.carType = CarType.CARGO_VAN",
"SUV", "this.carType = CarType.SUV",
"Truck", "this.carType = CarType.TRUCK",
"Sedan", "this.carType = CarType.SEDAN"

);
for (Map.Entry<String, String> e : expectedContents.entrySet()) {
String modelName = e.getKey();
String expectedContent = e.getValue();
File entityFile = files.get(modelName + ".java");
assertNotNull(entityFile);
assertThat(entityFile).content().doesNotContain("Type = this.getClass().getSimpleName();");
assertThat(entityFile).content().contains(expectedContent);
}
}

@Test
public void testRestTemplateHandleURIEnum() {
String[] expectedInnerEnumLines = new String[] {
Expand Down
Loading
Loading