diff --git a/src/ApiCodeGenerator.AsyncApi/DOM/AsyncApiDocument.cs b/src/ApiCodeGenerator.AsyncApi/DOM/AsyncApiDocument.cs index 940d204..ff12784 100644 --- a/src/ApiCodeGenerator.AsyncApi/DOM/AsyncApiDocument.cs +++ b/src/ApiCodeGenerator.AsyncApi/DOM/AsyncApiDocument.cs @@ -1,6 +1,7 @@ using Newtonsoft.Json; using Newtonsoft.Json.Linq; using NJsonSchema; +using NJsonSchema.CodeGeneration; using NJsonSchema.Generation; using NJsonSchema.Yaml; using YamlDotNet.Serialization; @@ -59,11 +60,13 @@ public static Task FromJsonAsync(string data) /// JSON text. /// Path to document. /// AsyncApi document object model. - public static Task FromJsonAsync(string data, string? documentPath) + public static async Task FromJsonAsync(string data, string? documentPath) { var document = JsonConvert.DeserializeObject(data, JSONSERIALIZERSETTINGS)!; document.DocumentPath = documentPath; - return UpdateSchemaReferencesAsync(document); + await UpdateSchemaReferencesAsync(document); + BuildAsyncApiDescriminatorMapping(document); + return document; } /// @@ -80,7 +83,7 @@ public static Task FromYamlAsync(string data) /// YAML text. /// Path to document. /// AsyncApi document object model. - public static Task FromYamlAsync(string data, string? documentPath) + public static async Task FromYamlAsync(string data, string? documentPath) { var deserializer = new DeserializerBuilder().Build(); using var reader = new StringReader(data); @@ -90,15 +93,40 @@ public static Task FromYamlAsync(string data, string? document var serializer = JsonSerializer.Create(JSONSERIALIZERSETTINGS); var doc = jObject.ToObject(serializer)!; doc.DocumentPath = documentPath; - return UpdateSchemaReferencesAsync(doc); + await UpdateSchemaReferencesAsync(doc); + BuildAsyncApiDescriminatorMapping(doc); + return doc; } - private static async Task UpdateSchemaReferencesAsync(AsyncApiDocument document) + private static Task UpdateSchemaReferencesAsync(AsyncApiDocument document) { - await JsonSchemaReferenceUtilities.UpdateSchemaReferencesAsync( + return JsonSchemaReferenceUtilities.UpdateSchemaReferencesAsync( document, new JsonAndYamlReferenceResolver(new AsyncApiSchemaResolver(document, new SystemTextJsonSchemaGeneratorSettings()))); - return document; + } + + private static void BuildAsyncApiDescriminatorMapping(AsyncApiDocument document) + { + foreach (var schema in document.Components?.Schemas.Values ?? []) + { + var discriminatorPropName = schema.DiscriminatorObject?.PropertyName; + if (discriminatorPropName != null) + { + var derivedSchemas = schema.GetDerivedSchemas(document); + foreach (var item in derivedSchemas) + { + var derivedSchema = item.Key; + if ((derivedSchema.Properties.TryGetValue(discriminatorPropName, out var discriminatorProp) + || derivedSchema.AllOf?.FirstOrDefault(i => i != schema && i.Properties.ContainsKey(discriminatorPropName))?.Properties.TryGetValue(discriminatorPropName, out discriminatorProp) == true) + && discriminatorProp.ExtensionData?.TryGetValue("const", out var constValue) == true) + { + var constValueStr = constValue!.ToString(); + discriminatorProp.ParentSchema!.Properties.Remove(discriminatorPropName); + schema.DiscriminatorObject!.Mapping.Add(constValueStr, derivedSchema); + } + } + } + } } } #pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. diff --git a/test/ApiCodeGenerator.AsyncApi.Tests/FunctionalTests.cs b/test/ApiCodeGenerator.AsyncApi.Tests/FunctionalTests.cs index 2820364..df379bd 100644 --- a/test/ApiCodeGenerator.AsyncApi.Tests/FunctionalTests.cs +++ b/test/ApiCodeGenerator.AsyncApi.Tests/FunctionalTests.cs @@ -169,6 +169,97 @@ public async Task GenerateMultipleClients() Assert.AreEqual(expected, actual); } + [Test] + public async Task GenerateDiscriminator() + { + var yaml = """ + asyncapi: 2.0 + info: { title: 'dd', version: '1.0' } + components: + schemas: + Pet: + additionalProperties: false + type: object + discriminator: petType + properties: + name: + type: string + petType: + type: string + required: + - name + - petType + Cat: + allOf: + - $ref: '#/components/schemas/Pet' + - type: object + properties: + huntingSkill: + type: string + required: + - huntingSkill + additionalProperties: false + StickInsect: + allOf: + - $ref: '#/components/schemas/Pet' + - type: object + properties: + petType: + const: StickBug + color: + type: string + required: + - color + additionalProperties: false + """; + var settingsJson = $$""" + { + "Namespace": "TestNS", + "GenerateDataAnnotations": false, + "GenerateClientClasses": false + } + """; + var generationContext = CreateContext(settingsJson, new StringReader(yaml)); + var generator = await CSharpClientContentGenerator.CreateAsync(generationContext); + var actual = generator.Generate(); + + var expectedDto = $$""" + [Newtonsoft.Json.JsonConverter(typeof(JsonInheritanceConverter), "petType")] + [JsonInheritanceAttribute("StickBug", typeof(StickInsect))] + [JsonInheritanceAttribute("Cat", typeof(Cat))] + {{GENERATED_CODE}} + public partial class Pet + { + [Newtonsoft.Json.JsonProperty("name", Required = Newtonsoft.Json.Required.Always)] + public string Name { get; set; } + + + } + + {{GENERATED_CODE}} + public partial class Cat : Pet + { + [Newtonsoft.Json.JsonProperty("huntingSkill", Required = Newtonsoft.Json.Required.Always)] + public string HuntingSkill { get; set; } + + + } + + {{GENERATED_CODE}} + public partial class StickInsect : Pet + { + [Newtonsoft.Json.JsonProperty("color", Required = Newtonsoft.Json.Required.Always)] + public string Color { get; set; } + + + } + + {{JSON_INHERITANCE_CONVERTER}} + """.Replace("\r", string.Empty); + var expected = GetExpectedCode(null, expectedDto); + Assert.AreEqual(expected, actual); + } + [TestCaseSource(nameof(TemplateDirectorySource))] public void TemplateDirectory(T settings) where T : CSharpGeneratorBaseSettings diff --git a/test/ApiCodeGenerator.AsyncApi.Tests/Infrastructure/TestHelpers.cs b/test/ApiCodeGenerator.AsyncApi.Tests/Infrastructure/TestHelpers.cs index 2496b61..70f475b 100644 --- a/test/ApiCodeGenerator.AsyncApi.Tests/Infrastructure/TestHelpers.cs +++ b/test/ApiCodeGenerator.AsyncApi.Tests/Infrastructure/TestHelpers.cs @@ -14,6 +14,146 @@ internal static partial class TestHelpers public static readonly string GENERATED_CODE = "[System.CodeDom.Compiler.GeneratedCode(\"NJsonSchema\", \"" + APICODEGEN_VERSION + "\")]"; public static readonly string GENERATED_CODE_ATTRIBUTE = "[System.CodeDom.Compiler.GeneratedCode(\"ApiCodeGenerator.AsyncApi\", \"" + APICODEGEN_VERSION + "\")]"; + public static readonly string JSON_INHERITANCE_CONVERTER = $$""" + {{GENERATED_CODE}} + [System.AttributeUsage(System.AttributeTargets.Class | System.AttributeTargets.Interface, AllowMultiple = true)] + internal class JsonInheritanceAttribute : System.Attribute + { + public JsonInheritanceAttribute(string key, System.Type type) + { + Key = key; + Type = type; + } + + public string Key { get; } + + public System.Type Type { get; } + } + + {{GENERATED_CODE}} + public class JsonInheritanceConverter : Newtonsoft.Json.JsonConverter + { + internal static readonly string DefaultDiscriminatorName = "discriminator"; + + private readonly string _discriminatorName; + + [System.ThreadStatic] + private static bool _isReading; + + [System.ThreadStatic] + private static bool _isWriting; + + public JsonInheritanceConverter() + { + _discriminatorName = DefaultDiscriminatorName; + } + + public JsonInheritanceConverter(string discriminatorName) + { + _discriminatorName = discriminatorName; + } + + public string DiscriminatorName { get { return _discriminatorName; } } + + public override void WriteJson(Newtonsoft.Json.JsonWriter writer, object value, Newtonsoft.Json.JsonSerializer serializer) + { + try + { + _isWriting = true; + + var jObject = Newtonsoft.Json.Linq.JObject.FromObject(value, serializer); + jObject.AddFirst(new Newtonsoft.Json.Linq.JProperty(_discriminatorName, GetSubtypeDiscriminator(value.GetType()))); + writer.WriteToken(jObject.CreateReader()); + } + finally + { + _isWriting = false; + } + } + + public override bool CanWrite + { + get + { + if (_isWriting) + { + _isWriting = false; + return false; + } + return true; + } + } + + public override bool CanRead + { + get + { + if (_isReading) + { + _isReading = false; + return false; + } + return true; + } + } + + public override bool CanConvert(System.Type objectType) + { + return true; + } + + public override object ReadJson(Newtonsoft.Json.JsonReader reader, System.Type objectType, object existingValue, Newtonsoft.Json.JsonSerializer serializer) + { + var jObject = serializer.Deserialize(reader); + if (jObject == null) + return null; + + var discriminatorValue = jObject.GetValue(_discriminatorName); + var discriminator = discriminatorValue != null ? Newtonsoft.Json.Linq.Extensions.Value(discriminatorValue) : null; + var subtype = GetObjectSubtype(objectType, discriminator); + + var objectContract = serializer.ContractResolver.ResolveContract(subtype) as Newtonsoft.Json.Serialization.JsonObjectContract; + if (objectContract == null || System.Linq.Enumerable.All(objectContract.Properties, p => p.PropertyName != _discriminatorName)) + { + jObject.Remove(_discriminatorName); + } + + try + { + _isReading = true; + return serializer.Deserialize(jObject.CreateReader(), subtype); + } + finally + { + _isReading = false; + } + } + + private System.Type GetObjectSubtype(System.Type objectType, string discriminator) + { + foreach (var attribute in System.Reflection.CustomAttributeExtensions.GetCustomAttributes(System.Reflection.IntrospectionExtensions.GetTypeInfo(objectType), true)) + { + if (attribute.Key == discriminator) + return attribute.Type; + } + + return objectType; + } + + private string GetSubtypeDiscriminator(System.Type objectType) + { + foreach (var attribute in System.Reflection.CustomAttributeExtensions.GetCustomAttributes(System.Reflection.IntrospectionExtensions.GetTypeInfo(objectType), true)) + { + if (attribute.Type == objectType) + return attribute.Key; + } + + return objectType.Name; + } + } + + """.Replace("\r", string.Empty); + public static string GetAsyncApiPath(string schemaFile) => Path.Combine("asyncApi", schemaFile); public static async Task LoadApiDocumentAsync(string fileName) @@ -41,10 +181,9 @@ public static async Task RunTest(CSharpClientGeneratorSettings settings, string Assert.That(actual, Is.EqualTo(expected)); } - public static GeneratorContext CreateContext(string settingsJson, string schemaFile, Core.ExtensionManager.Extensions? extensions = null) + public static GeneratorContext CreateContext(string settingsJson, TextReader docReader, Core.ExtensionManager.Extensions? extensions = null) { var jReader = new JsonTextReader(new StringReader(settingsJson)); - var docReader = File.OpenText(GetAsyncApiPath(schemaFile)); return new GeneratorContext( (t, s, v) => s!.Deserialize(jReader, t), @@ -55,6 +194,9 @@ public static GeneratorContext CreateContext(string settingsJson, string schemaF }; } + public static GeneratorContext CreateContext(string settingsJson, string schemaFile, Core.ExtensionManager.Extensions? extensions = null) + => CreateContext(settingsJson, File.OpenText(GetAsyncApiPath(schemaFile)), extensions); + public static string GetExpectedCode(string? expectedClientDeclartion, string? testOperResponseText, string @namespace = "TestNS", string? usings = null) { if (!string.IsNullOrWhiteSpace(expectedClientDeclartion) && !expectedClientDeclartion.Contains(GENERATED_CODE_ATTRIBUTE))