From 17ecd30e355d4e24450508e55644bb6abe0892fc Mon Sep 17 00:00:00 2001 From: Lucas Trzesniewski Date: Mon, 8 Jan 2024 19:21:29 +0100 Subject: [PATCH] Use new language features --- .../IntegrationTest.cs | 23 +- .../Build/GenerateZebusMessagesTask.cs | 117 +- .../IntegrationTest.cs | 23 +- .../MessageDslGeneratorTests.cs | 223 ++- .../Generator/Extensions.cs | 37 +- .../Generator/MessageDslDiagnostics.cs | 17 +- .../Generator/MessageDslGenerator.cs | 151 +- .../Support/CompilerServices.cs | 4 +- .../MessageDsl/CSharpGeneratorTests.cs | 1364 ++++++++--------- .../MessageDsl/CSharpSyntaxTests.cs | 63 +- .../MessageDsl/GeneratorTests.cs | 41 +- .../MessageDsl/ParsedContractsTests.cs | 1261 +++++++-------- .../MessageDsl/ProtoGeneratorTests.cs | 279 ++-- .../MessageDsl/TypeNameTests.cs | 445 +++--- .../TestTools/AssertionExtensions.cs | 93 +- .../Analysis/AstCreationVisitor.cs | 642 ++++---- .../Analysis/AstProcessor.cs | 191 ++- .../Analysis/AstValidator.cs | 324 ++-- .../Analysis/AttributeInterpreter.cs | 195 ++- .../Analysis/ContractsEnhancer.cs | 45 +- .../Analysis/KnownTypes.cs | 35 +- .../Analysis/SyntaxDebugHelper.cs | 31 +- .../Analysis/TextInterval.cs | 92 +- .../Ast/AccessModifier.cs | 11 +- src/Abc.Zebus.MessageDsl/Ast/AstNode.cs | 33 +- .../Ast/AttributeDefinition.cs | 67 +- src/Abc.Zebus.MessageDsl/Ast/AttributeSet.cs | 103 +- .../Ast/ContractOptions.cs | 9 +- .../Ast/EnumDefinition.cs | 108 +- .../Ast/EnumMemberDefinition.cs | 15 +- src/Abc.Zebus.MessageDsl/Ast/FieldRules.cs | 15 +- .../Ast/GenericConstraint.cs | 20 +- src/Abc.Zebus.MessageDsl/Ast/IClassNode.cs | 9 +- src/Abc.Zebus.MessageDsl/Ast/IMemberNode.cs | 11 +- src/Abc.Zebus.MessageDsl/Ast/INamedNode.cs | 9 +- .../Ast/InheritanceModifier.cs | 15 +- src/Abc.Zebus.MessageDsl/Ast/MemberOptions.cs | 32 +- .../Ast/MessageDefinition.cs | 72 +- src/Abc.Zebus.MessageDsl/Ast/MessageType.cs | 15 +- src/Abc.Zebus.MessageDsl/Ast/OptionsBase.cs | 69 +- .../Ast/ParameterDefinition.cs | 67 +- .../Ast/ParsedContracts.cs | 153 +- src/Abc.Zebus.MessageDsl/Ast/TypeName.cs | 319 ++-- .../Dsl/CollectingErrorListener.cs | 100 +- .../Dsl/MessageContracts.g4.lexer.cs | 9 +- .../Dsl/MessageContracts.g4.parser.cs | 123 +- src/Abc.Zebus.MessageDsl/Dsl/SyntaxError.cs | 47 +- .../Generator/CSharpGenerator.cs | 806 +++++----- .../Generator/CSharpSyntax.cs | 123 +- .../Generator/GeneratorBase.cs | 119 +- .../Generator/ProtoGenerator.cs | 257 ++-- .../Support/CodeAnalysis.cs | 18 + .../Support/Disposable.cs | 22 +- .../Support/Extensions.cs | 63 +- src/Abc.Zebus.MessageDsl/Support/Index.cs | 84 + 55 files changed, 4331 insertions(+), 4288 deletions(-) create mode 100644 src/Abc.Zebus.MessageDsl/Support/CodeAnalysis.cs create mode 100644 src/Abc.Zebus.MessageDsl/Support/Index.cs diff --git a/src/Abc.Zebus.MessageDsl.Build.Integration/IntegrationTest.cs b/src/Abc.Zebus.MessageDsl.Build.Integration/IntegrationTest.cs index f4d6fc2..f9a386d 100644 --- a/src/Abc.Zebus.MessageDsl.Build.Integration/IntegrationTest.cs +++ b/src/Abc.Zebus.MessageDsl.Build.Integration/IntegrationTest.cs @@ -2,20 +2,19 @@ using System; -namespace Abc.Zebus.MessageDsl.Build.Integration +namespace Abc.Zebus.MessageDsl.Build.Integration; + +public class IntegrationTest { - public class IntegrationTest + static IntegrationTest() { - static IntegrationTest() - { - GC.KeepAlive(typeof(SomeMessage)); - GC.KeepAlive(typeof(InnerNamespace.InnerMessage)); - GC.KeepAlive(typeof(Abc.Zebus.CustomNamespace.HasCustomNamespace)); - GC.KeepAlive(typeof(Abc.Zebus.CustomExplicitNamespace.HasCustomExplicitNamespace)); - GC.KeepAlive(typeof(global::HasEmptyNamespace)); - GC.KeepAlive(typeof(ExplicitItems.A.ExplicitlyDefinedMessage)); - GC.KeepAlive(typeof(ExplicitItems.B.ExplicitlyDefinedMessage)); - } + GC.KeepAlive(typeof(SomeMessage)); + GC.KeepAlive(typeof(InnerNamespace.InnerMessage)); + GC.KeepAlive(typeof(Abc.Zebus.CustomNamespace.HasCustomNamespace)); + GC.KeepAlive(typeof(Abc.Zebus.CustomExplicitNamespace.HasCustomExplicitNamespace)); + GC.KeepAlive(typeof(global::HasEmptyNamespace)); + GC.KeepAlive(typeof(ExplicitItems.A.ExplicitlyDefinedMessage)); + GC.KeepAlive(typeof(ExplicitItems.B.ExplicitlyDefinedMessage)); } } diff --git a/src/Abc.Zebus.MessageDsl.Build/Build/GenerateZebusMessagesTask.cs b/src/Abc.Zebus.MessageDsl.Build/Build/GenerateZebusMessagesTask.cs index ebd35a3..ce81fe9 100644 --- a/src/Abc.Zebus.MessageDsl.Build/Build/GenerateZebusMessagesTask.cs +++ b/src/Abc.Zebus.MessageDsl.Build/Build/GenerateZebusMessagesTask.cs @@ -8,85 +8,84 @@ #nullable enable -namespace Abc.Zebus.MessageDsl.Build +namespace Abc.Zebus.MessageDsl.Build; + +[UsedImplicitly(ImplicitUseTargetFlags.WithMembers)] +public class GenerateZebusMessagesTask : Task { - [UsedImplicitly(ImplicitUseTargetFlags.WithMembers)] - public class GenerateZebusMessagesTask : Task - { - private const string _logSubcategory = "Zebus.MessageDsl"; + private const string _logSubcategory = "Zebus.MessageDsl"; - [Required] - public ITaskItem[] InputFiles { get; set; } = default!; + [Required] + public ITaskItem[] InputFiles { get; set; } = default!; - public override bool Execute() + public override bool Execute() + { + foreach (var inputFile in InputFiles) { - foreach (var inputFile in InputFiles) + try { - try - { - TranslateFile(inputFile); - } - catch (Exception ex) - { - LogError(inputFile, $"Error translating file: {ex}"); - } + TranslateFile(inputFile); + } + catch (Exception ex) + { + LogError(inputFile, $"Error translating file: {ex}"); } - - return !Log.HasLoggedErrors; } - private void TranslateFile(ITaskItem inputFile) - { - var fileContents = File.ReadAllText(inputFile.ItemSpec); - var defaultNamespace = inputFile.GetMetadata("CustomToolNamespace")?.Trim() ?? string.Empty; - var contracts = ParsedContracts.Parse(fileContents, defaultNamespace); + return !Log.HasLoggedErrors; + } - if (!contracts.IsValid) - { - foreach (var error in contracts.Errors) - LogError(inputFile, error.Message, error.LineNumber, error.CharacterInLine); + private void TranslateFile(ITaskItem inputFile) + { + var fileContents = File.ReadAllText(inputFile.ItemSpec); + var defaultNamespace = inputFile.GetMetadata("CustomToolNamespace")?.Trim() ?? string.Empty; + var contracts = ParsedContracts.Parse(fileContents, defaultNamespace); - return; - } + if (!contracts.IsValid) + { + foreach (var error in contracts.Errors) + LogError(inputFile, error.Message, error.LineNumber, error.CharacterInLine); - GenerateCSharpOutput(inputFile, contracts); - GenerateProtoOutput(inputFile, contracts); + return; } - private void GenerateCSharpOutput(ITaskItem inputFile, ParsedContracts contracts) - { - var targetPath = GetValidTargetFilePath(inputFile); - - var output = CSharpGenerator.Generate(contracts); - File.WriteAllText(targetPath, output); + GenerateCSharpOutput(inputFile, contracts); + GenerateProtoOutput(inputFile, contracts); + } - LogDebug($"{inputFile.ItemSpec}: Translated {contracts.Messages.Count} message{(contracts.Messages.Count > 1 ? "s" : "")}"); - } + private void GenerateCSharpOutput(ITaskItem inputFile, ParsedContracts contracts) + { + var targetPath = GetValidTargetFilePath(inputFile); - private void GenerateProtoOutput(ITaskItem inputFile, ParsedContracts contracts) - { - if (!ProtoGenerator.HasProtoOutput(contracts)) - return; + var output = CSharpGenerator.Generate(contracts); + File.WriteAllText(targetPath, output); - var targetPath = Path.ChangeExtension(GetValidTargetFilePath(inputFile), "proto") ?? throw new InvalidOperationException("Invalid target path"); + LogDebug($"{inputFile.ItemSpec}: Translated {contracts.Messages.Count} message{(contracts.Messages.Count > 1 ? "s" : "")}"); + } - var output = ProtoGenerator.Generate(contracts); - File.WriteAllText(targetPath, output); + private void GenerateProtoOutput(ITaskItem inputFile, ParsedContracts contracts) + { + if (!ProtoGenerator.HasProtoOutput(contracts)) + return; - LogDebug($"{inputFile.ItemSpec}: Generated proto file"); - } + var targetPath = Path.ChangeExtension(GetValidTargetFilePath(inputFile), "proto") ?? throw new InvalidOperationException("Invalid target path"); - private static string GetValidTargetFilePath(ITaskItem inputFile) - { - var targetPath = inputFile.GetMetadata("GeneratorTargetPath") ?? throw new InvalidOperationException("No target path specified"); - Directory.CreateDirectory(Path.GetDirectoryName(targetPath) ?? throw new InvalidOperationException("Invalid target directory")); - return targetPath; - } + var output = ProtoGenerator.Generate(contracts); + File.WriteAllText(targetPath, output); - private void LogDebug(string message) - => Log.LogMessage(_logSubcategory, null, null, null, 0, 0, 0, 0, MessageImportance.Low, message, null); + LogDebug($"{inputFile.ItemSpec}: Generated proto file"); + } - private void LogError(ITaskItem? inputFile, string message, int lineNumber = 0, int columnNumber = 0) - => Log.LogError(_logSubcategory, null, null, inputFile?.ItemSpec, lineNumber, columnNumber, 0, 0, message, null); + private static string GetValidTargetFilePath(ITaskItem inputFile) + { + var targetPath = inputFile.GetMetadata("GeneratorTargetPath") ?? throw new InvalidOperationException("No target path specified"); + Directory.CreateDirectory(Path.GetDirectoryName(targetPath) ?? throw new InvalidOperationException("Invalid target directory")); + return targetPath; } + + private void LogDebug(string message) + => Log.LogMessage(_logSubcategory, null, null, null, 0, 0, 0, 0, MessageImportance.Low, message, null); + + private void LogError(ITaskItem? inputFile, string message, int lineNumber = 0, int columnNumber = 0) + => Log.LogError(_logSubcategory, null, null, inputFile?.ItemSpec, lineNumber, columnNumber, 0, 0, message, null); } diff --git a/src/Abc.Zebus.MessageDsl.Generator.Integration/IntegrationTest.cs b/src/Abc.Zebus.MessageDsl.Generator.Integration/IntegrationTest.cs index 57c9bf4..8a8e926 100644 --- a/src/Abc.Zebus.MessageDsl.Generator.Integration/IntegrationTest.cs +++ b/src/Abc.Zebus.MessageDsl.Generator.Integration/IntegrationTest.cs @@ -2,20 +2,19 @@ using System; -namespace Abc.Zebus.MessageDsl.Generator.Integration +namespace Abc.Zebus.MessageDsl.Generator.Integration; + +public class IntegrationTest { - public class IntegrationTest + static IntegrationTest() { - static IntegrationTest() - { - GC.KeepAlive(typeof(SomeMessage)); - GC.KeepAlive(typeof(InnerNamespace.InnerMessage)); - GC.KeepAlive(typeof(Abc.Zebus.CustomNamespace.HasCustomNamespace)); - GC.KeepAlive(typeof(Abc.Zebus.CustomExplicitNamespace.HasCustomExplicitNamespace)); - GC.KeepAlive(typeof(global::HasEmptyNamespace)); - GC.KeepAlive(typeof(ExplicitItems.A.ExplicitlyDefinedMessage)); - GC.KeepAlive(typeof(ExplicitItems.B.ExplicitlyDefinedMessage)); - } + GC.KeepAlive(typeof(SomeMessage)); + GC.KeepAlive(typeof(InnerNamespace.InnerMessage)); + GC.KeepAlive(typeof(Abc.Zebus.CustomNamespace.HasCustomNamespace)); + GC.KeepAlive(typeof(Abc.Zebus.CustomExplicitNamespace.HasCustomExplicitNamespace)); + GC.KeepAlive(typeof(global::HasEmptyNamespace)); + GC.KeepAlive(typeof(ExplicitItems.A.ExplicitlyDefinedMessage)); + GC.KeepAlive(typeof(ExplicitItems.B.ExplicitlyDefinedMessage)); } } diff --git a/src/Abc.Zebus.MessageDsl.Generator.Tests/MessageDslGeneratorTests.cs b/src/Abc.Zebus.MessageDsl.Generator.Tests/MessageDslGeneratorTests.cs index 9965234..8e489a1 100644 --- a/src/Abc.Zebus.MessageDsl.Generator.Tests/MessageDslGeneratorTests.cs +++ b/src/Abc.Zebus.MessageDsl.Generator.Tests/MessageDslGeneratorTests.cs @@ -8,134 +8,133 @@ using System.Linq; using System.Threading; -namespace Abc.Zebus.MessageDsl.Generator.Tests +namespace Abc.Zebus.MessageDsl.Generator.Tests; + +[TestFixture] +public class MessageDslGeneratorTests { - [TestFixture] - public class MessageDslGeneratorTests + [Test] + public void Should_generate_message_class_for_simple_dsl_file() { - [Test] - public void Should_generate_message_class_for_simple_dsl_file() - { - // Arrange - var additionalTextMock = CreateAdditionalTextMock(@"Dsl\Messages.msg", @"DoSomethingCommand(int foo);"); - var optionsProviderMock = CreateOptionProviderMock(new[] { additionalTextMock }, ("build_metadata.AdditionalFiles.ZebusMessageDslNamespace", "Abc.Zebus.TestNamespace")); - - // Act - var runResults = CSharpGeneratorDriver.Create(new MessageDslGenerator()) - .AddAdditionalTexts(ImmutableArray.Create(additionalTextMock.Object)) - .WithUpdatedAnalyzerConfigOptions(optionsProviderMock.Object) - .RunGenerators(CSharpCompilation.Create("Tests")) - .GetRunResult(); - - // Assert - var generatedSource = runResults.Results[0].GeneratedSources[0]; - Assert.That(generatedSource.HintName, Is.EqualTo("Dsl_Messages.msg.g.cs")); - var sourceText = generatedSource.SourceText.ToString(); - Assert.That(sourceText, Does.Contain(@"public sealed partial class DoSomethingCommand : ICommand")); - } - - [Test] - public void Should_generate_message_class_for_multiple_additional_files_with_conflicting_names() - { - // Arrange - var additionalTextMock1 = CreateAdditionalTextMock(@"Dsl\Messages.msg", @"DoSomethingCommand(int foo);"); - var additionalTextMock2 = CreateAdditionalTextMock(@"Dsl\Messages.msg", @"DoSomethingCommand(int foo);"); - - var optionsProviderMock1 = CreateOptionProviderMock(new[] { additionalTextMock1 }, ("build_metadata.AdditionalFiles.ZebusMessageDslNamespace", "Abc.Zebus.TestNamespace"), ("build_metadata.AdditionalFiles.ZebusMessageDslRelativePath", "Dsl/Messages1.msg")); - var optionsProviderMock2 = CreateOptionProviderMock(new[] { additionalTextMock2 }, ("build_metadata.AdditionalFiles.ZebusMessageDslNamespace", "Abc.Zebus.TestNamespace"), ("build_metadata.AdditionalFiles.ZebusMessageDslRelativePath", "Dsl/Messages2.msg")); - - // Act - var runResults = CSharpGeneratorDriver.Create(new MessageDslGenerator()) - .AddAdditionalTexts(ImmutableArray.Create(additionalTextMock1.Object, additionalTextMock2.Object)) - .WithUpdatedAnalyzerConfigOptions(CombineOptionProviderMocks(optionsProviderMock1.Object, optionsProviderMock2.Object).Object) - .RunGenerators(CSharpCompilation.Create("Tests")) - .GetRunResult(); - - // Assert - var generatedSource1 = runResults.Results[0].GeneratedSources.Single(x => x.HintName == "Dsl_Messages1.msg.g.cs"); - AssertMessageSourceIsCorrect(generatedSource1, "Dsl_Messages1.msg.g.cs", "public sealed partial class DoSomethingCommand : ICommand"); - var generatedSource2 = runResults.Results[0].GeneratedSources.Single(x => x.HintName == "Dsl_Messages2.msg.g.cs"); - AssertMessageSourceIsCorrect(generatedSource2, "Dsl_Messages2.msg.g.cs", "public sealed partial class DoSomethingCommand : ICommand"); - } - - [Test] - public void Should_not_generate_message_class_for_non_message_additional_files() - { - // Arrange - var additionalTextMock = CreateAdditionalTextMock(@"Dsl\Messages.notamessage", @"DoSomethingCommand(int foo);"); - var optionsProviderMock = CreateOptionProviderMock(new[] { additionalTextMock }, ("build_metadata.AdditionalFiles.ZebusMessageDslNamespace", "Abc.Zebus.TestNamespace")); - - // Act - var runResults = CSharpGeneratorDriver.Create(new MessageDslGenerator()) - .AddAdditionalTexts(ImmutableArray.Create(additionalTextMock.Object)) - .WithUpdatedAnalyzerConfigOptions(optionsProviderMock.Object) - .RunGenerators(CSharpCompilation.Create("Tests")) - .GetRunResult(); - - // Assert - Assert.That(runResults.Results[0].GeneratedSources, Is.Empty); - } + // Arrange + var additionalTextMock = CreateAdditionalTextMock(@"Dsl\Messages.msg", "DoSomethingCommand(int foo);"); + var optionsProviderMock = CreateOptionProviderMock([additionalTextMock], ("build_metadata.AdditionalFiles.ZebusMessageDslNamespace", "Abc.Zebus.TestNamespace")); + + // Act + var runResults = CSharpGeneratorDriver.Create(new MessageDslGenerator()) + .AddAdditionalTexts(ImmutableArray.Create(additionalTextMock.Object)) + .WithUpdatedAnalyzerConfigOptions(optionsProviderMock.Object) + .RunGenerators(CSharpCompilation.Create("Tests")) + .GetRunResult(); + + // Assert + var generatedSource = runResults.Results[0].GeneratedSources[0]; + Assert.That(generatedSource.HintName, Is.EqualTo("Dsl_Messages.msg.g.cs")); + var sourceText = generatedSource.SourceText.ToString(); + Assert.That(sourceText, Does.Contain("public sealed partial class DoSomethingCommand : ICommand")); + } - [Test] - public void Should_not_generate_message_class_if_ZebusMessageDslNamespace_option_is_not_set() - { - // Arrange - var additionalTextMock = CreateAdditionalTextMock(@"Dsl\Messages.msg", @"DoSomethingCommand(int foo);"); + [Test] + public void Should_generate_message_class_for_multiple_additional_files_with_conflicting_names() + { + // Arrange + var additionalTextMock1 = CreateAdditionalTextMock(@"Dsl\Messages.msg", "DoSomethingCommand(int foo);"); + var additionalTextMock2 = CreateAdditionalTextMock(@"Dsl\Messages.msg", "DoSomethingCommand(int foo);"); + + var optionsProviderMock1 = CreateOptionProviderMock([additionalTextMock1], ("build_metadata.AdditionalFiles.ZebusMessageDslNamespace", "Abc.Zebus.TestNamespace"), ("build_metadata.AdditionalFiles.ZebusMessageDslRelativePath", "Dsl/Messages1.msg")); + var optionsProviderMock2 = CreateOptionProviderMock([additionalTextMock2], ("build_metadata.AdditionalFiles.ZebusMessageDslNamespace", "Abc.Zebus.TestNamespace"), ("build_metadata.AdditionalFiles.ZebusMessageDslRelativePath", "Dsl/Messages2.msg")); + + // Act + var runResults = CSharpGeneratorDriver.Create(new MessageDslGenerator()) + .AddAdditionalTexts(ImmutableArray.Create(additionalTextMock1.Object, additionalTextMock2.Object)) + .WithUpdatedAnalyzerConfigOptions(CombineOptionProviderMocks(optionsProviderMock1.Object, optionsProviderMock2.Object).Object) + .RunGenerators(CSharpCompilation.Create("Tests")) + .GetRunResult(); + + // Assert + var generatedSource1 = runResults.Results[0].GeneratedSources.Single(x => x.HintName == "Dsl_Messages1.msg.g.cs"); + AssertMessageSourceIsCorrect(generatedSource1, "Dsl_Messages1.msg.g.cs", "public sealed partial class DoSomethingCommand : ICommand"); + var generatedSource2 = runResults.Results[0].GeneratedSources.Single(x => x.HintName == "Dsl_Messages2.msg.g.cs"); + AssertMessageSourceIsCorrect(generatedSource2, "Dsl_Messages2.msg.g.cs", "public sealed partial class DoSomethingCommand : ICommand"); + } - // Act - var runResults = CSharpGeneratorDriver.Create(new MessageDslGenerator()) - .AddAdditionalTexts(ImmutableArray.Create(additionalTextMock.Object)) - .RunGenerators(CSharpCompilation.Create("Tests")) - .GetRunResult(); + [Test] + public void Should_not_generate_message_class_for_non_message_additional_files() + { + // Arrange + var additionalTextMock = CreateAdditionalTextMock(@"Dsl\Messages.notamessage", "DoSomethingCommand(int foo);"); + var optionsProviderMock = CreateOptionProviderMock([additionalTextMock], ("build_metadata.AdditionalFiles.ZebusMessageDslNamespace", "Abc.Zebus.TestNamespace")); + + // Act + var runResults = CSharpGeneratorDriver.Create(new MessageDslGenerator()) + .AddAdditionalTexts(ImmutableArray.Create(additionalTextMock.Object)) + .WithUpdatedAnalyzerConfigOptions(optionsProviderMock.Object) + .RunGenerators(CSharpCompilation.Create("Tests")) + .GetRunResult(); + + // Assert + Assert.That(runResults.Results[0].GeneratedSources, Is.Empty); + } - // Assert - Assert.That(runResults.Results[0].GeneratedSources, Is.Empty); - } + [Test] + public void Should_not_generate_message_class_if_ZebusMessageDslNamespace_option_is_not_set() + { + // Arrange + var additionalTextMock = CreateAdditionalTextMock(@"Dsl\Messages.msg", "DoSomethingCommand(int foo);"); - private static Mock CreateOptionProviderMock(Mock[] additionalTextMocks, params (string key, string value)[] options) - { - var optionsProviderMock = new Mock(); - var optionsMock = new Mock(); + // Act + var runResults = CSharpGeneratorDriver.Create(new MessageDslGenerator()) + .AddAdditionalTexts(ImmutableArray.Create(additionalTextMock.Object)) + .RunGenerators(CSharpCompilation.Create("Tests")) + .GetRunResult(); - foreach (var option in options) - { - var fileNamespace = option.value; - optionsMock.Setup(x => x.TryGetValue(option.key, out fileNamespace)).Returns(true); - } + // Assert + Assert.That(runResults.Results[0].GeneratedSources, Is.Empty); + } - foreach (var additionalTextMock in additionalTextMocks) - { - optionsProviderMock.Setup(x => x.GetOptions(additionalTextMock.Object)).Returns(optionsMock.Object); - } + private static Mock CreateOptionProviderMock(Mock[] additionalTextMocks, params (string key, string value)[] options) + { + var optionsProviderMock = new Mock(); + var optionsMock = new Mock(); - return optionsProviderMock; + foreach (var option in options) + { + var fileNamespace = option.value; + optionsMock.Setup(x => x.TryGetValue(option.key, out fileNamespace)).Returns(true); } - private static Mock CombineOptionProviderMocks(params AnalyzerConfigOptionsProvider[] providers) + foreach (var additionalTextMock in additionalTextMocks) { - var optionsProviderMock = new Mock(); + optionsProviderMock.Setup(x => x.GetOptions(additionalTextMock.Object)).Returns(optionsMock.Object); + } - optionsProviderMock.Setup(x => x.GetOptions(It.IsAny())) - .Returns((AdditionalText additionalText) => providers.Select(p => p.GetOptions(additionalText)).FirstOrDefault(i => !ReferenceEquals(i, null))!); + return optionsProviderMock; + } - return optionsProviderMock; - } + private static Mock CombineOptionProviderMocks(params AnalyzerConfigOptionsProvider[] providers) + { + var optionsProviderMock = new Mock(); - private static Mock CreateAdditionalTextMock(string path, string source) - { - var additionalTextMock = new Mock(); - additionalTextMock.SetupGet(x => x.Path).Returns(path); - additionalTextMock.Setup(x => x.GetText(It.IsAny())) - .Returns(SourceText.From(source)); + optionsProviderMock.Setup(x => x.GetOptions(It.IsAny())) + .Returns((AdditionalText additionalText) => providers.Select(p => p.GetOptions(additionalText)).FirstOrDefault(i => !ReferenceEquals(i, null))!); - return additionalTextMock; - } + return optionsProviderMock; + } - private static void AssertMessageSourceIsCorrect(GeneratedSourceResult generatedSource, string expectedHintName, string expectedSourceFragment) - { - Assert.That(generatedSource.HintName, Is.EqualTo(expectedHintName)); - var sourceText = generatedSource.SourceText.ToString(); - Assert.That(sourceText, Does.Contain(expectedSourceFragment)); - } + private static Mock CreateAdditionalTextMock(string path, string source) + { + var additionalTextMock = new Mock(); + additionalTextMock.SetupGet(x => x.Path).Returns(path); + additionalTextMock.Setup(x => x.GetText(It.IsAny())) + .Returns(SourceText.From(source)); + + return additionalTextMock; + } + + private static void AssertMessageSourceIsCorrect(GeneratedSourceResult generatedSource, string expectedHintName, string expectedSourceFragment) + { + Assert.That(generatedSource.HintName, Is.EqualTo(expectedHintName)); + var sourceText = generatedSource.SourceText.ToString(); + Assert.That(sourceText, Does.Contain(expectedSourceFragment)); } } diff --git a/src/Abc.Zebus.MessageDsl.Generator/Generator/Extensions.cs b/src/Abc.Zebus.MessageDsl.Generator/Generator/Extensions.cs index 0f11c86..f9e2e04 100644 --- a/src/Abc.Zebus.MessageDsl.Generator/Generator/Extensions.cs +++ b/src/Abc.Zebus.MessageDsl.Generator/Generator/Extensions.cs @@ -3,30 +3,29 @@ #nullable enable -namespace Abc.Zebus.MessageDsl.Generator +namespace Abc.Zebus.MessageDsl.Generator; + +internal static class Extensions { - internal static class Extensions - { - private static readonly char[] _crlf = { '\r', '\n' }; + private static readonly char[] _crlf = ['\r', '\n']; - public static LinePositionSpan ToLinePositionSpan(this SyntaxError error) - { - if (error.LineNumber <= 0 || error.CharacterInLine <= 0) - return default; + public static LinePositionSpan ToLinePositionSpan(this SyntaxError error) + { + if (error.LineNumber <= 0 || error.CharacterInLine <= 0) + return default; - var length = 0; + var length = 0; - if (error.Token is not null) - { - length = error.Token.IndexOfAny(_crlf); - if (length < 0) - length = error.Token.Length; - } + if (error.Token is not null) + { + length = error.Token.IndexOfAny(_crlf); + if (length < 0) + length = error.Token.Length; + } - var startPosition = new LinePosition(error.LineNumber - 1, error.CharacterInLine - 1); - var endPosition = new LinePosition(startPosition.Line, startPosition.Character + length); + var startPosition = new LinePosition(error.LineNumber - 1, error.CharacterInLine - 1); + var endPosition = new LinePosition(startPosition.Line, startPosition.Character + length); - return new LinePositionSpan(startPosition, endPosition); - } + return new LinePositionSpan(startPosition, endPosition); } } diff --git a/src/Abc.Zebus.MessageDsl.Generator/Generator/MessageDslDiagnostics.cs b/src/Abc.Zebus.MessageDsl.Generator/Generator/MessageDslDiagnostics.cs index 746adcd..0620ed3 100644 --- a/src/Abc.Zebus.MessageDsl.Generator/Generator/MessageDslDiagnostics.cs +++ b/src/Abc.Zebus.MessageDsl.Generator/Generator/MessageDslDiagnostics.cs @@ -2,15 +2,14 @@ #nullable enable -namespace Abc.Zebus.MessageDsl.Generator +namespace Abc.Zebus.MessageDsl.Generator; + +internal static class MessageDslDiagnostics { - internal static class MessageDslDiagnostics - { - public static DiagnosticDescriptor MessageDslError { get; } = Error(1, "MessageDsl error", "{0}"); - public static DiagnosticDescriptor UnexpectedError { get; } = Error(2, "Unexpected error", "Unexpected error: {0}"); - public static DiagnosticDescriptor CouldNotReadFileContents { get; } = Error(3, "Could not read file contents", "Could not read file contents"); + public static DiagnosticDescriptor MessageDslError { get; } = Error(1, "MessageDsl error", "{0}"); + public static DiagnosticDescriptor UnexpectedError { get; } = Error(2, "Unexpected error", "Unexpected error: {0}"); + public static DiagnosticDescriptor CouldNotReadFileContents { get; } = Error(3, "Could not read file contents", "Could not read file contents"); - private static DiagnosticDescriptor Error(int index, string title, string messageFormat) - => new($"MessageDsl{index:D3}", title, messageFormat, "MessageDsl", DiagnosticSeverity.Error, true); - } + private static DiagnosticDescriptor Error(int index, string title, string messageFormat) + => new($"MessageDsl{index:D3}", title, messageFormat, "MessageDsl", DiagnosticSeverity.Error, true); } diff --git a/src/Abc.Zebus.MessageDsl.Generator/Generator/MessageDslGenerator.cs b/src/Abc.Zebus.MessageDsl.Generator/Generator/MessageDslGenerator.cs index 482a101..7a81fb9 100644 --- a/src/Abc.Zebus.MessageDsl.Generator/Generator/MessageDslGenerator.cs +++ b/src/Abc.Zebus.MessageDsl.Generator/Generator/MessageDslGenerator.cs @@ -8,105 +8,104 @@ #nullable enable -namespace Abc.Zebus.MessageDsl.Generator +namespace Abc.Zebus.MessageDsl.Generator; + +[Generator] +public class MessageDslGenerator : IIncrementalGenerator { - [Generator] - public class MessageDslGenerator : IIncrementalGenerator + private static readonly Regex _sanitizePathRegex = new(@"[:\\/]+", RegexOptions.Compiled); + + public void Initialize(IncrementalGeneratorInitializationContext context) { - private static readonly Regex _sanitizePathRegex = new(@"[:\\/]+", RegexOptions.Compiled); + var additionalTextsWithNamespaces = context.AdditionalTextsProvider + .Where(x => x.Path.EndsWith(".msg")) + .Combine(context.AnalyzerConfigOptionsProvider) + .Select((input, _) => + { + var (additionalText, config) = input; + var fileNamespace = config.GetOptions(additionalText).TryGetValue("build_metadata.AdditionalFiles.ZebusMessageDslNamespace", out var ns) ? ns : null; + var relativePath = config.GetOptions(additionalText).TryGetValue("build_metadata.AdditionalFiles.ZebusMessageDslRelativePath", out var dir) ? dir : null; + return new SourceGenerationInput(additionalText, fileNamespace, relativePath); + }) + .Where(x => x.FileNamespace != null) + .Select(GenerateCode) + .Collect(); + + context.RegisterSourceOutput(additionalTextsWithNamespaces, GenerateFiles); + } - public void Initialize(IncrementalGeneratorInitializationContext context) - { - var additionalTextsWithNamespaces = context.AdditionalTextsProvider - .Where(x => x.Path.EndsWith(".msg")) - .Combine(context.AnalyzerConfigOptionsProvider) - .Select((input, _) => - { - var (additionalText, config) = input; - var fileNamespace = config.GetOptions(additionalText).TryGetValue("build_metadata.AdditionalFiles.ZebusMessageDslNamespace", out var ns) ? ns : null; - var relativePath = config.GetOptions(additionalText).TryGetValue("build_metadata.AdditionalFiles.ZebusMessageDslRelativePath", out var dir) ? dir : null; - return new SourceGenerationInput(additionalText, fileNamespace, relativePath); - }) - .Where(x => x.FileNamespace != null) - .Select(GenerateCode) - .Collect(); - - context.RegisterSourceOutput(additionalTextsWithNamespaces, GenerateFiles); - } + private static SourceGenerationResult GenerateCode(SourceGenerationInput input, CancellationToken cancellationToken) + { + var (file, fileNamespace, relativePath) = input; - private static SourceGenerationResult GenerateCode(SourceGenerationInput input, CancellationToken cancellationToken) - { - var (file, fileNamespace, relativePath) = input; + var fileContents = file.GetText(cancellationToken)?.ToString(); - var fileContents = file.GetText(cancellationToken)?.ToString(); + if (fileContents is null) + return SourceGenerationResult.Error(Diagnostic.Create(MessageDslDiagnostics.CouldNotReadFileContents, Location.Create(file.Path, default, default))); - if (fileContents is null) - return SourceGenerationResult.Error(Diagnostic.Create(MessageDslDiagnostics.CouldNotReadFileContents, Location.Create(file.Path, default, default))); + try + { + var contracts = ParsedContracts.Parse(fileContents, fileNamespace!.Trim()); - try + if (!contracts.IsValid) { - var contracts = ParsedContracts.Parse(fileContents, fileNamespace!.Trim()); - - if (!contracts.IsValid) + var diagnostics = new List(); + foreach (var error in contracts.Errors) { - var diagnostics = new List(); - foreach (var error in contracts.Errors) - { - var location = Location.Create(file.Path, default, error.ToLinePositionSpan()); - diagnostics.Add(Diagnostic.Create(MessageDslDiagnostics.MessageDslError, location, error.Message)); - } - - return SourceGenerationResult.Error(diagnostics.ToArray()); + var location = Location.Create(file.Path, default, error.ToLinePositionSpan()); + diagnostics.Add(Diagnostic.Create(MessageDslDiagnostics.MessageDslError, location, error.Message)); } - var output = CSharpGenerator.Generate(contracts); - - return SourceGenerationResult.Success(file, output, relativePath); - } - catch (Exception ex) when (ex is not OperationCanceledException) - { - return SourceGenerationResult.Error(Diagnostic.Create(MessageDslDiagnostics.UnexpectedError, Location.None, ex.ToString())); + return SourceGenerationResult.Error(diagnostics.ToArray()); } + + var output = CSharpGenerator.Generate(contracts); + + return SourceGenerationResult.Success(file, output, relativePath); } + catch (Exception ex) when (ex is not OperationCanceledException) + { + return SourceGenerationResult.Error(Diagnostic.Create(MessageDslDiagnostics.UnexpectedError, Location.None, ex.ToString())); + } + } - private static void GenerateFiles(SourceProductionContext context, ImmutableArray results) + private static void GenerateFiles(SourceProductionContext context, ImmutableArray results) + { + foreach (var result in results) { - foreach (var result in results) - { - foreach (var diagnostic in result.Diagnostics) - context.ReportDiagnostic(diagnostic); + foreach (var diagnostic in result.Diagnostics) + context.ReportDiagnostic(diagnostic); - if (result.AdditionalText == null || result.GeneratedSource == null) - continue; + if (result.AdditionalText == null || result.GeneratedSource == null) + continue; - var hintName = _sanitizePathRegex.Replace(result.RelativePath ?? result.AdditionalText.Path, "_") + ".g.cs"; + var hintName = _sanitizePathRegex.Replace(result.RelativePath ?? result.AdditionalText.Path, "_") + ".g.cs"; - context.AddSource(hintName, result.GeneratedSource); - } + context.AddSource(hintName, result.GeneratedSource); } + } - private record SourceGenerationInput(AdditionalText AdditionalText, string? FileNamespace, string? RelativePath); + private record SourceGenerationInput(AdditionalText AdditionalText, string? FileNamespace, string? RelativePath); - public class SourceGenerationResult - { - public IList Diagnostics { get; } - public string? GeneratedSource { get; set; } - public AdditionalText? AdditionalText { get; } - public string? RelativePath { get; } + public class SourceGenerationResult + { + public IList Diagnostics { get; } + public string? GeneratedSource { get; set; } + public AdditionalText? AdditionalText { get; } + public string? RelativePath { get; } - public SourceGenerationResult(IList diagnostics, string? generatedSource, AdditionalText? additionalText, string? relativePath) - { - Diagnostics = diagnostics; - GeneratedSource = generatedSource; - AdditionalText = additionalText; - RelativePath = relativePath; - } + private SourceGenerationResult(IList diagnostics, string? generatedSource, AdditionalText? additionalText, string? relativePath) + { + Diagnostics = diagnostics; + GeneratedSource = generatedSource; + AdditionalText = additionalText; + RelativePath = relativePath; + } - public static SourceGenerationResult Error(params Diagnostic[] diagnostics) - => new(diagnostics, null, null, null); + public static SourceGenerationResult Error(params Diagnostic[] diagnostics) + => new(diagnostics, null, null, null); - public static SourceGenerationResult Success(AdditionalText? additionalText, string generatedSource, string? relativePath) - => new(Array.Empty(), generatedSource, additionalText, relativePath); - } + public static SourceGenerationResult Success(AdditionalText? additionalText, string generatedSource, string? relativePath) + => new(Array.Empty(), generatedSource, additionalText, relativePath); } } diff --git a/src/Abc.Zebus.MessageDsl.Generator/Support/CompilerServices.cs b/src/Abc.Zebus.MessageDsl.Generator/Support/CompilerServices.cs index 8d89869..f72d845 100644 --- a/src/Abc.Zebus.MessageDsl.Generator/Support/CompilerServices.cs +++ b/src/Abc.Zebus.MessageDsl.Generator/Support/CompilerServices.cs @@ -1,6 +1,4 @@ // ReSharper disable once CheckNamespace namespace System.Runtime.CompilerServices; -internal static class IsExternalInit -{ -} +internal static class IsExternalInit; diff --git a/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/CSharpGeneratorTests.cs b/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/CSharpGeneratorTests.cs index 2d8e47e..4a5e41c 100644 --- a/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/CSharpGeneratorTests.cs +++ b/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/CSharpGeneratorTests.cs @@ -4,893 +4,893 @@ using Abc.Zebus.MessageDsl.Tests.TestTools; using NUnit.Framework; -namespace Abc.Zebus.MessageDsl.Tests.MessageDsl +namespace Abc.Zebus.MessageDsl.Tests.MessageDsl; + +[TestFixture] +public class CSharpGeneratorTests : GeneratorTests { - [TestFixture] - public class CSharpGeneratorTests : GeneratorTests + [Test] + public void should_generate_code() { - [Test] - public void should_generate_code() + var code = Generate(new MessageDefinition { - var code = Generate(new MessageDefinition + Name = "FooExecuted", + Parameters = { - Name = "FooExecuted", - Parameters = - { - new ParameterDefinition("System.Int32", "foo"), - new ParameterDefinition("string", "bar") - } - }); + new ParameterDefinition("System.Int32", "foo"), + new ParameterDefinition("string", "bar") + } + }); - code.ShouldContain("public sealed partial class FooExecuted : IEvent"); - code.ShouldContain("public FooExecuted(int foo, string bar)"); - } + code.ShouldContain("public sealed partial class FooExecuted : IEvent"); + code.ShouldContain("public FooExecuted(int foo, string bar)"); + } - [Test] - public void should_generate_default_values() + [Test] + public void should_generate_default_values() + { + var code = Generate(new MessageDefinition { - var code = Generate(new MessageDefinition + Name = "FooExecuted", + Parameters = { - Name = "FooExecuted", - Parameters = - { - new ParameterDefinition("int", "foo") { DefaultValue = "42" } - } - }); + new ParameterDefinition("int", "foo") { DefaultValue = "42" } + } + }); + + code.ShouldContain("public FooExecuted(int foo = 42)"); + } - code.ShouldContain("public FooExecuted(int foo = 42)"); - } + [Test] + public void should_generate_packed_members() + { + var code = Generate(new MessageDefinition + { + Name = "FooExecuted", + Parameters = + { + new ParameterDefinition("System.Int32[]", "foo"), + new ParameterDefinition("LolType[]", "bar"), + new ParameterDefinition("List", "fooList"), + new ParameterDefinition("List", "barList"), + } + }); + + code.ShouldContainIgnoreIndent("[ProtoMember(1, IsRequired = false, IsPacked = true)]\npublic int[] Foo { get; private set; }"); + code.ShouldContainIgnoreIndent("[ProtoMember(2, IsRequired = false)]\npublic LolType[] Bar { get; private set; }"); + code.ShouldContainIgnoreIndent("[ProtoMember(3, IsRequired = false, IsPacked = true)]\npublic List FooList { get; private set; }"); + code.ShouldContainIgnoreIndent("[ProtoMember(4, IsRequired = false)]\npublic List BarList { get; private set; }"); + code.ShouldContain("Foo = Array.Empty();"); + code.ShouldContain("Bar = Array.Empty();"); + code.ShouldContain("FooList = new List();"); + code.ShouldContain("BarList = new List();"); + code.ShouldContain("using System.Collections.Generic;"); + } - [Test] - public void should_generate_packed_members() + [Test] + public void should_call_constructor_for_Dictionary() + { + var code = Generate(new MessageDefinition { - var code = Generate(new MessageDefinition + Name = "FooExecuted", + Parameters = { - Name = "FooExecuted", - Parameters = - { - new ParameterDefinition("System.Int32[]", "foo"), - new ParameterDefinition("LolType[]", "bar"), - new ParameterDefinition("List", "fooList"), - new ParameterDefinition("List", "barList"), - } - }); - - code.ShouldContainIgnoreIndent("[ProtoMember(1, IsRequired = false, IsPacked = true)]\npublic int[] Foo { get; private set; }"); - code.ShouldContainIgnoreIndent("[ProtoMember(2, IsRequired = false)]\npublic LolType[] Bar { get; private set; }"); - code.ShouldContainIgnoreIndent("[ProtoMember(3, IsRequired = false, IsPacked = true)]\npublic List FooList { get; private set; }"); - code.ShouldContainIgnoreIndent("[ProtoMember(4, IsRequired = false)]\npublic List BarList { get; private set; }"); - code.ShouldContain("Foo = Array.Empty();"); - code.ShouldContain("Bar = Array.Empty();"); - code.ShouldContain("FooList = new List();"); - code.ShouldContain("BarList = new List();"); - code.ShouldContain("using System.Collections.Generic;"); - } - - [Test] - public void should_call_constructor_for_Dictionary() - { - var code = Generate(new MessageDefinition - { - Name = "FooExecuted", - Parameters = - { - new ParameterDefinition("Dictionary", "fooDico"), - } - }); + new ParameterDefinition("Dictionary", "fooDico"), + } + }); - code.ShouldContainIgnoreIndent("[ProtoMember(1, IsRequired = true)]\n[ProtoMap(DisableMap = true)]\npublic Dictionary FooDico { get; private set; }"); - code.ShouldContain("using System.Collections.Generic;"); - code.ShouldContain("FooDico = new Dictionary();"); - } + code.ShouldContainIgnoreIndent("[ProtoMember(1, IsRequired = true)]\n[ProtoMap(DisableMap = true)]\npublic Dictionary FooDico { get; private set; }"); + code.ShouldContain("using System.Collections.Generic;"); + code.ShouldContain("FooDico = new Dictionary();"); + } - [Test] - public void should_call_constructor_for_HashSet() + [Test] + public void should_call_constructor_for_HashSet() + { + var code = Generate(new MessageDefinition { - var code = Generate(new MessageDefinition + Name = "FooExecuted", + Parameters = { - Name = "FooExecuted", - Parameters = - { - new ParameterDefinition("HashSet", "fooHashSet"), - } - }); + new ParameterDefinition("HashSet", "fooHashSet"), + } + }); - code.ShouldContainIgnoreIndent("[ProtoMember(1, IsRequired = false)]\npublic HashSet FooHashSet { get; private set; }"); - code.ShouldContain("using System.Collections.Generic;"); - code.ShouldContain("FooHashSet = new HashSet();"); - } + code.ShouldContainIgnoreIndent("[ProtoMember(1, IsRequired = false)]\npublic HashSet FooHashSet { get; private set; }"); + code.ShouldContain("using System.Collections.Generic;"); + code.ShouldContain("FooHashSet = new HashSet();"); + } - [Test] - public void should_generate_attributes() + [Test] + public void should_generate_attributes() + { + var code = Generate(new MessageDefinition { - var code = Generate(new MessageDefinition + Name = "FooExecuted", + Attributes = { - Name = "FooExecuted", - Attributes = - { - new AttributeDefinition("Transient") - }, - Parameters = + new AttributeDefinition("Transient") + }, + Parameters = + { + new ParameterDefinition("int", "foo") { - new ParameterDefinition("int", "foo") + Attributes = { - Attributes = - { - new AttributeDefinition("LolAttribute", "LolParam = 42") - } + new AttributeDefinition("LolAttribute", "LolParam = 42") } } - }); + } + }); - code.ShouldContain("[Transient]"); - code.ShouldContain("Lol(LolParam = 42)"); - } + code.ShouldContain("[Transient]"); + code.ShouldContain("Lol(LolParam = 42)"); + } - [Test] - public void should_generate_mutable_properties() - { - var contract = new ParsedContracts(); + [Test] + public void should_generate_mutable_properties() + { + var contract = new ParsedContracts(); - contract.Messages.Add(new MessageDefinition + contract.Messages.Add(new MessageDefinition + { + Name = "FooExecuted", + Options = { Mutable = true }, + Parameters = { - Name = "FooExecuted", - Options = { Mutable = true }, - Parameters = - { - new ParameterDefinition("System.Int32", "foo"), - new ParameterDefinition("string", "bar") - } - }); - var code = Generate(contract); + new ParameterDefinition("System.Int32", "foo"), + new ParameterDefinition("string", "bar") + } + }); + var code = Generate(contract); - code.ShouldContain("public int Foo { get; set; }"); - } + code.ShouldContain("public int Foo { get; set; }"); + } - [Test] - public void should_generate_properties() - { - var contract = new ParsedContracts(); + [Test] + public void should_generate_properties() + { + var contract = new ParsedContracts(); - contract.Messages.Add(new MessageDefinition + contract.Messages.Add(new MessageDefinition + { + Name = "FooExecuted", + Parameters = { - Name = "FooExecuted", - Parameters = - { - new ParameterDefinition("System.Int32", "foo"), - new ParameterDefinition("string", "bar") - } - }); - var code = Generate(contract); + new ParameterDefinition("System.Int32", "foo"), + new ParameterDefinition("string", "bar") + } + }); + var code = Generate(contract); - code.ShouldContain("public int Foo { get; private set; }"); - } + code.ShouldContain("public int Foo { get; private set; }"); + } + + [Test] + public void should_generate_generic_messages() + { + var contract = new ParsedContracts(); - [Test] - public void should_generate_generic_messages() + contract.Messages.Add(new MessageDefinition { - var contract = new ParsedContracts(); + Name = "FooExecuted", + GenericParameters = { "TFoo", "TBar" } + }); + var code = Generate(contract); - contract.Messages.Add(new MessageDefinition - { - Name = "FooExecuted", - GenericParameters = { "TFoo", "TBar" } - }); - var code = Generate(contract); + code.ShouldContain("class FooExecuted : IEvent"); + } - code.ShouldContain("class FooExecuted : IEvent"); - } + [Test] + public void should_generate_generic_constraints() + { + var contract = new ParsedContracts(); - [Test] - public void should_generate_generic_constraints() + contract.Messages.Add(new MessageDefinition { - var contract = new ParsedContracts(); - - contract.Messages.Add(new MessageDefinition + Name = "FooExecuted", + GenericParameters = { "TFoo", "TBar" }, + GenericConstraints = { - Name = "FooExecuted", - GenericParameters = { "TFoo", "TBar" }, - GenericConstraints = + new GenericConstraint { - new GenericConstraint - { - GenericParameterName = "TFoo", - IsClass = true, - HasDefaultConstructor = true, - Types = { "IDisposable" } - }, - new GenericConstraint - { - GenericParameterName = "TBar", - IsStruct = true, - HasDefaultConstructor = true, - Types = { "IDisposable" } - }, - } - }); - var code = Generate(contract); + GenericParameterName = "TFoo", + IsClass = true, + HasDefaultConstructor = true, + Types = { "IDisposable" } + }, + new GenericConstraint + { + GenericParameterName = "TBar", + IsStruct = true, + HasDefaultConstructor = true, + Types = { "IDisposable" } + }, + } + }); + var code = Generate(contract); - code.ShouldContain("where TFoo : class, IDisposable, new()"); - code.ShouldContain("where TBar : struct, IDisposable"); - } + code.ShouldContain("where TFoo : class, IDisposable, new()"); + code.ShouldContain("where TBar : struct, IDisposable"); + } - [Test] - public void should_not_generate_constructor_when_there_are_no_parameters() + [Test] + public void should_not_generate_constructor_when_there_are_no_parameters() + { + var code = Generate(new MessageDefinition { - var code = Generate(new MessageDefinition - { - Name = "FooExecuted" - }); + Name = "FooExecuted" + }); - code.ShouldNotContain("FooExecuted("); - } + code.ShouldNotContain("FooExecuted("); + } - [Test] - public void should_handle_escaped_keywords() + [Test] + public void should_handle_escaped_keywords() + { + var code = Generate(new MessageDefinition { - var code = Generate(new MessageDefinition + Name = "void", + Parameters = { - Name = "void", - Parameters = - { - new ParameterDefinition("int", "double"), - new ParameterDefinition("if", "else") - }, - GenericParameters = { "float" }, - GenericConstraints = + new ParameterDefinition("int", "double"), + new ParameterDefinition("if", "else") + }, + GenericParameters = { "float" }, + GenericConstraints = + { + new GenericConstraint { - new GenericConstraint - { - GenericParameterName = "float", - Types = { "volatile" } - } + GenericParameterName = "float", + Types = { "volatile" } } - }); - - code.ShouldContain("class @void<@float>"); - code.ShouldContain("where @float : @volatile"); - code.ShouldContain("public @void(int @double, @if @else)"); - code.ShouldContain("Double = @double;"); - code.ShouldContain("Else = @else;"); - code.ShouldContain("where @float : @volatile"); - } + } + }); + + code.ShouldContain("class @void<@float>"); + code.ShouldContain("where @float : @volatile"); + code.ShouldContain("public @void(int @double, @if @else)"); + code.ShouldContain("Double = @double;"); + code.ShouldContain("Else = @else;"); + code.ShouldContain("where @float : @volatile"); + } - [Test] - public void should_generate_simple_enums() + [Test] + public void should_generate_simple_enums() + { + var code = Generate(new ParsedContracts { - var code = Generate(new ParsedContracts + Enums = { - Enums = + new EnumDefinition { - new EnumDefinition + Name = "Foo", + Attributes = + { + new AttributeDefinition("EnumAttr") + }, + Members = { - Name = "Foo", - Attributes = + new EnumMemberDefinition { - new AttributeDefinition("EnumAttr") + Name = "Default" }, - Members = + new EnumMemberDefinition { - new EnumMemberDefinition - { - Name = "Default" - }, - new EnumMemberDefinition - { - Name = "Bar", - Value = "-2" - }, - new EnumMemberDefinition - { - Name = "Baz" - } + Name = "Bar", + Value = "-2" + }, + new EnumMemberDefinition + { + Name = "Baz" } } } - }); + } + }); - code.ShouldContain("public enum Foo"); - } + code.ShouldContain("public enum Foo"); + } - [Test] - public void should_generate_enums() + [Test] + public void should_generate_enums() + { + var code = Generate(new ParsedContracts { - var code = Generate(new ParsedContracts + Enums = { - Enums = + new EnumDefinition { - new EnumDefinition + Name = "Foo", + UnderlyingType = "short", + Attributes = + { + new AttributeDefinition("EnumAttr") + }, + Members = { - Name = "Foo", - UnderlyingType = "short", - Attributes = + new EnumMemberDefinition { - new AttributeDefinition("EnumAttr") + Name = "Default" }, - Members = + new EnumMemberDefinition { - new EnumMemberDefinition + Name = "Bar", + Value = "-2", + Attributes = { - Name = "Default" + new AttributeDefinition("Description(\"Beer!\")") }, - new EnumMemberDefinition + }, + new EnumMemberDefinition + { + Name = "Baz", + Value = "Bar", + Attributes = { - Name = "Bar", - Value = "-2", - Attributes = - { - new AttributeDefinition("Description(\"Beer!\")") - }, + new AttributeDefinition("EnumValueAttr") }, - new EnumMemberDefinition - { - Name = "Baz", - Value = "Bar", - Attributes = - { - new AttributeDefinition("EnumValueAttr") - }, - } } } - }, - Messages = + } + }, + Messages = + { + new MessageDefinition { - new MessageDefinition - { - Name = "Test" - } + Name = "Test" } - }); - - code.ShouldContain("public enum Foo : short"); - code.ShouldContain("Default,"); - code.ShouldContain("Bar = -2,"); - code.ShouldContain("Baz = Bar"); - code.ShouldContain("[EnumAttr]"); - code.ShouldContain("[EnumValueAttr]"); - } + } + }); + + code.ShouldContain("public enum Foo : short"); + code.ShouldContain("Default,"); + code.ShouldContain("Bar = -2,"); + code.ShouldContain("Baz = Bar"); + code.ShouldContain("[EnumAttr]"); + code.ShouldContain("[EnumValueAttr]"); + } - [Test] - public void should_handle_obsolete_attribute() + [Test] + public void should_handle_obsolete_attribute() + { + var code = Generate(new MessageDefinition { - var code = Generate(new MessageDefinition + Name = "FooExecuted", + Parameters = { - Name = "FooExecuted", - Parameters = - { - new ParameterDefinition("int", "foo") - } - }); + new ParameterDefinition("int", "foo") + } + }); - code.ShouldNotContain("#pragma warning disable 612"); + code.ShouldNotContain("#pragma warning disable 612"); - code = Generate(new MessageDefinition + code = Generate(new MessageDefinition + { + Name = "FooExecuted", + Parameters = { - Name = "FooExecuted", - Parameters = + new ParameterDefinition("int", "foo") { - new ParameterDefinition("int", "foo") - { - Attributes = { new AttributeDefinition("Obsolete") } - } + Attributes = { new AttributeDefinition("Obsolete") } } - }); + } + }); - code.ShouldContain("#pragma warning disable 612"); + code.ShouldContain("#pragma warning disable 612"); - code = Generate(new MessageDefinition - { - Name = "FooExecuted", - Attributes = { new AttributeDefinition("Obsolete") } - }); + code = Generate(new MessageDefinition + { + Name = "FooExecuted", + Attributes = { new AttributeDefinition("Obsolete") } + }); - code.ShouldContain("#pragma warning disable 612"); - } + code.ShouldContain("#pragma warning disable 612"); + } - [Test] - public void should_handle_custom_contract_attribute() + [Test] + public void should_handle_custom_contract_attribute() + { + var code = Generate(new MessageDefinition { - var code = Generate(new MessageDefinition - { - Name = "FooExecuted", - Attributes = { new AttributeDefinition("ProtoContract", "EnumPassthru = true") } - }); + Name = "FooExecuted", + Attributes = { new AttributeDefinition("ProtoContract", "EnumPassthru = true") } + }); - code.ShouldContain("[ProtoContract(EnumPassthru = true)]"); - code.ShouldNotContain("[ProtoContract]"); - } + code.ShouldContain("[ProtoContract(EnumPassthru = true)]"); + code.ShouldNotContain("[ProtoContract]"); + } - [Test] - public void should_handle_custom_member_attribute() + [Test] + public void should_handle_custom_member_attribute() + { + var code = Generate(new MessageDefinition { - var code = Generate(new MessageDefinition + Name = "FooExecuted", + Parameters = { - Name = "FooExecuted", - Parameters = + new ParameterDefinition("string", "foo") { - new ParameterDefinition("string", "foo") - { - Attributes = { new AttributeDefinition("ProtoMember", "42, AsReference = true") } - } + Attributes = { new AttributeDefinition("ProtoMember", "42, AsReference = true") } } - }); + } + }); - code.ShouldContain("ProtoMember(42, AsReference = true)"); - code.ShouldNotContain("ProtoMember(1"); - } + code.ShouldContain("ProtoMember(42, AsReference = true)"); + code.ShouldNotContain("ProtoMember(1"); + } - [Test] - public void should_handle_custom_contract_attribute_on_enums() + [Test] + public void should_handle_custom_contract_attribute_on_enums() + { + var code = Generate(new ParsedContracts { - var code = Generate(new ParsedContracts + Enums = { - Enums = + new EnumDefinition { - new EnumDefinition + Name = "Foo", + Attributes = { - Name = "Foo", - Attributes = - { - new AttributeDefinition("ProtoContract", "EnumPassthru = true") - }, - Members = + new AttributeDefinition("ProtoContract", "EnumPassthru = true") + }, + Members = + { + new EnumMemberDefinition { - new EnumMemberDefinition - { - Name = "Default" - } + Name = "Default" } } } - }); + } + }); - code.ShouldContain("[ProtoContract(EnumPassthru = true)]"); - code.ShouldNotContain("[ProtoContract]"); - } + code.ShouldContain("[ProtoContract(EnumPassthru = true)]"); + code.ShouldNotContain("[ProtoContract]"); + } - [Test] - public void should_add_protomap_attribute_to_dictionaries() + [Test] + public void should_add_protomap_attribute_to_dictionaries() + { + var code = Generate(new MessageDefinition { - var code = Generate(new MessageDefinition + Name = "FooExecuted", + Parameters = { - Name = "FooExecuted", - Parameters = - { - new ParameterDefinition("Dictionary", "foo") - } - }); + new ParameterDefinition("Dictionary", "foo") + } + }); - code.ShouldContain("ProtoMap(DisableMap = true)"); - } + code.ShouldContain("ProtoMap(DisableMap = true)"); + } - [Test] - public void should_leave_supplied_protomap() + [Test] + public void should_leave_supplied_protomap() + { + var code = Generate(new MessageDefinition { - var code = Generate(new MessageDefinition + Name = "FooExecuted", + Parameters = { - Name = "FooExecuted", - Parameters = + new ParameterDefinition("Dictionary", "foo") { - new ParameterDefinition("Dictionary", "foo") + Attributes = { - Attributes = - { - new AttributeDefinition("ProtoMap", "Foo = lol") - } + new AttributeDefinition("ProtoMap", "Foo = lol") } } - }); + } + }); - code.ShouldContain("ProtoMap(Foo = lol)"); - code.ShouldNotContain("ProtoMap(DisableMap = true)"); - } + code.ShouldContain("ProtoMap(Foo = lol)"); + code.ShouldNotContain("ProtoMap(DisableMap = true)"); + } - [Test] - public void should_generate_two_classes_with_same_name_and_different_arity() + [Test] + public void should_generate_two_classes_with_same_name_and_different_arity() + { + var code = Generate(new ParsedContracts { - var code = Generate(new ParsedContracts + Messages = { - Messages = - { - new MessageDefinition { Name = "FooExecuted" }, - new MessageDefinition { Name = "FooExecuted", GenericParameters = { "T" } } - } - }); + new MessageDefinition { Name = "FooExecuted" }, + new MessageDefinition { Name = "FooExecuted", GenericParameters = { "T" } } + } + }); - code.ShouldContain("public sealed partial class FooExecuted : IEvent"); - code.ShouldContain("public sealed partial class FooExecuted : IEvent"); - } + code.ShouldContain("public sealed partial class FooExecuted : IEvent"); + code.ShouldContain("public sealed partial class FooExecuted : IEvent"); + } - [Test] - public void should_generate_internal_messages() + [Test] + public void should_generate_internal_messages() + { + var code = Generate(new MessageDefinition { - var code = Generate(new MessageDefinition - { - Name = "FooExecuted", - AccessModifier = AccessModifier.Internal - }); + Name = "FooExecuted", + AccessModifier = AccessModifier.Internal + }); - code.ShouldContain("internal sealed partial class FooExecuted : IEvent"); - } + code.ShouldContain("internal sealed partial class FooExecuted : IEvent"); + } - [Test] - public void should_generate_internal_enums() + [Test] + public void should_generate_internal_enums() + { + var code = Generate(new ParsedContracts { - var code = Generate(new ParsedContracts + Enums = { - Enums = + new EnumDefinition { - new EnumDefinition + Name = "Foo", + Members = { - Name = "Foo", - Members = + new EnumMemberDefinition { - new EnumMemberDefinition - { - Name = "Default" - } - }, - AccessModifier = AccessModifier.Internal - } + Name = "Default" + } + }, + AccessModifier = AccessModifier.Internal } - }); + } + }); - code.ShouldContain("internal enum Foo"); - } + code.ShouldContain("internal enum Foo"); + } - [Test] - public void should_generate_unsealed_messages() + [Test] + public void should_generate_unsealed_messages() + { + var code = Generate(new MessageDefinition { - var code = Generate(new MessageDefinition - { - Name = "FooExecuted", - InheritanceModifier = InheritanceModifier.None - }); + Name = "FooExecuted", + InheritanceModifier = InheritanceModifier.None + }); - code.ShouldContain("public partial class FooExecuted : IEvent"); - } + code.ShouldContain("public partial class FooExecuted : IEvent"); + } - [Test] - public void should_generate_sealed_messages() + [Test] + public void should_generate_sealed_messages() + { + var code = Generate(new MessageDefinition { - var code = Generate(new MessageDefinition - { - Name = "FooExecuted", - InheritanceModifier = InheritanceModifier.Sealed - }); + Name = "FooExecuted", + InheritanceModifier = InheritanceModifier.Sealed + }); - code.ShouldContain("public sealed partial class FooExecuted : IEvent"); - } + code.ShouldContain("public sealed partial class FooExecuted : IEvent"); + } - [Test] - public void should_generate_abstract_messages() + [Test] + public void should_generate_abstract_messages() + { + var code = Generate(new MessageDefinition { - var code = Generate(new MessageDefinition + Name = "FooExecuted", + InheritanceModifier = InheritanceModifier.Abstract, + Parameters = { - Name = "FooExecuted", - InheritanceModifier = InheritanceModifier.Abstract, - Parameters = - { - new ParameterDefinition("int", "foo"), - new ParameterDefinition("int", "bar") - } - }); + new ParameterDefinition("int", "foo"), + new ParameterDefinition("int", "bar") + } + }); - code.ShouldContain("public abstract partial class FooExecuted : IEvent"); - code.ShouldContain("protected FooExecuted("); - code.ShouldNotContain("private FooExecuted("); - code.ShouldNotContain("public FooExecuted("); - } + code.ShouldContain("public abstract partial class FooExecuted : IEvent"); + code.ShouldContain("protected FooExecuted("); + code.ShouldNotContain("private FooExecuted("); + code.ShouldNotContain("public FooExecuted("); + } - [Test] - public void should_handle_nullable_reference_types() + [Test] + public void should_handle_nullable_reference_types() + { + var code = Generate(new ParsedContracts { - var code = Generate(new ParsedContracts + Messages = { - Messages = - { - new MessageDefinition { Name = "FooMessage" }, - new MessageDefinition { Name = "BarMessage", Options = { Nullable = true } }, - new MessageDefinition { Name = "BazMessage" } - } - }); + new MessageDefinition { Name = "FooMessage" }, + new MessageDefinition { Name = "BarMessage", Options = { Nullable = true } }, + new MessageDefinition { Name = "BazMessage" } + } + }); - var fooIndex = code.IndexOf("FooMessage", StringComparison.Ordinal); - var barIndex = code.IndexOf("BarMessage", StringComparison.Ordinal); - var bazIndex = code.IndexOf("BazMessage", StringComparison.Ordinal); + var fooIndex = code.IndexOf("FooMessage", StringComparison.Ordinal); + var barIndex = code.IndexOf("BarMessage", StringComparison.Ordinal); + var bazIndex = code.IndexOf("BazMessage", StringComparison.Ordinal); - var nullableEnableIndex = code.IndexOf("#nullable enable", StringComparison.Ordinal); - var nullableDisableIndex = code.IndexOf("#nullable disable", StringComparison.Ordinal); + var nullableEnableIndex = code.IndexOf("#nullable enable", StringComparison.Ordinal); + var nullableDisableIndex = code.IndexOf("#nullable disable", StringComparison.Ordinal); - foreach (var index in new[] { fooIndex, barIndex, bazIndex, nullableEnableIndex, nullableDisableIndex }) - index.ShouldBeGreaterThan(0); - - nullableEnableIndex.ShouldBeBetween(fooIndex, barIndex); - nullableDisableIndex.ShouldBeBetween(barIndex, bazIndex); - } - - [Test] - public void should_generate_initializers_for_nullable_reference_types() - { - var code = Generate(new MessageDefinition - { - Name = "FooExecuted", - Parameters = - { - new ParameterDefinition("string", "strNotNull"), - new ParameterDefinition("string?", "strNull"), - new ParameterDefinition("int[]", "arrayNotNull"), - new ParameterDefinition("int[]?", "arrayNull") - }, - Options = { Nullable = true } - }); + foreach (var index in new[] { fooIndex, barIndex, bazIndex, nullableEnableIndex, nullableDisableIndex }) + index.ShouldBeGreaterThan(0); - code.ShouldContain("StrNotNull = default!;"); - code.ShouldNotContain("StrNull = default!;"); - code.ShouldContain("ArrayNotNull = Array.Empty();"); - code.ShouldNotContain("ArrayNull = Array.Empty();"); - } + nullableEnableIndex.ShouldBeBetween(fooIndex, barIndex); + nullableDisableIndex.ShouldBeBetween(barIndex, bazIndex); + } - [Test] - public void should_not_generate_initializers_for_known_nullable_value_types() - { - var code = Generate(new MessageDefinition - { - Name = "FooExecuted", - Parameters = - { - new ParameterDefinition("int", "intNotNull"), - new ParameterDefinition("int?", "intNull"), - new ParameterDefinition("int[]", "arrayNotNull"), - new ParameterDefinition("int[]?", "arrayNull") - }, - Options = { Nullable = true } - }); + [Test] + public void should_generate_initializers_for_nullable_reference_types() + { + var code = Generate(new MessageDefinition + { + Name = "FooExecuted", + Parameters = + { + new ParameterDefinition("string", "strNotNull"), + new ParameterDefinition("string?", "strNull"), + new ParameterDefinition("int[]", "arrayNotNull"), + new ParameterDefinition("int[]?", "arrayNull") + }, + Options = { Nullable = true } + }); + + code.ShouldContain("StrNotNull = default!;"); + code.ShouldNotContain("StrNull = default!;"); + code.ShouldContain("ArrayNotNull = Array.Empty();"); + code.ShouldNotContain("ArrayNull = Array.Empty();"); + } - code.ShouldNotContain("IntNotNull = default!;"); - code.ShouldNotContain("IntNull = default!;"); - code.ShouldContain("ArrayNotNull = Array.Empty();"); - code.ShouldNotContain("ArrayNull = Array.Empty();"); - } + [Test] + public void should_not_generate_initializers_for_known_nullable_value_types() + { + var code = Generate(new MessageDefinition + { + Name = "FooExecuted", + Parameters = + { + new ParameterDefinition("int", "intNotNull"), + new ParameterDefinition("int?", "intNull"), + new ParameterDefinition("int[]", "arrayNotNull"), + new ParameterDefinition("int[]?", "arrayNull") + }, + Options = { Nullable = true } + }); + + code.ShouldNotContain("IntNotNull = default!;"); + code.ShouldNotContain("IntNull = default!;"); + code.ShouldContain("ArrayNotNull = Array.Empty();"); + code.ShouldNotContain("ArrayNull = Array.Empty();"); + } - [Test] - public void should_not_reorder_base_types() + [Test] + public void should_not_reorder_base_types() + { + var code = Generate(new MessageDefinition { - var code = Generate(new MessageDefinition - { - Name = "FooExecuted", - BaseTypes = { "BType", "AType" } - }); + Name = "FooExecuted", + BaseTypes = { "BType", "AType" } + }); - code.ShouldContain("FooExecuted : BType, AType"); - } + code.ShouldContain("FooExecuted : BType, AType"); + } - [Test] - public void should_forward_base_type_parameters() + [Test] + public void should_forward_base_type_parameters() + { + var code = Generate(new ParsedContracts { - var code = Generate(new ParsedContracts + Messages = { - Messages = + new MessageDefinition { - new MessageDefinition + Name = "FooMessage", + BaseTypes = { "BarMessage" }, + Parameters = { - Name = "FooMessage", - BaseTypes = { "BarMessage" }, - Parameters = - { - new ParameterDefinition("int", "fooA"), - new ParameterDefinition("int", "barA") { DefaultValue = "10" } - } - }, - new MessageDefinition + new ParameterDefinition("int", "fooA"), + new ParameterDefinition("int", "barA") { DefaultValue = "10" } + } + }, + new MessageDefinition + { + Name = "BarMessage", + BaseTypes = { "BazMessage" }, + InheritanceModifier = InheritanceModifier.Abstract, + Parameters = { - Name = "BarMessage", - BaseTypes = { "BazMessage" }, - InheritanceModifier = InheritanceModifier.Abstract, - Parameters = - { - new ParameterDefinition("int", "fooB") { DefaultValue = "20" }, - new ParameterDefinition("int", "barB") { DefaultValue = "21" } - } - }, - new MessageDefinition + new ParameterDefinition("int", "fooB") { DefaultValue = "20" }, + new ParameterDefinition("int", "barB") { DefaultValue = "21" } + } + }, + new MessageDefinition + { + Name = "BazMessage", + InheritanceModifier = InheritanceModifier.Abstract, + Parameters = { - Name = "BazMessage", - InheritanceModifier = InheritanceModifier.Abstract, - Parameters = - { - new ParameterDefinition("int", "fooC"), - new ParameterDefinition("int", "barC") { DefaultValue = "30" } - } + new ParameterDefinition("int", "fooC"), + new ParameterDefinition("int", "barC") { DefaultValue = "30" } } } - }); + } + }); - code.ShouldContain("FooMessage(int fooC, int barC, int fooB, int barB, int fooA, int barA = 10)"); - code.ShouldContain(": base(fooC, barC, fooB, barB)"); + code.ShouldContain("FooMessage(int fooC, int barC, int fooB, int barB, int fooA, int barA = 10)"); + code.ShouldContain(": base(fooC, barC, fooB, barB)"); - code.ShouldContain("BarMessage(int fooC, int barC = 30, int fooB = 20, int barB = 21)"); - code.ShouldContain(": base(fooC, barC)"); + code.ShouldContain("BarMessage(int fooC, int barC = 30, int fooB = 20, int barB = 21)"); + code.ShouldContain(": base(fooC, barC)"); - code.ShouldContain("BazMessage(int fooC, int barC = 30)"); - } + code.ShouldContain("BazMessage(int fooC, int barC = 30)"); + } - [Test] - public void should_not_forward_base_type_parameters_for_mutable_base_types_1() + [Test] + public void should_not_forward_base_type_parameters_for_mutable_base_types_1() + { + var code = Generate(new ParsedContracts { - var code = Generate(new ParsedContracts + Messages = { - Messages = + new MessageDefinition { - new MessageDefinition + Name = "FooMessage", + BaseTypes = { "BarMessage" }, + Parameters = { - Name = "FooMessage", - BaseTypes = { "BarMessage" }, - Parameters = - { - new ParameterDefinition("int", "fooA") - } - }, - new MessageDefinition + new ParameterDefinition("int", "fooA") + } + }, + new MessageDefinition + { + Name = "BarMessage", + BaseTypes = { "BazMessage" }, + InheritanceModifier = InheritanceModifier.Abstract, + Options = { Mutable = true }, + Parameters = { - Name = "BarMessage", - BaseTypes = { "BazMessage" }, - InheritanceModifier = InheritanceModifier.Abstract, - Options = { Mutable = true }, - Parameters = - { - new ParameterDefinition("int", "fooB") - } - }, - new MessageDefinition + new ParameterDefinition("int", "fooB") + } + }, + new MessageDefinition + { + Name = "BazMessage", + InheritanceModifier = InheritanceModifier.Abstract, + Options = { Mutable = true }, + Parameters = { - Name = "BazMessage", - InheritanceModifier = InheritanceModifier.Abstract, - Options = { Mutable = true }, - Parameters = - { - new ParameterDefinition("int", "fooC") - } + new ParameterDefinition("int", "fooC") } } - }); + } + }); - code.ShouldContain("FooMessage(int fooA)"); - code.ShouldContain("BarMessage(int fooB)"); - code.ShouldContain("BazMessage(int fooC)"); - code.ShouldNotContain("base("); - } + code.ShouldContain("FooMessage(int fooA)"); + code.ShouldContain("BarMessage(int fooB)"); + code.ShouldContain("BazMessage(int fooC)"); + code.ShouldNotContain("base("); + } - [Test] - public void should_not_forward_base_type_parameters_for_mutable_base_types_2() + [Test] + public void should_not_forward_base_type_parameters_for_mutable_base_types_2() + { + var code = Generate(new ParsedContracts { - var code = Generate(new ParsedContracts + Messages = { - Messages = + new MessageDefinition { - new MessageDefinition + Name = "FooMessage", + BaseTypes = { "BarMessage" }, + Parameters = { - Name = "FooMessage", - BaseTypes = { "BarMessage" }, - Parameters = - { - new ParameterDefinition("int", "fooA") - } - }, - new MessageDefinition + new ParameterDefinition("int", "fooA") + } + }, + new MessageDefinition + { + Name = "BarMessage", + BaseTypes = { "BazMessage" }, + InheritanceModifier = InheritanceModifier.Abstract, + Options = { Mutable = true }, + Parameters = { - Name = "BarMessage", - BaseTypes = { "BazMessage" }, - InheritanceModifier = InheritanceModifier.Abstract, - Options = { Mutable = true }, - Parameters = - { - new ParameterDefinition("int", "fooB") - } - }, - new MessageDefinition + new ParameterDefinition("int", "fooB") + } + }, + new MessageDefinition + { + Name = "BazMessage", + InheritanceModifier = InheritanceModifier.Abstract, + Parameters = { - Name = "BazMessage", - InheritanceModifier = InheritanceModifier.Abstract, - Parameters = - { - new ParameterDefinition("int", "fooC") - } + new ParameterDefinition("int", "fooC") } } - }); + } + }); - code.ShouldContain("public FooMessage(int fooC, int fooA)"); - code.ShouldNotContain("protected FooMessage"); - code.ShouldContain(": base(fooC)"); + code.ShouldContain("public FooMessage(int fooC, int fooA)"); + code.ShouldNotContain("protected FooMessage"); + code.ShouldContain(": base(fooC)"); - code.ShouldContain("protected BarMessage(int fooC, int fooB)"); - code.ShouldContain("protected BarMessage(int fooC)"); + code.ShouldContain("protected BarMessage(int fooC, int fooB)"); + code.ShouldContain("protected BarMessage(int fooC)"); - code.ShouldContain("protected BazMessage(int fooC)"); - } + code.ShouldContain("protected BazMessage(int fooC)"); + } - [Test] - public void should_not_forward_base_type_parameters_for_mutable_base_types_3() + [Test] + public void should_not_forward_base_type_parameters_for_mutable_base_types_3() + { + var code = Generate(new ParsedContracts { - var code = Generate(new ParsedContracts + Messages = { - Messages = + new MessageDefinition { - new MessageDefinition + Name = "FooMessage", + BaseTypes = { "BarMessage" }, + Parameters = { - Name = "FooMessage", - BaseTypes = { "BarMessage" }, - Parameters = - { - new ParameterDefinition("int", "fooA") - } - }, - new MessageDefinition + new ParameterDefinition("int", "fooA") + } + }, + new MessageDefinition + { + Name = "BarMessage", + BaseTypes = { "BazMessage" }, + InheritanceModifier = InheritanceModifier.Abstract, + Parameters = { - Name = "BarMessage", - BaseTypes = { "BazMessage" }, - InheritanceModifier = InheritanceModifier.Abstract, - Parameters = - { - new ParameterDefinition("int", "fooB") - } - }, - new MessageDefinition + new ParameterDefinition("int", "fooB") + } + }, + new MessageDefinition + { + Name = "BazMessage", + InheritanceModifier = InheritanceModifier.Abstract, + Options = { Mutable = true }, + Parameters = { - Name = "BazMessage", - InheritanceModifier = InheritanceModifier.Abstract, - Options = { Mutable = true }, - Parameters = - { - new ParameterDefinition("int", "fooC") - } + new ParameterDefinition("int", "fooC") } } - }); - - code.ShouldContain("FooMessage(int fooB, int fooA)"); - code.ShouldContain(": base(fooB)"); + } + }); - code.ShouldContain("BarMessage(int fooB)"); - code.ShouldContain("BazMessage(int fooC)"); - } + code.ShouldContain("FooMessage(int fooB, int fooA)"); + code.ShouldContain(": base(fooB)"); - [Test] - public void should_handle_nested_classes() - { - var code = Generate(new MessageDefinition - { - Name = "Baz", - ContainingClasses = { "Foo", "Bar" } - }); - - code.ShouldContainIgnoreIndent("partial class Foo\n{\npartial class Bar\n{\n[ProtoContract]"); - } + code.ShouldContain("BarMessage(int fooB)"); + code.ShouldContain("BazMessage(int fooC)"); + } - [Test] - public void should_coalesce_array_to_non_null() + [Test] + public void should_handle_nested_classes() + { + var code = Generate(new MessageDefinition { - var code = Generate(new MessageDefinition - { - Name = "Foo", - Parameters = - { - new ParameterDefinition("int[]", "bar"), - new ParameterDefinition("int?[]", "bar2"), - new ParameterDefinition("string[]?", "baz"), - new ParameterDefinition("string?[]", "baz2") - } - }); + Name = "Baz", + ContainingClasses = { "Foo", "Bar" } + }); - code.ShouldContain("Bar = bar ?? Array.Empty();"); - code.ShouldContain("Bar2 = bar2 ?? Array.Empty();"); - code.ShouldNotContain("Array.Empty();"); - code.ShouldContain("Baz = baz;"); - code.ShouldContain("Baz2 = baz2 ?? Array.Empty();"); - } + code.ShouldContainIgnoreIndent("partial class Foo\n{\npartial class Bar\n{\n[ProtoContract]"); + } - protected override string GenerateRaw(ParsedContracts contracts) => CSharpGenerator.Generate(contracts); + [Test] + public void should_coalesce_array_to_non_null() + { + var code = Generate(new MessageDefinition + { + Name = "Foo", + Parameters = + { + new ParameterDefinition("int[]", "bar"), + new ParameterDefinition("int?[]", "bar2"), + new ParameterDefinition("string[]?", "baz"), + new ParameterDefinition("string?[]", "baz2") + } + }); + + code.ShouldContain("Bar = bar ?? Array.Empty();"); + code.ShouldContain("Bar2 = bar2 ?? Array.Empty();"); + code.ShouldNotContain("Array.Empty();"); + code.ShouldContain("Baz = baz;"); + code.ShouldContain("Baz2 = baz2 ?? Array.Empty();"); } + + protected override string GenerateRaw(ParsedContracts contracts) + => CSharpGenerator.Generate(contracts); } diff --git a/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/CSharpSyntaxTests.cs b/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/CSharpSyntaxTests.cs index 5abd2f2..9659f5b 100644 --- a/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/CSharpSyntaxTests.cs +++ b/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/CSharpSyntaxTests.cs @@ -1,37 +1,38 @@ using Abc.Zebus.MessageDsl.Generator; using NUnit.Framework; -namespace Abc.Zebus.MessageDsl.Tests.MessageDsl +namespace Abc.Zebus.MessageDsl.Tests.MessageDsl; + +[TestFixture] +public class CSharpSyntaxTests { - [TestFixture] - public class CSharpSyntaxTests - { - [Test] - [TestCase(null, ExpectedResult = false)] - [TestCase("", ExpectedResult = false)] - [TestCase("foo", ExpectedResult = true)] - [TestCase("int", ExpectedResult = false)] - [TestCase("_int", ExpectedResult = true)] - [TestCase("@int", ExpectedResult = true)] - [TestCase("42lol", ExpectedResult = false)] - [TestCase("lol_42", ExpectedResult = true)] - [TestCase("lol.foo", ExpectedResult = false)] - [TestCase("l\\u006f", ExpectedResult = true)] - [TestCase("l\\U0000006f", ExpectedResult = true)] - [TestCase("l\\U42", ExpectedResult = false)] - [TestCase("in\\u0074", ExpectedResult = false)] - public bool should_validate_identifiers(string value) => CSharpSyntax.IsValidIdentifier(value); + [Test] + [TestCase(null, ExpectedResult = false)] + [TestCase("", ExpectedResult = false)] + [TestCase("foo", ExpectedResult = true)] + [TestCase("int", ExpectedResult = false)] + [TestCase("_int", ExpectedResult = true)] + [TestCase("@int", ExpectedResult = true)] + [TestCase("42lol", ExpectedResult = false)] + [TestCase("lol_42", ExpectedResult = true)] + [TestCase("lol.foo", ExpectedResult = false)] + [TestCase("l\\u006f", ExpectedResult = true)] + [TestCase("l\\U0000006f", ExpectedResult = true)] + [TestCase("l\\U42", ExpectedResult = false)] + [TestCase("in\\u0074", ExpectedResult = false)] + public bool should_validate_identifiers(string value) + => CSharpSyntax.IsValidIdentifier(value); - [Test] - [TestCase(null, ExpectedResult = false)] - [TestCase("", ExpectedResult = false)] - [TestCase("foo", ExpectedResult = true)] - [TestCase("foo.foo", ExpectedResult = true)] - [TestCase("foo.int", ExpectedResult = false)] - [TestCase("foo.@int", ExpectedResult = true)] - [TestCase("foo..foo", ExpectedResult = false)] - [TestCase(".foo.foo", ExpectedResult = false)] - [TestCase("foo.foo.", ExpectedResult = false)] - public bool should_validate_namespaces(string value) => CSharpSyntax.IsValidNamespace(value); - } + [Test] + [TestCase(null, ExpectedResult = false)] + [TestCase("", ExpectedResult = false)] + [TestCase("foo", ExpectedResult = true)] + [TestCase("foo.foo", ExpectedResult = true)] + [TestCase("foo.int", ExpectedResult = false)] + [TestCase("foo.@int", ExpectedResult = true)] + [TestCase("foo..foo", ExpectedResult = false)] + [TestCase(".foo.foo", ExpectedResult = false)] + [TestCase("foo.foo.", ExpectedResult = false)] + public bool should_validate_namespaces(string value) + => CSharpSyntax.IsValidNamespace(value); } diff --git a/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/GeneratorTests.cs b/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/GeneratorTests.cs index b2410b1..5b2375d 100644 --- a/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/GeneratorTests.cs +++ b/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/GeneratorTests.cs @@ -2,33 +2,32 @@ using Abc.Zebus.MessageDsl.Ast; using Abc.Zebus.MessageDsl.Tests.TestTools; -namespace Abc.Zebus.MessageDsl.Tests.MessageDsl +namespace Abc.Zebus.MessageDsl.Tests.MessageDsl; + +public abstract class GeneratorTests { - public abstract class GeneratorTests - { - protected abstract string GenerateRaw(ParsedContracts contracts); + protected abstract string GenerateRaw(ParsedContracts contracts); - protected string Generate(MessageDefinition message) - { - var contracts = new ParsedContracts(); - contracts.Messages.Add(message); - return Generate(contracts); - } + protected string Generate(MessageDefinition message) + { + var contracts = new ParsedContracts(); + contracts.Messages.Add(message); + return Generate(contracts); + } - protected string Generate(ParsedContracts contracts) - { - if (!contracts.ImportedNamespaces.Contains("System")) - contracts.Process(); + protected string Generate(ParsedContracts contracts) + { + if (!contracts.ImportedNamespaces.Contains("System")) + contracts.Process(); - var result = GenerateRaw(contracts); + var result = GenerateRaw(contracts); - Console.WriteLine("----- START -----"); - Console.WriteLine(result); - Console.WriteLine("----- END -----"); + Console.WriteLine("----- START -----"); + Console.WriteLine(result); + Console.WriteLine("----- END -----"); - contracts.Errors.ShouldBeEmpty(); + contracts.Errors.ShouldBeEmpty(); - return result; - } + return result; } } diff --git a/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/ParsedContractsTests.cs b/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/ParsedContractsTests.cs index 7ca5822..e59a8e9 100644 --- a/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/ParsedContractsTests.cs +++ b/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/ParsedContractsTests.cs @@ -6,732 +6,733 @@ using Abc.Zebus.MessageDsl.Tests.TestTools; using NUnit.Framework; -namespace Abc.Zebus.MessageDsl.Tests.MessageDsl +namespace Abc.Zebus.MessageDsl.Tests.MessageDsl; + +[TestFixture] +[SuppressMessage("ReSharper", "RedundantCast")] +public class ParsedContractsTests { - [TestFixture] - [SuppressMessage("ReSharper", "RedundantCast")] - public class ParsedContractsTests + [Test] + public void should_parse_simple_contracts() { - [Test] - public void should_parse_simple_contracts() - { - var contracts = ParseValid(@"FooCommand(int id); FooExecuted(int id, bool success = true);"); - - contracts.Errors.ShouldBeEmpty(); - contracts.Namespace.ShouldEqual("Some.Namespace"); - contracts.ImportedNamespaces.ShouldContain("System"); - contracts.ImportedNamespaces.ShouldContain("ProtoBuf"); - contracts.ImportedNamespaces.ShouldContain("Abc.Zebus"); - - contracts.Messages.Count.ShouldEqual(2); - - var cmd = contracts.Messages[0]; - var evt = contracts.Messages[1]; - - cmd.Name.ShouldEqual("FooCommand"); - cmd.IsCustom.ShouldBeFalse(); - cmd.Type.ShouldEqual(MessageType.Command); - cmd.Parameters.Count.ShouldEqual(1); - cmd.Parameters[0].Name.ShouldEqual("id"); - cmd.Parameters[0].Type.ShouldEqual(new TypeName("int")); - - evt.Name.ShouldEqual("FooExecuted"); - evt.IsCustom.ShouldBeFalse(); - evt.Type.ShouldEqual(MessageType.Event); - evt.Parameters.Count.ShouldEqual(2); - evt.Parameters[0].Name.ShouldEqual("id"); - evt.Parameters[0].Type.ShouldEqual(new TypeName("int")); - evt.Parameters[1].Name.ShouldEqual("success"); - evt.Parameters[1].Type.ShouldEqual(new TypeName("bool")); - evt.Parameters[1].DefaultValue.ShouldEqual("true"); - } + var contracts = ParseValid("FooCommand(int id); FooExecuted(int id, bool success = true);"); + + contracts.Errors.ShouldBeEmpty(); + contracts.Namespace.ShouldEqual("Some.Namespace"); + contracts.ImportedNamespaces.ShouldContain("System"); + contracts.ImportedNamespaces.ShouldContain("ProtoBuf"); + contracts.ImportedNamespaces.ShouldContain("Abc.Zebus"); + + contracts.Messages.Count.ShouldEqual(2); + + var cmd = contracts.Messages[0]; + var evt = contracts.Messages[1]; + + cmd.Name.ShouldEqual("FooCommand"); + cmd.IsCustom.ShouldBeFalse(); + cmd.Type.ShouldEqual(MessageType.Command); + cmd.Parameters.Count.ShouldEqual(1); + cmd.Parameters[0].Name.ShouldEqual("id"); + cmd.Parameters[0].Type.ShouldEqual(new TypeName("int")); + + evt.Name.ShouldEqual("FooExecuted"); + evt.IsCustom.ShouldBeFalse(); + evt.Type.ShouldEqual(MessageType.Event); + evt.Parameters.Count.ShouldEqual(2); + evt.Parameters[0].Name.ShouldEqual("id"); + evt.Parameters[0].Type.ShouldEqual(new TypeName("int")); + evt.Parameters[1].Name.ShouldEqual("success"); + evt.Parameters[1].Type.ShouldEqual(new TypeName("bool")); + evt.Parameters[1].DefaultValue.ShouldEqual("true"); + } - [Test] - public void should_detect_duplicate_names() - { - ParseInvalid("Foo(); Foo();"); - ParseInvalid("Foo(); enum Foo();"); - ParseInvalid("enum Foo(); enum Foo();"); - } + [Test] + public void should_detect_duplicate_names() + { + ParseInvalid("Foo(); Foo();"); + ParseInvalid("Foo(); enum Foo();"); + ParseInvalid("enum Foo(); enum Foo();"); + } - [Test] - public void should_handle_separators() - { - ParseValid("Foo()"); - ParseValid("Foo();"); - ParseValid(";;;;"); - ParseValid("Foo();Bar();"); - ParseValid("Foo()\r\nBar()"); - ParseValid("\r\nFoo()\r\n\r\nBar()\r\n"); - ParseInvalid("Foo() Bar()"); - } + [Test] + public void should_handle_separators() + { + ParseValid("Foo()"); + ParseValid("Foo();"); + ParseValid(";;;;"); + ParseValid("Foo();Bar();"); + ParseValid("Foo()\r\nBar()"); + ParseValid("\r\nFoo()\r\n\r\nBar()\r\n"); + ParseInvalid("Foo() Bar()"); + } - [Test] - public void should_handle_usings() - { - var contracts = ParseValid("using Foo.Bar; Foo();"); - contracts.ImportedNamespaces.ShouldContain("Foo.Bar"); - } + [Test] + public void should_handle_usings() + { + var contracts = ParseValid("using Foo.Bar; Foo();"); + contracts.ImportedNamespaces.ShouldContain("Foo.Bar"); + } - [Test] - public void should_disallow_usings_after_messages() - { - var contracts = ParseInvalid("Foo(); using Foo.Bar;"); - ShouldContainError(contracts, "top of the file"); - } + [Test] + public void should_disallow_usings_after_messages() + { + var contracts = ParseInvalid("Foo(); using Foo.Bar;"); + ShouldContainError(contracts, "top of the file"); + } - [Test] - public void should_disallow_usings_after_enums() - { - var contracts = ParseInvalid("enum Foo { Bar }; using Foo.Bar;"); - ShouldContainError(contracts, "top of the file"); - } + [Test] + public void should_disallow_usings_after_enums() + { + var contracts = ParseInvalid("enum Foo { Bar }; using Foo.Bar;"); + ShouldContainError(contracts, "top of the file"); + } - [Test] - public void should_handle_explicit_namespace() - { - var contracts = ParseValid("namespace Foo.Bar; Foo();"); - contracts.ExplicitNamespace.ShouldBeTrue(); - contracts.Namespace.ShouldEqual("Foo.Bar"); - } + [Test] + public void should_handle_explicit_namespace() + { + var contracts = ParseValid("namespace Foo.Bar; Foo();"); + contracts.ExplicitNamespace.ShouldBeTrue(); + contracts.Namespace.ShouldEqual("Foo.Bar"); + } - [Test] - public void should_handle_default_namespace() - { - var contracts = ParseValid("Foo();"); - contracts.ExplicitNamespace.ShouldBeFalse(); - contracts.Namespace.ShouldEqual("Some.Namespace"); - } + [Test] + public void should_handle_default_namespace() + { + var contracts = ParseValid("Foo();"); + contracts.ExplicitNamespace.ShouldBeFalse(); + contracts.Namespace.ShouldEqual("Some.Namespace"); + } - [Test] - public void should_error_on_missing_namespace() - { - var contracts = ParseInvalid("namespace;"); - ShouldContainError(contracts, "Mismatched input ';'"); - } + [Test] + public void should_error_on_missing_namespace() + { + var contracts = ParseInvalid("namespace;"); + ShouldContainError(contracts, "Mismatched input ';'"); + } - [Test] - public void should_disallow_namespace_clause_after_messages() - { - var contracts = ParseInvalid("Foo(); namespace Foo.Bar;"); - ShouldContainError(contracts, "top of the file"); - } + [Test] + public void should_disallow_namespace_clause_after_messages() + { + var contracts = ParseInvalid("Foo(); namespace Foo.Bar;"); + ShouldContainError(contracts, "top of the file"); + } - [Test] - public void should_disallow_namespace_clause_after_enums() - { - var contracts = ParseInvalid("enum Foo { Bar }; namespace Foo.Bar;"); - ShouldContainError(contracts, "top of the file"); - } + [Test] + public void should_disallow_namespace_clause_after_enums() + { + var contracts = ParseInvalid("enum Foo { Bar }; namespace Foo.Bar;"); + ShouldContainError(contracts, "top of the file"); + } - [Test] - public void should_handle_attributes() - { - var contracts = ParseValid(@"[Transient, System.ObsoleteAttribute(""No good"")] FooExecuted([LolAttribute] int id);"); - - var msg = contracts.Messages.ExpectedSingle(); - msg.Attributes.Count.ShouldEqual(2); - msg.Attributes[0].TypeName.ShouldEqual(new TypeName("Transient")); - msg.Attributes[1].TypeName.ShouldEqual(new TypeName("Obsolete")); - msg.Attributes[1].Parameters.ShouldEqual("\"No good\""); - msg.Parameters[0].Attributes.Count.ShouldEqual(1); - msg.Parameters[0].Attributes[0].TypeName.ShouldEqual(new TypeName("Lol")); - } + [Test] + public void should_handle_attributes() + { + var contracts = ParseValid("""[Transient, System.ObsoleteAttribute("No good")] FooExecuted([LolAttribute] int id);"""); + + var msg = contracts.Messages.ExpectedSingle(); + msg.Attributes.Count.ShouldEqual(2); + msg.Attributes[0].TypeName.ShouldEqual(new TypeName("Transient")); + msg.Attributes[1].TypeName.ShouldEqual(new TypeName("Obsolete")); + msg.Attributes[1].Parameters.ShouldEqual("\"No good\""); + msg.Parameters[0].Attributes.Count.ShouldEqual(1); + msg.Parameters[0].Attributes[0].TypeName.ShouldEqual(new TypeName("Lol")); + } - [Test] - public void should_set_message_as_transient() - { - var contracts = ParseValid(@"[Transient] FooExecuted(int id);"); + [Test] + public void should_set_message_as_transient() + { + var contracts = ParseValid("[Transient] FooExecuted(int id);"); - var msg = contracts.Messages.ExpectedSingle(); - msg.Attributes.Count.ShouldEqual(1); - msg.Attributes[0].TypeName.ShouldEqual(new TypeName("Transient")); - msg.IsTransient.ShouldBeTrue(); - } + var msg = contracts.Messages.ExpectedSingle(); + msg.Attributes.Count.ShouldEqual(1); + msg.Attributes[0].TypeName.ShouldEqual(new TypeName("Transient")); + msg.IsTransient.ShouldBeTrue(); + } - [Test] - public void should_set_message_as_routable() - { - var contracts = ParseValid(@"[Routable] FooExecuted([RoutingPosition(1)] int id, [RoutingPosition(2)] int id2);"); + [Test] + public void should_set_message_as_routable() + { + var contracts = ParseValid("[Routable] FooExecuted([RoutingPosition(1)] int id, [RoutingPosition(2)] int id2);"); - var msg = contracts.Messages.ExpectedSingle(); - msg.IsRoutable.ShouldBeTrue(); - msg.Parameters[0].RoutingPosition.ShouldEqual(1); - msg.Parameters[1].RoutingPosition.ShouldEqual(2); - } + var msg = contracts.Messages.ExpectedSingle(); + msg.IsRoutable.ShouldBeTrue(); + msg.Parameters[0].RoutingPosition.ShouldEqual(1); + msg.Parameters[1].RoutingPosition.ShouldEqual(2); + } - [Test] - public void should_detect_duplicated_routing_position() - { - ParseInvalid(@"[Routable] FooExecuted([RoutingPosition(1)] int id, [RoutingPosition(1)] int id2);"); - } + [Test] + public void should_detect_duplicated_routing_position() + { + ParseInvalid("[Routable] FooExecuted([RoutingPosition(1)] int id, [RoutingPosition(1)] int id2);"); + } - [Test] - public void should_detect_zero_based_routing_position() - { - ParseInvalid(@"[Routable] FooExecuted([RoutingPosition(0)] int id, [RoutingPosition(1)] int id2);"); - } + [Test] + public void should_detect_zero_based_routing_position() + { + ParseInvalid("[Routable] FooExecuted([RoutingPosition(0)] int id, [RoutingPosition(1)] int id2);"); + } - [Test] - public void should_detect_multiple_routing_position() - { - ParseInvalid(@"[Routable] FooExecuted([RoutingPosition(1), RoutingPosition(2)] int id, int id2);"); - } + [Test] + public void should_detect_multiple_routing_position() + { + ParseInvalid("[Routable] FooExecuted([RoutingPosition(1), RoutingPosition(2)] int id, int id2);"); + } - [Test] - public void should_detect_bad_routing_position() - { - ParseInvalid(@"[Routable] FooExecuted([RoutingPosition(""pouet"")] int id, int id2);"); - } + [Test] + public void should_detect_bad_routing_position() + { + ParseInvalid("""[Routable] FooExecuted([RoutingPosition("pouet")] int id, int id2);"""); + } - [Test] - public void should_detect_non_consecutive_routing_position() - { - ParseInvalid(@"[Routable] FooExecuted([RoutingPosition(1)] int id, [RoutingPosition(3)] int id2);"); - } + [Test] + public void should_detect_non_consecutive_routing_position() + { + ParseInvalid("[Routable] FooExecuted([RoutingPosition(1)] int id, [RoutingPosition(3)] int id2);"); + } - [Test] - public void should_detect_routing_positions_on_non_routable_messages() - { - ParseInvalid(@"FooExecuted([RoutingPosition(1)] int id, int id2);"); - } + [Test] + public void should_detect_routing_positions_on_non_routable_messages() + { + ParseInvalid("FooExecuted([RoutingPosition(1)] int id, int id2);"); + } - [Test] - public void should_handle_explicit_tags() - { - var contracts = ParseValid(@"FooExecuted([42] int id, int other);"); + [Test] + public void should_handle_explicit_tags() + { + var contracts = ParseValid("FooExecuted([42] int id, int other);"); - var msg = contracts.Messages.ExpectedSingle(); - msg.Parameters[0].Attributes.ShouldBeEmpty(); - msg.Parameters[0].Tag.ShouldEqual(42); - msg.Parameters[1].Tag.ShouldEqual(43); - } + var msg = contracts.Messages.ExpectedSingle(); + msg.Parameters[0].Attributes.ShouldBeEmpty(); + msg.Parameters[0].Tag.ShouldEqual(42); + msg.Parameters[1].Tag.ShouldEqual(43); + } - [Test] - public void should_detect_invalid_tags() - { - ParseInvalid(@"FooExecuted([42] int a, int b, [42] int c);"); - ParseInvalid(@"FooExecuted([42] int a, int b, [ProtoMember(42)] int c);"); - ParseInvalid(@"FooExecuted([42] int a, int b, int c, [43] int d);"); - ParseInvalid(@"FooExecuted([42, 43] int a);"); - ParseInvalid(@"FooExecuted([42.10] int a);"); - ParseInvalid(@"FooExecuted([-42] int a);"); - ParseInvalid(@"FooExecuted([0] int a);"); - ParseInvalid(@"FooExecuted([19500] int a);"); - } + [Test] + public void should_detect_invalid_tags() + { + ParseInvalid("FooExecuted([42] int a, int b, [42] int c);"); + ParseInvalid("FooExecuted([42] int a, int b, [ProtoMember(42)] int c);"); + ParseInvalid("FooExecuted([42] int a, int b, int c, [43] int d);"); + ParseInvalid("FooExecuted([42, 43] int a);"); + ParseInvalid("FooExecuted([42.10] int a);"); + ParseInvalid("FooExecuted([-42] int a);"); + ParseInvalid("FooExecuted([0] int a);"); + ParseInvalid("FooExecuted([19500] int a);"); + } - [Test] - public void should_detect_invalid_parameters() - { - ParseInvalid(@"FooExecuted(int a, int a)"); - ParseInvalid(@"FooExecuted([0] int a)"); - ParseInvalid(@"FooExecuted([ProtoMember(0)] int a)"); - ParseInvalid(@"FooExecuted([19042] int a)"); - ParseInvalid(@"FooExecuted([ProtoMember(19042)] int a)"); - ParseInvalid(@"FooExecuted([42] int a, [42] int b)"); - ParseInvalid(@"FooExecuted([42] int a, [ProtoMember(42)] int b)"); - } + [Test] + public void should_detect_invalid_parameters() + { + ParseInvalid("FooExecuted(int a, int a)"); + ParseInvalid("FooExecuted([0] int a)"); + ParseInvalid("FooExecuted([ProtoMember(0)] int a)"); + ParseInvalid("FooExecuted([19042] int a)"); + ParseInvalid("FooExecuted([ProtoMember(19042)] int a)"); + ParseInvalid("FooExecuted([42] int a, [42] int b)"); + ParseInvalid("FooExecuted([42] int a, [ProtoMember(42)] int b)"); + } - [Test] - public void should_validate_tags_on_proto_include() - { - ParseInvalid(@"[ProtoInclude(1, typeof(MsgB))] MsgA(int a);"); - ParseValid(@"[ProtoInclude(1, typeof(MsgB))] MsgA();"); - ParseValid(@"[ProtoInclude(2, typeof(MsgB))] MsgA(int a);"); + [Test] + public void should_validate_tags_on_proto_include() + { + ParseInvalid("[ProtoInclude(1, typeof(MsgB))] MsgA(int a);"); + ParseValid("[ProtoInclude(1, typeof(MsgB))] MsgA();"); + ParseValid("[ProtoInclude(2, typeof(MsgB))] MsgA(int a);"); - ParseInvalid(@"[ProtoInclude(42, typeof(MsgB)), ProtoInclude(42, typeof(MsgC))] MsgA(int a);"); - ParseValid(@"[ProtoInclude(42, typeof(MsgB)), ProtoInclude(43, typeof(MsgC))] MsgA(int a);"); + ParseInvalid("[ProtoInclude(42, typeof(MsgB)), ProtoInclude(42, typeof(MsgC))] MsgA(int a);"); + ParseValid("[ProtoInclude(42, typeof(MsgB)), ProtoInclude(43, typeof(MsgC))] MsgA(int a);"); - ParseInvalid(@"[ProtoInclude(19500, typeof(MsgB))] MsgA();"); - ParseInvalid(@"[ProtoInclude(""foo"")] MsgA();"); - ParseInvalid(@"[ProtoInclude] MsgA();"); - } + ParseInvalid("[ProtoInclude(19500, typeof(MsgB))] MsgA();"); + ParseInvalid("""[ProtoInclude("foo")] MsgA();"""); + ParseInvalid("[ProtoInclude] MsgA();"); + } - [Test] - public void should_parse_boolean_options() - { - var contracts = ParseValid("#pragma Proto \r\nFoo()"); - contracts.Messages.First().Options.Proto.ShouldBeTrue(); + [Test] + public void should_parse_boolean_options() + { + var contracts = ParseValid("#pragma Proto \r\nFoo()"); + contracts.Messages.First().Options.Proto.ShouldBeTrue(); - contracts = ParseValid("#pragma proto \r\nFoo()"); - contracts.Messages.First().Options.Proto.ShouldBeTrue(); + contracts = ParseValid("#pragma proto \r\nFoo()"); + contracts.Messages.First().Options.Proto.ShouldBeTrue(); - contracts = ParseValid("#pragma Proto true \r\nFoo()"); - contracts.Messages.First().Options.Proto.ShouldBeTrue(); + contracts = ParseValid("#pragma Proto true \r\nFoo()"); + contracts.Messages.First().Options.Proto.ShouldBeTrue(); - contracts = ParseValid("#pragma Proto false \r\nFoo()"); - contracts.Messages.First().Options.Proto.ShouldBeFalse(); + contracts = ParseValid("#pragma Proto false \r\nFoo()"); + contracts.Messages.First().Options.Proto.ShouldBeFalse(); - contracts = ParseValid("#pragma !proto \r\nFoo()"); - contracts.Messages.First().Options.Proto.ShouldBeFalse(); + contracts = ParseValid("#pragma !proto \r\nFoo()"); + contracts.Messages.First().Options.Proto.ShouldBeFalse(); - contracts = ParseValid("#pragma Proto = true \r\nFoo()"); - contracts.Messages.First().Options.Proto.ShouldBeTrue(); + contracts = ParseValid("#pragma Proto = true \r\nFoo()"); + contracts.Messages.First().Options.Proto.ShouldBeTrue(); - contracts = ParseValid(" # pragma Proto \r\nFoo()"); - contracts.Messages.First().Options.Proto.ShouldBeTrue(); - } + contracts = ParseValid(" # pragma Proto \r\nFoo()"); + contracts.Messages.First().Options.Proto.ShouldBeTrue(); + } - [Test] - public void should_detect_invalid_options() - { - ParseInvalid("#pragma !proto true \r\nFoo()"); - ParseInvalid("#pragma \r\n mutable"); - ParseInvalid("#pragma ! \r\n mutable"); - ParseInvalid("#pragma mutable \r\n false"); - } + [Test] + public void should_detect_invalid_options() + { + ParseInvalid("#pragma !proto true \r\nFoo()"); + ParseInvalid("#pragma \r\n mutable"); + ParseInvalid("#pragma ! \r\n mutable"); + ParseInvalid("#pragma mutable \r\n false"); + } - [Test] - public void should_not_allow_anything_after_pragma_on_the_same_line() - { - ParseInvalid("#pragma proto; Foo()"); - ParseInvalid("#pragma proto Foo()"); - ParseInvalid("#pragma proto foo bar"); - } + [Test] + public void should_not_allow_anything_after_pragma_on_the_same_line() + { + ParseInvalid("#pragma proto; Foo()"); + ParseInvalid("#pragma proto Foo()"); + ParseInvalid("#pragma proto foo bar"); + } - [Test] - public void should_parse_pragma_internal_and_public() - { - var contracts = ParseValid("#pragma internal\r\nMsgA()\r\n#pragma public\r\nMsgB()"); - contracts.Messages[0].Options.Internal.ShouldBeTrue(); - contracts.Messages[0].Options.Public.ShouldBeFalse(); - contracts.Messages[1].Options.Internal.ShouldBeFalse(); - contracts.Messages[1].Options.Public.ShouldBeTrue(); - } + [Test] + public void should_parse_pragma_internal_and_public() + { + var contracts = ParseValid("#pragma internal\r\nMsgA()\r\n#pragma public\r\nMsgB()"); + contracts.Messages[0].Options.Internal.ShouldBeTrue(); + contracts.Messages[0].Options.Public.ShouldBeFalse(); + contracts.Messages[1].Options.Internal.ShouldBeFalse(); + contracts.Messages[1].Options.Public.ShouldBeTrue(); + } - [Test] - public void should_parse_pragma_nullable() - { - var contracts = ParseValid("#pragma nullable\r\nMsgA()\r\n#pragma !nullable\r\nMsgB()"); - contracts.Messages[0].Options.Nullable.ShouldBeTrue(); - contracts.Messages[1].Options.Nullable.ShouldBeFalse(); - } + [Test] + public void should_parse_pragma_nullable() + { + var contracts = ParseValid("#pragma nullable\r\nMsgA()\r\n#pragma !nullable\r\nMsgB()"); + contracts.Messages[0].Options.Nullable.ShouldBeTrue(); + contracts.Messages[1].Options.Nullable.ShouldBeFalse(); + } - [Test] - public void should_parse_access_modifiers() - { - var contracts = ParseValid("public MsgA(); internal MsgB();"); - contracts.Messages[0].AccessModifier.ShouldEqual(AccessModifier.Public); - contracts.Messages[1].AccessModifier.ShouldEqual(AccessModifier.Internal); - } + [Test] + public void should_parse_access_modifiers() + { + var contracts = ParseValid("public MsgA(); internal MsgB();"); + contracts.Messages[0].AccessModifier.ShouldEqual(AccessModifier.Public); + contracts.Messages[1].AccessModifier.ShouldEqual(AccessModifier.Internal); + } - [Test] - public void should_parse_access_modifiers_in_internal_scope() - { - var contracts = ParseValid("#pragma internal\r\npublic MsgA(); internal MsgB();"); - contracts.Messages[0].AccessModifier.ShouldEqual(AccessModifier.Public); - contracts.Messages[1].AccessModifier.ShouldEqual(AccessModifier.Internal); - } + [Test] + public void should_parse_access_modifiers_in_internal_scope() + { + var contracts = ParseValid("#pragma internal\r\npublic MsgA(); internal MsgB();"); + contracts.Messages[0].AccessModifier.ShouldEqual(AccessModifier.Public); + contracts.Messages[1].AccessModifier.ShouldEqual(AccessModifier.Internal); + } - [Test] - public void should_parse_inheritance_modifiers() - { - var contracts = ParseValid("sealed MsgA(); abstract MsgB();"); - contracts.Messages[0].InheritanceModifier.ShouldEqual(InheritanceModifier.Sealed); - contracts.Messages[1].InheritanceModifier.ShouldEqual(InheritanceModifier.Abstract); - } + [Test] + public void should_parse_inheritance_modifiers() + { + var contracts = ParseValid("sealed MsgA(); abstract MsgB();"); + contracts.Messages[0].InheritanceModifier.ShouldEqual(InheritanceModifier.Sealed); + contracts.Messages[1].InheritanceModifier.ShouldEqual(InheritanceModifier.Abstract); + } - [Test] - public void should_default_to_sealed_message_classes() - { - var contracts = ParseValid("MsgA();"); - contracts.Messages[0].InheritanceModifier.ShouldEqual(InheritanceModifier.Sealed); - } + [Test] + public void should_default_to_sealed_message_classes() + { + var contracts = ParseValid("MsgA();"); + contracts.Messages[0].InheritanceModifier.ShouldEqual(InheritanceModifier.Sealed); + } - [Test] - public void should_not_mark_inherited_messages_as_sealed() - { - var contracts = ParseValid("[ProtoInclude(10, typeof(MsgB))] MsgA(); MsgB() : MsgA;"); - contracts.Messages[0].InheritanceModifier.ShouldEqual(InheritanceModifier.None); - } + [Test] + public void should_not_mark_inherited_messages_as_sealed() + { + var contracts = ParseValid("[ProtoInclude(10, typeof(MsgB))] MsgA(); MsgB() : MsgA;"); + contracts.Messages[0].InheritanceModifier.ShouldEqual(InheritanceModifier.None); + } - [Test] - public void should_parse_custom_types() - { - var contracts = ParseValid(@"FooType!(int id);"); + [Test] + public void should_parse_custom_types() + { + var contracts = ParseValid("FooType!(int id);"); - var msg = contracts.Messages.ExpectedSingle(); - msg.IsCustom.ShouldBeTrue(); - } + var msg = contracts.Messages.ExpectedSingle(); + msg.IsCustom.ShouldBeTrue(); + } - [Test] - public void should_handle_generic_types() - { - var contracts = ParseValid(@"FooEvent(IDictionary foo, IList >bar)"); - var msg = contracts.Messages.ExpectedSingle(); - msg.Parameters[0].Type.NetType.ShouldEqual("IDictionary"); - msg.Parameters[1].Type.NetType.ShouldEqual("IList>"); - - contracts = ParseValid(@"FooEvent(SomeGenericStruct? foo)"); - msg = contracts.Messages.ExpectedSingle(); - msg.Parameters[0].Type.NetType.ShouldEqual("SomeGenericStruct?"); - } + [Test] + public void should_handle_generic_types() + { + var contracts = ParseValid("FooEvent(IDictionary foo, IList >bar)"); + var msg = contracts.Messages.ExpectedSingle(); + msg.Parameters[0].Type.NetType.ShouldEqual("IDictionary"); + msg.Parameters[1].Type.NetType.ShouldEqual("IList>"); + + contracts = ParseValid("FooEvent(SomeGenericStruct? foo)"); + msg = contracts.Messages.ExpectedSingle(); + msg.Parameters[0].Type.NetType.ShouldEqual("SomeGenericStruct?"); + } - [Test] - public void should_handle_array_types() - { - ParseValid(@"FooEvent(int[] foo)"); - ParseValid(@"FooEvent(int[,] foo)"); - ParseValid(@"FooEvent(int[,][,,] foo)"); - ParseValid(@"FooEvent(int[,][,,][,,,] foo)"); - ParseValid(@"FooEvent(int?[,][,,][,,,] foo)"); - ParseValid(@"FooEvent(List[] foo)"); - ParseValid(@"FooEvent(List[] foo)"); - } + [Test] + public void should_handle_array_types() + { + ParseValid("FooEvent(int[] foo)"); + ParseValid("FooEvent(int[,] foo)"); + ParseValid("FooEvent(int[,][,,] foo)"); + ParseValid("FooEvent(int[,][,,][,,,] foo)"); + ParseValid("FooEvent(int?[,][,,][,,,] foo)"); + ParseValid("FooEvent(List[] foo)"); + ParseValid("FooEvent(List[] foo)"); + } - [Test] - public void should_reject_invalid_types() - { - ParseInvalid("FooEvent(int?? id)"); - ParseInvalid("FooEvent(int? ? id)"); - ParseInvalid("FooEvent(int[] id)"); - ParseInvalid("FooEvent(int[]? id)"); - ParseInvalid("FooEvent(List id)"); - ParseInvalid("FooEvent(List? id)"); - ParseInvalid("FooEvent(List? id)"); - } + [Test] + public void should_reject_invalid_types() + { + ParseInvalid("FooEvent(int?? id)"); + ParseInvalid("FooEvent(int? ? id)"); + ParseInvalid("FooEvent(int[] id)"); + ParseInvalid("FooEvent(int[]? id)"); + ParseInvalid("FooEvent(List id)"); + ParseInvalid("FooEvent(List? id)"); + ParseInvalid("FooEvent(List? id)"); + } - [Test] - public void should_parse_nullable_reference_types() - { - ParseValid(@"FooEvent(string?[] foo)"); - ParseValid(@"FooEvent(string[]? foo)"); - ParseValid(@"FooEvent(string?[]? foo)"); - ParseValid(@"FooEvent(string?[]?[] foo)"); - ParseValid(@"FooEvent(string?[]?[]? foo)"); - ParseValid(@"FooEvent(List[] foo)"); - ParseValid(@"FooEvent(List?[] id)"); - ParseValid(@"FooEvent(List[]? foo)"); - ParseValid(@"FooEvent(List?[]? foo)"); - } + [Test] + public void should_parse_nullable_reference_types() + { + ParseValid("FooEvent(string?[] foo)"); + ParseValid("FooEvent(string[]? foo)"); + ParseValid("FooEvent(string?[]? foo)"); + ParseValid("FooEvent(string?[]?[] foo)"); + ParseValid("FooEvent(string?[]?[]? foo)"); + ParseValid("FooEvent(List[] foo)"); + ParseValid("FooEvent(List?[] id)"); + ParseValid("FooEvent(List[]? foo)"); + ParseValid("FooEvent(List?[]? foo)"); + } - [Test] - public void should_handle_namespaces() - { - var contracts = ParseValid("FooEvent(global::System.Collection.Generic.List foo)"); - var msg = contracts.Messages.ExpectedSingle(); - msg.Parameters.ExpectedSingle().Type.NetType.ShouldEqual("global::System.Collection.Generic.List"); + [Test] + public void should_handle_namespaces() + { + var contracts = ParseValid("FooEvent(global::System.Collection.Generic.List foo)"); + var msg = contracts.Messages.ExpectedSingle(); + msg.Parameters.ExpectedSingle().Type.NetType.ShouldEqual("global::System.Collection.Generic.List"); - contracts = ParseValid("FooEvent(global::System.Int32 id)"); - msg = contracts.Messages.ExpectedSingle(); - msg.Parameters.ExpectedSingle().Type.NetType.ShouldEqual("int"); - } + contracts = ParseValid("FooEvent(global::System.Int32 id)"); + msg = contracts.Messages.ExpectedSingle(); + msg.Parameters.ExpectedSingle().Type.NetType.ShouldEqual("int"); + } - [Test] - public void should_handle_generic_messages() - { - ParseValid("Foo()"); - ParseValid("Foo()"); - ParseValid("Foo()"); - - ParseInvalid("Foo<>()"); - ParseInvalid("Foo<42>()"); - ParseInvalid("Foo()"); - ParseInvalid("Foo()"); - ParseInvalid("Foo()"); - ParseInvalid("Foo<,>()"); - ParseInvalid("Foo()"); - ParseInvalid("Foo()"); - } + [Test] + public void should_handle_generic_messages() + { + ParseValid("Foo()"); + ParseValid("Foo()"); + ParseValid("Foo()"); + + ParseInvalid("Foo<>()"); + ParseInvalid("Foo<42>()"); + ParseInvalid("Foo()"); + ParseInvalid("Foo()"); + ParseInvalid("Foo()"); + ParseInvalid("Foo<,>()"); + ParseInvalid("Foo()"); + ParseInvalid("Foo()"); + } - [Test] - public void should_handle_generic_constraints() - { - ParseValid("Foo() where T : class"); - ParseValid("Foo() where T : struct"); - ParseValid("Foo() where T : new()"); - ParseValid("Foo() where T : IDisposable"); - ParseValid("Foo() where T : class, IDisposable"); - ParseValid("Foo() where T : IDisposable, new()"); - ParseValid("Foo() where T : class, IDisposable, new()"); - ParseValid("Foo() where T : new(), IDisposable, class"); - ParseValid("Foo() where T : struct, new()"); - ParseValid("Foo() where U : class"); - ParseValid("Foo() where T : class where U : class"); - ParseValid("Foo() where T : ISomething"); - ParseValid("Foo() where T : @class, struct"); - - ParseInvalid("Foo() where T : class, struct"); - ParseInvalid("Foo() where T : class, class"); - ParseInvalid("Foo() where T : struct, struct"); - ParseInvalid("Foo() where T : new(), new()"); - ParseInvalid("Foo() where T : @new()"); - ParseInvalid("Foo() where T : IDisposable, IDisposable"); - ParseInvalid("Foo() where U : class"); - ParseInvalid("Foo() where T : class where T : IDisposable"); - ParseInvalid("Foo() where T : ISomething"); - ParseInvalid("Foo() where T : ISomething"); - } + [Test] + public void should_handle_generic_constraints() + { + ParseValid("Foo() where T : class"); + ParseValid("Foo() where T : struct"); + ParseValid("Foo() where T : new()"); + ParseValid("Foo() where T : IDisposable"); + ParseValid("Foo() where T : class, IDisposable"); + ParseValid("Foo() where T : IDisposable, new()"); + ParseValid("Foo() where T : class, IDisposable, new()"); + ParseValid("Foo() where T : new(), IDisposable, class"); + ParseValid("Foo() where T : struct, new()"); + ParseValid("Foo() where U : class"); + ParseValid("Foo() where T : class where U : class"); + ParseValid("Foo() where T : ISomething"); + ParseValid("Foo() where T : @class, struct"); + + ParseInvalid("Foo() where T : class, struct"); + ParseInvalid("Foo() where T : class, class"); + ParseInvalid("Foo() where T : struct, struct"); + ParseInvalid("Foo() where T : new(), new()"); + ParseInvalid("Foo() where T : @new()"); + ParseInvalid("Foo() where T : IDisposable, IDisposable"); + ParseInvalid("Foo() where U : class"); + ParseInvalid("Foo() where T : class where T : IDisposable"); + ParseInvalid("Foo() where T : ISomething"); + ParseInvalid("Foo() where T : ISomething"); + } - [Test] - public void should_disallow_unhandled_features_in_proto() - { - ParseInvalid("#pragma proto\r\nFoo()"); - ParseInvalid("#pragma proto\r\nFoo#v0()"); - } + [Test] + public void should_disallow_unhandled_features_in_proto() + { + ParseInvalid("#pragma proto\r\nFoo()"); + ParseInvalid("#pragma proto\r\nFoo#v0()"); + } - [Test] - public void should_handle_keywords() - { - ParseValid("@Foo(@if @void)"); - ParseValid("@null(@true @false, @public @static)"); + [Test] + public void should_handle_keywords() + { + ParseValid("@Foo(@if @void)"); + ParseValid("@null(@true @false, @public @static)"); - ParseInvalid("Foo(@ if test)"); - ParseInvalid("Foo(if void)"); - ParseInvalid("Foo(null true)"); - } + ParseInvalid("Foo(@ if test)"); + ParseInvalid("Foo(if void)"); + ParseInvalid("Foo(null true)"); + } - [Test] - public void should_handle_additional_interfaces() - { - ParseValid("Foo(int id) : ITest").Messages.Single().BaseTypes.ShouldContain((TypeName)"ITest"); - ParseValid("Foo(int id) : IGeneric").Messages.Single().BaseTypes.ShouldContain((TypeName)"IGeneric"); - } + [Test] + public void should_handle_additional_interfaces() + { + ParseValid("Foo(int id) : ITest").Messages.Single().BaseTypes.ShouldContain((TypeName)"ITest"); + ParseValid("Foo(int id) : IGeneric").Messages.Single().BaseTypes.ShouldContain((TypeName)"IGeneric"); + } - [Test] - public void should_override_message_type() - { - var msg = ParseValid("Foo(int id)").Messages.Single(); - msg.Type.ShouldEqual(MessageType.Event); - msg.BaseTypes.ShouldContain((TypeName)"IEvent"); - msg.BaseTypes.ShouldNotContain((TypeName)"ICommand"); - - msg = ParseValid("Foo(int id) : ICommand").Messages.Single(); - msg.Type.ShouldEqual(MessageType.Command); - msg.BaseTypes.ShouldContain((TypeName)"ICommand"); - msg.BaseTypes.ShouldNotContain((TypeName)"IEvent"); - } + [Test] + public void should_override_message_type() + { + var msg = ParseValid("Foo(int id)").Messages.Single(); + msg.Type.ShouldEqual(MessageType.Event); + msg.BaseTypes.ShouldContain((TypeName)"IEvent"); + msg.BaseTypes.ShouldNotContain((TypeName)"ICommand"); + + msg = ParseValid("Foo(int id) : ICommand").Messages.Single(); + msg.Type.ShouldEqual(MessageType.Command); + msg.BaseTypes.ShouldContain((TypeName)"ICommand"); + msg.BaseTypes.ShouldNotContain((TypeName)"IEvent"); + } - [Test] - public void should_parse_enums() - { - var msg = ParseValid("[EnumAttr] enum Foo { Red, Green = 4, [MemberAttr] Blue = Red }; Test();"); - var fooEnum = msg.Enums.ExpectedSingle(); + [Test] + public void should_parse_enums() + { + var msg = ParseValid("[EnumAttr] enum Foo { Red, Green = 4, [MemberAttr] Blue = Red }; Test();"); + var fooEnum = msg.Enums.ExpectedSingle(); - fooEnum.Name.ShouldEqual("Foo"); - fooEnum.UnderlyingType.ShouldEqual((TypeName)"int"); - fooEnum.Attributes.ExpectedSingle().TypeName.ShouldEqual((TypeName)"EnumAttr"); + fooEnum.Name.ShouldEqual("Foo"); + fooEnum.UnderlyingType.ShouldEqual((TypeName)"int"); + fooEnum.Attributes.ExpectedSingle().TypeName.ShouldEqual((TypeName)"EnumAttr"); - fooEnum.Members.Count.ShouldEqual(3); - fooEnum.Members[0].Name.ShouldEqual("Red"); - fooEnum.Members[0].Value.ShouldBeNull(); - fooEnum.Members[0].Attributes.ShouldBeEmpty(); + fooEnum.Members.Count.ShouldEqual(3); + fooEnum.Members[0].Name.ShouldEqual("Red"); + fooEnum.Members[0].Value.ShouldBeNull(); + fooEnum.Members[0].Attributes.ShouldBeEmpty(); - fooEnum.Members[1].Name.ShouldEqual("Green"); - fooEnum.Members[1].Value.ShouldEqual("4"); - fooEnum.Members[1].Attributes.ShouldBeEmpty(); + fooEnum.Members[1].Name.ShouldEqual("Green"); + fooEnum.Members[1].Value.ShouldEqual("4"); + fooEnum.Members[1].Attributes.ShouldBeEmpty(); - fooEnum.Members[2].Name.ShouldEqual("Blue"); - fooEnum.Members[2].Value.ShouldEqual("Red"); - fooEnum.Members[2].Attributes.ExpectedSingle().TypeName.ShouldEqual((TypeName)"MemberAttr"); - } + fooEnum.Members[2].Name.ShouldEqual("Blue"); + fooEnum.Members[2].Value.ShouldEqual("Red"); + fooEnum.Members[2].Attributes.ExpectedSingle().TypeName.ShouldEqual((TypeName)"MemberAttr"); + } - [Test] - public void should_parse_complex_enums() - { - var msg = ParseValid(@" - [Flags] - public enum Metrics - { - None = 0, - HitOrdersProportion = 1 << 0, - HitRatio = 1 << 1, - IocExecution = 1 << 2, - Latency = 1 << 3, - Lifetime = 1 << 4, - MessageCounter = 1 << 5, - CancelReject = 1 << 6, - ExcessiveMessageRatioNasdaq = 1 << 7, - All = ~0, - AllExceptOrderContextMetrics = All & ~HitOrdersProportion & ~HitRatio & ~ExcessiveMessageRatioNasdaq, - Test = Lifetime | (Latency) - } -"); - - var enumDef = msg.Enums.ExpectedSingle(); - enumDef.Members.Count.ShouldEqual(12); - enumDef.Members.ExpectedSingle(i => i.Name == "Test").Value.ShouldEqual("Lifetime | (Latency)"); - } + [Test] + public void should_parse_complex_enums() + { + var msg = ParseValid( + """ + [Flags] + public enum Metrics + { + None = 0, + HitOrdersProportion = 1 << 0, + HitRatio = 1 << 1, + IocExecution = 1 << 2, + Latency = 1 << 3, + Lifetime = 1 << 4, + MessageCounter = 1 << 5, + CancelReject = 1 << 6, + ExcessiveMessageRatioNasdaq = 1 << 7, + All = ~0, + AllExceptOrderContextMetrics = All & ~HitOrdersProportion & ~HitRatio & ~ExcessiveMessageRatioNasdaq, + Test = Lifetime | (Latency) + } + """ + ); - [Test] - public void should_handle_double_angled_brackets() - { - ParseValid(@"enum Foo { Bar = 1 << 4 };"); - ParseValid(@"enum Foo { Bar = 1 >> 0 };"); - ParseInvalid(@"enum Foo { Bar = 1 < < 4 };"); - ParseInvalid(@"enum Foo { Bar = 1 > > 0 };"); + var enumDef = msg.Enums.ExpectedSingle(); + enumDef.Members.Count.ShouldEqual(12); + enumDef.Members.ExpectedSingle(i => i.Name == "Test").Value.ShouldEqual("Lifetime | (Latency)"); + } - ParseValid(@"Foo(Dictionary> bar)"); - ParseValid(@"Foo(Dictionary > bar)"); - } + [Test] + public void should_handle_double_angled_brackets() + { + ParseValid("enum Foo { Bar = 1 << 4 };"); + ParseValid("enum Foo { Bar = 1 >> 0 };"); + ParseInvalid("enum Foo { Bar = 1 < < 4 };"); + ParseInvalid("enum Foo { Bar = 1 > > 0 };"); - [Test] - public void should_allow_public_keyword_for_enums() - { - var msg = ParseValid("[EnumAttr] public enum Foo { Red, Green = 4, [MemberAttr] Blue = Red }; Test();"); - msg.Enums.ExpectedSingle().Name.ShouldEqual("Foo"); - } + ParseValid("Foo(Dictionary> bar)"); + ParseValid("Foo(Dictionary > bar)"); + } - [Test] - public void should_parse_enum_separators() - { - ParseValid("enum Foo { }; enum Bar { A }; enum Baz { A, }"); - ParseInvalid("enum Foo { , }"); - } + [Test] + public void should_allow_public_keyword_for_enums() + { + var msg = ParseValid("[EnumAttr] public enum Foo { Red, Green = 4, [MemberAttr] Blue = Red }; Test();"); + msg.Enums.ExpectedSingle().Name.ShouldEqual("Foo"); + } - [Test] - public void should_detect_duplicate_enum_members() - { - ParseValid("enum Foo { Bar, bar }"); - ParseInvalid("enum Foo { Bar, Bar }"); - } + [Test] + public void should_parse_enum_separators() + { + ParseValid("enum Foo { }; enum Bar { A }; enum Baz { A, }"); + ParseInvalid("enum Foo { , }"); + } - [Test] - public void should_detect_invalid_underlying_enum_types() - { - ParseValid("enum Foo { Bar }"); - ParseValid("enum Foo : byte { Bar }"); - ParseValid("enum Foo : sbyte { Bar }"); - ParseValid("enum Foo : short { Bar }"); - ParseValid("enum Foo : ushort { Bar }"); - ParseValid("enum Foo : int { Bar }"); - ParseValid("enum Foo : uint { Bar }"); - ParseValid("enum Foo : long { Bar }"); - ParseValid("enum Foo : ulong { Bar }"); - - ParseInvalid("enum Foo : string { Bar }"); - ParseInvalid("enum Foo : Baz { Bar }"); - - ParseInvalid("#pragma proto\nenum Foo : short { Bar }"); - } + [Test] + public void should_detect_duplicate_enum_members() + { + ParseValid("enum Foo { Bar, bar }"); + ParseInvalid("enum Foo { Bar, Bar }"); + } - [Test] - public void should_detect_invalid_custom_protomember_attributes() - { - ParseValid("Foo([ProtoMember(42)] int bar)"); - ParseValid("Foo([ProtoMember( 42 )] int bar)"); - ParseValid("Foo([ProtoMember( 42, IsRequired = false )] int bar)"); - - ParseInvalid("Foo([10] [ProtoMember(42)] int bar)"); - ParseInvalid("Foo([ProtoMember()] int bar)"); - ParseInvalid("Foo([ProtoMember] int bar)"); - ParseInvalid("Foo([ProtoMember('a')] int bar)"); - } + [Test] + public void should_detect_invalid_underlying_enum_types() + { + ParseValid("enum Foo { Bar }"); + ParseValid("enum Foo : byte { Bar }"); + ParseValid("enum Foo : sbyte { Bar }"); + ParseValid("enum Foo : short { Bar }"); + ParseValid("enum Foo : ushort { Bar }"); + ParseValid("enum Foo : int { Bar }"); + ParseValid("enum Foo : uint { Bar }"); + ParseValid("enum Foo : long { Bar }"); + ParseValid("enum Foo : ulong { Bar }"); + + ParseInvalid("enum Foo : string { Bar }"); + ParseInvalid("enum Foo : Baz { Bar }"); + + ParseInvalid("#pragma proto\nenum Foo : short { Bar }"); + } - [Test] - public void should_return_source_text() - { - var contracts = ParseValid(@" [Attr(42)] Foo ( int bar ); "); + [Test] + public void should_detect_invalid_custom_protomember_attributes() + { + ParseValid("Foo([ProtoMember(42)] int bar)"); + ParseValid("Foo([ProtoMember( 42 )] int bar)"); + ParseValid("Foo([ProtoMember( 42, IsRequired = false )] int bar)"); + + ParseInvalid("Foo([10] [ProtoMember(42)] int bar)"); + ParseInvalid("Foo([ProtoMember()] int bar)"); + ParseInvalid("Foo([ProtoMember] int bar)"); + ParseInvalid("Foo([ProtoMember('a')] int bar)"); + } - var foo = contracts.Messages.ExpectedSingle(); - foo.GetSourceText().ShouldEqual("[Attr(42)] Foo ( int bar )"); - foo.GetSourceTextInterval().ShouldEqual(new TextInterval(2, 28)); + [Test] + public void should_return_source_text() + { + var contracts = ParseValid(" [Attr(42)] Foo ( int bar ); "); - var bar = foo.Parameters.ExpectedSingle(); - bar.GetSourceText().ShouldEqual("int bar"); - bar.GetSourceTextInterval().ShouldEqual(new TextInterval(19, 26)); + var foo = contracts.Messages.ExpectedSingle(); + foo.GetSourceText().ShouldEqual("[Attr(42)] Foo ( int bar )"); + foo.GetSourceTextInterval().ShouldEqual(new TextInterval(2, 28)); - var attr = foo.Attributes.ExpectedSingle(); - attr.GetSourceText().ShouldEqual("Attr(42)"); - attr.GetSourceTextInterval().ShouldEqual(new TextInterval(3, 11)); - } + var bar = foo.Parameters.ExpectedSingle(); + bar.GetSourceText().ShouldEqual("int bar"); + bar.GetSourceTextInterval().ShouldEqual(new TextInterval(19, 26)); - [Test] - public void should_return_two_messages_with_same_name_but_different_arity() - { - var contracts = ParseValid("Foo(); Foo();"); + var attr = foo.Attributes.ExpectedSingle(); + attr.GetSourceText().ShouldEqual("Attr(42)"); + attr.GetSourceTextInterval().ShouldEqual(new TextInterval(3, 11)); + } - contracts.Messages.Count.ShouldEqual(2); - contracts.Messages[0].Name.ShouldEqual("Foo"); - contracts.Messages[1].Name.ShouldEqual("Foo"); - contracts.Messages[1].GenericParameters.ExpectedSingle().ShouldEqual("T"); - } + [Test] + public void should_return_two_messages_with_same_name_but_different_arity() + { + var contracts = ParseValid("Foo(); Foo();"); - [Test] - public void should_parse_default_values() - { - ParseValid(@"Foo(int i = 42);"); - ParseValid(@"Foo(int i = default);"); - ParseValid(@"Foo(int i = default(int));"); - ParseValid(@"Foo(double d = 42.42);"); - ParseValid(@"Foo(decimal d = 42.42m);"); - ParseValid(@"Foo(bool b = true);"); - ParseValid(@"Foo(bool b = false);"); - ParseValid(@"Foo(string s = null);"); - ParseValid(@"Foo(string s = ""foo"");"); - ParseValid(@"Foo(char c = 'c');"); - ParseValid(@"Foo(Type t = typeof(string));"); // Invalid protobuf, but accepted in the DSL - ParseValid(@"Foo(DayOfWeek d = DayOfWeek.Friday);"); - } + contracts.Messages.Count.ShouldEqual(2); + contracts.Messages[0].Name.ShouldEqual("Foo"); + contracts.Messages[1].Name.ShouldEqual("Foo"); + contracts.Messages[1].GenericParameters.ExpectedSingle().ShouldEqual("T"); + } - [Test] - public void should_parse_nested_classes() - { - var contracts = ParseValid("Foo.Bar.Baz();"); - var message = contracts.Messages.ExpectedSingle(); - message.Name.ShouldEqual("Baz"); - message.ContainingClasses.ShouldEqual(new TypeName[] { "Foo", "Bar" }); - } + [Test] + public void should_parse_default_values() + { + ParseValid("Foo(int i = 42);"); + ParseValid("Foo(int i = default);"); + ParseValid("Foo(int i = default(int));"); + ParseValid("Foo(double d = 42.42);"); + ParseValid("Foo(decimal d = 42.42m);"); + ParseValid("Foo(bool b = true);"); + ParseValid("Foo(bool b = false);"); + ParseValid("Foo(string s = null);"); + ParseValid("""Foo(string s = "foo");"""); + ParseValid("Foo(char c = 'c');"); + ParseValid("Foo(Type t = typeof(string));"); // Invalid protobuf, but accepted in the DSL + ParseValid("Foo(DayOfWeek d = DayOfWeek.Friday);"); + } - [Test] - public void should_detect_inheritance_loops() - { - ParseInvalid(@"Foo() : Foo;"); - ParseInvalid(@"Foo() : Bar; [ProtoInclude(1, typeof(Foo))] Bar() : Foo;"); - ParseInvalid(@"Foo() : Bar; [ProtoInclude(1, typeof(Foo))] Bar() : Baz; [ProtoInclude(1, typeof(Bar))] Baz() : Foo;"); + [Test] + public void should_parse_nested_classes() + { + var contracts = ParseValid("Foo.Bar.Baz();"); + var message = contracts.Messages.ExpectedSingle(); + message.Name.ShouldEqual("Baz"); + message.ContainingClasses.ShouldEqual(new TypeName[] { "Foo", "Bar" }); + } - ParseValid(@"Foo() : Bar; [ProtoInclude(1, typeof(Foo))] Bar();"); - } + [Test] + public void should_detect_inheritance_loops() + { + ParseInvalid("Foo() : Foo;"); + ParseInvalid("Foo() : Bar; [ProtoInclude(1, typeof(Foo))] Bar() : Foo;"); + ParseInvalid("Foo() : Bar; [ProtoInclude(1, typeof(Foo))] Bar() : Baz; [ProtoInclude(1, typeof(Bar))] Baz() : Foo;"); - [Test] - public void should_detect_misplaced_optional_parameters() - { - ParseInvalid("Foo(int a, int b = 42, int c);"); - ParseValid("Foo(int a, int b = 42, int c = 10);"); - } + ParseValid("Foo() : Bar; [ProtoInclude(1, typeof(Foo))] Bar();"); + } - private static ParsedContracts ParseValid(string definitionText) - { - var contracts = Parse(definitionText); + [Test] + public void should_detect_misplaced_optional_parameters() + { + ParseInvalid("Foo(int a, int b = 42, int c);"); + ParseValid("Foo(int a, int b = 42, int c = 10);"); + } - if (contracts.Errors.Any()) - { - SyntaxDebugHelper.DumpParseTree(contracts); - Assert.Fail("There are unexpected errors"); - } + private static ParsedContracts ParseValid(string definitionText) + { + var contracts = Parse(definitionText); - return contracts; + if (contracts.Errors.Any()) + { + SyntaxDebugHelper.DumpParseTree(contracts); + Assert.Fail("There are unexpected errors"); } - private static ParsedContracts ParseInvalid(string definitionText) - { - var contracts = Parse(definitionText); + return contracts; + } - if (!contracts.Errors.Any()) - { - SyntaxDebugHelper.DumpParseTree(contracts); - Assert.Fail("Errors were expected"); - } + private static ParsedContracts ParseInvalid(string definitionText) + { + var contracts = Parse(definitionText); - return contracts; + if (!contracts.Errors.Any()) + { + SyntaxDebugHelper.DumpParseTree(contracts); + Assert.Fail("Errors were expected"); } - private static ParsedContracts Parse(string definitionText) - { - Console.WriteLine("PARSE: {0}", definitionText); - var contracts = ParsedContracts.Parse(definitionText, "Some.Namespace"); + return contracts; + } - foreach (var error in contracts.Errors) - Console.WriteLine("ERROR: {0}", error); + private static ParsedContracts Parse(string definitionText) + { + Console.WriteLine("PARSE: {0}", definitionText); + var contracts = ParsedContracts.Parse(definitionText, "Some.Namespace"); - return contracts; - } + foreach (var error in contracts.Errors) + Console.WriteLine("ERROR: {0}", error); - private static void ShouldContainError(ParsedContracts contracts, string expectedMessage) - { - var containsError = contracts.Errors.Any(err => err.Message.IndexOf(expectedMessage, StringComparison.OrdinalIgnoreCase) >= 0); - if (!containsError) - Assert.Fail($"Expected error: {expectedMessage}"); - } + return contracts; + } + + private static void ShouldContainError(ParsedContracts contracts, string expectedMessage) + { + var containsError = contracts.Errors.Any(err => err.Message.IndexOf(expectedMessage, StringComparison.OrdinalIgnoreCase) >= 0); + if (!containsError) + Assert.Fail($"Expected error: {expectedMessage}"); } } diff --git a/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/ProtoGeneratorTests.cs b/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/ProtoGeneratorTests.cs index 3e44e73..7902e93 100644 --- a/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/ProtoGeneratorTests.cs +++ b/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/ProtoGeneratorTests.cs @@ -3,191 +3,188 @@ using Abc.Zebus.MessageDsl.Tests.TestTools; using NUnit.Framework; -namespace Abc.Zebus.MessageDsl.Tests.MessageDsl +namespace Abc.Zebus.MessageDsl.Tests.MessageDsl; + +[TestFixture] +public class ProtoGeneratorTests : GeneratorTests { - [TestFixture] - public class ProtoGeneratorTests : GeneratorTests + [Test] + public void should_generate_code() { - [Test] - public void should_generate_code() + var code = Generate(new MessageDefinition { - var code = Generate(new MessageDefinition + Name = "FooExecuted", + Parameters = { - Name = "FooExecuted", - Parameters = - { - new ParameterDefinition("int?", "foo"), - new ParameterDefinition("string", "bar"), - new ParameterDefinition("BarBaz", "baz") - }, - Options = - { - Proto = true - } - }); + new ParameterDefinition("int?", "foo"), + new ParameterDefinition("string", "bar"), + new ParameterDefinition("BarBaz", "baz") + }, + Options = + { + Proto = true + } + }); - code.ShouldContain("optional int32 Foo = 1;"); - code.ShouldContain("required string Bar = 2;"); - code.ShouldContain("required BarBaz Baz = 3;"); - } + code.ShouldContain("optional int32 Foo = 1;"); + code.ShouldContain("required string Bar = 2;"); + code.ShouldContain("required BarBaz Baz = 3;"); + } - [Test] - public void should_handle_deprecated_fields() + [Test] + public void should_handle_deprecated_fields() + { + var code = Generate(new MessageDefinition { - var code = Generate(new MessageDefinition + Name = "FooExecuted", + Parameters = { - Name = "FooExecuted", - Parameters = - { - new ParameterDefinition("int", "foo") - { - Attributes = { new AttributeDefinition("Obsolete") } - } - }, - Options = + new ParameterDefinition("int", "foo") { - Proto = true + Attributes = { new AttributeDefinition("Obsolete") } } - }); + }, + Options = + { + Proto = true + } + }); - code.ShouldContain("required int32 Foo = 1 [deprecated = true];"); - } + code.ShouldContain("required int32 Foo = 1 [deprecated = true];"); + } - [Test] - public void should_generate_packed_members() + [Test] + public void should_generate_packed_members() + { + var code = Generate(new MessageDefinition { - var code = Generate(new MessageDefinition + Name = "FooExecuted", + Parameters = { - Name = "FooExecuted", - Parameters = - { - new ParameterDefinition("System.Int32[]", "foo"), - new ParameterDefinition("LolType[]", "bar"), - new ParameterDefinition("List", "fooList"), - new ParameterDefinition("System.Collections.Generic.List", "barList"), - }, - Options = - { - Proto = true - } - }); + new ParameterDefinition("System.Int32[]", "foo"), + new ParameterDefinition("LolType[]", "bar"), + new ParameterDefinition("List", "fooList"), + new ParameterDefinition("System.Collections.Generic.List", "barList"), + }, + Options = + { + Proto = true + } + }); - code.ShouldContain("repeated int32 Foo = 1 [packed = true];"); - code.ShouldContain("repeated LolType Bar = 2;"); - code.ShouldContain("repeated int32 FooList = 3 [packed = true];"); - code.ShouldContain("repeated LolType BarList = 4;"); - } + code.ShouldContain("repeated int32 Foo = 1 [packed = true];"); + code.ShouldContain("repeated LolType Bar = 2;"); + code.ShouldContain("repeated int32 FooList = 3 [packed = true];"); + code.ShouldContain("repeated LolType BarList = 4;"); + } - [Test] - public void should_generate_simple_enums() + [Test] + public void should_generate_simple_enums() + { + var code = Generate(new ParsedContracts { - var code = Generate(new ParsedContracts + Enums = { - Enums = + new EnumDefinition { - new EnumDefinition + Name = "Foo", + Members = { - Name = "Foo", - Members = + new EnumMemberDefinition { - new EnumMemberDefinition - { - Name = "Default" - }, - new EnumMemberDefinition - { - Name = "Bar", - Value = "-2" - } + Name = "Default" }, - Options = + new EnumMemberDefinition { - Proto = true + Name = "Bar", + Value = "-2" } + }, + Options = + { + Proto = true } } - }); + } + }); - code.ShouldContain("enum Foo {"); - code.ShouldNotContain("option allow_alias = true;"); - code.ShouldContain("Default = 0;"); - code.ShouldContain("Bar = -2;"); - } + code.ShouldContain("enum Foo {"); + code.ShouldNotContain("option allow_alias = true;"); + code.ShouldContain("Default = 0;"); + code.ShouldContain("Bar = -2;"); + } - [Test] - public void should_generate_enums() + [Test] + public void should_generate_enums() + { + var code = Generate(new ParsedContracts { - var code = Generate(new ParsedContracts + Enums = { - Enums = + new EnumDefinition { - new EnumDefinition + Name = "Foo", + Attributes = + { + new AttributeDefinition("EnumAttr") + }, + Members = { - Name = "Foo", - Attributes = + new EnumMemberDefinition { - new AttributeDefinition("EnumAttr") + Name = "Default" }, - Members = + new EnumMemberDefinition { - new EnumMemberDefinition - { - Name = "Default" - }, - new EnumMemberDefinition - { - Name = "Bar", - Value = "-2" - }, - new EnumMemberDefinition - { - Name = "Baz" - }, - new EnumMemberDefinition - { - Name = "Alias" - } + Name = "Bar", + Value = "-2" }, - Options = + new EnumMemberDefinition { - Proto = true + Name = "Baz" + }, + new EnumMemberDefinition + { + Name = "Alias" } + }, + Options = + { + Proto = true } } - }); + } + }); - code.ShouldContain("enum Foo {"); - code.ShouldContain("option allow_alias = true;"); - code.ShouldContain("Default = 0;"); - code.ShouldContain("Bar = -2;"); - code.ShouldContain("Baz = -1;"); - code.ShouldContain("Alias = 0;"); - } + code.ShouldContain("enum Foo {"); + code.ShouldContain("option allow_alias = true;"); + code.ShouldContain("Default = 0;"); + code.ShouldContain("Bar = -2;"); + code.ShouldContain("Baz = -1;"); + code.ShouldContain("Alias = 0;"); + } - [Test] - public void should_handle_message_inheritance() + [Test] + public void should_handle_message_inheritance() + { + var code = Generate(new MessageDefinition { - var code = Generate(new MessageDefinition + Name = "MsgA", + Attributes = { - Name = "MsgA", - Attributes = - { - new AttributeDefinition("ProtoInclude", "10, typeof(MsgB)"), - new AttributeDefinition("ProtoInclude", "11, typeof(MsgC)") - }, - Options = - { - Proto = true - } - }); - - code.ShouldContain("optional MsgB _subTypeMsgB = 10;"); - code.ShouldContain("optional MsgC _subTypeMsgC = 11;"); - } + new AttributeDefinition("ProtoInclude", "10, typeof(MsgB)"), + new AttributeDefinition("ProtoInclude", "11, typeof(MsgC)") + }, + Options = + { + Proto = true + } + }); - protected override string GenerateRaw(ParsedContracts contracts) - { - return ProtoGenerator.Generate(contracts); - } + code.ShouldContain("optional MsgB _subTypeMsgB = 10;"); + code.ShouldContain("optional MsgC _subTypeMsgC = 11;"); } + + protected override string GenerateRaw(ParsedContracts contracts) + => ProtoGenerator.Generate(contracts); } diff --git a/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/TypeNameTests.cs b/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/TypeNameTests.cs index af34a66..5427787 100644 --- a/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/TypeNameTests.cs +++ b/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/TypeNameTests.cs @@ -2,244 +2,243 @@ using Abc.Zebus.MessageDsl.Tests.TestTools; using NUnit.Framework; -namespace Abc.Zebus.MessageDsl.Tests.MessageDsl +namespace Abc.Zebus.MessageDsl.Tests.MessageDsl; + +[TestFixture] +public class TypeNameTests { - [TestFixture] - public class TypeNameTests + [Test] + public void should_normalize_type() + { + var type = new TypeName("System.Int64"); + type.NetType.ShouldEqual("long"); + } + + [Test] + public void should_remove_system_namespace() { - [Test] - public void should_normalize_type() - { - var type = new TypeName("System.Int64"); - type.NetType.ShouldEqual("long"); - } - - [Test] - public void should_remove_system_namespace() - { - var type = new TypeName("System.DateTime?"); - type.NetType.ShouldEqual("DateTime?"); - } - - [Test] - public void should_not_remove_system_namespace() - { - var type = new TypeName("System.Diagnostics.DebuggerNonUserCode"); - type.NetType.ShouldEqual("System.Diagnostics.DebuggerNonUserCode"); - } - - [Test] - public void should_detect_repeated_types() - { - var type = new TypeName("int"); - type.IsArray.ShouldBeFalse(); - type.IsList.ShouldBeFalse(); - type.IsHashSet.ShouldBeFalse(); - type.IsDictionary.ShouldBeFalse(); - type.IsRepeated.ShouldBeFalse(); - type.IsPackable.ShouldBeFalse(); - - type = new TypeName("int[]"); - type.IsArray.ShouldBeTrue(); - type.IsList.ShouldBeFalse(); - type.IsHashSet.ShouldBeFalse(); - type.IsDictionary.ShouldBeFalse(); - type.IsRepeated.ShouldBeTrue(); - type.IsPackable.ShouldBeTrue(); - - type = new TypeName("int[]?"); - type.IsArray.ShouldBeTrue(); - type.IsList.ShouldBeFalse(); - type.IsHashSet.ShouldBeFalse(); - type.IsDictionary.ShouldBeFalse(); - type.IsRepeated.ShouldBeTrue(); - type.IsPackable.ShouldBeTrue(); - - type = new TypeName("int?[]"); - type.IsArray.ShouldBeTrue(); - type.IsList.ShouldBeFalse(); - type.IsHashSet.ShouldBeFalse(); - type.IsDictionary.ShouldBeFalse(); - type.IsRepeated.ShouldBeTrue(); - type.IsPackable.ShouldBeFalse(); - - type = new TypeName("List"); - type.IsArray.ShouldBeFalse(); - type.IsList.ShouldBeTrue(); - type.IsHashSet.ShouldBeFalse(); - type.IsDictionary.ShouldBeFalse(); - type.IsRepeated.ShouldBeTrue(); - type.IsPackable.ShouldBeTrue(); - - type = new TypeName("List?"); - type.IsArray.ShouldBeFalse(); - type.IsList.ShouldBeTrue(); - type.IsHashSet.ShouldBeFalse(); - type.IsDictionary.ShouldBeFalse(); - type.IsRepeated.ShouldBeTrue(); - type.IsPackable.ShouldBeTrue(); - - type = new TypeName("HashSet"); - type.IsArray.ShouldBeFalse(); - type.IsList.ShouldBeFalse(); - type.IsHashSet.ShouldBeTrue(); - type.IsDictionary.ShouldBeFalse(); - type.IsRepeated.ShouldBeTrue(); - type.IsPackable.ShouldBeTrue(); - - type = new TypeName("HashSet?"); - type.IsArray.ShouldBeFalse(); - type.IsList.ShouldBeFalse(); - type.IsHashSet.ShouldBeTrue(); - type.IsDictionary.ShouldBeFalse(); - type.IsRepeated.ShouldBeTrue(); - type.IsPackable.ShouldBeTrue(); - - type = new TypeName("System.Collections.Generic.List"); - type.IsArray.ShouldBeFalse(); - type.IsList.ShouldBeTrue(); - type.IsHashSet.ShouldBeFalse(); - type.IsDictionary.ShouldBeFalse(); - type.IsRepeated.ShouldBeTrue(); - type.IsPackable.ShouldBeTrue(); - - type = new TypeName("System.Collections.Generic.List?"); - type.IsArray.ShouldBeFalse(); - type.IsList.ShouldBeTrue(); - type.IsHashSet.ShouldBeFalse(); - type.IsDictionary.ShouldBeFalse(); - type.IsRepeated.ShouldBeTrue(); - type.IsPackable.ShouldBeTrue(); - } - - [Test] - public void should_detect_nullables() - { - var type = new TypeName("int"); - type.IsNullable.ShouldBeFalse(); - - type = new TypeName("int?"); - type.IsNullable.ShouldBeTrue(); - - type = new TypeName("string"); - type.IsNullable.ShouldBeFalse(); - - type = new TypeName("string?"); - type.IsNullable.ShouldBeTrue(); - } - - [Test] - public void should_map_to_protobuf_type() - { - var type = new TypeName("int"); - type.ProtoBufType.ShouldEqual("int32"); - type.IsPackable.ShouldBeFalse(); - - type = new TypeName("int[]"); - type.ProtoBufType.ShouldEqual("int32"); - type.IsPackable.ShouldBeTrue(); - - type = new TypeName("int[]?"); - type.ProtoBufType.ShouldEqual("int32"); - type.IsPackable.ShouldBeTrue(); - - type = new TypeName("System.String"); - type.ProtoBufType.ShouldEqual("string"); - type.IsPackable.ShouldBeFalse(); - - type = new TypeName("System.String?"); - type.ProtoBufType.ShouldEqual("string"); - type.IsPackable.ShouldBeFalse(); - - type = new TypeName("System.DateTime"); - type.ProtoBufType.ShouldEqual("bcl.DateTime"); - - type = new TypeName("DateTime"); - type.ProtoBufType.ShouldEqual("bcl.DateTime"); - - type = new TypeName("global::System.DateTime"); - type.ProtoBufType.ShouldEqual("bcl.DateTime"); - - type = new TypeName("System.Decimal"); - type.ProtoBufType.ShouldEqual("bcl.Decimal"); - - type = new TypeName("Decimal"); - type.ProtoBufType.ShouldEqual("bcl.Decimal"); - - type = new TypeName("decimal"); - type.ProtoBufType.ShouldEqual("bcl.Decimal"); - - type = new TypeName("Decimal[]"); - type.ProtoBufType.ShouldEqual("bcl.Decimal"); - - type = new TypeName("decimal[]"); - type.ProtoBufType.ShouldEqual("bcl.Decimal"); - - type = new TypeName("decimal?"); - type.ProtoBufType.ShouldEqual("bcl.Decimal"); - - type = new TypeName("Decimal?"); - type.ProtoBufType.ShouldEqual("bcl.Decimal"); - - type = new TypeName("System.Decimal?"); - type.ProtoBufType.ShouldEqual("bcl.Decimal"); - - type = new TypeName("global::System.Decimal?"); - type.ProtoBufType.ShouldEqual("bcl.Decimal"); + var type = new TypeName("System.DateTime?"); + type.NetType.ShouldEqual("DateTime?"); + } + + [Test] + public void should_not_remove_system_namespace() + { + var type = new TypeName("System.Diagnostics.DebuggerNonUserCode"); + type.NetType.ShouldEqual("System.Diagnostics.DebuggerNonUserCode"); + } + + [Test] + public void should_detect_repeated_types() + { + var type = new TypeName("int"); + type.IsArray.ShouldBeFalse(); + type.IsList.ShouldBeFalse(); + type.IsHashSet.ShouldBeFalse(); + type.IsDictionary.ShouldBeFalse(); + type.IsRepeated.ShouldBeFalse(); + type.IsPackable.ShouldBeFalse(); + + type = new TypeName("int[]"); + type.IsArray.ShouldBeTrue(); + type.IsList.ShouldBeFalse(); + type.IsHashSet.ShouldBeFalse(); + type.IsDictionary.ShouldBeFalse(); + type.IsRepeated.ShouldBeTrue(); + type.IsPackable.ShouldBeTrue(); + + type = new TypeName("int[]?"); + type.IsArray.ShouldBeTrue(); + type.IsList.ShouldBeFalse(); + type.IsHashSet.ShouldBeFalse(); + type.IsDictionary.ShouldBeFalse(); + type.IsRepeated.ShouldBeTrue(); + type.IsPackable.ShouldBeTrue(); + + type = new TypeName("int?[]"); + type.IsArray.ShouldBeTrue(); + type.IsList.ShouldBeFalse(); + type.IsHashSet.ShouldBeFalse(); + type.IsDictionary.ShouldBeFalse(); + type.IsRepeated.ShouldBeTrue(); + type.IsPackable.ShouldBeFalse(); + + type = new TypeName("List"); + type.IsArray.ShouldBeFalse(); + type.IsList.ShouldBeTrue(); + type.IsHashSet.ShouldBeFalse(); + type.IsDictionary.ShouldBeFalse(); + type.IsRepeated.ShouldBeTrue(); + type.IsPackable.ShouldBeTrue(); + + type = new TypeName("List?"); + type.IsArray.ShouldBeFalse(); + type.IsList.ShouldBeTrue(); + type.IsHashSet.ShouldBeFalse(); + type.IsDictionary.ShouldBeFalse(); + type.IsRepeated.ShouldBeTrue(); + type.IsPackable.ShouldBeTrue(); + + type = new TypeName("HashSet"); + type.IsArray.ShouldBeFalse(); + type.IsList.ShouldBeFalse(); + type.IsHashSet.ShouldBeTrue(); + type.IsDictionary.ShouldBeFalse(); + type.IsRepeated.ShouldBeTrue(); + type.IsPackable.ShouldBeTrue(); + + type = new TypeName("HashSet?"); + type.IsArray.ShouldBeFalse(); + type.IsList.ShouldBeFalse(); + type.IsHashSet.ShouldBeTrue(); + type.IsDictionary.ShouldBeFalse(); + type.IsRepeated.ShouldBeTrue(); + type.IsPackable.ShouldBeTrue(); + + type = new TypeName("System.Collections.Generic.List"); + type.IsArray.ShouldBeFalse(); + type.IsList.ShouldBeTrue(); + type.IsHashSet.ShouldBeFalse(); + type.IsDictionary.ShouldBeFalse(); + type.IsRepeated.ShouldBeTrue(); + type.IsPackable.ShouldBeTrue(); + + type = new TypeName("System.Collections.Generic.List?"); + type.IsArray.ShouldBeFalse(); + type.IsList.ShouldBeTrue(); + type.IsHashSet.ShouldBeFalse(); + type.IsDictionary.ShouldBeFalse(); + type.IsRepeated.ShouldBeTrue(); + type.IsPackable.ShouldBeTrue(); + } + + [Test] + public void should_detect_nullables() + { + var type = new TypeName("int"); + type.IsNullable.ShouldBeFalse(); + + type = new TypeName("int?"); + type.IsNullable.ShouldBeTrue(); + + type = new TypeName("string"); + type.IsNullable.ShouldBeFalse(); + + type = new TypeName("string?"); + type.IsNullable.ShouldBeTrue(); + } + + [Test] + public void should_map_to_protobuf_type() + { + var type = new TypeName("int"); + type.ProtoBufType.ShouldEqual("int32"); + type.IsPackable.ShouldBeFalse(); + + type = new TypeName("int[]"); + type.ProtoBufType.ShouldEqual("int32"); + type.IsPackable.ShouldBeTrue(); + + type = new TypeName("int[]?"); + type.ProtoBufType.ShouldEqual("int32"); + type.IsPackable.ShouldBeTrue(); - type = new TypeName("foo::System.Decimal"); - type.ProtoBufType.ShouldEqual("foo.System.Decimal"); - } + type = new TypeName("System.String"); + type.ProtoBufType.ShouldEqual("string"); + type.IsPackable.ShouldBeFalse(); - [Test] - public void should_escape_csharp_identifiers() - { - TypeName type = "new"; - type.NetType.ShouldEqual("@new"); + type = new TypeName("System.String?"); + type.ProtoBufType.ShouldEqual("string"); + type.IsPackable.ShouldBeFalse(); - type = "int"; - type.NetType.ShouldEqual("int"); + type = new TypeName("System.DateTime"); + type.ProtoBufType.ShouldEqual("bcl.DateTime"); - type = "void"; - type.NetType.ShouldEqual("@void"); + type = new TypeName("DateTime"); + type.ProtoBufType.ShouldEqual("bcl.DateTime"); - type = "@void"; - type.NetType.ShouldEqual("@void"); + type = new TypeName("global::System.DateTime"); + type.ProtoBufType.ShouldEqual("bcl.DateTime"); - type = "@int"; - type.NetType.ShouldEqual("int"); + type = new TypeName("System.Decimal"); + type.ProtoBufType.ShouldEqual("bcl.Decimal"); - type = "@foo"; - type.NetType.ShouldEqual("foo"); - } + type = new TypeName("Decimal"); + type.ProtoBufType.ShouldEqual("bcl.Decimal"); - [Test] - public void should_normalize_spaces() - { - TypeName type = " IFoo < Bar ? ,@Baz,Hello [, , ], @World < Tanks>, Int32>"; - type.NetType.ShouldEqual("IFoo, int>"); - } + type = new TypeName("decimal"); + type.ProtoBufType.ShouldEqual("bcl.Decimal"); - [Test] - public void should_provide_repeated_item_type() - { - var type = new TypeName("int[]"); - type.GetRepeatedItemType().ShouldNotBeNull().NetType.ShouldEqual("int"); + type = new TypeName("Decimal[]"); + type.ProtoBufType.ShouldEqual("bcl.Decimal"); + + type = new TypeName("decimal[]"); + type.ProtoBufType.ShouldEqual("bcl.Decimal"); + + type = new TypeName("decimal?"); + type.ProtoBufType.ShouldEqual("bcl.Decimal"); + + type = new TypeName("Decimal?"); + type.ProtoBufType.ShouldEqual("bcl.Decimal"); + + type = new TypeName("System.Decimal?"); + type.ProtoBufType.ShouldEqual("bcl.Decimal"); + + type = new TypeName("global::System.Decimal?"); + type.ProtoBufType.ShouldEqual("bcl.Decimal"); + + type = new TypeName("foo::System.Decimal"); + type.ProtoBufType.ShouldEqual("foo.System.Decimal"); + } + + [Test] + public void should_escape_csharp_identifiers() + { + TypeName type = "new"; + type.NetType.ShouldEqual("@new"); + + type = "int"; + type.NetType.ShouldEqual("int"); + + type = "void"; + type.NetType.ShouldEqual("@void"); + + type = "@void"; + type.NetType.ShouldEqual("@void"); + + type = "@int"; + type.NetType.ShouldEqual("int"); + + type = "@foo"; + type.NetType.ShouldEqual("foo"); + } + + [Test] + public void should_normalize_spaces() + { + TypeName type = " IFoo < Bar ? ,@Baz,Hello [, , ], @World < Tanks>, Int32>"; + type.NetType.ShouldEqual("IFoo, int>"); + } + + [Test] + public void should_provide_repeated_item_type() + { + var type = new TypeName("int[]"); + type.GetRepeatedItemType().ShouldNotBeNull().NetType.ShouldEqual("int"); - type = new TypeName("int[]?"); - type.GetRepeatedItemType().ShouldNotBeNull().NetType.ShouldEqual("int"); + type = new TypeName("int[]?"); + type.GetRepeatedItemType().ShouldNotBeNull().NetType.ShouldEqual("int"); - type = new TypeName("List"); - type.GetRepeatedItemType().ShouldNotBeNull().NetType.ShouldEqual("int"); + type = new TypeName("List"); + type.GetRepeatedItemType().ShouldNotBeNull().NetType.ShouldEqual("int"); - type = new TypeName("List?"); - type.GetRepeatedItemType().ShouldNotBeNull().NetType.ShouldEqual("int"); + type = new TypeName("List?"); + type.GetRepeatedItemType().ShouldNotBeNull().NetType.ShouldEqual("int"); - type = new TypeName("HashSet"); - type.GetRepeatedItemType().ShouldNotBeNull().NetType.ShouldEqual("int"); + type = new TypeName("HashSet"); + type.GetRepeatedItemType().ShouldNotBeNull().NetType.ShouldEqual("int"); - type = new TypeName("HashSet?"); - type.GetRepeatedItemType().ShouldNotBeNull().NetType.ShouldEqual("int"); - } + type = new TypeName("HashSet?"); + type.GetRepeatedItemType().ShouldNotBeNull().NetType.ShouldEqual("int"); } } diff --git a/src/Abc.Zebus.MessageDsl.Tests/TestTools/AssertionExtensions.cs b/src/Abc.Zebus.MessageDsl.Tests/TestTools/AssertionExtensions.cs index 22606fb..eab5591 100644 --- a/src/Abc.Zebus.MessageDsl.Tests/TestTools/AssertionExtensions.cs +++ b/src/Abc.Zebus.MessageDsl.Tests/TestTools/AssertionExtensions.cs @@ -6,71 +6,70 @@ #nullable enable -namespace Abc.Zebus.MessageDsl.Tests.TestTools +namespace Abc.Zebus.MessageDsl.Tests.TestTools; + +public static class AssertionExtensions { - public static class AssertionExtensions - { - public static void ShouldBeTrue(this bool actual) - => Assert.That(actual, Is.True); + public static void ShouldBeTrue(this bool actual) + => Assert.That(actual, Is.True); - public static void ShouldBeFalse(this bool actual) - => Assert.That(actual, Is.False); + public static void ShouldBeFalse(this bool actual) + => Assert.That(actual, Is.False); - public static void ShouldBeNull(this object? actual) - => Assert.That(actual, Is.Null); + public static void ShouldBeNull(this object? actual) + => Assert.That(actual, Is.Null); - public static T ShouldNotBeNull(this T? actual) - where T : class - { - Assert.That(actual, Is.Not.Null); - return actual!; - } + public static T ShouldNotBeNull(this T? actual) + where T : class + { + Assert.That(actual, Is.Not.Null); + return actual!; + } - public static void ShouldContain(this string actual, string expected) - => Assert.That(actual, Contains.Substring(expected)); + public static void ShouldContain(this string actual, string expected) + => Assert.That(actual, Contains.Substring(expected)); - public static void ShouldContain(this IEnumerable actual, T expected) - => Assert.That(actual, Contains.Item(expected)); + public static void ShouldContain(this IEnumerable actual, T expected) + => Assert.That(actual, Contains.Item(expected)); - public static void ShouldContainIgnoreIndent(this string actual, string expected) - => Assert.That(Regex.Replace(actual, @"^[ ]+|\r", string.Empty, RegexOptions.CultureInvariant | RegexOptions.Multiline), Contains.Substring(expected)); + public static void ShouldContainIgnoreIndent(this string actual, string expected) + => Assert.That(Regex.Replace(actual, @"^[ ]+|\r", string.Empty, RegexOptions.CultureInvariant | RegexOptions.Multiline), Contains.Substring(expected)); - public static void ShouldNotContain(this string actual, string unexpected) - => Assert.That(actual, Does.Not.Contain(unexpected)); + public static void ShouldNotContain(this string actual, string unexpected) + => Assert.That(actual, Does.Not.Contain(unexpected)); - public static void ShouldNotContain(this IEnumerable actual, T unexpected) - => Assert.That(actual, Does.Not.Contain(unexpected)); + public static void ShouldNotContain(this IEnumerable actual, T unexpected) + => Assert.That(actual, Does.Not.Contain(unexpected)); - public static void ShouldBeEmpty(this object actual) - => Assert.That(actual, Is.Empty); + public static void ShouldBeEmpty(this object actual) + => Assert.That(actual, Is.Empty); - public static void ShouldEqual(this T? actual, T? expected) - => Assert.That(actual, Is.EqualTo(expected)); + public static void ShouldEqual(this T? actual, T? expected) + => Assert.That(actual, Is.EqualTo(expected)); - public static void ShouldBeGreaterThan(this int actual, int value) - => Assert.That(actual, Is.GreaterThan(value)); + public static void ShouldBeGreaterThan(this int actual, int value) + => Assert.That(actual, Is.GreaterThan(value)); - public static void ShouldBeBetween(this int actual, int min, int max) - => Assert.That(actual, Is.InRange(min, max)); + public static void ShouldBeBetween(this int actual, int min, int max) + => Assert.That(actual, Is.InRange(min, max)); - public static T ExpectedSingle(this IEnumerable actual) - { - var list = actual.ToList(); + public static T ExpectedSingle(this IEnumerable actual) + { + var list = actual.ToList(); - if (list.Count != 1) - Assert.Fail($"Sequence should contain a single element, but had {list.Count}"); + if (list.Count != 1) + Assert.Fail($"Sequence should contain a single element, but had {list.Count}"); - return list[0]; - } + return list[0]; + } - public static T ExpectedSingle(this IEnumerable actual, Func predicate) - { - var list = actual.Where(predicate).ToList(); + public static T ExpectedSingle(this IEnumerable actual, Func predicate) + { + var list = actual.Where(predicate).ToList(); - if (list.Count != 1) - Assert.Fail($"Sequence should contain a single element matching the predicate, but had {list.Count}"); + if (list.Count != 1) + Assert.Fail($"Sequence should contain a single element matching the predicate, but had {list.Count}"); - return list[0]; - } + return list[0]; } } diff --git a/src/Abc.Zebus.MessageDsl/Analysis/AstCreationVisitor.cs b/src/Abc.Zebus.MessageDsl/Analysis/AstCreationVisitor.cs index 31a65f3..2e14e84 100644 --- a/src/Abc.Zebus.MessageDsl/Analysis/AstCreationVisitor.cs +++ b/src/Abc.Zebus.MessageDsl/Analysis/AstCreationVisitor.cs @@ -8,434 +8,432 @@ using Antlr4.Runtime; using static Abc.Zebus.MessageDsl.Dsl.MessageContractsParser; -namespace Abc.Zebus.MessageDsl.Analysis +namespace Abc.Zebus.MessageDsl.Analysis; + +[SuppressMessage("ReSharper", "ReturnTypeCanBeNotNullable")] +internal class AstCreationVisitor : MessageContractsBaseVisitor { - [SuppressMessage("ReSharper", "ReturnTypeCanBeNotNullable")] - internal class AstCreationVisitor : MessageContractsBaseVisitor + private readonly ParsedContracts _contracts; + private readonly HashSet _definedContractOptions = new(StringComparer.OrdinalIgnoreCase); + + private bool _hasDefinitions; + private MessageDefinition? _currentMessage; + private ParameterDefinition? _currentParameter; + private AttributeSet? _currentAttributeSet; + private MemberOptions _currentMemberOptions; + + public AstCreationVisitor(ParsedContracts contracts) { - private readonly ParsedContracts _contracts; - private readonly HashSet _definedContractOptions = new(StringComparer.OrdinalIgnoreCase); + _contracts = contracts; + _currentMemberOptions = new MemberOptions(); + } - private bool _hasDefinitions; - private MessageDefinition? _currentMessage; - private ParameterDefinition? _currentParameter; - private AttributeSet? _currentAttributeSet; - private MemberOptions _currentMemberOptions; + public override AstNode? VisitPragmaDefinition(PragmaDefinitionContext context) + { + var pragmaName = context.name?.token?.Text; - public AstCreationVisitor(ParsedContracts contracts) + if (string.IsNullOrEmpty(pragmaName)) { - _contracts = contracts; - _currentMemberOptions = new MemberOptions(); + _contracts.AddError(context, "Missing pragma name"); + return null; } - public override AstNode? VisitPragmaDefinition(PragmaDefinitionContext context) + var optionDescriptor = _contracts.Options.GetOptionDescriptor(pragmaName); + if (optionDescriptor != null) { - var pragmaName = context.name?.token?.Text; - - if (string.IsNullOrEmpty(pragmaName)) + if (!_definedContractOptions.Add(pragmaName)) { - _contracts.AddError(context, "Missing pragma name"); + _contracts.AddError(context.name, "Duplicate file-level pragma: {0}", pragmaName); return null; } - var optionDescriptor = _contracts.Options.GetOptionDescriptor(pragmaName); - if (optionDescriptor != null) - { - if (!_definedContractOptions.Add(pragmaName!)) - { - _contracts.AddError(context.name, "Duplicate file-level pragma: {0}", pragmaName); - return null; - } - - if (_hasDefinitions) - _contracts.AddError(context, "File-level pragma {0} should be set at the top of the file", pragmaName); - } - else - { - _currentMemberOptions = _currentMemberOptions.Clone(); - optionDescriptor = _currentMemberOptions.GetOptionDescriptor(pragmaName); - - if (optionDescriptor == null) - { - _contracts.AddError(context.name, "Unknown pragma: '{0}'", pragmaName); - return null; - } - } - - pragmaName = optionDescriptor.Name; - - var value = context._valueTokens.Count > 0 - ? context._valueTokens.First().token.GetFullTextUntil(context._valueTokens.Last().token) - : null; + if (_hasDefinitions) + _contracts.AddError(context, "File-level pragma {0} should be set at the top of the file", pragmaName); + } + else + { + _currentMemberOptions = _currentMemberOptions.Clone(); + optionDescriptor = _currentMemberOptions.GetOptionDescriptor(pragmaName); - if (value == null) + if (optionDescriptor == null) { - if (optionDescriptor.IsBoolean) - { - value = context.not == null ? "true" : "false"; - } - else - { - _contracts.AddError(context, "Pragma {0} expects a value", pragmaName); - return null; - } + _contracts.AddError(context.name, "Unknown pragma: '{0}'", pragmaName); + return null; } - - if (!optionDescriptor.SetValue(value)) - _contracts.AddError(context, "Invalid option value: '{0}' for {1}", value, pragmaName); - - return null; } - public override AstNode? VisitUsingDefinition(UsingDefinitionContext context) - { - if (_hasDefinitions) - _contracts.AddError(context, "using clauses should be set at the top of the file"); + pragmaName = optionDescriptor.Name; - var ns = context.@namespace().GetText(); - _contracts.ImportedNamespaces.Add(ns); - return null; - } + var value = context._valueTokens.Count > 0 + ? context._valueTokens.First().token.GetFullTextUntil(context._valueTokens.Last().token) + : null; - public override AstNode? VisitNamespaceDefinition(NamespaceDefinitionContext context) + if (value == null) { - if (_contracts.ExplicitNamespace) + if (optionDescriptor.IsBoolean) { - _contracts.AddError(context, "The namespace has already been set"); + value = context.not == null ? "true" : "false"; + } + else + { + _contracts.AddError(context, "Pragma {0} expects a value", pragmaName); return null; } + } - if (_hasDefinitions) - _contracts.AddError(context, "The namespace should be set at the top of the file"); + if (!optionDescriptor.SetValue(value)) + _contracts.AddError(context, "Invalid option value: '{0}' for {1}", value, pragmaName); - var ns = context.name?.GetText(); + return null; + } - if (string.IsNullOrEmpty(ns)) - { - _contracts.AddError(context, "Missing namespace name"); - return null; - } + public override AstNode? VisitUsingDefinition(UsingDefinitionContext context) + { + if (_hasDefinitions) + _contracts.AddError(context, "using clauses should be set at the top of the file"); - _contracts.Namespace = ns!; - _contracts.ExplicitNamespace = true; + var ns = context.@namespace().GetText(); + _contracts.ImportedNamespaces.Add(ns); + return null; + } + + public override AstNode? VisitNamespaceDefinition(NamespaceDefinitionContext context) + { + if (_contracts.ExplicitNamespace) + { + _contracts.AddError(context, "The namespace has already been set"); return null; } - public override AstNode? VisitEnumDefinition(EnumDefinitionContext context) - { - _hasDefinitions = true; + if (_hasDefinitions) + _contracts.AddError(context, "The namespace should be set at the top of the file"); - var enumDef = new EnumDefinition - { - ParseContext = context, - Name = GetId(context.name), - Options = _currentMemberOptions - }; + var ns = context.name?.GetText(); - var accessModifier = context.accessModifier(); - if (accessModifier != null) - ProcessTypeModifiers(enumDef, new[] { accessModifier.type }); + if (string.IsNullOrEmpty(ns)) + { + _contracts.AddError(context, "Missing namespace name"); + return null; + } - ProcessAttributes(enumDef.Attributes, context.attributes()); + _contracts.Namespace = ns; + _contracts.ExplicitNamespace = true; + return null; + } - if (context.underlyingType != null) - enumDef.UnderlyingType = context.underlyingType.GetText(); + public override AstNode? VisitEnumDefinition(EnumDefinitionContext context) + { + _hasDefinitions = true; - foreach (var enumMemberContext in context.enumMember()) - { - var memberDef = new EnumMemberDefinition - { - ParseContext = enumMemberContext, - Name = GetId(enumMemberContext.name), - Value = enumMemberContext.value?.GetFullText() - }; + var enumDef = new EnumDefinition + { + ParseContext = context, + Name = GetId(context.name), + Options = _currentMemberOptions + }; - ProcessAttributes(memberDef.Attributes, enumMemberContext.attributes()); + var accessModifier = context.accessModifier(); + if (accessModifier != null) + ProcessTypeModifiers(enumDef, [accessModifier.type]); - enumDef.Members.Add(memberDef); - } + ProcessAttributes(enumDef.Attributes, context.attributes()); - _contracts.Enums.Add(enumDef); - return enumDef; - } + if (context.underlyingType != null) + enumDef.UnderlyingType = context.underlyingType.GetText(); - public override AstNode? VisitMessageDefinition(MessageDefinitionContext context) + foreach (var enumMemberContext in context.enumMember()) { - var message = new MessageDefinition + var memberDef = new EnumMemberDefinition { - IsCustom = context.customModifier != null + ParseContext = enumMemberContext, + Name = GetId(enumMemberContext.name), + Value = enumMemberContext.value?.GetFullText() }; - ProcessMessage(message, context); - return message; - } - - public override AstNode? VisitParameterList(ParameterListContext context) - { - foreach (var param in context.parameterDefinition().Select(Visit).OfType()) - _currentMessage!.Parameters.Add(param); + ProcessAttributes(memberDef.Attributes, enumMemberContext.attributes()); - return null; + enumDef.Members.Add(memberDef); } - public override AstNode? VisitParameterDefinition(ParameterDefinitionContext context) - { - try - { - _currentParameter = new ParameterDefinition - { - Name = GetId(context.paramName), - Type = context.typeName().GetText(), - IsMarkedOptional = context.optionalModifier != null, - DefaultValue = context.defaultValue?.GetText(), - ParseContext = context - }; - - ProcessAttributes(_currentParameter.Attributes, context.attributes()); - - return _currentParameter; - } - finally - { - _currentParameter = null; - } - } + _contracts.Enums.Add(enumDef); + return enumDef; + } - public override AstNode? VisitCustomAttribute(CustomAttributeContext context) + public override AstNode? VisitMessageDefinition(MessageDefinitionContext context) + { + var message = new MessageDefinition { - if (_currentAttributeSet == null) - return null; + IsCustom = context.customModifier != null + }; - var attrParameters = context.attributeParameters(); - var attrParametersText = attrParameters?.GetFullText(); + ProcessMessage(message, context); + return message; + } - var attr = new AttributeDefinition(context.attributeType.GetText(), attrParametersText) + public override AstNode? VisitParameterList(ParameterListContext context) + { + foreach (var param in context.parameterDefinition().Select(Visit).OfType()) + _currentMessage!.Parameters.Add(param); + + return null; + } + + public override AstNode? VisitParameterDefinition(ParameterDefinitionContext context) + { + try + { + _currentParameter = new ParameterDefinition { + Name = GetId(context.paramName), + Type = context.typeName().GetText(), + IsMarkedOptional = context.optionalModifier != null, + DefaultValue = context.defaultValue?.GetText(), ParseContext = context }; - _currentAttributeSet.Add(attr); - return attr; - } + ProcessAttributes(_currentParameter.Attributes, context.attributes()); - public override AstNode? VisitExplicitTag(ExplicitTagContext context) + return _currentParameter; + } + finally { - if (_currentParameter == null) - { - _contracts.AddError(context, "Tags can be defined only inside parameters"); - return null; - } + _currentParameter = null; + } + } - if (!int.TryParse(context.tagNumber?.Text, out var tag)) - { - _contracts.AddError(context, "Invalid tag value for parameter '{0}': {1}", _currentParameter.Name, context.tagNumber?.Text); - return null; - } + public override AstNode? VisitCustomAttribute(CustomAttributeContext context) + { + if (_currentAttributeSet == null) + return null; - if (!AstValidator.IsValidTag(tag)) - { - _contracts.AddError(context, "Tag for parameter '{0}' is not within the valid range ({1})", _currentParameter.Name, context.tagNumber?.Text); - return null; - } + var attrParameters = context.attributeParameters(); + var attrParametersText = attrParameters?.GetFullText(); - if (_currentParameter.Tag != 0) - { - _contracts.AddError(context, "The parameter '{0}' already has an explicit tag ({1})", _currentParameter.Name, _currentParameter.Tag); - return null; - } + var attr = new AttributeDefinition(context.attributeType.GetText(), attrParametersText) + { + ParseContext = context + }; + + _currentAttributeSet.Add(attr); + return attr; + } - _currentParameter.Tag = tag; + public override AstNode? VisitExplicitTag(ExplicitTagContext context) + { + if (_currentParameter == null) + { + _contracts.AddError(context, "Tags can be defined only inside parameters"); return null; } - public override AstNode? VisitTypeParamConstraintList(TypeParamConstraintListContext context) + if (!int.TryParse(context.tagNumber?.Text, out var tag)) { - foreach (var constraintContext in context.typeParamConstraint()) - { - if (Visit(constraintContext) is GenericConstraint constraint) - _currentMessage!.GenericConstraints.Add(constraint); - } - + _contracts.AddError(context, "Invalid tag value for parameter '{0}': {1}", _currentParameter.Name, context.tagNumber?.Text); return null; } - public override AstNode? VisitBaseTypeList(BaseTypeListContext context) + if (!AstValidator.IsValidTag(tag)) { - _currentMessage!.BaseTypes.AddRange( - context.GetRuleContexts() - .Select(typeContext => new TypeName(typeContext.GetText()))); - + _contracts.AddError(context, "Tag for parameter '{0}' is not within the valid range ({1})", _currentParameter.Name, context.tagNumber?.Text); return null; } - public override AstNode? VisitTypeParamConstraint(TypeParamConstraintContext context) + if (_currentParameter.Tag != 0) { - var constraint = new GenericConstraint - { - GenericParameterName = GetId(context.name), - ParseContext = context - }; + _contracts.AddError(context, "The parameter '{0}' already has an explicit tag ({1})", _currentParameter.Name, _currentParameter.Tag); + return null; + } - foreach (var clause in context.typeParamConstraintClause()) - { - if (clause is TypeParamConstraintClauseClassContext) - { - if (constraint.IsClass) - _contracts.AddError(clause, "Duplicate class constraint"); - - constraint.IsClass = true; - continue; - } - - if (clause is TypeParamConstraintClauseStructContext) - { - if (constraint.IsStruct) - _contracts.AddError(clause, "Duplicate struct constraint"); - - constraint.IsStruct = true; - continue; - } - - if (clause is TypeParamConstraintClauseNewContext) - { - if (constraint.HasDefaultConstructor) - _contracts.AddError(clause, "Duplicate new() constraint"); - - constraint.HasDefaultConstructor = true; - continue; - } - - if (clause is TypeParamConstraintClauseTypeContext typeClause) - { - var typeName = new TypeName(typeClause.typeName().GetText()); - if (!constraint.Types.Add(typeName)) - _contracts.AddError(clause, "Duplicate type constraint: '{0}'", typeName); - } - } + _currentParameter.Tag = tag; + return null; + } - return constraint; + public override AstNode? VisitTypeParamConstraintList(TypeParamConstraintListContext context) + { + foreach (var constraintContext in context.typeParamConstraint()) + { + if (Visit(constraintContext) is GenericConstraint constraint) + _currentMessage!.GenericConstraints.Add(constraint); } - private void ProcessMessage(MessageDefinition message, MessageDefinitionContext context) - { - try - { - _hasDefinitions = true; - _currentMessage = message; - _currentMessage.ParseContext = context; - _currentMessage.Options = _currentMemberOptions; + return null; + } - var nameContext = context.GetRuleContext(0); - message.Name = GetId(nameContext.name); + public override AstNode? VisitBaseTypeList(BaseTypeListContext context) + { + _currentMessage!.BaseTypes.AddRange( + context.GetRuleContexts() + .Select(typeContext => new TypeName(typeContext.GetText())) + ); - message.ContainingClasses.AddRange( - nameContext._containingTypes.Select(name => new TypeName(name.GetText())) - ); + return null; + } + + public override AstNode? VisitTypeParamConstraint(TypeParamConstraintContext context) + { + var constraint = new GenericConstraint + { + GenericParameterName = GetId(context.name), + ParseContext = context + }; - ProcessTypeModifiers(message, context.typeModifier().Select(i => i.type)); + foreach (var clause in context.typeParamConstraintClause()) + { + if (clause is TypeParamConstraintClauseClassContext) + { + if (constraint.IsClass) + _contracts.AddError(clause, "Duplicate class constraint"); - foreach (var typeParamToken in nameContext._typeParams) - { - var paramId = GetId(typeParamToken); + constraint.IsClass = true; + continue; + } - if (message.GenericParameters.Contains(paramId)) - _contracts.AddError(typeParamToken, "Duplicate generic parameter: '{0}'", paramId); + if (clause is TypeParamConstraintClauseStructContext) + { + if (constraint.IsStruct) + _contracts.AddError(clause, "Duplicate struct constraint"); - message.GenericParameters.Add(paramId); - } + constraint.IsStruct = true; + continue; + } - ProcessAttributes(message.Attributes, context.GetRuleContext(0)); - Visit(context.GetRuleContext(0)); - Visit(context.GetRuleContext(0)); - Visit(context.GetRuleContext(0)); + if (clause is TypeParamConstraintClauseNewContext) + { + if (constraint.HasDefaultConstructor) + _contracts.AddError(clause, "Duplicate new() constraint"); - _contracts.Messages.Add(_currentMessage); + constraint.HasDefaultConstructor = true; + continue; } - finally + + if (clause is TypeParamConstraintClauseTypeContext typeClause) { - _currentMessage = null; + var typeName = new TypeName(typeClause.typeName().GetText()); + if (!constraint.Types.Add(typeName)) + _contracts.AddError(clause, "Duplicate type constraint: '{0}'", typeName); } } - private void ProcessTypeModifiers(IMemberNode member, IEnumerable modifiers) + return constraint; + } + + private void ProcessMessage(MessageDefinition message, MessageDefinitionContext context) + { + try { - AccessModifier? accessModifier = null; - InheritanceModifier? inheritanceModifier = null; + _hasDefinitions = true; + _currentMessage = message; + _currentMessage.ParseContext = context; + _currentMessage.Options = _currentMemberOptions; - foreach (var modifier in modifiers) - { - switch (modifier.Type) - { - case MessageContractsLexer.KW_PUBLIC: - case MessageContractsLexer.KW_INTERNAL: - if (accessModifier == null) - { - accessModifier = modifier.Type switch - { - MessageContractsLexer.KW_PUBLIC => AccessModifier.Public, - MessageContractsLexer.KW_INTERNAL => AccessModifier.Internal, - _ => throw new InvalidOperationException($"Cannot map access modifier: {modifier.Text}") - }; - } - else - { - _contracts.AddError(modifier, "An access modifier has already been provided"); - } + var nameContext = context.GetRuleContext(0); + message.Name = GetId(nameContext.name); - break; + message.ContainingClasses.AddRange( + nameContext._containingTypes.Select(name => new TypeName(name.GetText())) + ); - case MessageContractsLexer.KW_SEALED: - case MessageContractsLexer.KW_ABSTRACT: - if (inheritanceModifier == null) - { - if (!(member is IClassNode)) - _contracts.AddError(modifier, "Cannot apply inheritance modifier to a non-class type"); - - inheritanceModifier = modifier.Type switch - { - MessageContractsLexer.KW_SEALED => InheritanceModifier.Sealed, - MessageContractsLexer.KW_ABSTRACT => InheritanceModifier.Abstract, - _ => throw new InvalidOperationException($"Cannot map inheritance modifier: {modifier.Text}") - }; - } - else - { - _contracts.AddError(modifier, "An inheritance modifier has already been provided"); - } + ProcessTypeModifiers(message, context.typeModifier().Select(i => i.type)); + + foreach (var typeParamToken in nameContext._typeParams) + { + var paramId = GetId(typeParamToken); - break; - } + if (message.GenericParameters.Contains(paramId)) + _contracts.AddError(typeParamToken, "Duplicate generic parameter: '{0}'", paramId); + + message.GenericParameters.Add(paramId); } - member.AccessModifier = accessModifier ?? member.Options.GetAccessModifier(); + ProcessAttributes(message.Attributes, context.GetRuleContext(0)); + Visit(context.GetRuleContext(0)); + Visit(context.GetRuleContext(0)); + Visit(context.GetRuleContext(0)); - if (inheritanceModifier != null && member is IClassNode classNode) - classNode.InheritanceModifier = inheritanceModifier.Value; + _contracts.Messages.Add(_currentMessage); } - - private void ProcessAttributes(AttributeSet? attributeSet, AttributesContext? context) + finally { - if (context == null || attributeSet == null) - return; - - var previousAttributeSet = _currentAttributeSet; + _currentMessage = null; + } + } - try - { - _currentAttributeSet = attributeSet; - _currentAttributeSet.ParseContext = context; + private void ProcessTypeModifiers(IMemberNode member, IEnumerable modifiers) + { + AccessModifier? accessModifier = null; + InheritanceModifier? inheritanceModifier = null; - Visit(context); - } - finally + foreach (var modifier in modifiers) + { + switch (modifier.Type) { - _currentAttributeSet = previousAttributeSet; + case MessageContractsLexer.KW_PUBLIC: + case MessageContractsLexer.KW_INTERNAL: + if (accessModifier == null) + { + accessModifier = modifier.Type switch + { + MessageContractsLexer.KW_PUBLIC => AccessModifier.Public, + MessageContractsLexer.KW_INTERNAL => AccessModifier.Internal, + _ => throw new InvalidOperationException($"Cannot map access modifier: {modifier.Text}") + }; + } + else + { + _contracts.AddError(modifier, "An access modifier has already been provided"); + } + + break; + + case MessageContractsLexer.KW_SEALED: + case MessageContractsLexer.KW_ABSTRACT: + if (inheritanceModifier == null) + { + if (member is not IClassNode) + _contracts.AddError(modifier, "Cannot apply inheritance modifier to a non-class type"); + + inheritanceModifier = modifier.Type switch + { + MessageContractsLexer.KW_SEALED => InheritanceModifier.Sealed, + MessageContractsLexer.KW_ABSTRACT => InheritanceModifier.Abstract, + _ => throw new InvalidOperationException($"Cannot map inheritance modifier: {modifier.Text}") + }; + } + else + { + _contracts.AddError(modifier, "An inheritance modifier has already been provided"); + } + + break; } } - private string GetId(IdContext? context) + member.AccessModifier = accessModifier ?? member.Options.GetAccessModifier(); + + if (inheritanceModifier != null && member is IClassNode classNode) + classNode.InheritanceModifier = inheritanceModifier.Value; + } + + private void ProcessAttributes(AttributeSet? attributeSet, AttributesContext? context) + { + if (context == null || attributeSet == null) + return; + + var previousAttributeSet = _currentAttributeSet; + + try { - return context?.GetValidatedId(_contracts) ?? string.Empty; + _currentAttributeSet = attributeSet; + _currentAttributeSet.ParseContext = context; + + Visit(context); + } + finally + { + _currentAttributeSet = previousAttributeSet; } } + + private string GetId(IdContext? context) + => context?.GetValidatedId(_contracts) ?? string.Empty; } diff --git a/src/Abc.Zebus.MessageDsl/Analysis/AstProcessor.cs b/src/Abc.Zebus.MessageDsl/Analysis/AstProcessor.cs index b0e0783..4211cdc 100644 --- a/src/Abc.Zebus.MessageDsl/Analysis/AstProcessor.cs +++ b/src/Abc.Zebus.MessageDsl/Analysis/AstProcessor.cs @@ -3,131 +3,130 @@ using System.Linq; using Abc.Zebus.MessageDsl.Ast; -namespace Abc.Zebus.MessageDsl.Analysis +namespace Abc.Zebus.MessageDsl.Analysis; + +internal class AstProcessor { - internal class AstProcessor + private readonly ParsedContracts _contracts; + + public AstProcessor(ParsedContracts contracts) { - private readonly ParsedContracts _contracts; + _contracts = contracts; + } - public AstProcessor(ParsedContracts contracts) - { - _contracts = contracts; - } + public void PreProcess() + { + _contracts.ImportedNamespaces.Add("System"); + _contracts.ImportedNamespaces.Add("ProtoBuf"); + _contracts.ImportedNamespaces.Add("Abc.Zebus"); + } - public void PreProcess() + public void PostProcess() + { + foreach (var message in _contracts.Messages) { - _contracts.ImportedNamespaces.Add("System"); - _contracts.ImportedNamespaces.Add("ProtoBuf"); - _contracts.ImportedNamespaces.Add("Abc.Zebus"); + ResolveTags(message); + AddInterfaces(message); + AddImplicitNamespaces(message); + SetInheritanceModifier(message); } - public void PostProcess() + foreach (var enumDef in _contracts.Enums) { - foreach (var message in _contracts.Messages) - { - ResolveTags(message); - AddInterfaces(message); - AddImplicitNamespaces(message); - SetInheritanceModifier(message); - } - - foreach (var enumDef in _contracts.Enums) - { - ResolveEnumValues(enumDef); - AddImplicitNamespaces(enumDef.Attributes); - - foreach (var memberDef in enumDef.Members) - AddImplicitNamespaces(memberDef.Attributes); - } - } + ResolveEnumValues(enumDef); + AddImplicitNamespaces(enumDef.Attributes); - private static void AddInterfaces(MessageDefinition message) - { - switch (message.Type) - { - case MessageType.Event: - message.BaseTypes.Add(KnownTypes.EventInterface); - break; - - case MessageType.Command: - message.BaseTypes.Add(KnownTypes.CommandInterface); - break; - - case MessageType.Custom: - message.BaseTypes.Add(KnownTypes.MessageInterface); - break; - } + foreach (var memberDef in enumDef.Members) + AddImplicitNamespaces(memberDef.Attributes); } + } - private static void ResolveTags(MessageDefinition message) + private static void AddInterfaces(MessageDefinition message) + { + switch (message.Type) { - var nextTag = AstValidator.ProtoMinTag; + case MessageType.Event: + message.BaseTypes.Add(KnownTypes.EventInterface); + break; - foreach (var param in message.Parameters) - { - if (param.Tag == 0) - param.Tag = nextTag; + case MessageType.Command: + message.BaseTypes.Add(KnownTypes.CommandInterface); + break; - nextTag = param.Tag + 1; - - if (nextTag is >= AstValidator.ProtoFirstReservedTag and <= AstValidator.ProtoLastReservedTag) - nextTag = AstValidator.ProtoLastReservedTag + 1; - } + case MessageType.Custom: + message.BaseTypes.Add(KnownTypes.MessageInterface); + break; } + } + + private static void ResolveTags(MessageDefinition message) + { + var nextTag = AstValidator.ProtoMinTag; - private static void ResolveEnumValues(EnumDefinition enumDef) + foreach (var param in message.Parameters) { - if (!enumDef.Options.Proto) - return; + if (param.Tag == 0) + param.Tag = nextTag; - if (enumDef.UnderlyingType.NetType != "int") - return; + nextTag = param.Tag + 1; - var nextValue = (int?)0; + if (nextTag is >= AstValidator.ProtoFirstReservedTag and <= AstValidator.ProtoLastReservedTag) + nextTag = AstValidator.ProtoLastReservedTag + 1; + } + } - foreach (var member in enumDef.Members) - { - member.ProtoValue = string.IsNullOrEmpty(member.Value) - ? nextValue - : enumDef.GetValidUnderlyingValue(member.Value) as int?; + private static void ResolveEnumValues(EnumDefinition enumDef) + { + if (!enumDef.Options.Proto) + return; - nextValue = member.ProtoValue + 1; - } - } + if (enumDef.UnderlyingType.NetType != "int") + return; - private void AddImplicitNamespaces(MessageDefinition message) + var nextValue = (int?)0; + + foreach (var member in enumDef.Members) { - AddImplicitNamespaces(message.Attributes); - - foreach (var paramDef in message.Parameters) - { - AddImplicitNamespaces(paramDef.Attributes); - - if (paramDef.Type.IsList) - _contracts.ImportedNamespaces.Add(typeof(List<>).Namespace!); - else if (paramDef.Type.IsDictionary) - _contracts.ImportedNamespaces.Add(typeof(Dictionary<,>).Namespace!); - else if (paramDef.Type.IsHashSet) - _contracts.ImportedNamespaces.Add(typeof(HashSet<>).Namespace!); - } + member.ProtoValue = string.IsNullOrEmpty(member.Value) + ? nextValue + : enumDef.GetValidUnderlyingValue(member.Value) as int?; + + nextValue = member.ProtoValue + 1; } + } + + private void AddImplicitNamespaces(MessageDefinition message) + { + AddImplicitNamespaces(message.Attributes); - private void AddImplicitNamespaces(AttributeSet attributes) + foreach (var paramDef in message.Parameters) { - if (attributes.HasAttribute(KnownTypes.DescriptionAttribute)) - _contracts.ImportedNamespaces.Add(typeof(DescriptionAttribute).Namespace!); + AddImplicitNamespaces(paramDef.Attributes); + + if (paramDef.Type.IsList) + _contracts.ImportedNamespaces.Add(typeof(List<>).Namespace!); + else if (paramDef.Type.IsDictionary) + _contracts.ImportedNamespaces.Add(typeof(Dictionary<,>).Namespace!); + else if (paramDef.Type.IsHashSet) + _contracts.ImportedNamespaces.Add(typeof(HashSet<>).Namespace!); } + } - private static void SetInheritanceModifier(MessageDefinition message) - { - if (message.InheritanceModifier != InheritanceModifier.Default) - return; + private void AddImplicitNamespaces(AttributeSet attributes) + { + if (attributes.HasAttribute(KnownTypes.DescriptionAttribute)) + _contracts.ImportedNamespaces.Add(typeof(DescriptionAttribute).Namespace!); + } + + private static void SetInheritanceModifier(MessageDefinition message) + { + if (message.InheritanceModifier != InheritanceModifier.Default) + return; - var hasInheritedMessages = message.Attributes.Any(attr => Equals(attr.TypeName, KnownTypes.ProtoIncludeAttribute)); + var hasInheritedMessages = message.Attributes.Any(attr => Equals(attr.TypeName, KnownTypes.ProtoIncludeAttribute)); - message.InheritanceModifier = hasInheritedMessages - ? InheritanceModifier.None - : InheritanceModifier.Sealed; - } + message.InheritanceModifier = hasInheritedMessages + ? InheritanceModifier.None + : InheritanceModifier.Sealed; } } diff --git a/src/Abc.Zebus.MessageDsl/Analysis/AstValidator.cs b/src/Abc.Zebus.MessageDsl/Analysis/AstValidator.cs index 96dc33a..ab5e844 100644 --- a/src/Abc.Zebus.MessageDsl/Analysis/AstValidator.cs +++ b/src/Abc.Zebus.MessageDsl/Analysis/AstValidator.cs @@ -3,230 +3,228 @@ using Abc.Zebus.MessageDsl.Ast; using Antlr4.Runtime; -namespace Abc.Zebus.MessageDsl.Analysis +namespace Abc.Zebus.MessageDsl.Analysis; + +internal class AstValidator { - internal class AstValidator + public const int ProtoMinTag = 1; + public const int ProtoMaxTag = 536870911; + public const int ProtoFirstReservedTag = 19000; + public const int ProtoLastReservedTag = 19999; + + private readonly ParsedContracts _contracts; + + public AstValidator(ParsedContracts contracts) { - public const int ProtoMinTag = 1; - public const int ProtoMaxTag = 536870911; - public const int ProtoFirstReservedTag = 19000; - public const int ProtoLastReservedTag = 19999; + _contracts = contracts; + } - private readonly ParsedContracts _contracts; + public void Validate() + { + foreach (var message in _contracts.Messages) + ValidateMessage(message); - public AstValidator(ParsedContracts contracts) - { - _contracts = contracts; - } + foreach (var enumDef in _contracts.Enums) + ValidateEnum(enumDef); - public void Validate() - { - foreach (var message in _contracts.Messages) - ValidateMessage(message); + DetectDuplicateTypes(); + } - foreach (var enumDef in _contracts.Enums) - ValidateEnum(enumDef); + private void ValidateMessage(MessageDefinition message) + { + var paramNames = new HashSet(); + var genericConstraints = new HashSet(); - DetectDuplicateTypes(); + if (message.Options.Proto) + { + if (message.GenericParameters.Count > 0) + _contracts.AddError(message.ParseContext, "Cannot generate .proto for generic message {0}", message.Name); } - private void ValidateMessage(MessageDefinition message) + ValidateAttributes(message.Attributes); + ValidateTags(message); + + foreach (var param in message.Parameters) { - var paramNames = new HashSet(); + var errorContext = param.ParseContext ?? message.ParseContext; - var genericConstraints = new HashSet(); + if (!paramNames.Add(param.Name)) + _contracts.AddError(errorContext, "Duplicate parameter name: {0}", param.Name); - if (message.Options.Proto) - { - if (message.GenericParameters.Count > 0) - _contracts.AddError(message.ParseContext, "Cannot generate .proto for generic message {0}", message.Name); - } + ValidateType(param.Type, param.ParseContext); + ValidateAttributes(param.Attributes); + } - ValidateAttributes(message.Attributes); - ValidateTags(message); + var requiredParameterSeen = false; - foreach (var param in message.Parameters) - { - var errorContext = param.ParseContext ?? message.ParseContext; + for (var i = message.Parameters.Count - 1; i >= 0; --i) + { + var param = message.Parameters[i]; + var errorContext = param.ParseContext ?? message.ParseContext; - if (!paramNames.Add(param.Name)) - _contracts.AddError(errorContext, "Duplicate parameter name: {0}", param.Name); + if (string.IsNullOrEmpty(param.DefaultValue)) + requiredParameterSeen = true; + else if (requiredParameterSeen) + _contracts.AddError(errorContext, "Optional parameter {0} cannot appear before a required parameter", param.Name); + } - ValidateType(param.Type, param.ParseContext); - ValidateAttributes(param.Attributes); - } + foreach (var constraint in message.GenericConstraints) + { + var errorContext = constraint.ParseContext ?? message.ParseContext; - var requiredParameterSeen = false; + if (!genericConstraints.Add(constraint.GenericParameterName)) + _contracts.AddError(errorContext, "Duplicate generic constraint: '{0}'", constraint.GenericParameterName); - for (var i = message.Parameters.Count - 1; i >= 0; --i) - { - var param = message.Parameters[i]; - var errorContext = param.ParseContext ?? message.ParseContext; + if (!message.GenericParameters.Contains(constraint.GenericParameterName)) + _contracts.AddError(errorContext, "Undefined generic parameter: '{0}'", constraint.GenericParameterName); - if (string.IsNullOrEmpty(param.DefaultValue)) - requiredParameterSeen = true; - else if (requiredParameterSeen) - _contracts.AddError(errorContext, "Optional parameter {0} cannot appear before a required parameter", param.Name); - } + if (constraint.IsClass && constraint.IsStruct) + _contracts.AddError(errorContext, "Constraint on '{0}' cannot require both class and struct", constraint.GenericParameterName); - foreach (var constraint in message.GenericConstraints) - { - var errorContext = constraint.ParseContext ?? message.ParseContext; + foreach (var constraintType in constraint.Types) + ValidateType(constraintType, message.ParseContext); + } - if (!genericConstraints.Add(constraint.GenericParameterName)) - _contracts.AddError(errorContext, "Duplicate generic constraint: '{0}'", constraint.GenericParameterName); + foreach (var baseType in message.BaseTypes) + ValidateType(baseType, message.ParseContext); - if (!message.GenericParameters.Contains(constraint.GenericParameterName)) - _contracts.AddError(errorContext, "Undefined generic parameter: '{0}'", constraint.GenericParameterName); + ValidateInheritance(message); + } - if (constraint.IsClass && constraint.IsStruct) - _contracts.AddError(errorContext, "Constraint on '{0}' cannot require both class and struct", constraint.GenericParameterName); + private void ValidateTags(MessageDefinition message) + { + var tags = new HashSet(); - foreach (var constraintType in constraint.Types) - ValidateType(constraintType, message.ParseContext); - } + foreach (var param in message.Parameters) + { + var errorContext = param.ParseContext ?? message.ParseContext; - foreach (var baseType in message.BaseTypes) - ValidateType(baseType, message.ParseContext); + if (!IsValidTag(param.Tag)) + _contracts.AddError(errorContext, "Tag for parameter '{0}' is not within the valid range ({1})", param.Name, param.Tag); - ValidateInheritance(message); + if (!tags.Add(param.Tag)) + _contracts.AddError(errorContext, "Duplicate tag {0} on parameter {1}", param.Tag, param.Name); } - private void ValidateTags(MessageDefinition message) + foreach (var attr in message.Attributes) { - var tags = new HashSet(); - - foreach (var param in message.Parameters) - { - var errorContext = param.ParseContext ?? message.ParseContext; + if (!Equals(attr.TypeName, KnownTypes.ProtoIncludeAttribute)) + continue; - if (!IsValidTag(param.Tag)) - _contracts.AddError(errorContext, "Tag for parameter '{0}' is not within the valid range ({1})", param.Name, param.Tag); + var errorContext = attr.ParseContext ?? message.ParseContext; - if (!tags.Add(param.Tag)) - _contracts.AddError(errorContext, "Duplicate tag {0} on parameter {1}", param.Tag, param.Name); + if (!AttributeInterpreter.TryParseProtoInclude(attr, out var tag, out _)) + { + _contracts.AddError(errorContext, "Invalid [{0}] parameters", KnownTypes.ProtoIncludeAttribute); + continue; } - foreach (var attr in message.Attributes) - { - if (!Equals(attr.TypeName, KnownTypes.ProtoIncludeAttribute)) - continue; + if (!IsValidTag(tag)) + _contracts.AddError(errorContext, "Tag for [{0}] is not within the valid range ({1})", KnownTypes.ProtoIncludeAttribute, tag); + + if (!tags.Add(tag)) + _contracts.AddError(errorContext, "Duplicate tag {0} on [{1}]", tag, KnownTypes.ProtoIncludeAttribute); + } + } - var errorContext = attr.ParseContext ?? message.ParseContext; + private void ValidateEnum(EnumDefinition enumDef) + { + if (!enumDef.IsValidUnderlyingType()) + _contracts.AddError(enumDef.ParseContext, "Invalid underlying type: {0}", enumDef.UnderlyingType); - if (!AttributeInterpreter.TryParseProtoInclude(attr, out var tag, out _)) - { - _contracts.AddError(errorContext, "Invalid [{0}] parameters", KnownTypes.ProtoIncludeAttribute); - continue; - } + if (enumDef.Options.Proto && enumDef.UnderlyingType.NetType != "int") + _contracts.AddError(enumDef.ParseContext, "An enum used in a proto file must have an underlying type of int"); - if (!IsValidTag(tag)) - _contracts.AddError(errorContext, "Tag for [{0}] is not within the valid range ({1})", KnownTypes.ProtoIncludeAttribute, tag); + ValidateAttributes(enumDef.Attributes); - if (!tags.Add(tag)) - _contracts.AddError(errorContext, "Duplicate tag {0} on [{1}]", tag, KnownTypes.ProtoIncludeAttribute); - } - } + var definedMembers = new HashSet(); - private void ValidateEnum(EnumDefinition enumDef) + foreach (var member in enumDef.Members) { - if (!enumDef.IsValidUnderlyingType()) - _contracts.AddError(enumDef.ParseContext, "Invalid underlying type: {0}", enumDef.UnderlyingType); + if (!definedMembers.Add(member.Name)) + _contracts.AddError(member.ParseContext, "Duplicate enum member: {0}", member.Name); - if (enumDef.Options.Proto && enumDef.UnderlyingType.NetType != "int") - _contracts.AddError(enumDef.ParseContext, "An enum used in a proto file must have an underlying type of int"); - - ValidateAttributes(enumDef.Attributes); + ValidateAttributes(member.Attributes); + } + } - var definedMembers = new HashSet(); + private void ValidateAttributes(AttributeSet attributes) + { + foreach (var attribute in attributes) + ValidateType(attribute.TypeName, attribute.ParseContext); + } - foreach (var member in enumDef.Members) - { - if (!definedMembers.Add(member.Name)) - _contracts.AddError(member.ParseContext, "Duplicate enum member: {0}", member.Name); + private void ValidateType(TypeName type, ParserRuleContext? context) + { + if (type.NetType.Contains("??")) + _contracts.AddError(context, "Invalid type: {0}", type.NetType); + } - ValidateAttributes(member.Attributes); - } - } + private void ValidateInheritance(MessageDefinition message) + { + if (message.BaseTypes.Count == 0) + return; - private void ValidateAttributes(AttributeSet attributes) + var seenTypes = new HashSet { - foreach (var attribute in attributes) - ValidateType(attribute.TypeName, attribute.ParseContext); - } + message.Name + }; - private void ValidateType(TypeName type, ParserRuleContext? context) - { - if (type.NetType.Contains("??")) - _contracts.AddError(context, "Invalid type: {0}", type.NetType); - } + var currentMessage = message; - private void ValidateInheritance(MessageDefinition message) + while (true) { - if (message.BaseTypes.Count == 0) - return; + if (currentMessage.BaseTypes.Count == 0) + break; - var seenTypes = new HashSet - { - message.Name - }; + currentMessage = _contracts.Messages.FirstOrDefault(m => m.Name == currentMessage.BaseTypes[0].NetType); + if (currentMessage is null) + break; - var currentMessage = message; - - while (true) + if (!seenTypes.Add(currentMessage.Name)) { - if (currentMessage.BaseTypes.Count == 0) - break; - - currentMessage = _contracts.Messages.FirstOrDefault(m => m.Name == currentMessage.BaseTypes[0].NetType); - if (currentMessage is null) - break; - - if (!seenTypes.Add(currentMessage.Name)) - { - _contracts.AddError(message.ParseContext, "There is a loop in the inheritance chain"); - break; - } + _contracts.AddError(message.ParseContext, "There is a loop in the inheritance chain"); + break; } } + } - private void DetectDuplicateTypes() - { - var seenTypes = new HashSet(); - var duplicates = new HashSet(); + private void DetectDuplicateTypes() + { + var seenTypes = new HashSet(); + var duplicates = new HashSet(); - var types = _contracts.Messages - .Cast() - .Concat(_contracts.Enums) - .ToList(); + var types = _contracts.Messages + .Cast() + .Concat(_contracts.Enums) + .ToList(); - foreach (var typeNode in types) - { - var nameWithGenericArity = GetNameWithGenericArity(typeNode); + foreach (var typeNode in types) + { + var nameWithGenericArity = GetNameWithGenericArity(typeNode); - if (!seenTypes.Add(nameWithGenericArity)) - duplicates.Add(nameWithGenericArity); - } + if (!seenTypes.Add(nameWithGenericArity)) + duplicates.Add(nameWithGenericArity); + } - foreach (var typeNode in types) - { - var nameWithGenericArity = GetNameWithGenericArity(typeNode); + foreach (var typeNode in types) + { + var nameWithGenericArity = GetNameWithGenericArity(typeNode); - if (duplicates.Contains(nameWithGenericArity)) - _contracts.AddError(typeNode.ParseContext, "Duplicate type name: {0}", nameWithGenericArity); - } + if (duplicates.Contains(nameWithGenericArity)) + _contracts.AddError(typeNode.ParseContext, "Duplicate type name: {0}", nameWithGenericArity); + } - static string GetNameWithGenericArity(AstNode node) - { - var name = ((INamedNode)node).Name; - if (node is MessageDefinition messageDef && messageDef.GenericParameters.Count > 0) - name = $"{name}`{messageDef.GenericParameters.Count}"; + static string GetNameWithGenericArity(AstNode node) + { + var name = ((INamedNode)node).Name; + if (node is MessageDefinition messageDef && messageDef.GenericParameters.Count > 0) + name = $"{name}`{messageDef.GenericParameters.Count}"; - return name; - } + return name; } - - public static bool IsValidTag(int tag) - => tag is >= ProtoMinTag and <= ProtoMaxTag and (< ProtoFirstReservedTag or > ProtoLastReservedTag); } + + public static bool IsValidTag(int tag) + => tag is >= ProtoMinTag and <= ProtoMaxTag and (< ProtoFirstReservedTag or > ProtoLastReservedTag); } diff --git a/src/Abc.Zebus.MessageDsl/Analysis/AttributeInterpreter.cs b/src/Abc.Zebus.MessageDsl/Analysis/AttributeInterpreter.cs index 24ae1b0..be55467 100644 --- a/src/Abc.Zebus.MessageDsl/Analysis/AttributeInterpreter.cs +++ b/src/Abc.Zebus.MessageDsl/Analysis/AttributeInterpreter.cs @@ -2,140 +2,139 @@ using System.Text.RegularExpressions; using Abc.Zebus.MessageDsl.Ast; -namespace Abc.Zebus.MessageDsl.Analysis +namespace Abc.Zebus.MessageDsl.Analysis; + +internal class AttributeInterpreter { - internal class AttributeInterpreter - { - private static readonly Regex _reProtoIncludeParams = new(@"^\s*(?[0-9]+)\s*,\s*typeof\s*\(\s*(?.+)\s*\)", RegexOptions.Compiled | RegexOptions.CultureInvariant); + private static readonly Regex _reProtoIncludeParams = new(@"^\s*(?[0-9]+)\s*,\s*typeof\s*\(\s*(?.+)\s*\)", RegexOptions.Compiled | RegexOptions.CultureInvariant); - private readonly ParsedContracts _contracts; + private readonly ParsedContracts _contracts; - public AttributeInterpreter(ParsedContracts contracts) - { - _contracts = contracts; - } + public AttributeInterpreter(ParsedContracts contracts) + { + _contracts = contracts; + } - public void InterpretAttributes() + public void InterpretAttributes() + { + foreach (var messageDefinition in _contracts.Messages) { - foreach (var messageDefinition in _contracts.Messages) - { - CheckIfTransient(messageDefinition); - CheckIfRoutable(messageDefinition); + CheckIfTransient(messageDefinition); + CheckIfRoutable(messageDefinition); - foreach (var parameterDefinition in messageDefinition.Parameters) - ProcessProtoMemberAttribute(parameterDefinition); - } + foreach (var parameterDefinition in messageDefinition.Parameters) + ProcessProtoMemberAttribute(parameterDefinition); } + } + + private void CheckIfRoutable(MessageDefinition message) + { + message.IsRoutable = !message.IsCustom && message.Attributes.HasAttribute(KnownTypes.RoutableAttribute); - private void CheckIfRoutable(MessageDefinition message) + if (message.IsRoutable) { - message.IsRoutable = !message.IsCustom && message.Attributes.HasAttribute(KnownTypes.RoutableAttribute); + _contracts.ImportedNamespaces.Add("Abc.Zebus.Routing"); - if (message.IsRoutable) + foreach (var param in message.Parameters) { - _contracts.ImportedNamespaces.Add("Abc.Zebus.Routing"); + var routingAttr = param.Attributes.Where(attr => Equals(attr.TypeName, KnownTypes.RoutingPositionAttribute)).ToList(); - foreach (var param in message.Parameters) + switch (routingAttr) { - var routingAttr = param.Attributes.Where(attr => Equals(attr.TypeName, KnownTypes.RoutingPositionAttribute)).ToList(); - - if (routingAttr.Count == 0) + case []: continue; - if (routingAttr.Count == 1) - { - var attr = routingAttr.Single(); - + case [var attr]: if (int.TryParse(attr.Parameters, out var routingPosition)) param.RoutingPosition = routingPosition; else _contracts.AddError(attr.ParseContext, "Invalid routing position: {0}", attr.Parameters); - } - else - { - _contracts.AddError(routingAttr.First().ParseContext, "Multiple routing positions are not allowed"); - } - } + break; - var routableParamCount = message.Parameters.Count(p => p.RoutingPosition != null); - var isValidSequence = message.Parameters - .Where(p => p.RoutingPosition != null) - .OrderBy(p => p.RoutingPosition) - .Select(p => p.RoutingPosition.GetValueOrDefault()) - .SequenceEqual(Enumerable.Range(1, routableParamCount)); - - if (routableParamCount == 0) - _contracts.AddError(message.ParseContext, "A routable message must have parameters with routing positions"); - else if (!isValidSequence) - _contracts.AddError(message.ParseContext, "Routing positions must form a continuous sequence starting with 1"); + case [var first, ..]: + _contracts.AddError(first.ParseContext, "Multiple routing positions are not allowed"); + break; + } } - else - { - var firstRoutingPositionAttr = message.Parameters - .SelectMany(p => p.Attributes) - .FirstOrDefault(attr => Equals(attr.TypeName, KnownTypes.RoutingPositionAttribute)); - if (firstRoutingPositionAttr != null) - _contracts.AddError(firstRoutingPositionAttr.ParseContext, "A non-routable message should not have RoutingPosition attributes"); - } + var routableParamCount = message.Parameters.Count(p => p.RoutingPosition != null); + var isValidSequence = message.Parameters + .Where(p => p.RoutingPosition != null) + .OrderBy(p => p.RoutingPosition) + .Select(p => p.RoutingPosition.GetValueOrDefault()) + .SequenceEqual(Enumerable.Range(1, routableParamCount)); + + if (routableParamCount == 0) + _contracts.AddError(message.ParseContext, "A routable message must have parameters with routing positions"); + else if (!isValidSequence) + _contracts.AddError(message.ParseContext, "Routing positions must form a continuous sequence starting with 1"); } - - private static void CheckIfTransient(MessageDefinition message) + else { - message.IsTransient = !message.IsCustom && message.Attributes.HasAttribute(KnownTypes.TransientAttribute); - } + var firstRoutingPositionAttr = message.Parameters + .SelectMany(p => p.Attributes) + .FirstOrDefault(attr => Equals(attr.TypeName, KnownTypes.RoutingPositionAttribute)); - private void ProcessProtoMemberAttribute(ParameterDefinition param) - { - var attr = param.Attributes.GetAttribute(KnownTypes.ProtoMemberAttribute); - if (attr == null) - return; + if (firstRoutingPositionAttr != null) + _contracts.AddError(firstRoutingPositionAttr.ParseContext, "A non-routable message should not have RoutingPosition attributes"); + } + } - if (string.IsNullOrWhiteSpace(attr.Parameters)) - { - _contracts.AddError(attr.ParseContext, "The [{0}] attribute must have parameters", KnownTypes.ProtoMemberAttribute); - return; - } + private static void CheckIfTransient(MessageDefinition message) + { + message.IsTransient = !message.IsCustom && message.Attributes.HasAttribute(KnownTypes.TransientAttribute); + } - var match = Regex.Match(attr.Parameters, @"^\s*(?[0-9]+)\s*(?:,|$)"); - if (!match.Success || !int.TryParse(match.Groups["nb"].Value, out var tagNb)) - { - _contracts.AddError(attr.ParseContext, "Invalid [{0}] parameters", KnownTypes.ProtoMemberAttribute); - return; - } + private void ProcessProtoMemberAttribute(ParameterDefinition param) + { + var attr = param.Attributes.GetAttribute(KnownTypes.ProtoMemberAttribute); + if (attr == null) + return; - if (param.Tag != 0) - { - _contracts.AddError(attr.ParseContext, "The parameter '{0}' already has an explicit tag ({1})", param.Name, param.Tag); - return; - } + if (string.IsNullOrWhiteSpace(attr.Parameters)) + { + _contracts.AddError(attr.ParseContext, "The [{0}] attribute must have parameters", KnownTypes.ProtoMemberAttribute); + return; + } - if (!AstValidator.IsValidTag(tagNb)) - { - _contracts.AddError(attr.ParseContext, "Tag for parameter '{0}' is not within the valid range ({1})", param.Name, tagNb); - return; - } + var match = Regex.Match(attr.Parameters, @"^\s*(?[0-9]+)\s*(?:,|$)"); + if (!match.Success || !int.TryParse(match.Groups["nb"].Value, out var tagNb)) + { + _contracts.AddError(attr.ParseContext, "Invalid [{0}] parameters", KnownTypes.ProtoMemberAttribute); + return; + } - param.Tag = tagNb; + if (param.Tag != 0) + { + _contracts.AddError(attr.ParseContext, "The parameter '{0}' already has an explicit tag ({1})", param.Name, param.Tag); + return; } - public static bool TryParseProtoInclude(AttributeDefinition? attribute, out int tag, out TypeName messageType) + if (!AstValidator.IsValidTag(tagNb)) { - tag = 0; - messageType = null!; + _contracts.AddError(attr.ParseContext, "Tag for parameter '{0}' is not within the valid range ({1})", param.Name, tagNb); + return; + } - if (!Equals(attribute?.TypeName, KnownTypes.ProtoIncludeAttribute)) - return false; + param.Tag = tagNb; + } - var match = _reProtoIncludeParams.Match(attribute.Parameters ?? string.Empty); - if (!match.Success) - return false; + public static bool TryParseProtoInclude(AttributeDefinition? attribute, out int tag, out TypeName messageType) + { + tag = 0; + messageType = null!; - if (!int.TryParse(match.Groups["tag"].Value, out tag)) - return false; + if (!Equals(attribute?.TypeName, KnownTypes.ProtoIncludeAttribute)) + return false; - messageType = match.Groups["typeName"].Value; - return true; - } + var match = _reProtoIncludeParams.Match(attribute.Parameters ?? string.Empty); + if (!match.Success) + return false; + + if (!int.TryParse(match.Groups["tag"].Value, out tag)) + return false; + + messageType = match.Groups["typeName"].Value; + return true; } } diff --git a/src/Abc.Zebus.MessageDsl/Analysis/ContractsEnhancer.cs b/src/Abc.Zebus.MessageDsl/Analysis/ContractsEnhancer.cs index c54a82a..5675a83 100644 --- a/src/Abc.Zebus.MessageDsl/Analysis/ContractsEnhancer.cs +++ b/src/Abc.Zebus.MessageDsl/Analysis/ContractsEnhancer.cs @@ -1,32 +1,31 @@ using Abc.Zebus.MessageDsl.Ast; -namespace Abc.Zebus.MessageDsl.Analysis +namespace Abc.Zebus.MessageDsl.Analysis; + +internal class ContractsEnhancer { - internal class ContractsEnhancer - { - private readonly ParsedContracts _contracts; + private readonly ParsedContracts _contracts; - public ContractsEnhancer(ParsedContracts contracts) - { - _contracts = contracts; - } + public ContractsEnhancer(ParsedContracts contracts) + { + _contracts = contracts; + } - public void Process() - { - foreach (var messageDefinition in _contracts.Messages) - ProcessMessage(messageDefinition); - } + public void Process() + { + foreach (var messageDefinition in _contracts.Messages) + ProcessMessage(messageDefinition); + } - private static void ProcessMessage(MessageDefinition message) - { - foreach (var parameterDefinition in message.Parameters) - ProcessParameter(parameterDefinition); - } + private static void ProcessMessage(MessageDefinition message) + { + foreach (var parameterDefinition in message.Parameters) + ProcessParameter(parameterDefinition); + } - private static void ProcessParameter(ParameterDefinition parameter) - { - if (parameter.Type.IsDictionary && !parameter.Attributes.HasAttribute(KnownTypes.ProtoMapAttribute)) - parameter.Attributes.Add(new AttributeDefinition(KnownTypes.ProtoMapAttribute, "DisableMap = true")); // https://github.com/mgravell/protobuf-net/issues/379 - } + private static void ProcessParameter(ParameterDefinition parameter) + { + if (parameter.Type.IsDictionary && !parameter.Attributes.HasAttribute(KnownTypes.ProtoMapAttribute)) + parameter.Attributes.Add(new AttributeDefinition(KnownTypes.ProtoMapAttribute, "DisableMap = true")); // https://github.com/mgravell/protobuf-net/issues/379 } } diff --git a/src/Abc.Zebus.MessageDsl/Analysis/KnownTypes.cs b/src/Abc.Zebus.MessageDsl/Analysis/KnownTypes.cs index 1658f16..2f11366 100644 --- a/src/Abc.Zebus.MessageDsl/Analysis/KnownTypes.cs +++ b/src/Abc.Zebus.MessageDsl/Analysis/KnownTypes.cs @@ -1,27 +1,26 @@ using Abc.Zebus.MessageDsl.Ast; -namespace Abc.Zebus.MessageDsl.Analysis +namespace Abc.Zebus.MessageDsl.Analysis; + +internal static class KnownTypes { - internal static class KnownTypes - { - public static TypeName EventInterface { get; } = new("IEvent"); - public static TypeName EventInterfaceFullName { get; } = new("Abc.Zebus.IEvent"); + public static TypeName EventInterface { get; } = new("IEvent"); + public static TypeName EventInterfaceFullName { get; } = new("Abc.Zebus.IEvent"); - public static TypeName CommandInterface { get; } = new("ICommand"); - public static TypeName CommandInterfaceFullName { get; } = new("Abc.Zebus.ICommand"); + public static TypeName CommandInterface { get; } = new("ICommand"); + public static TypeName CommandInterfaceFullName { get; } = new("Abc.Zebus.ICommand"); - public static TypeName MessageInterface { get; } = new("IMessage"); + public static TypeName MessageInterface { get; } = new("IMessage"); - public static TypeName RoutableAttribute { get; } = new("Routable"); - public static TypeName RoutingPositionAttribute { get; } = new("RoutingPosition"); - public static TypeName TransientAttribute { get; } = new("Transient"); + public static TypeName RoutableAttribute { get; } = new("Routable"); + public static TypeName RoutingPositionAttribute { get; } = new("RoutingPosition"); + public static TypeName TransientAttribute { get; } = new("Transient"); - public static TypeName ProtoContractAttribute { get; } = new("ProtoContract"); - public static TypeName ProtoMemberAttribute { get; } = new("ProtoMember"); - public static TypeName ProtoMapAttribute { get; } = new("ProtoMap"); - public static TypeName ProtoIncludeAttribute { get; } = new("ProtoInclude"); + public static TypeName ProtoContractAttribute { get; } = new("ProtoContract"); + public static TypeName ProtoMemberAttribute { get; } = new("ProtoMember"); + public static TypeName ProtoMapAttribute { get; } = new("ProtoMap"); + public static TypeName ProtoIncludeAttribute { get; } = new("ProtoInclude"); - public static TypeName ObsoleteAttribute { get; } = new("Obsolete"); - public static TypeName DescriptionAttribute { get; } = new("Description"); - } + public static TypeName ObsoleteAttribute { get; } = new("Obsolete"); + public static TypeName DescriptionAttribute { get; } = new("Description"); } diff --git a/src/Abc.Zebus.MessageDsl/Analysis/SyntaxDebugHelper.cs b/src/Abc.Zebus.MessageDsl/Analysis/SyntaxDebugHelper.cs index 024bfe3..7db3407 100644 --- a/src/Abc.Zebus.MessageDsl/Analysis/SyntaxDebugHelper.cs +++ b/src/Abc.Zebus.MessageDsl/Analysis/SyntaxDebugHelper.cs @@ -3,24 +3,23 @@ using Abc.Zebus.MessageDsl.Ast; using Antlr4.Runtime.Tree; -namespace Abc.Zebus.MessageDsl.Analysis +namespace Abc.Zebus.MessageDsl.Analysis; + +internal static class SyntaxDebugHelper { - internal static class SyntaxDebugHelper + public static string DumpParseTree(ParsedContracts contracts) { - public static string DumpParseTree(ParsedContracts contracts) - { - var writer = new IndentedTextWriter(new StringWriter()); - DumpParseTree(contracts.ParseTree, writer); - return writer.InnerWriter.ToString() ?? string.Empty; - } + var writer = new IndentedTextWriter(new StringWriter()); + DumpParseTree(contracts.ParseTree, writer); + return writer.InnerWriter.ToString() ?? string.Empty; + } - private static void DumpParseTree(IParseTree tree, IndentedTextWriter writer) - { - writer.WriteLine("{0}: {1}", tree.GetType().Name, tree.GetText()); - writer.Indent++; - for (var i = 0; i < tree.ChildCount; i++) - DumpParseTree(tree.GetChild(i), writer); - writer.Indent--; - } + private static void DumpParseTree(IParseTree tree, IndentedTextWriter writer) + { + writer.WriteLine("{0}: {1}", tree.GetType().Name, tree.GetText()); + writer.Indent++; + for (var i = 0; i < tree.ChildCount; i++) + DumpParseTree(tree.GetChild(i), writer); + writer.Indent--; } } diff --git a/src/Abc.Zebus.MessageDsl/Analysis/TextInterval.cs b/src/Abc.Zebus.MessageDsl/Analysis/TextInterval.cs index 15bb802..e561a30 100644 --- a/src/Abc.Zebus.MessageDsl/Analysis/TextInterval.cs +++ b/src/Abc.Zebus.MessageDsl/Analysis/TextInterval.cs @@ -1,66 +1,66 @@ using System; using JetBrains.Annotations; -namespace Abc.Zebus.MessageDsl.Analysis +namespace Abc.Zebus.MessageDsl.Analysis; + +public readonly struct TextInterval : IEquatable { - public readonly struct TextInterval : IEquatable - { - public static TextInterval Empty { get; } = new(); + public static TextInterval Empty { get; } = new(); - public int Start { get; } - public int End { get; } + public int Start { get; } + public int End { get; } - public int Length => End - Start; - public bool IsEmpty => Length == 0; + public int Length => End - Start; + public bool IsEmpty => Length == 0; - public TextInterval(int offset) - : this(offset, offset) - { - } + public TextInterval(int offset) + : this(offset, offset) + { + } - public TextInterval(int startOffset, int endOffset) - { - if (startOffset < 0) - throw new ArgumentOutOfRangeException(nameof(startOffset), "Start offset cannot be negative"); + public TextInterval(int startOffset, int endOffset) + { + if (startOffset < 0) + throw new ArgumentOutOfRangeException(nameof(startOffset), "Start offset cannot be negative"); - if (endOffset < startOffset) - throw new ArgumentOutOfRangeException(nameof(endOffset), "End offset cannot be less than start offset"); + if (endOffset < startOffset) + throw new ArgumentOutOfRangeException(nameof(endOffset), "End offset cannot be less than start offset"); - Start = startOffset; - End = endOffset; - } + Start = startOffset; + End = endOffset; + } - [Pure] - public TextInterval Intersect(TextInterval other) - => OverlapsOrIsAdjacent(other) - ? new TextInterval(Math.Max(Start, other.Start), Math.Min(End, other.End)) - : Empty; + [Pure] + public TextInterval Intersect(TextInterval other) + => OverlapsOrIsAdjacent(other) + ? new TextInterval(Math.Max(Start, other.Start), Math.Min(End, other.End)) + : Empty; - [Pure] - public bool Contains(TextInterval other) - => Start <= other.Start && other.End <= End; + [Pure] + public bool Contains(TextInterval other) + => Start <= other.Start && other.End <= End; - [Pure] - public bool Contains(int offset) - => Start <= offset && offset <= End; + [Pure] + public bool Contains(int offset) + => Start <= offset && offset <= End; - [Pure] - public bool Overlaps(TextInterval other) - => other.End > Start && other.Start < End; + [Pure] + public bool Overlaps(TextInterval other) + => other.End > Start && other.Start < End; - [Pure] - public bool OverlapsOrIsAdjacent(TextInterval other) - => other.End >= Start && other.Start <= End; + [Pure] + public bool OverlapsOrIsAdjacent(TextInterval other) + => other.End >= Start && other.Start <= End; - public bool Equals(TextInterval other) - => Start == other.Start && End == other.End; + public bool Equals(TextInterval other) + => Start == other.Start && End == other.End; - public override bool Equals(object? obj) => obj is TextInterval interval && Equals(interval); - public override int GetHashCode() => unchecked((Start * 397) ^ End); + public override bool Equals(object? obj) => obj is TextInterval interval && Equals(interval); + public override int GetHashCode() => unchecked((Start * 397) ^ End); - public static bool operator ==(TextInterval left, TextInterval right) => left.Equals(right); - public static bool operator !=(TextInterval left, TextInterval right) => !left.Equals(right); + public static bool operator ==(TextInterval left, TextInterval right) => left.Equals(right); + public static bool operator !=(TextInterval left, TextInterval right) => !left.Equals(right); - public override string ToString() => IsEmpty ? Start.ToString() : $"{Start}-{End}"; - } + public override string ToString() + => IsEmpty ? Start.ToString() : $"{Start}-{End}"; } diff --git a/src/Abc.Zebus.MessageDsl/Ast/AccessModifier.cs b/src/Abc.Zebus.MessageDsl/Ast/AccessModifier.cs index c6da079..1edeb20 100644 --- a/src/Abc.Zebus.MessageDsl/Ast/AccessModifier.cs +++ b/src/Abc.Zebus.MessageDsl/Ast/AccessModifier.cs @@ -1,8 +1,7 @@ -namespace Abc.Zebus.MessageDsl.Ast +namespace Abc.Zebus.MessageDsl.Ast; + +public enum AccessModifier { - public enum AccessModifier - { - Public, - Internal - } + Public, + Internal } diff --git a/src/Abc.Zebus.MessageDsl/Ast/AstNode.cs b/src/Abc.Zebus.MessageDsl/Ast/AstNode.cs index a7deb7c..cf922a1 100644 --- a/src/Abc.Zebus.MessageDsl/Ast/AstNode.cs +++ b/src/Abc.Zebus.MessageDsl/Ast/AstNode.cs @@ -2,27 +2,26 @@ using Antlr4.Runtime; using Antlr4.Runtime.Misc; -namespace Abc.Zebus.MessageDsl.Ast +namespace Abc.Zebus.MessageDsl.Ast; + +public class AstNode { - public class AstNode - { - internal ParserRuleContext? ParseContext { get; set; } + internal ParserRuleContext? ParseContext { get; set; } - public TextInterval GetSourceTextInterval() - { - var startIndex = ParseContext?.Start?.StartIndex ?? 0; - var endIndex = ParseContext?.Stop?.StopIndex + 1 ?? 0; + public TextInterval GetSourceTextInterval() + { + var startIndex = ParseContext?.Start?.StartIndex ?? 0; + var endIndex = ParseContext?.Stop?.StopIndex + 1 ?? 0; - return new TextInterval(startIndex, endIndex); - } + return new TextInterval(startIndex, endIndex); + } - public string GetSourceText() - { - var interval = GetSourceTextInterval(); - if (interval.IsEmpty) - return string.Empty; + public string GetSourceText() + { + var interval = GetSourceTextInterval(); + if (interval.IsEmpty) + return string.Empty; - return ParseContext?.Start?.InputStream?.GetText(Interval.Of(interval.Start, interval.End - 1)) ?? string.Empty; - } + return ParseContext?.Start?.InputStream?.GetText(Interval.Of(interval.Start, interval.End - 1)) ?? string.Empty; } } diff --git a/src/Abc.Zebus.MessageDsl/Ast/AttributeDefinition.cs b/src/Abc.Zebus.MessageDsl/Ast/AttributeDefinition.cs index a8f8a70..5d1e82f 100644 --- a/src/Abc.Zebus.MessageDsl/Ast/AttributeDefinition.cs +++ b/src/Abc.Zebus.MessageDsl/Ast/AttributeDefinition.cs @@ -1,46 +1,45 @@ -namespace Abc.Zebus.MessageDsl.Ast +namespace Abc.Zebus.MessageDsl.Ast; + +public class AttributeDefinition : AstNode { - public class AttributeDefinition : AstNode - { - private TypeName _typeName = default!; + private TypeName _typeName = default!; - public TypeName TypeName - { - get => _typeName; - private set => _typeName = NormalizeAttributeTypeName(value); - } + public TypeName TypeName + { + get => _typeName; + private set => _typeName = NormalizeAttributeTypeName(value); + } - public string? Parameters { get; set; } + public string? Parameters { get; set; } - public AttributeDefinition(TypeName attributeType, string? parameters = null) - { - TypeName = attributeType; - Parameters = parameters; - } + public AttributeDefinition(TypeName attributeType, string? parameters = null) + { + TypeName = attributeType; + Parameters = parameters; + } - public static TypeName NormalizeAttributeTypeName(TypeName typeName) - { - const string attributeSuffix = "Attribute"; + public static TypeName NormalizeAttributeTypeName(TypeName typeName) + { + const string attributeSuffix = "Attribute"; - if (typeName.NetType.EndsWith(attributeSuffix)) - return typeName.NetType.Substring(0, typeName.NetType.Length - attributeSuffix.Length); + if (typeName.NetType.EndsWith(attributeSuffix)) + return typeName.NetType.Substring(0, typeName.NetType.Length - attributeSuffix.Length); - return typeName; - } + return typeName; + } - public AttributeDefinition Clone() + public AttributeDefinition Clone() + { + return new AttributeDefinition(TypeName, Parameters) { - return new AttributeDefinition(TypeName, Parameters) - { - ParseContext = ParseContext - }; - } + ParseContext = ParseContext + }; + } - public override string ToString() - { - return string.IsNullOrEmpty(Parameters) - ? TypeName.ToString() - : $"{TypeName}({Parameters})"; - } + public override string ToString() + { + return string.IsNullOrEmpty(Parameters) + ? TypeName.ToString() + : $"{TypeName}({Parameters})"; } } diff --git a/src/Abc.Zebus.MessageDsl/Ast/AttributeSet.cs b/src/Abc.Zebus.MessageDsl/Ast/AttributeSet.cs index faa2803..c68551f 100644 --- a/src/Abc.Zebus.MessageDsl/Ast/AttributeSet.cs +++ b/src/Abc.Zebus.MessageDsl/Ast/AttributeSet.cs @@ -3,70 +3,69 @@ using System.Linq; using System.Text; -namespace Abc.Zebus.MessageDsl.Ast +namespace Abc.Zebus.MessageDsl.Ast; + +public class AttributeSet : AstNode, IList { - public class AttributeSet : AstNode, IList + public IList Attributes { get; } = new List(); + + public AttributeDefinition? GetAttribute(TypeName attributeType) { - public IList Attributes { get; } = new List(); + attributeType = AttributeDefinition.NormalizeAttributeTypeName(attributeType); + return Attributes.FirstOrDefault(attr => Equals(attr.TypeName, attributeType)); + } - public AttributeDefinition? GetAttribute(TypeName attributeType) - { - attributeType = AttributeDefinition.NormalizeAttributeTypeName(attributeType); - return Attributes.FirstOrDefault(attr => Equals(attr.TypeName, attributeType)); - } + public bool HasAttribute(TypeName attributeType) + => GetAttribute(attributeType) != null; - public bool HasAttribute(TypeName attributeType) - => GetAttribute(attributeType) != null; + public void AddFlagAttribute(TypeName attributeType) + { + if (!HasAttribute(attributeType)) + Attributes.Add(new AttributeDefinition(attributeType)); + } - public void AddFlagAttribute(TypeName attributeType) - { - if (!HasAttribute(attributeType)) - Attributes.Add(new AttributeDefinition(attributeType)); - } + public AttributeSet Clone() + { + var newSet = new AttributeSet(); - public AttributeSet Clone() - { - var newSet = new AttributeSet(); + foreach (var attribute in Attributes) + newSet.Attributes.Add(attribute.Clone()); - foreach (var attribute in Attributes) - newSet.Attributes.Add(attribute.Clone()); + return newSet; + } - return newSet; - } + public IEnumerator GetEnumerator() => Attributes.GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)Attributes).GetEnumerator(); + public void Add(AttributeDefinition item) => Attributes.Add(item); + public void Clear() => Attributes.Clear(); + public bool Contains(AttributeDefinition item) => Attributes.Contains(item); + public void CopyTo(AttributeDefinition[] array, int arrayIndex) => Attributes.CopyTo(array, arrayIndex); + public bool Remove(AttributeDefinition item) => Attributes.Remove(item); + public int Count => Attributes.Count; + public bool IsReadOnly => Attributes.IsReadOnly; + public int IndexOf(AttributeDefinition item) => Attributes.IndexOf(item); + public void Insert(int index, AttributeDefinition item) => Attributes.Insert(index, item); + public void RemoveAt(int index) => Attributes.RemoveAt(index); - public IEnumerator GetEnumerator() => Attributes.GetEnumerator(); - IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)Attributes).GetEnumerator(); - public void Add(AttributeDefinition item) => Attributes.Add(item); - public void Clear() => Attributes.Clear(); - public bool Contains(AttributeDefinition item) => Attributes.Contains(item); - public void CopyTo(AttributeDefinition[] array, int arrayIndex) => Attributes.CopyTo(array, arrayIndex); - public bool Remove(AttributeDefinition item) => Attributes.Remove(item); - public int Count => Attributes.Count; - public bool IsReadOnly => Attributes.IsReadOnly; - public int IndexOf(AttributeDefinition item) => Attributes.IndexOf(item); - public void Insert(int index, AttributeDefinition item) => Attributes.Insert(index, item); - public void RemoveAt(int index) => Attributes.RemoveAt(index); + public AttributeDefinition this[int index] + { + get => Attributes[index]; + set => Attributes[index] = value; + } - public AttributeDefinition this[int index] - { - get => Attributes[index]; - set => Attributes[index] = value; - } + public override string ToString() + { + var sb = new StringBuilder(); + sb.Append("["); - public override string ToString() + if (Attributes.Count > 0) { - var sb = new StringBuilder(); - sb.Append("["); - - if (Attributes.Count > 0) - { - foreach (var attr in Attributes) - sb.Append(attr).Append(", "); - sb.Length -= 2; - } - - sb.Append("]"); - return sb.ToString(); + foreach (var attr in Attributes) + sb.Append(attr).Append(", "); + sb.Length -= 2; } + + sb.Append("]"); + return sb.ToString(); } } diff --git a/src/Abc.Zebus.MessageDsl/Ast/ContractOptions.cs b/src/Abc.Zebus.MessageDsl/Ast/ContractOptions.cs index 64a7df8..aec9f65 100644 --- a/src/Abc.Zebus.MessageDsl/Ast/ContractOptions.cs +++ b/src/Abc.Zebus.MessageDsl/Ast/ContractOptions.cs @@ -1,6 +1,3 @@ -namespace Abc.Zebus.MessageDsl.Ast -{ - public class ContractOptions : OptionsBase - { - } -} +namespace Abc.Zebus.MessageDsl.Ast; + +public class ContractOptions : OptionsBase; diff --git a/src/Abc.Zebus.MessageDsl/Ast/EnumDefinition.cs b/src/Abc.Zebus.MessageDsl/Ast/EnumDefinition.cs index 39574c2..c5e7970 100644 --- a/src/Abc.Zebus.MessageDsl/Ast/EnumDefinition.cs +++ b/src/Abc.Zebus.MessageDsl/Ast/EnumDefinition.cs @@ -3,72 +3,72 @@ using System.Diagnostics.CodeAnalysis; using System.Globalization; -namespace Abc.Zebus.MessageDsl.Ast +namespace Abc.Zebus.MessageDsl.Ast; + +public class EnumDefinition : AstNode, IMemberNode { - public class EnumDefinition : AstNode, IMemberNode - { - private MemberOptions? _options; + private MemberOptions? _options; - public string Name { get; set; } = default!; - public TypeName UnderlyingType { get; set; } = "int"; - public AccessModifier AccessModifier { get; set; } - public AttributeSet Attributes { get; } = new(); - public IList Members { get; } = new List(); + public string Name { get; set; } = default!; + public TypeName UnderlyingType { get; set; } = "int"; + public AccessModifier AccessModifier { get; set; } + public AttributeSet Attributes { get; } = new(); + public IList Members { get; } = new List(); - public MemberOptions Options - { - get => _options ??= new MemberOptions(); - set => _options = value; - } + public MemberOptions Options + { + get => _options ??= new MemberOptions(); + set => _options = value; + } - public override string ToString() => Name; + public override string ToString() + => Name; - internal bool IsValidUnderlyingType() + internal bool IsValidUnderlyingType() + { + switch (UnderlyingType.NetType) { - switch (UnderlyingType.NetType) - { - case "byte": - case "sbyte": - case "short": - case "ushort": - case "int": - case "uint": - case "long": - case "ulong": - return true; + case "byte": + case "sbyte": + case "short": + case "ushort": + case "int": + case "uint": + case "long": + case "ulong": + return true; - default: - return false; - } + default: + return false; } + } - [SuppressMessage("ReSharper", "HeapView.BoxingAllocation")] - internal object? GetValidUnderlyingValue(string? value) - { - if (string.IsNullOrEmpty(value)) - return null; + [SuppressMessage("ReSharper", "HeapView.BoxingAllocation")] + internal object? GetValidUnderlyingValue(string? value) + { + if (string.IsNullOrEmpty(value)) + return null; - value = value!.Trim(); - var numberStyles = value.StartsWith("0x", StringComparison.OrdinalIgnoreCase) ? NumberStyles.HexNumber : NumberStyles.Integer; + value = value.Trim(); + var numberStyles = value.StartsWith("0x", StringComparison.OrdinalIgnoreCase) ? NumberStyles.HexNumber : NumberStyles.Integer; - if (numberStyles == NumberStyles.HexNumber) - value = value.Substring(2); + if (numberStyles == NumberStyles.HexNumber) + value = value.Substring(2); - if (value.Length == 0) - return null; + if (value.Length == 0) + return null; - return UnderlyingType.NetType switch - { - "byte" => byte.TryParse(value, numberStyles, CultureInfo.InvariantCulture, out var result) ? result : null, - "sbyte" => sbyte.TryParse(value, numberStyles, CultureInfo.InvariantCulture, out var result) ? result : null, - "short" => short.TryParse(value, numberStyles, CultureInfo.InvariantCulture, out var result) ? result : null, - "ushort" => ushort.TryParse(value, numberStyles, CultureInfo.InvariantCulture, out var result) ? result : null, - "int" => int.TryParse(value, numberStyles, CultureInfo.InvariantCulture, out var result) ? result : null, - "uint" => uint.TryParse(value, numberStyles, CultureInfo.InvariantCulture, out var result) ? result : null, - "long" => long.TryParse(value, numberStyles, CultureInfo.InvariantCulture, out var result) ? result : null, - "ulong" => ulong.TryParse(value, numberStyles, CultureInfo.InvariantCulture, out var result) ? result : null, - _ => null - }; - } + return UnderlyingType.NetType switch + { + "byte" => byte.TryParse(value, numberStyles, CultureInfo.InvariantCulture, out var result) ? result : null, + "sbyte" => sbyte.TryParse(value, numberStyles, CultureInfo.InvariantCulture, out var result) ? result : null, + "short" => short.TryParse(value, numberStyles, CultureInfo.InvariantCulture, out var result) ? result : null, + "ushort" => ushort.TryParse(value, numberStyles, CultureInfo.InvariantCulture, out var result) ? result : null, + "int" => int.TryParse(value, numberStyles, CultureInfo.InvariantCulture, out var result) ? result : null, + "uint" => uint.TryParse(value, numberStyles, CultureInfo.InvariantCulture, out var result) ? result : null, + "long" => long.TryParse(value, numberStyles, CultureInfo.InvariantCulture, out var result) ? result : null, + "ulong" => ulong.TryParse(value, numberStyles, CultureInfo.InvariantCulture, out var result) ? result : null, + _ => null + }; } } diff --git a/src/Abc.Zebus.MessageDsl/Ast/EnumMemberDefinition.cs b/src/Abc.Zebus.MessageDsl/Ast/EnumMemberDefinition.cs index 906a712..f76dc72 100644 --- a/src/Abc.Zebus.MessageDsl/Ast/EnumMemberDefinition.cs +++ b/src/Abc.Zebus.MessageDsl/Ast/EnumMemberDefinition.cs @@ -1,10 +1,9 @@ -namespace Abc.Zebus.MessageDsl.Ast +namespace Abc.Zebus.MessageDsl.Ast; + +public class EnumMemberDefinition : AstNode, INamedNode { - public class EnumMemberDefinition : AstNode, INamedNode - { - public string Name { get; set; } = default!; - public string? Value { get; set; } - public AttributeSet Attributes { get; } = new(); - internal int? ProtoValue { get; set; } - } + public string Name { get; set; } = default!; + public string? Value { get; set; } + public AttributeSet Attributes { get; } = new(); + internal int? ProtoValue { get; set; } } diff --git a/src/Abc.Zebus.MessageDsl/Ast/FieldRules.cs b/src/Abc.Zebus.MessageDsl/Ast/FieldRules.cs index d1dd121..decfc22 100644 --- a/src/Abc.Zebus.MessageDsl/Ast/FieldRules.cs +++ b/src/Abc.Zebus.MessageDsl/Ast/FieldRules.cs @@ -1,10 +1,9 @@ -namespace Abc.Zebus.MessageDsl.Ast +namespace Abc.Zebus.MessageDsl.Ast; + +public enum FieldRules { - public enum FieldRules - { - Unspecified, - Required, - Optional, - Repeated - } + Unspecified, + Required, + Optional, + Repeated } diff --git a/src/Abc.Zebus.MessageDsl/Ast/GenericConstraint.cs b/src/Abc.Zebus.MessageDsl/Ast/GenericConstraint.cs index 6e6418f..a38e880 100644 --- a/src/Abc.Zebus.MessageDsl/Ast/GenericConstraint.cs +++ b/src/Abc.Zebus.MessageDsl/Ast/GenericConstraint.cs @@ -1,17 +1,17 @@ using System.Collections.Generic; -namespace Abc.Zebus.MessageDsl.Ast +namespace Abc.Zebus.MessageDsl.Ast; + +public class GenericConstraint : AstNode { - public class GenericConstraint : AstNode - { - public string GenericParameterName { get; set; } = default!; + public string GenericParameterName { get; set; } = default!; - public bool IsClass { get; set; } - public bool IsStruct { get; set; } - public bool HasDefaultConstructor { get; set; } + public bool IsClass { get; set; } + public bool IsStruct { get; set; } + public bool HasDefaultConstructor { get; set; } - public ISet Types { get; } = new HashSet(); + public ISet Types { get; } = new HashSet(); - public override string ToString() => GenericParameterName; - } + public override string ToString() + => GenericParameterName; } diff --git a/src/Abc.Zebus.MessageDsl/Ast/IClassNode.cs b/src/Abc.Zebus.MessageDsl/Ast/IClassNode.cs index 12659f5..eb8803d 100644 --- a/src/Abc.Zebus.MessageDsl/Ast/IClassNode.cs +++ b/src/Abc.Zebus.MessageDsl/Ast/IClassNode.cs @@ -1,7 +1,6 @@ -namespace Abc.Zebus.MessageDsl.Ast +namespace Abc.Zebus.MessageDsl.Ast; + +internal interface IClassNode : IMemberNode { - internal interface IClassNode : IMemberNode - { - InheritanceModifier InheritanceModifier { get; set; } - } + InheritanceModifier InheritanceModifier { get; set; } } diff --git a/src/Abc.Zebus.MessageDsl/Ast/IMemberNode.cs b/src/Abc.Zebus.MessageDsl/Ast/IMemberNode.cs index fa2f715..f40e1e0 100644 --- a/src/Abc.Zebus.MessageDsl/Ast/IMemberNode.cs +++ b/src/Abc.Zebus.MessageDsl/Ast/IMemberNode.cs @@ -1,8 +1,7 @@ -namespace Abc.Zebus.MessageDsl.Ast +namespace Abc.Zebus.MessageDsl.Ast; + +internal interface IMemberNode : INamedNode { - internal interface IMemberNode : INamedNode - { - AccessModifier AccessModifier { get; set; } - MemberOptions Options { get; } - } + AccessModifier AccessModifier { get; set; } + MemberOptions Options { get; } } diff --git a/src/Abc.Zebus.MessageDsl/Ast/INamedNode.cs b/src/Abc.Zebus.MessageDsl/Ast/INamedNode.cs index ab943c5..420b5b3 100644 --- a/src/Abc.Zebus.MessageDsl/Ast/INamedNode.cs +++ b/src/Abc.Zebus.MessageDsl/Ast/INamedNode.cs @@ -1,7 +1,6 @@ -namespace Abc.Zebus.MessageDsl.Ast +namespace Abc.Zebus.MessageDsl.Ast; + +public interface INamedNode { - public interface INamedNode - { - string Name { get; } - } + string Name { get; } } diff --git a/src/Abc.Zebus.MessageDsl/Ast/InheritanceModifier.cs b/src/Abc.Zebus.MessageDsl/Ast/InheritanceModifier.cs index b95e9c1..a7ec13a 100644 --- a/src/Abc.Zebus.MessageDsl/Ast/InheritanceModifier.cs +++ b/src/Abc.Zebus.MessageDsl/Ast/InheritanceModifier.cs @@ -1,10 +1,9 @@ -namespace Abc.Zebus.MessageDsl.Ast +namespace Abc.Zebus.MessageDsl.Ast; + +public enum InheritanceModifier { - public enum InheritanceModifier - { - Default, - None, - Sealed, - Abstract - } + Default, + None, + Sealed, + Abstract } diff --git a/src/Abc.Zebus.MessageDsl/Ast/MemberOptions.cs b/src/Abc.Zebus.MessageDsl/Ast/MemberOptions.cs index 8226d9c..2ecc82a 100644 --- a/src/Abc.Zebus.MessageDsl/Ast/MemberOptions.cs +++ b/src/Abc.Zebus.MessageDsl/Ast/MemberOptions.cs @@ -1,25 +1,25 @@ using JetBrains.Annotations; -namespace Abc.Zebus.MessageDsl.Ast +namespace Abc.Zebus.MessageDsl.Ast; + +public class MemberOptions : OptionsBase { - public class MemberOptions : OptionsBase - { - public bool Proto { get; set; } - public bool Mutable { get; set; } + public bool Proto { get; set; } + public bool Mutable { get; set; } - public bool Internal { get; set; } + public bool Internal { get; set; } - public bool Public - { - get => !Internal; - [UsedImplicitly] set => Internal = !value; - } + public bool Public + { + get => !Internal; + [UsedImplicitly] set => Internal = !value; + } - public bool Nullable { get; set; } + public bool Nullable { get; set; } - public AccessModifier GetAccessModifier() - => Internal ? AccessModifier.Internal : AccessModifier.Public; + public AccessModifier GetAccessModifier() + => Internal ? AccessModifier.Internal : AccessModifier.Public; - public MemberOptions Clone() => (MemberOptions)MemberwiseClone(); - } + public MemberOptions Clone() + => (MemberOptions)MemberwiseClone(); } diff --git a/src/Abc.Zebus.MessageDsl/Ast/MessageDefinition.cs b/src/Abc.Zebus.MessageDsl/Ast/MessageDefinition.cs index 275a884..360543d 100644 --- a/src/Abc.Zebus.MessageDsl/Ast/MessageDefinition.cs +++ b/src/Abc.Zebus.MessageDsl/Ast/MessageDefinition.cs @@ -1,53 +1,53 @@ using System.Collections.Generic; using Abc.Zebus.MessageDsl.Analysis; -namespace Abc.Zebus.MessageDsl.Ast +namespace Abc.Zebus.MessageDsl.Ast; + +public class MessageDefinition : AstNode, IClassNode { - public class MessageDefinition : AstNode, IClassNode - { - private MemberOptions? _options; + private MemberOptions? _options; - public string Name { get; set; } = default!; - public AccessModifier AccessModifier { get; set; } - public InheritanceModifier InheritanceModifier { get; set; } - public IList GenericParameters { get; } = new List(); - public IList GenericConstraints { get; } = new List(); - public IList ContainingClasses { get; } = new List(); + public string Name { get; set; } = default!; + public AccessModifier AccessModifier { get; set; } + public InheritanceModifier InheritanceModifier { get; set; } + public IList GenericParameters { get; } = new List(); + public IList GenericConstraints { get; } = new List(); + public IList ContainingClasses { get; } = new List(); - public IList Parameters { get; } = new List(); - public IList BaseTypes { get; } = new List(); - public AttributeSet Attributes { get; } = new(); + public IList Parameters { get; } = new List(); + public IList BaseTypes { get; } = new List(); + public AttributeSet Attributes { get; } = new(); - public MemberOptions Options - { - get => _options ??= new MemberOptions(); - set => _options = value; - } + public MemberOptions Options + { + get => _options ??= new MemberOptions(); + set => _options = value; + } - public bool IsCustom { get; set; } - public bool IsTransient { get; set; } - public bool IsRoutable { get; set; } + public bool IsCustom { get; set; } + public bool IsTransient { get; set; } + public bool IsRoutable { get; set; } - public MessageType Type + public MessageType Type + { + get { - get - { - if (IsCustom) - return MessageType.Custom; + if (IsCustom) + return MessageType.Custom; - if (BaseTypes.Contains(KnownTypes.CommandInterface) || BaseTypes.Contains(KnownTypes.CommandInterfaceFullName)) - return MessageType.Command; + if (BaseTypes.Contains(KnownTypes.CommandInterface) || BaseTypes.Contains(KnownTypes.CommandInterfaceFullName)) + return MessageType.Command; - if (BaseTypes.Contains(KnownTypes.EventInterface) || BaseTypes.Contains(KnownTypes.EventInterfaceFullName)) - return MessageType.Event; + if (BaseTypes.Contains(KnownTypes.EventInterface) || BaseTypes.Contains(KnownTypes.EventInterfaceFullName)) + return MessageType.Event; - if (Name.EndsWith("Command")) - return MessageType.Command; + if (Name.EndsWith("Command")) + return MessageType.Command; - return MessageType.Event; - } + return MessageType.Event; } - - public override string ToString() => Name; } + + public override string ToString() + => Name; } diff --git a/src/Abc.Zebus.MessageDsl/Ast/MessageType.cs b/src/Abc.Zebus.MessageDsl/Ast/MessageType.cs index 2f08cb0..78185c0 100644 --- a/src/Abc.Zebus.MessageDsl/Ast/MessageType.cs +++ b/src/Abc.Zebus.MessageDsl/Ast/MessageType.cs @@ -1,10 +1,9 @@ -namespace Abc.Zebus.MessageDsl.Ast +namespace Abc.Zebus.MessageDsl.Ast; + +public enum MessageType { - public enum MessageType - { - None, - Custom, - Command, - Event - } + None, + Custom, + Command, + Event } diff --git a/src/Abc.Zebus.MessageDsl/Ast/OptionsBase.cs b/src/Abc.Zebus.MessageDsl/Ast/OptionsBase.cs index b04c95f..2208aad 100644 --- a/src/Abc.Zebus.MessageDsl/Ast/OptionsBase.cs +++ b/src/Abc.Zebus.MessageDsl/Ast/OptionsBase.cs @@ -3,52 +3,51 @@ using System.Reflection; using JetBrains.Annotations; -namespace Abc.Zebus.MessageDsl.Ast +namespace Abc.Zebus.MessageDsl.Ast; + +[UsedImplicitly(ImplicitUseTargetFlags.WithMembers)] +public abstract class OptionsBase { - [UsedImplicitly(ImplicitUseTargetFlags.WithMembers)] - public abstract class OptionsBase + public OptionDescriptor? GetOptionDescriptor(string? optionName) { - public OptionDescriptor? GetOptionDescriptor(string? optionName) - { - if (string.IsNullOrEmpty(optionName)) - return null; - - var property = GetType().GetProperties(BindingFlags.Public | BindingFlags.Instance) - .FirstOrDefault(prop => prop.Name.Equals(optionName, StringComparison.OrdinalIgnoreCase)); + if (string.IsNullOrEmpty(optionName)) + return null; - if (property == null) - return null; + var property = GetType().GetProperties(BindingFlags.Public | BindingFlags.Instance) + .FirstOrDefault(prop => prop.Name.Equals(optionName, StringComparison.OrdinalIgnoreCase)); - return new OptionDescriptor(this, property); - } + if (property == null) + return null; - public class OptionDescriptor - { - private readonly OptionsBase _options; - private readonly PropertyInfo _property; + return new OptionDescriptor(this, property); + } + + public class OptionDescriptor + { + private readonly OptionsBase _options; + private readonly PropertyInfo _property; + + public string Name => _property.Name; - public string Name => _property.Name; + public bool IsBoolean => _property.PropertyType == typeof(bool); - public bool IsBoolean => _property.PropertyType == typeof(bool); + internal OptionDescriptor(OptionsBase options, PropertyInfo property) + { + _options = options; + _property = property; + } - internal OptionDescriptor(OptionsBase options, PropertyInfo property) + public bool SetValue(string value) + { + try { - _options = options; - _property = property; + var typedValue = Convert.ChangeType(value, _property.PropertyType); + _property.SetValue(_options, typedValue); + return true; } - - public bool SetValue(string value) + catch { - try - { - var typedValue = Convert.ChangeType(value, _property.PropertyType); - _property.SetValue(_options, typedValue); - return true; - } - catch - { - return false; - } + return false; } } } diff --git a/src/Abc.Zebus.MessageDsl/Ast/ParameterDefinition.cs b/src/Abc.Zebus.MessageDsl/Ast/ParameterDefinition.cs index 44d6e40..a9ee56f 100644 --- a/src/Abc.Zebus.MessageDsl/Ast/ParameterDefinition.cs +++ b/src/Abc.Zebus.MessageDsl/Ast/ParameterDefinition.cs @@ -1,45 +1,46 @@ -namespace Abc.Zebus.MessageDsl.Ast +namespace Abc.Zebus.MessageDsl.Ast; + +public class ParameterDefinition : AstNode, INamedNode { - public class ParameterDefinition : AstNode, INamedNode + public int Tag { get; set; } + public string Name { get; set; } = default!; + public TypeName Type { get; set; } = default!; + public bool IsMarkedOptional { get; set; } + public string? DefaultValue { get; set; } + public bool IsWritableProperty { get; set; } + public AttributeSet Attributes { get; private set; } = new(); + + public FieldRules Rules { - public int Tag { get; set; } - public string Name { get; set; } = default!; - public TypeName Type { get; set; } = default!; - public bool IsMarkedOptional { get; set; } - public string? DefaultValue { get; set; } - public bool IsWritableProperty { get; set; } - public AttributeSet Attributes { get; private set; } = new(); - - public FieldRules Rules + get { - get - { - if (Type.IsRepeated) - return FieldRules.Repeated; + if (Type.IsRepeated) + return FieldRules.Repeated; - if (IsMarkedOptional || Type.IsNullable) - return FieldRules.Optional; + if (IsMarkedOptional || Type.IsNullable) + return FieldRules.Optional; - return FieldRules.Required; - } + return FieldRules.Required; } + } - public bool IsPacked => Rules == FieldRules.Repeated && Type.IsPackable; + public bool IsPacked => Rules == FieldRules.Repeated && Type.IsPackable; - public int? RoutingPosition { get; set; } + public int? RoutingPosition { get; set; } - public ParameterDefinition() - { - } + public ParameterDefinition() + { + } - internal ParameterDefinition(TypeName type, string name) - : this() - { - Type = type; - Name = name; - } + internal ParameterDefinition(TypeName type, string name) + : this() + { + Type = type; + Name = name; + } - public ParameterDefinition Clone() => new() + public ParameterDefinition Clone() + => new() { Tag = Tag, Name = Name, @@ -50,6 +51,6 @@ internal ParameterDefinition(TypeName type, string name) ParseContext = ParseContext }; - public override string ToString() => $"{Type} {Name}"; - } + public override string ToString() + => $"{Type} {Name}"; } diff --git a/src/Abc.Zebus.MessageDsl/Ast/ParsedContracts.cs b/src/Abc.Zebus.MessageDsl/Ast/ParsedContracts.cs index 76315ed..ac048eb 100644 --- a/src/Abc.Zebus.MessageDsl/Ast/ParsedContracts.cs +++ b/src/Abc.Zebus.MessageDsl/Ast/ParsedContracts.cs @@ -4,107 +4,106 @@ using Antlr4.Runtime; using JetBrains.Annotations; -namespace Abc.Zebus.MessageDsl.Ast +namespace Abc.Zebus.MessageDsl.Ast; + +public class ParsedContracts { - public class ParsedContracts - { - public IList Messages { get; } = new List(); - public IList Enums { get; } = new List(); - public ContractOptions Options { get; } = new(); - public ICollection Errors { get; } - public string Namespace { get; set; } = string.Empty; - public bool ExplicitNamespace { get; internal set; } - public ICollection ImportedNamespaces { get; } = new HashSet(); + public IList Messages { get; } = new List(); + public IList Enums { get; } = new List(); + public ContractOptions Options { get; } = new(); + public ICollection Errors { get; } + public string Namespace { get; set; } = string.Empty; + public bool ExplicitNamespace { get; internal set; } + public ICollection ImportedNamespaces { get; } = new HashSet(); - public CommonTokenStream TokenStream { get; } - public MessageContractsParser.CompileUnitContext ParseTree { get; } + public CommonTokenStream TokenStream { get; } + public MessageContractsParser.CompileUnitContext ParseTree { get; } - public bool IsValid => Errors.Count == 0; + public bool IsValid => Errors.Count == 0; - internal ParsedContracts() - { - // For unit tests + internal ParsedContracts() + { + // For unit tests - TokenStream = default!; - ParseTree = default!; - Errors = new List(); - } + TokenStream = default!; + ParseTree = default!; + Errors = new List(); + } - private ParsedContracts(CommonTokenStream tokenStream, MessageContractsParser.CompileUnitContext parseTree, ICollection errors) - { - TokenStream = tokenStream; - ParseTree = parseTree; - Errors = errors; - } + private ParsedContracts(CommonTokenStream tokenStream, MessageContractsParser.CompileUnitContext parseTree, ICollection errors) + { + TokenStream = tokenStream; + ParseTree = parseTree; + Errors = errors; + } - public static ParsedContracts CreateParseTree(string definitionText) - { - var errorListener = new CollectingErrorListener(); + public static ParsedContracts CreateParseTree(string definitionText) + { + var errorListener = new CollectingErrorListener(); - var input = new AntlrInputStream(definitionText); + var input = new AntlrInputStream(definitionText); - var lexer = new MessageContractsLexer(input); - lexer.RemoveErrorListeners(); - lexer.AddErrorListener(errorListener); + var lexer = new MessageContractsLexer(input); + lexer.RemoveErrorListeners(); + lexer.AddErrorListener(errorListener); - var tokenStream = new CommonTokenStream(lexer); + var tokenStream = new CommonTokenStream(lexer); - var parser = new MessageContractsParser(tokenStream); - parser.RemoveErrorListeners(); - parser.AddErrorListener(errorListener); + var parser = new MessageContractsParser(tokenStream); + parser.RemoveErrorListeners(); + parser.AddErrorListener(errorListener); - var parseTree = parser.compileUnit(); + var parseTree = parser.compileUnit(); - return new ParsedContracts(tokenStream, parseTree, errorListener.Errors); - } - - public static ParsedContracts Parse(string definitionText, string defaultNamespace) - { - var result = CreateParseTree(definitionText); + return new ParsedContracts(tokenStream, parseTree, errorListener.Errors); + } - if (!result.ExplicitNamespace) - result.Namespace = defaultNamespace; + public static ParsedContracts Parse(string definitionText, string defaultNamespace) + { + var result = CreateParseTree(definitionText); - if (result.Errors.Count == 0) - { - new AstCreationVisitor(result).VisitCompileUnit(result.ParseTree); - result.Process(); - } + if (!result.ExplicitNamespace) + result.Namespace = defaultNamespace; - return result; + if (result.Errors.Count == 0) + { + new AstCreationVisitor(result).VisitCompileUnit(result.ParseTree); + result.Process(); } - internal void Process() - { - var processor = new AstProcessor(this); + return result; + } - processor.PreProcess(); - new AttributeInterpreter(this).InterpretAttributes(); - new ContractsEnhancer(this).Process(); - processor.PostProcess(); + internal void Process() + { + var processor = new AstProcessor(this); - new AstValidator(this).Validate(); - } + processor.PreProcess(); + new AttributeInterpreter(this).InterpretAttributes(); + new ContractsEnhancer(this).Process(); + processor.PostProcess(); - public void AddError(string message) - => Errors.Add(new SyntaxError(message)); + new AstValidator(this).Validate(); + } - public void AddError(IToken? token, string message) - => Errors.Add(new SyntaxError(message, token)); + public void AddError(string message) + => Errors.Add(new SyntaxError(message)); - public void AddError(ParserRuleContext? context, string message) - => AddError(context?.Start, message); + public void AddError(IToken? token, string message) + => Errors.Add(new SyntaxError(message, token)); - [StringFormatMethod("format")] - public void AddError(ParserRuleContext? context, string format, params object?[] args) - => AddError(context, string.Format(format, args)); + public void AddError(ParserRuleContext? context, string message) + => AddError(context?.Start, message); - [StringFormatMethod("format")] - public void AddError(IToken? token, string format, params object?[] args) - => AddError(token, string.Format(format, args)); + [StringFormatMethod("format")] + public void AddError(ParserRuleContext? context, string format, params object?[] args) + => AddError(context, string.Format(format, args)); - [StringFormatMethod("format")] - public void AddError(string format, params object?[] args) - => AddError(string.Format(format, args)); - } + [StringFormatMethod("format")] + public void AddError(IToken? token, string format, params object?[] args) + => AddError(token, string.Format(format, args)); + + [StringFormatMethod("format")] + public void AddError(string format, params object?[] args) + => AddError(string.Format(format, args)); } diff --git a/src/Abc.Zebus.MessageDsl/Ast/TypeName.cs b/src/Abc.Zebus.MessageDsl/Ast/TypeName.cs index b6f9203..f149325 100644 --- a/src/Abc.Zebus.MessageDsl/Ast/TypeName.cs +++ b/src/Abc.Zebus.MessageDsl/Ast/TypeName.cs @@ -5,193 +5,192 @@ using Abc.Zebus.MessageDsl.Generator; using Abc.Zebus.MessageDsl.Support; -namespace Abc.Zebus.MessageDsl.Ast +namespace Abc.Zebus.MessageDsl.Ast; + +public sealed class TypeName : IEquatable { - public sealed class TypeName : IEquatable + private static readonly Regex _reSystemTypeName = new(@"\b(?:global::|(?\w+)(?!\.)\b", RegexOptions.Compiled | RegexOptions.CultureInvariant); + private static readonly Regex _reUnqualifiedName = new(@"\b(?]),", RegexOptions.Compiled | RegexOptions.CultureInvariant); + private static readonly Regex _reIdentifierPart = new(@"[^,\[\]<>\s]+", RegexOptions.Compiled | RegexOptions.CultureInvariant); + + private static readonly Dictionary _aliasTypeMap = new() + { + { "bool", "Boolean" }, + { "byte", "Byte" }, + { "sbyte", "SByte" }, + { "char", "Char" }, + { "decimal", "Decimal" }, + { "double", "Double" }, + { "float", "Single" }, + { "int", "Int32" }, + { "uint", "UInt32" }, + { "long", "Int64" }, + { "ulong", "UInt64" }, + { "object", "Object" }, + { "short", "Int16" }, + { "ushort", "UInt16" }, + { "string", "String" }, + }; + + private static readonly Dictionary _clrTypeToAlias; + + private static readonly HashSet _knownBclTypes = new() + { + "TimeSpan", + "DateTime", + "Guid", + "Decimal" + }; + + private static readonly Dictionary _protoTypeNameMap = new() + { + { "Double", "double" }, + { "Single", "float" }, + { "Int32", "int32" }, + { "Int64", "int64" }, + { "UInt32", "uint32" }, + { "UInt64", "uint64" }, + { "Boolean", "bool" }, + { "String", "string" }, + }; + + private static readonly HashSet _packableProtoBufTypes = new() + { + "double", "float", "int32", "int64", "uint32", "uint64", "sint32", "sint64", "fixed32", "fixed64", "sfixed32", "sfixed64", "bool" + }; + + private static readonly HashSet _knownValueTypes = new() + { + "bool", "byte", "sbyte", "char", "decimal", "double", "float", "int", "uint", "long", "ulong", "short", "ushort", + "TimeSpan", "DateTime", "Guid" + }; + + private static readonly HashSet _csharpNonTypeKeywords = CSharpSyntax.EnumerateCSharpKeywords().Except(_aliasTypeMap.Keys).ToHashSet(); + + static TypeName() + { + _clrTypeToAlias = new Dictionary(); + foreach (var pair in _aliasTypeMap) + _clrTypeToAlias.Add(pair.Value, pair.Key); + + _clrTypeToAlias.Add(typeof(List<>).Namespace + ".List", "List"); + } + + private string? _protoBufType; + + public string NetType { get; } + + public string ProtoBufType => _protoBufType ??= GetProtoBufType(); + + public bool IsArray => NetType.EndsWith("[]") || NetType.EndsWith("[]?"); + public bool IsList => NetType.StartsWith("List<") && (NetType.EndsWith(">") || NetType.EndsWith(">?")); + public bool IsDictionary => NetType.StartsWith("Dictionary<") && (NetType.EndsWith(">") || NetType.EndsWith(">?")); + public bool IsHashSet => NetType.StartsWith("HashSet<") && (NetType.EndsWith(">") || NetType.EndsWith(">?")); + public bool IsRepeated => IsArray || IsList || IsHashSet; + + public bool IsNullable => NetType.EndsWith("?"); + + public bool IsPackable => IsRepeated && _packableProtoBufTypes.Contains(ProtoBufType); + + public TypeName(string? netType) { - private static readonly Regex _reSystemTypeName = new(@"\b(?:global::|(?\w+)(?!\.)\b", RegexOptions.Compiled | RegexOptions.CultureInvariant); - private static readonly Regex _reUnqualifiedName = new(@"\b(?]),", RegexOptions.Compiled | RegexOptions.CultureInvariant); - private static readonly Regex _reIdentifierPart = new(@"[^,\[\]<>\s]+", RegexOptions.Compiled | RegexOptions.CultureInvariant); - - private static readonly Dictionary _aliasTypeMap = new() - { - { "bool", "Boolean" }, - { "byte", "Byte" }, - { "sbyte", "SByte" }, - { "char", "Char" }, - { "decimal", "Decimal" }, - { "double", "Double" }, - { "float", "Single" }, - { "int", "Int32" }, - { "uint", "UInt32" }, - { "long", "Int64" }, - { "ulong", "UInt64" }, - { "object", "Object" }, - { "short", "Int16" }, - { "ushort", "UInt16" }, - { "string", "String" }, - }; - - private static readonly Dictionary _clrTypeToAlias; - - private static readonly HashSet _knownBclTypes = new() - { - "TimeSpan", - "DateTime", - "Guid", - "Decimal" - }; - - private static readonly Dictionary _protoTypeNameMap = new() - { - { "Double", "double" }, - { "Single", "float" }, - { "Int32", "int32" }, - { "Int64", "int64" }, - { "UInt32", "uint32" }, - { "UInt64", "uint64" }, - { "Boolean", "bool" }, - { "String", "string" }, - }; - - private static readonly HashSet _packableProtoBufTypes = new() - { - "double", "float", "int32", "int64", "uint32", "uint64", "sint32", "sint64", "fixed32", "fixed64", "sfixed32", "sfixed64", "bool" - }; - - private static readonly HashSet _knownValueTypes = new() - { - "bool", "byte", "sbyte", "char", "decimal", "double", "float", "int", "uint", "long", "ulong", "short", "ushort", - "TimeSpan", "DateTime", "Guid" - }; - - private static readonly HashSet _csharpNonTypeKeywords = CSharpSyntax.EnumerateCSharpKeywords().Except(_aliasTypeMap.Keys).ToHashSet(); - - static TypeName() - { - _clrTypeToAlias = new Dictionary(); - foreach (var pair in _aliasTypeMap) - _clrTypeToAlias.Add(pair.Value, pair.Key); - - _clrTypeToAlias.Add(typeof(List<>).Namespace + ".List", "List"); - } - - private string? _protoBufType; - - public string NetType { get; } - - public string ProtoBufType => _protoBufType ??= GetProtoBufType(); - - public bool IsArray => NetType.EndsWith("[]") || NetType.EndsWith("[]?"); - public bool IsList => NetType.StartsWith("List<") && (NetType.EndsWith(">") || NetType.EndsWith(">?")); - public bool IsDictionary => NetType.StartsWith("Dictionary<") && (NetType.EndsWith(">") || NetType.EndsWith(">?")); - public bool IsHashSet => NetType.StartsWith("HashSet<") && (NetType.EndsWith(">") || NetType.EndsWith(">?")); - public bool IsRepeated => IsArray || IsList || IsHashSet; - - public bool IsNullable => NetType.EndsWith("?"); + NetType = NormalizeName(netType ?? string.Empty); + } - public bool IsPackable => IsRepeated && _packableProtoBufTypes.Contains(ProtoBufType); - - public TypeName(string? netType) - { - NetType = NormalizeName(netType ?? string.Empty); - } - - public TypeName? GetRepeatedItemType() - { - var nullableCharLength = IsNullable ? "?".Length : 0; - - if (IsArray) - return NetType.Substring(0, NetType.Length - "[]".Length - nullableCharLength); - - if (IsList) - return NetType.Substring("List<".Length, NetType.Length - "List<>".Length - nullableCharLength); - - if (IsHashSet) - return NetType.Substring("HashSet<".Length, NetType.Length - "HashSet<>".Length - nullableCharLength); + public TypeName? GetRepeatedItemType() + { + var nullableCharLength = IsNullable ? "?".Length : 0; - return null; - } + if (IsArray) + return NetType.Substring(0, NetType.Length - "[]".Length - nullableCharLength); + + if (IsList) + return NetType.Substring("List<".Length, NetType.Length - "List<>".Length - nullableCharLength); + + if (IsHashSet) + return NetType.Substring("HashSet<".Length, NetType.Length - "HashSet<>".Length - nullableCharLength); + + return null; + } - public TypeName GetNonNullableType() - { - if (IsNullable) - return NetType.Substring(0, NetType.Length - "?".Length); + public TypeName GetNonNullableType() + { + if (IsNullable) + return NetType.Substring(0, NetType.Length - "?".Length); - return this; - } + return this; + } - public bool IsKnownValueType() - => _knownValueTypes.Contains(GetNonNullableType().NetType); + public bool IsKnownValueType() + => _knownValueTypes.Contains(GetNonNullableType().NetType); - public static implicit operator TypeName(string? netType) => new(netType); + public static implicit operator TypeName(string? netType) => new(netType); - public override bool Equals(object? obj) => Equals(obj as TypeName); + public override bool Equals(object? obj) => Equals(obj as TypeName); - public bool Equals(TypeName? other) => other != null && other.NetType == NetType; + public bool Equals(TypeName? other) => other != null && other.NetType == NetType; - public override int GetHashCode() => NetType.GetHashCode(); + public override int GetHashCode() => NetType.GetHashCode(); - public override string ToString() => NetType; + public override string ToString() => NetType; - private static string NormalizeName(string name) - { - name = _reWhitespace.Replace(name, string.Empty); - name = _reComma.Replace(name, ", "); + private static string NormalizeName(string name) + { + name = _reWhitespace.Replace(name, string.Empty); + name = _reComma.Replace(name, ", "); - name = _reIdentifierPart.Replace(name, match => NormalizeNamePart(match.Value)); + name = _reIdentifierPart.Replace(name, match => NormalizeNamePart(match.Value)); - return name; - } + return name; + } - private static string NormalizeNamePart(string name) - { - name = name.TrimStart('@'); + private static string NormalizeNamePart(string name) + { + name = name.TrimStart('@'); - name = _reSystemTypeName.Replace( - name, - match => - { - var unqualifiedName = match.Groups["unqualifiedName"].Value; - return _clrTypeToAlias.GetValueOrDefault(unqualifiedName) ?? unqualifiedName; - } - ); + name = _reSystemTypeName.Replace( + name, + match => + { + var unqualifiedName = match.Groups["unqualifiedName"].Value; + return _clrTypeToAlias.GetValueOrDefault(unqualifiedName) ?? unqualifiedName; + } + ); - name = _clrTypeToAlias.GetValueOrDefault(name) ?? name; + name = _clrTypeToAlias.GetValueOrDefault(name) ?? name; - if (_csharpNonTypeKeywords.Contains(name)) - return "@" + name; + if (_csharpNonTypeKeywords.Contains(name)) + return "@" + name; - return name; - } + return name; + } - private string GetClrSystemTypeName() - { - var name = _reSystemTypeName.Replace(NetType, "${unqualifiedName}"); - return _reUnqualifiedName.Replace(name, match => _aliasTypeMap.GetValueOrDefault(match.Value) ?? match.Value); - } + private string GetClrSystemTypeName() + { + var name = _reSystemTypeName.Replace(NetType, "${unqualifiedName}"); + return _reUnqualifiedName.Replace(name, match => _aliasTypeMap.GetValueOrDefault(match.Value) ?? match.Value); + } - private string GetProtoBufType() - { - var type = GetNonNullableType(); + private string GetProtoBufType() + { + var type = GetNonNullableType(); - if (type.IsRepeated) - type = type.GetRepeatedItemType()!; + if (type.IsRepeated) + type = type.GetRepeatedItemType()!; - var clrName = type.GetClrSystemTypeName(); + var clrName = type.GetClrSystemTypeName(); - if (clrName.StartsWith("@")) - clrName = clrName.Substring(1); + if (clrName.StartsWith("@")) + clrName = clrName.Substring(1); - if (_knownBclTypes.Contains(clrName)) - return "bcl." + clrName; + if (_knownBclTypes.Contains(clrName)) + return "bcl." + clrName; - var name = _protoTypeNameMap.GetValueOrDefault(clrName) ?? type.NetType; - name = name.Replace("::", "."); + var name = _protoTypeNameMap.GetValueOrDefault(clrName) ?? type.NetType; + name = name.Replace("::", "."); - return name; - } + return name; } } diff --git a/src/Abc.Zebus.MessageDsl/Dsl/CollectingErrorListener.cs b/src/Abc.Zebus.MessageDsl/Dsl/CollectingErrorListener.cs index e2854da..d225105 100644 --- a/src/Abc.Zebus.MessageDsl/Dsl/CollectingErrorListener.cs +++ b/src/Abc.Zebus.MessageDsl/Dsl/CollectingErrorListener.cs @@ -3,76 +3,68 @@ using Antlr4.Runtime; using Antlr4.Runtime.Atn; -namespace Abc.Zebus.MessageDsl.Dsl -{ - internal class CollectingErrorListener : IAntlrErrorListener, IAntlrErrorListener - { - public ICollection Errors { get; } +namespace Abc.Zebus.MessageDsl.Dsl; - public CollectingErrorListener() - { - Errors = new List(); - } +internal class CollectingErrorListener : IAntlrErrorListener, IAntlrErrorListener +{ + public ICollection Errors { get; } = new List(); - public void SyntaxError(IRecognizer recognizer, int offendingSymbol, int line, int charPositionInLine, string msg, RecognitionException e) + public void SyntaxError(IRecognizer recognizer, int offendingSymbol, int line, int charPositionInLine, string msg, RecognitionException e) + { + var fakeToken = new CommonToken(0, string.Empty) { - var fakeToken = new CommonToken(0, string.Empty) - { - Line = line, - Column = charPositionInLine - }; + Line = line, + Column = charPositionInLine + }; - ReportError(msg, fakeToken, e); - } + ReportError(msg, fakeToken, e); + } - public void SyntaxError(IRecognizer recognizer, IToken offendingSymbol, int line, int charPositionInLine, string msg, RecognitionException e) - { - ReportError(msg, offendingSymbol, e); - } + public void SyntaxError(IRecognizer recognizer, IToken offendingSymbol, int line, int charPositionInLine, string msg, RecognitionException e) + => ReportError(msg, offendingSymbol, e); - private void ReportError(string msg, IToken offendingSymbol, RecognitionException exception) + private void ReportError(string msg, IToken offendingSymbol, RecognitionException exception) + { + switch (exception) { - switch (exception) + case NoViableAltException noViableAlt: { - case NoViableAltException noViableAlt: + var errorToken = noViableAlt.OffendingToken ?? noViableAlt.StartToken; + if (errorToken != null) { - var errorToken = noViableAlt.OffendingToken ?? noViableAlt.StartToken; - if (errorToken != null) - { - msg = errorToken.Type == Recognizer.Eof - ? "More input expected, the file is not terminated properly" - : $"Unexpected input at {GetTokenDisplay(errorToken)}"; - } - - break; + msg = errorToken.Type == Recognizer.Eof + ? "More input expected, the file is not terminated properly" + : $"Unexpected input at {GetTokenDisplay(errorToken)}"; } - case FailedPredicateException failedPredicate: - { - var tokenDisplay = GetTokenDisplay(failedPredicate.OffendingToken); - msg = $"Syntax error at {tokenDisplay}"; - - if (failedPredicate.Context is MessageContractsParser.EndOfLineContext) - msg = $"End of line expected at {tokenDisplay}"; - break; - } + break; } - msg = Regex.Replace(msg, @"expecting\s+\{(.+)\}", "expecting one of: $1"); - Errors.Add(new SyntaxError(msg, offendingSymbol)); + case FailedPredicateException failedPredicate: + { + var tokenDisplay = GetTokenDisplay(failedPredicate.OffendingToken); + msg = $"Syntax error at {tokenDisplay}"; + + if (failedPredicate.Context is MessageContractsParser.EndOfLineContext) + msg = $"End of line expected at {tokenDisplay}"; + break; + } } - private static string GetTokenDisplay(IToken? token) - { - if (token == null) - return "(unknown)"; + msg = Regex.Replace(msg, @"expecting\s+\{(.+)\}", "expecting one of: $1"); + Errors.Add(new SyntaxError(msg, offendingSymbol)); + } - if (token.Type == Recognizer.Eof) - return "(end of expression)"; + private static string GetTokenDisplay(IToken? token) + { + if (token == null) + return "(unknown)"; - var str = token.Text ?? string.Empty; - str = str.Replace('\n', ' ').Replace('\r', ' '); - return $"'{str}'"; - } + if (token.Type == Recognizer.Eof) + return "(end of expression)"; + + var str = token.Text ?? string.Empty; + str = str.Replace('\n', ' ').Replace('\r', ' '); + return $"'{str}'"; } } diff --git a/src/Abc.Zebus.MessageDsl/Dsl/MessageContracts.g4.lexer.cs b/src/Abc.Zebus.MessageDsl/Dsl/MessageContracts.g4.lexer.cs index aa5433e..4b6843e 100644 --- a/src/Abc.Zebus.MessageDsl/Dsl/MessageContracts.g4.lexer.cs +++ b/src/Abc.Zebus.MessageDsl/Dsl/MessageContracts.g4.lexer.cs @@ -1,6 +1,3 @@ -namespace Abc.Zebus.MessageDsl.Dsl -{ - partial class MessageContractsLexer - { - } -} +namespace Abc.Zebus.MessageDsl.Dsl; + +partial class MessageContractsLexer; diff --git a/src/Abc.Zebus.MessageDsl/Dsl/MessageContracts.g4.parser.cs b/src/Abc.Zebus.MessageDsl/Dsl/MessageContracts.g4.parser.cs index 754d0eb..927d8c3 100644 --- a/src/Abc.Zebus.MessageDsl/Dsl/MessageContracts.g4.parser.cs +++ b/src/Abc.Zebus.MessageDsl/Dsl/MessageContracts.g4.parser.cs @@ -2,88 +2,87 @@ using Abc.Zebus.MessageDsl.Generator; using Antlr4.Runtime; -namespace Abc.Zebus.MessageDsl.Dsl +namespace Abc.Zebus.MessageDsl.Dsl; + +partial class MessageContractsParser { - partial class MessageContractsParser + private bool IsAtStartOfPragma() { - private bool IsAtStartOfPragma() - { - var prevToken = _input.Lt(-1); - var sharpToken = _input.Lt(1); - var pragmaToken = _input.Lt(2); + var prevToken = _input.Lt(-1); + var sharpToken = _input.Lt(1); + var pragmaToken = _input.Lt(2); - if (sharpToken == null || pragmaToken == null) - return false; + if (sharpToken == null || pragmaToken == null) + return false; - if (sharpToken.Text != "#" || pragmaToken.Text != "pragma") - return false; + if (sharpToken.Text != "#" || pragmaToken.Text != "pragma") + return false; - if (prevToken != null && prevToken.Line == sharpToken.Line) - return false; + if (prevToken != null && prevToken.Line == sharpToken.Line) + return false; - if (sharpToken.Line != pragmaToken.Line) - return false; + if (sharpToken.Line != pragmaToken.Line) + return false; - return true; - } + return true; + } - private bool IsAtImplicitSeparator() - { - var prevToken = _input.Lt(-1); - var nextToken = _input.Lt(1); - - return prevToken == null - || nextToken == null - || nextToken.Type == TokenConstants.Eof - || prevToken.Type == MessageContractsLexer.SEP - || nextToken.Type == MessageContractsLexer.SEP - || prevToken.Line != nextToken.Line; - } + private bool IsAtImplicitSeparator() + { + var prevToken = _input.Lt(-1); + var nextToken = _input.Lt(1); + + return prevToken == null + || nextToken == null + || nextToken.Type == TokenConstants.Eof + || prevToken.Type == MessageContractsLexer.SEP + || nextToken.Type == MessageContractsLexer.SEP + || prevToken.Line != nextToken.Line; + } - private bool IsAtEndOfLine() - { - var prevToken = _input.Lt(-1); - var nextToken = _input.Lt(1); + private bool IsAtEndOfLine() + { + var prevToken = _input.Lt(-1); + var nextToken = _input.Lt(1); - return prevToken == null - || nextToken == null - || nextToken.Type == TokenConstants.Eof - || prevToken.Line != nextToken.Line; - } + return prevToken == null + || nextToken == null + || nextToken.Type == TokenConstants.Eof + || prevToken.Line != nextToken.Line; + } - private static bool IsValidIdEscape(IToken? escapeToken, IToken? nameToken) - { - if (escapeToken == null) - return true; + private static bool IsValidIdEscape(IToken? escapeToken, IToken? nameToken) + { + if (escapeToken == null) + return true; - return nameToken?.StartIndex == escapeToken.StopIndex + 1; - } + return nameToken?.StartIndex == escapeToken.StopIndex + 1; + } - private static bool IsValidIdEscape(IToken? escapeToken, ParserRuleContext? nameContext) - => IsValidIdEscape(escapeToken, nameContext?.Start); + private static bool IsValidIdEscape(IToken? escapeToken, ParserRuleContext? nameContext) + => IsValidIdEscape(escapeToken, nameContext?.Start); - private bool AreTwoNextTokensConsecutive() - { - var firstToken = _input.Lt(1); - var secondToken = _input.Lt(2); + private bool AreTwoNextTokensConsecutive() + { + var firstToken = _input.Lt(1); + var secondToken = _input.Lt(2); - return secondToken.StartIndex == firstToken.StopIndex + 1; - } + return secondToken.StartIndex == firstToken.StopIndex + 1; + } - partial class IdContext + partial class IdContext + { + public string GetValidatedId(ParsedContracts contracts) { - public string GetValidatedId(ParsedContracts contracts) - { - var id = nameId?.Text ?? nameCtxKw?.GetText() ?? nameKw?.GetText(); + var id = nameId?.Text ?? nameCtxKw?.GetText() ?? nameKw?.GetText(); - if (string.IsNullOrEmpty(id)) - return string.Empty; + if (string.IsNullOrEmpty(id)) + return string.Empty; - if (CSharpSyntax.IsCSharpKeyword(id!) && escape == null) - contracts.AddError(this, "'{0}' is a C# keyword and has to be escaped with '@'", id); + if (CSharpSyntax.IsCSharpKeyword(id) && escape == null) + contracts.AddError(this, "'{0}' is a C# keyword and has to be escaped with '@'", id); - return id!; - } + return id; } } } diff --git a/src/Abc.Zebus.MessageDsl/Dsl/SyntaxError.cs b/src/Abc.Zebus.MessageDsl/Dsl/SyntaxError.cs index af08296..a5b8f16 100644 --- a/src/Abc.Zebus.MessageDsl/Dsl/SyntaxError.cs +++ b/src/Abc.Zebus.MessageDsl/Dsl/SyntaxError.cs @@ -1,34 +1,33 @@ using Antlr4.Runtime; -namespace Abc.Zebus.MessageDsl.Dsl +namespace Abc.Zebus.MessageDsl.Dsl; + +public class SyntaxError { - public class SyntaxError - { - public int LineNumber { get; } - public int CharacterInLine { get; } - public string? Token { get; } + public int LineNumber { get; } + public int CharacterInLine { get; } + public string? Token { get; } - public string Message { get; } + public string Message { get; } - public SyntaxError(string message, IToken? startToken = null) - { - Message = message; - - if (startToken != null) - { - LineNumber = startToken.Line; - CharacterInLine = startToken.Column + 1; - Token = startToken.Text; - } - } + public SyntaxError(string message, IToken? startToken = null) + { + Message = message; - public override string ToString() + if (startToken != null) { - return LineNumber > 0 - ? CharacterInLine > 0 - ? $"[{LineNumber}:{CharacterInLine}] {Message}" - : $"[{LineNumber}] {Message}" - : Message; + LineNumber = startToken.Line; + CharacterInLine = startToken.Column + 1; + Token = startToken.Text; } } + + public override string ToString() + { + return LineNumber > 0 + ? CharacterInLine > 0 + ? $"[{LineNumber}:{CharacterInLine}] {Message}" + : $"[{LineNumber}] {Message}" + : Message; + } } diff --git a/src/Abc.Zebus.MessageDsl/Generator/CSharpGenerator.cs b/src/Abc.Zebus.MessageDsl/Generator/CSharpGenerator.cs index 58df062..24bb846 100644 --- a/src/Abc.Zebus.MessageDsl/Generator/CSharpGenerator.cs +++ b/src/Abc.Zebus.MessageDsl/Generator/CSharpGenerator.cs @@ -5,383 +5,431 @@ using Abc.Zebus.MessageDsl.Analysis; using Abc.Zebus.MessageDsl.Ast; -namespace Abc.Zebus.MessageDsl.Generator +namespace Abc.Zebus.MessageDsl.Generator; + +public sealed class CSharpGenerator : GeneratorBase { - public sealed class CSharpGenerator : GeneratorBase - { - private static readonly AttributeDefinition _attrProtoContract = new(KnownTypes.ProtoContractAttribute); - private static readonly AttributeDefinition _attrNonUserCode = new("System.Diagnostics.DebuggerNonUserCode"); - private static readonly AttributeDefinition _attrGeneratedCode = new("System.CodeDom.Compiler.GeneratedCode", $@"""{GeneratorName}"", ""{GeneratorVersion}"""); + private static readonly AttributeDefinition _attrProtoContract = new(KnownTypes.ProtoContractAttribute); + private static readonly AttributeDefinition _attrNonUserCode = new("System.Diagnostics.DebuggerNonUserCode"); + private static readonly AttributeDefinition _attrGeneratedCode = new("System.CodeDom.Compiler.GeneratedCode", $@"""{GeneratorName}"", ""{GeneratorVersion}"""); - private readonly Dictionary _messagesByName = new(); + private readonly Dictionary _messagesByName = new(); - private ParsedContracts Contracts { get; } + private ParsedContracts Contracts { get; } - private CSharpGenerator(ParsedContracts contracts) - { - Contracts = contracts; + private CSharpGenerator(ParsedContracts contracts) + { + Contracts = contracts; - foreach (var message in contracts.Messages) - _messagesByName[message.Name] = message; - } + foreach (var message in contracts.Messages) + _messagesByName[message.Name] = message; + } - public static string Generate(ParsedContracts contracts) - { - using var generator = new CSharpGenerator(contracts); - return generator.Generate(); - } + public static string Generate(ParsedContracts contracts) + { + using var generator = new CSharpGenerator(contracts); + return generator.Generate(); + } - private string Generate() - { - Reset(); + private string Generate() + { + Reset(); + + WriteHeader(); + WriteUsingDirectives(); + WritePragmas(); - WriteHeader(); - WriteUsingDirectives(); - WritePragmas(); + var hasNamespace = !string.IsNullOrEmpty(Contracts.Namespace); + if (hasNamespace) + Writer.WriteLine("namespace {0}", Identifier(Contracts.Namespace)); - var hasNamespace = !string.IsNullOrEmpty(Contracts.Namespace); - if (hasNamespace) - Writer.WriteLine("namespace {0}", Identifier(Contracts.Namespace)); + using (hasNamespace ? Block() : null) + { + var firstMember = true; - using (hasNamespace ? Block() : null) + foreach (var enumDef in Contracts.Enums) { - var firstMember = true; + if (!firstMember) + Writer.WriteLine(); - foreach (var enumDef in Contracts.Enums) - { - if (!firstMember) - Writer.WriteLine(); + WriteEnum(enumDef); + firstMember = false; + } - WriteEnum(enumDef); - firstMember = false; - } + var nullableRefTypes = false; - var nullableRefTypes = false; + foreach (var message in Contracts.Messages) + { + if (!firstMember) + Writer.WriteLine(); - foreach (var message in Contracts.Messages) + if (message.Options.Nullable != nullableRefTypes) { - if (!firstMember) - Writer.WriteLine(); - - if (message.Options.Nullable != nullableRefTypes) - { - WriteNullableDirective(message.Options.Nullable); - nullableRefTypes = message.Options.Nullable; - } - - WriteMessage(message); - firstMember = false; + WriteNullableDirective(message.Options.Nullable); + nullableRefTypes = message.Options.Nullable; } - } - return GeneratedOutput(); + WriteMessage(message); + firstMember = false; + } } - private void WriteHeader() - { - Writer.WriteLine("//------------------------------------------------------------------------------"); - Writer.WriteLine("// "); - Writer.WriteLine("// This code was generated by a tool."); - Writer.WriteLine("// "); - Writer.WriteLine("//------------------------------------------------------------------------------"); - Writer.WriteLine(); - } + return GeneratedOutput(); + } - private void WriteUsingDirectives() - { - var orderedNamespaces = Contracts.ImportedNamespaces - .OrderByDescending(ns => ns == "System" || ns.StartsWith("System.")) - .ThenBy(ns => ns, StringComparer.OrdinalIgnoreCase); + private void WriteHeader() + { + Writer.WriteLine("//------------------------------------------------------------------------------"); + Writer.WriteLine("// "); + Writer.WriteLine("// This code was generated by a tool."); + Writer.WriteLine("// "); + Writer.WriteLine("//------------------------------------------------------------------------------"); + Writer.WriteLine(); + } - foreach (var ns in orderedNamespaces) - Writer.WriteLine("using {0};", Identifier(ns)); + private void WriteUsingDirectives() + { + var orderedNamespaces = Contracts.ImportedNamespaces + .OrderByDescending(ns => ns == "System" || ns.StartsWith("System.")) + .ThenBy(ns => ns, StringComparer.OrdinalIgnoreCase); - Writer.WriteLine(); - } + foreach (var ns in orderedNamespaces) + Writer.WriteLine("using {0};", Identifier(ns)); - private void WritePragmas() - { - var hasObsolete = Contracts.Messages.Any(m => m.Attributes.HasAttribute(KnownTypes.ObsoleteAttribute)) - || Contracts.Messages.SelectMany(m => m.Parameters).Any(p => p.Attributes.HasAttribute(KnownTypes.ObsoleteAttribute)) - || Contracts.Enums.Any(m => m.Attributes.HasAttribute(KnownTypes.ObsoleteAttribute)) - || Contracts.Enums.SelectMany(m => m.Members).Any(m => m.Attributes.HasAttribute(KnownTypes.ObsoleteAttribute)); + Writer.WriteLine(); + } - if (hasObsolete) - { - Writer.WriteLine("#pragma warning disable 612"); - Writer.WriteLine(""); - } - } + private void WritePragmas() + { + var hasObsolete = Contracts.Messages.Any(m => m.Attributes.HasAttribute(KnownTypes.ObsoleteAttribute)) + || Contracts.Messages.SelectMany(m => m.Parameters).Any(p => p.Attributes.HasAttribute(KnownTypes.ObsoleteAttribute)) + || Contracts.Enums.Any(m => m.Attributes.HasAttribute(KnownTypes.ObsoleteAttribute)) + || Contracts.Enums.SelectMany(m => m.Members).Any(m => m.Attributes.HasAttribute(KnownTypes.ObsoleteAttribute)); - private void WriteEnum(EnumDefinition enumDef) + if (hasObsolete) { - if (!enumDef.Attributes.HasAttribute(_attrProtoContract.TypeName)) - WriteAttributeLine(_attrProtoContract); + Writer.WriteLine("#pragma warning disable 612"); + Writer.WriteLine(""); + } + } - WriteAttributeLine(_attrGeneratedCode); + private void WriteEnum(EnumDefinition enumDef) + { + if (!enumDef.Attributes.HasAttribute(_attrProtoContract.TypeName)) + WriteAttributeLine(_attrProtoContract); - foreach (var attribute in enumDef.Attributes) - WriteAttributeLine(attribute); + WriteAttributeLine(_attrGeneratedCode); - Writer.Write( - "{0} enum {1}", - AccessModifier(enumDef.AccessModifier), - Identifier(enumDef.Name) - ); + foreach (var attribute in enumDef.Attributes) + WriteAttributeLine(attribute); - if (enumDef.UnderlyingType.NetType != "int") - Writer.Write(" : {0}", enumDef.UnderlyingType.NetType); + Writer.Write( + "{0} enum {1}", + AccessModifier(enumDef.AccessModifier), + Identifier(enumDef.Name) + ); - Writer.WriteLine(); + if (enumDef.UnderlyingType.NetType != "int") + Writer.Write(" : {0}", enumDef.UnderlyingType.NetType); - using (Block()) - { - var hasAnyAttributeOnMembers = enumDef.Members.Any(m => m.Attributes.Count > 0); - var lastMember = enumDef.Members.LastOrDefault(); + Writer.WriteLine(); - foreach (var member in enumDef.Members) - { - foreach (var attribute in member.Attributes) - WriteAttributeLine(attribute); + using (Block()) + { + var hasAnyAttributeOnMembers = enumDef.Members.Any(m => m.Attributes.Count > 0); + var lastMember = enumDef.Members.LastOrDefault(); - Writer.Write(Identifier(member.Name)); + foreach (var member in enumDef.Members) + { + foreach (var attribute in member.Attributes) + WriteAttributeLine(attribute); - if (!string.IsNullOrEmpty(member.Value)) - Writer.Write(" = {0}", member.Value); + Writer.Write(Identifier(member.Name)); - if (member != lastMember) - { - Writer.Write(","); + if (!string.IsNullOrEmpty(member.Value)) + Writer.Write(" = {0}", member.Value); - if (hasAnyAttributeOnMembers) - Writer.WriteLine(); - } + if (member != lastMember) + { + Writer.Write(","); - Writer.WriteLine(); + if (hasAnyAttributeOnMembers) + Writer.WriteLine(); } + + Writer.WriteLine(); } } + } - private void WriteNullableDirective(bool enable) - { - Writer.WriteLine("#nullable {0}", enable ? "enable" : "disable"); - Writer.WriteLine(); - } + private void WriteNullableDirective(bool enable) + { + Writer.WriteLine("#nullable {0}", enable ? "enable" : "disable"); + Writer.WriteLine(); + } - private void WriteMessage(MessageDefinition message) + private void WriteMessage(MessageDefinition message) + { + var containingClassesStack = new Stack(); + foreach (var containingClass in message.ContainingClasses) { - var containingClassesStack = new Stack(); - foreach (var containingClass in message.ContainingClasses) - { - Writer.Write("partial class "); - Writer.WriteLine(containingClass.NetType); + Writer.Write("partial class "); + Writer.WriteLine(containingClass.NetType); - containingClassesStack.Push(Block()); - } + containingClassesStack.Push(Block()); + } + + if (!message.Attributes.HasAttribute(_attrProtoContract.TypeName)) + WriteAttributeLine(_attrProtoContract); - if (!message.Attributes.HasAttribute(_attrProtoContract.TypeName)) - WriteAttributeLine(_attrProtoContract); + WriteAttributeLine(_attrNonUserCode); + WriteAttributeLine(_attrGeneratedCode); - WriteAttributeLine(_attrNonUserCode); - WriteAttributeLine(_attrGeneratedCode); + foreach (var attribute in message.Attributes) + WriteAttributeLine(attribute); - foreach (var attribute in message.Attributes) - WriteAttributeLine(attribute); + Writer.Write(AccessModifier(message.AccessModifier)); + Writer.Write(" "); - Writer.Write(AccessModifier(message.AccessModifier)); + var inheritanceModifier = InheritanceModifier(message.InheritanceModifier); + if (!string.IsNullOrEmpty(inheritanceModifier)) + { + Writer.Write(inheritanceModifier); Writer.Write(" "); + } + + Writer.Write("partial class "); + Writer.Write(Identifier(message.Name)); + + if (message.GenericParameters.Count > 0) + { + Writer.Write("<"); + var templateParamList = List(); - var inheritanceModifier = InheritanceModifier(message.InheritanceModifier); - if (!string.IsNullOrEmpty(inheritanceModifier)) + foreach (var templateParameter in message.GenericParameters) { - Writer.Write(inheritanceModifier); - Writer.Write(" "); + templateParamList.NextItem(); + Writer.Write(Identifier(templateParameter)); } - Writer.Write("partial class "); - Writer.Write(Identifier(message.Name)); + Writer.Write(">"); + } - if (message.GenericParameters.Count > 0) + if (message.BaseTypes.Count > 0) + { + Writer.Write(" : "); + var baseTypeList = List(); + + foreach (var baseType in message.BaseTypes.Distinct()) { - Writer.Write("<"); - var templateParamList = List(); + baseTypeList.NextItem(); + Writer.Write(baseType.NetType); + } + } - foreach (var templateParameter in message.GenericParameters) - { - templateParamList.NextItem(); - Writer.Write(Identifier(templateParameter)); - } + Writer.WriteLine(); - Writer.Write(">"); - } + WriteGenericConstraints(message); - if (message.BaseTypes.Count > 0) - { - Writer.Write(" : "); - var baseTypeList = List(); + using (Block()) + { + foreach (var param in message.Parameters) + WriteParameterMember(message, param); - foreach (var baseType in message.BaseTypes.Distinct()) - { - baseTypeList.NextItem(); - Writer.Write(baseType.NetType); - } + var parameters = GetConstructorParameters(message); + if (parameters.Count != 0) + { + WriteDefaultConstructor(message); + WriteForwardingConstructor(message, parameters); + WriteMessageConstructor(message, parameters); } + } - Writer.WriteLine(); + while (containingClassesStack.Count != 0) + containingClassesStack.Pop().Dispose(); + } - WriteGenericConstraints(message); + private void WriteGenericConstraints(MessageDefinition message) + { + if (message.GenericConstraints.Count == 0) + return; - using (Block()) + using (Indent()) + { + foreach (var genericConstraint in message.GenericConstraints) { - foreach (var param in message.Parameters) - WriteParameterMember(message, param); + Writer.Write("where "); + Writer.Write(Identifier(genericConstraint.GenericParameterName)); + Writer.Write(" : "); + + var constraintList = List(); - var parameters = GetConstructorParameters(message); - if (parameters.Count != 0) + if (genericConstraint.IsClass) { - WriteDefaultConstructor(message); - WriteForwardingConstructor(message, parameters); - WriteMessageConstructor(message, parameters); + constraintList.NextItem(); + Writer.Write("class"); + } + else if (genericConstraint.IsStruct) + { + constraintList.NextItem(); + Writer.Write("struct"); } - } - - while (containingClassesStack.Count != 0) - containingClassesStack.Pop().Dispose(); - } - - private void WriteGenericConstraints(MessageDefinition message) - { - if (message.GenericConstraints.Count == 0) - return; - using (Indent()) - { - foreach (var genericConstraint in message.GenericConstraints) + foreach (var type in genericConstraint.Types) { - Writer.Write("where "); - Writer.Write(Identifier(genericConstraint.GenericParameterName)); - Writer.Write(" : "); - - var constraintList = List(); - - if (genericConstraint.IsClass) - { - constraintList.NextItem(); - Writer.Write("class"); - } - else if (genericConstraint.IsStruct) - { - constraintList.NextItem(); - Writer.Write("struct"); - } - - foreach (var type in genericConstraint.Types) - { - constraintList.NextItem(); - Writer.Write(type.NetType); - } - - if (genericConstraint.HasDefaultConstructor && !genericConstraint.IsStruct) - { - constraintList.NextItem(); - Writer.Write("new()"); - } + constraintList.NextItem(); + Writer.Write(type.NetType); + } - Writer.WriteLine(); + if (genericConstraint.HasDefaultConstructor && !genericConstraint.IsStruct) + { + constraintList.NextItem(); + Writer.Write("new()"); } + + Writer.WriteLine(); } } + } - private void WriteParameterMember(MessageDefinition message, ParameterDefinition param) + private void WriteParameterMember(MessageDefinition message, ParameterDefinition param) + { + if (!param.Attributes.HasAttribute(KnownTypes.ProtoMemberAttribute)) { - if (!param.Attributes.HasAttribute(KnownTypes.ProtoMemberAttribute)) - { - var protoMemberParams = new StringBuilder(); + var protoMemberParams = new StringBuilder(); - protoMemberParams.Append(param.Tag); - protoMemberParams.AppendFormat(", IsRequired = {0}", param.Rules == FieldRules.Required ? "true" : "false"); + protoMemberParams.Append(param.Tag); + protoMemberParams.AppendFormat(", IsRequired = {0}", param.Rules == FieldRules.Required ? "true" : "false"); - if (param.IsPacked) - protoMemberParams.Append(", IsPacked = true"); + if (param.IsPacked) + protoMemberParams.Append(", IsPacked = true"); - WriteAttributeLine(new AttributeDefinition(KnownTypes.ProtoMemberAttribute, protoMemberParams.ToString())); - } - - foreach (var attribute in param.Attributes) - WriteAttributeLine(attribute); - - var isWritable = param.IsWritableProperty || message.Options.Mutable; - - Writer.Write("public {0} {1}", param.Type.NetType, Identifier(MemberCase(param.Name))); - Writer.WriteLine(isWritable ? " { get; set; }" : " { get; private set; }"); - Writer.WriteLine(); + WriteAttributeLine(new AttributeDefinition(KnownTypes.ProtoMemberAttribute, protoMemberParams.ToString())); } - private void WriteDefaultConstructor(MessageDefinition message) - { - Writer.Write( - message.InheritanceModifier == Ast.InheritanceModifier.Abstract - ? "protected" - : message.Options.Mutable - ? "public" - : "private" - ); + foreach (var attribute in param.Attributes) + WriteAttributeLine(attribute); - Writer.Write(" "); - Writer.Write(Identifier(message.Name)); - Writer.WriteLine("()"); + var isWritable = param.IsWritableProperty || message.Options.Mutable; - WriteDefaultConstructorBody(message); - } + Writer.Write("public {0} {1}", param.Type.NetType, Identifier(MemberCase(param.Name))); + Writer.WriteLine(isWritable ? " { get; set; }" : " { get; private set; }"); + Writer.WriteLine(); + } + + private void WriteDefaultConstructor(MessageDefinition message) + { + Writer.Write( + message.InheritanceModifier == Ast.InheritanceModifier.Abstract + ? "protected" + : message.Options.Mutable + ? "public" + : "private" + ); + + Writer.Write(" "); + Writer.Write(Identifier(message.Name)); + Writer.WriteLine("()"); + + WriteDefaultConstructorBody(message); + } - private void WriteDefaultConstructorBody(MessageDefinition message) + private void WriteDefaultConstructorBody(MessageDefinition message) + { + using (Block()) { - using (Block()) + foreach (var param in message.Parameters) { - foreach (var param in message.Parameters) - { - if (param.Type.IsNullable) - continue; - - if (param.Type.IsArray) - Writer.WriteLine("{0} = Array.Empty<{1}>();", Identifier(MemberCase(param.Name)), param.Type.GetRepeatedItemType()!.NetType); - else if (param.Type.IsList || param.Type.IsDictionary || param.Type.IsHashSet) - Writer.WriteLine("{0} = new {1}();", Identifier(MemberCase(param.Name)), param.Type); - else if (message.Options.Nullable && !param.Type.IsKnownValueType()) - Writer.WriteLine("{0} = default!;", Identifier(MemberCase(param.Name))); - } + if (param.Type.IsNullable) + continue; + + if (param.Type.IsArray) + Writer.WriteLine("{0} = Array.Empty<{1}>();", Identifier(MemberCase(param.Name)), param.Type.GetRepeatedItemType()!.NetType); + else if (param.Type.IsList || param.Type.IsDictionary || param.Type.IsHashSet) + Writer.WriteLine("{0} = new {1}();", Identifier(MemberCase(param.Name)), param.Type); + else if (message.Options.Nullable && !param.Type.IsKnownValueType()) + Writer.WriteLine("{0} = default!;", Identifier(MemberCase(param.Name))); } } + } - private void WriteForwardingConstructor(MessageDefinition message, List parameters) + private void WriteForwardingConstructor(MessageDefinition message, List parameters) + { + if (message.InheritanceModifier == Ast.InheritanceModifier.Sealed + || !message.Options.Mutable + || parameters.Count == 0 + || !parameters[0].IsFromBase + || parameters.All(p => p.IsFromBase)) { - if (message.InheritanceModifier == Ast.InheritanceModifier.Sealed - || !message.Options.Mutable - || parameters.Count == 0 - || !parameters[0].IsFromBase - || parameters.All(p => p.IsFromBase)) - { - return; - } + return; + } + + Writer.WriteLine(); + + Writer.Write("protected "); + Writer.Write(Identifier(message.Name)); + Writer.Write("("); + + var paramList = List(); + foreach (var param in parameters) + { + if (!param.IsFromBase) + break; + + paramList.NextItem(); + Writer.Write("{0} {1}", param.Parameter.Type.NetType, Identifier(ParameterCase(param.Parameter.Name))); + } - Writer.WriteLine(); + Writer.WriteLine(")"); - Writer.Write("protected "); - Writer.Write(Identifier(message.Name)); - Writer.Write("("); + using (Indent()) + { + Writer.Write(": base("); - var paramList = List(); + paramList.Reset(); foreach (var param in parameters) { if (!param.IsFromBase) break; paramList.NextItem(); - Writer.Write("{0} {1}", param.Parameter.Type.NetType, Identifier(ParameterCase(param.Parameter.Name))); + Writer.Write(Identifier(ParameterCase(param.Parameter.Name))); } Writer.WriteLine(")"); + } + + WriteDefaultConstructorBody(message); + } + + private void WriteMessageConstructor(MessageDefinition message, List parameters) + { + Writer.WriteLine(); + + Writer.Write( + message.InheritanceModifier == Ast.InheritanceModifier.Abstract + ? "protected" + : "public" + ); + + Writer.Write(" "); + Writer.Write(Identifier(message.Name)); + Writer.Write("("); + + var paramList = List(); + foreach (var param in parameters) + { + paramList.NextItem(); + Writer.Write("{0} {1}", param.Parameter.Type.NetType, Identifier(ParameterCase(param.Parameter.Name))); + + if (!param.IsRequired && !string.IsNullOrEmpty(param.Parameter.DefaultValue)) + Writer.Write(" = {0}", param.Parameter.DefaultValue); + } + + Writer.WriteLine(")"); + if (parameters.Count != 0 && parameters[0].IsFromBase) + { using (Indent()) { Writer.Write(": base("); @@ -398,173 +446,119 @@ private void WriteForwardingConstructor(MessageDefinition message, List parameters) + using (Block()) { - Writer.WriteLine(); - - Writer.Write( - message.InheritanceModifier == Ast.InheritanceModifier.Abstract - ? "protected" - : "public" - ); - - Writer.Write(" "); - Writer.Write(Identifier(message.Name)); - Writer.Write("("); - - var paramList = List(); foreach (var param in parameters) { - paramList.NextItem(); - Writer.Write("{0} {1}", param.Parameter.Type.NetType, Identifier(ParameterCase(param.Parameter.Name))); - - if (!param.IsRequired && !string.IsNullOrEmpty(param.Parameter.DefaultValue)) - Writer.Write(" = {0}", param.Parameter.DefaultValue); - } - - Writer.WriteLine(")"); - - if (parameters.Count != 0 && parameters[0].IsFromBase) - { - using (Indent()) - { - Writer.Write(": base("); + if (param.IsFromBase) + continue; - paramList.Reset(); - foreach (var param in parameters) - { - if (!param.IsFromBase) - break; + Writer.Write("{0} = {1}", Identifier(MemberCase(param.Parameter.Name)), Identifier(ParameterCase(param.Parameter.Name))); - paramList.NextItem(); - Writer.Write(Identifier(ParameterCase(param.Parameter.Name))); - } + if (param.Parameter.Type.IsArray && !param.Parameter.Type.IsNullable) + Writer.Write(" ?? Array.Empty<{0}>()", param.Parameter.Type.GetRepeatedItemType()!.NetType); - Writer.WriteLine(")"); - } + Writer.WriteLine(";"); } + } + } - using (Block()) - { - foreach (var param in parameters) - { - if (param.IsFromBase) - continue; - - Writer.Write("{0} = {1}", Identifier(MemberCase(param.Parameter.Name)), Identifier(ParameterCase(param.Parameter.Name))); - - if (param.Parameter.Type.IsArray && !param.Parameter.Type.IsNullable) - Writer.Write(" ?? Array.Empty<{0}>()", param.Parameter.Type.GetRepeatedItemType()!.NetType); + private List GetConstructorParameters(MessageDefinition message) + { + var result = new List(); - Writer.WriteLine(";"); - } - } - } + var baseTypeName = message.BaseTypes.FirstOrDefault(); - private List GetConstructorParameters(MessageDefinition message) + while (baseTypeName != null) { - var result = new List(); + if (!_messagesByName.TryGetValue(baseTypeName.NetType, out var baseType)) + break; - var baseTypeName = message.BaseTypes.FirstOrDefault(); - - while (baseTypeName != null) + if (!baseType.Options.Mutable) { - if (!_messagesByName.TryGetValue(baseTypeName.NetType, out var baseType)) - break; + var index = 0; - if (!baseType.Options.Mutable) + foreach (var param in baseType.Parameters) { - var index = 0; - - foreach (var param in baseType.Parameters) - { - if (IsConstructorParameter(param)) - result.Insert(index++, new ParameterData(param, true)); - } + if (IsConstructorParameter(param)) + result.Insert(index++, new ParameterData(param, true)); } - - baseTypeName = baseType.BaseTypes.FirstOrDefault(); } - foreach (var param in message.Parameters) - { - if (IsConstructorParameter(param)) - result.Add(new ParameterData(param, false)); - } - - var requiredParameterSeen = false; + baseTypeName = baseType.BaseTypes.FirstOrDefault(); + } - for (var i = result.Count - 1; i >= 0; --i) - { - var param = result[i]; + foreach (var param in message.Parameters) + { + if (IsConstructorParameter(param)) + result.Add(new ParameterData(param, false)); + } - if (string.IsNullOrEmpty(param.Parameter.DefaultValue)) - requiredParameterSeen = true; + var requiredParameterSeen = false; - if (requiredParameterSeen) - param.IsRequired = true; - } + for (var i = result.Count - 1; i >= 0; --i) + { + var param = result[i]; - return result; + if (string.IsNullOrEmpty(param.Parameter.DefaultValue)) + requiredParameterSeen = true; - static bool IsConstructorParameter(ParameterDefinition parameter) - => !parameter.IsWritableProperty; + if (requiredParameterSeen) + param.IsRequired = true; } - private void WriteAttributeLine(AttributeDefinition attribute) - { - Writer.Write("["); - WriteAttribute(attribute); - Writer.WriteLine("]"); - } + return result; - private void WriteAttribute(AttributeDefinition attribute) - { - Writer.Write(Identifier(attribute.TypeName.NetType)); + static bool IsConstructorParameter(ParameterDefinition parameter) + => !parameter.IsWritableProperty; + } - if (!string.IsNullOrEmpty(attribute.Parameters)) - Writer.Write("({0})", attribute.Parameters); - } + private void WriteAttributeLine(AttributeDefinition attribute) + { + Writer.Write("["); + WriteAttribute(attribute); + Writer.WriteLine("]"); + } - private static string AccessModifier(AccessModifier accessModifier) - { - return accessModifier switch - { - Ast.AccessModifier.Public => "public", - Ast.AccessModifier.Internal => "internal", - _ => throw new ArgumentOutOfRangeException(nameof(accessModifier), accessModifier, null) - }; - } + private void WriteAttribute(AttributeDefinition attribute) + { + Writer.Write(Identifier(attribute.TypeName.NetType)); - private static string InheritanceModifier(InheritanceModifier inheritanceModifier) - { - return inheritanceModifier switch - { - Ast.InheritanceModifier.Default => string.Empty, - Ast.InheritanceModifier.None => string.Empty, - Ast.InheritanceModifier.Sealed => "sealed", - Ast.InheritanceModifier.Abstract => "abstract", - _ => throw new ArgumentOutOfRangeException(nameof(inheritanceModifier), inheritanceModifier, null) - }; - } + if (!string.IsNullOrEmpty(attribute.Parameters)) + Writer.Write("({0})", attribute.Parameters); + } - private static string Identifier(string id) => CSharpSyntax.Identifier(id); + private static string AccessModifier(AccessModifier accessModifier) + { + return accessModifier switch + { + Ast.AccessModifier.Public => "public", + Ast.AccessModifier.Internal => "internal", + _ => throw new ArgumentOutOfRangeException(nameof(accessModifier), accessModifier, null) + }; + } - private class ParameterData + private static string InheritanceModifier(InheritanceModifier inheritanceModifier) + { + return inheritanceModifier switch { - public ParameterDefinition Parameter { get; } - public bool IsFromBase { get; } - public bool IsRequired { get; set; } + Ast.InheritanceModifier.Default => string.Empty, + Ast.InheritanceModifier.None => string.Empty, + Ast.InheritanceModifier.Sealed => "sealed", + Ast.InheritanceModifier.Abstract => "abstract", + _ => throw new ArgumentOutOfRangeException(nameof(inheritanceModifier), inheritanceModifier, null) + }; + } - public ParameterData(ParameterDefinition parameter, bool isFromBase) - { - Parameter = parameter; - IsFromBase = isFromBase; - } - } + private static string Identifier(string id) + => CSharpSyntax.Identifier(id); + + private class ParameterData(ParameterDefinition parameter, bool isFromBase) + { + public ParameterDefinition Parameter { get; } = parameter; + public bool IsFromBase { get; } = isFromBase; + public bool IsRequired { get; set; } } } diff --git a/src/Abc.Zebus.MessageDsl/Generator/CSharpSyntax.cs b/src/Abc.Zebus.MessageDsl/Generator/CSharpSyntax.cs index 0a7f0b0..5676f98 100644 --- a/src/Abc.Zebus.MessageDsl/Generator/CSharpSyntax.cs +++ b/src/Abc.Zebus.MessageDsl/Generator/CSharpSyntax.cs @@ -3,78 +3,83 @@ using System.Linq; using System.Text.RegularExpressions; -namespace Abc.Zebus.MessageDsl.Generator +namespace Abc.Zebus.MessageDsl.Generator; + +public static class CSharpSyntax { - public static class CSharpSyntax + private static readonly HashSet _csharpKeywords = new() + { + "abstract", "as", "base", "bool", "break", + "byte", "case", "catch", "char", "checked", + "class", "const", "continue", "decimal", "default", + "delegate", "do", "double", "else", "enum", + "event", "explicit", "extern", "false", "finally", + "fixed", "float", "for", "foreach", "goto", + "if", "implicit", "in", "int", "interface", + "internal", "is", "lock", "long", "namespace", + "new", "null", "object", "operator", "out", + "override", "params", "private", "protected", "public", + "readonly", "ref", "return", "sbyte", "sealed", + "short", "sizeof", "stackalloc", "static", "string", + "struct", "switch", "this", "throw", "true", + "try", "typeof", "uint", "ulong", "unchecked", + "unsafe", "ushort", "using", "virtual", "void", + "volatile", "while" + }; + + public static IEnumerable EnumerateCSharpKeywords() + => _csharpKeywords.Select(i => i); + + public static bool IsCSharpKeyword(string id) + => _csharpKeywords.Contains(id); + + private static readonly Regex _tokenRe = new(@"@?\s*(?\w+)", RegexOptions.Compiled | RegexOptions.CultureInvariant); + + // https://msdn.microsoft.com/en-us/library/aa664670.aspx + private static readonly Regex _identifierRe = new(@"^@?[\p{Lu}\p{Ll}\p{Lt}\p{Lm}\p{Lo}\p{Nl}_][\p{Lu}\p{Ll}\p{Lt}\p{Lm}\p{Lo}\p{Nl}\p{Nd}\p{Pc}\p{Mn}\p{Mc}\p{Cf}]*$", RegexOptions.Compiled | RegexOptions.CultureInvariant); + + // https://msdn.microsoft.com/en-us/library/aa664669.aspx + private static readonly Regex _unicodeEscapeSequence = new(@"\\u(?[0-9a-fA-F]{4})|\\U(?[0-9a-fA-F]{8})", RegexOptions.Compiled | RegexOptions.CultureInvariant); + + public static string Identifier(string? id) { - private static readonly HashSet _csharpKeywords = new() - { - "abstract", "as", "base", "bool", "break", - "byte", "case", "catch", "char", "checked", - "class", "const", "continue", "decimal", "default", - "delegate", "do", "double", "else", "enum", - "event", "explicit", "extern", "false", "finally", - "fixed", "float", "for", "foreach", "goto", - "if", "implicit", "in", "int", "interface", - "internal", "is", "lock", "long", "namespace", - "new", "null", "object", "operator", "out", - "override", "params", "private", "protected", "public", - "readonly", "ref", "return", "sbyte", "sealed", - "short", "sizeof", "stackalloc", "static", "string", - "struct", "switch", "this", "throw", "true", - "try", "typeof", "uint", "ulong", "unchecked", - "unsafe", "ushort", "using", "virtual", "void", - "volatile", "while" - }; - - public static IEnumerable EnumerateCSharpKeywords() => _csharpKeywords.Select(i => i); - public static bool IsCSharpKeyword(string id) => _csharpKeywords.Contains(id); - - private static readonly Regex _tokenRe = new(@"@?\s*(?\w+)", RegexOptions.Compiled | RegexOptions.CultureInvariant); - - // https://msdn.microsoft.com/en-us/library/aa664670.aspx - private static readonly Regex _identifierRe = new(@"^@?[\p{Lu}\p{Ll}\p{Lt}\p{Lm}\p{Lo}\p{Nl}_][\p{Lu}\p{Ll}\p{Lt}\p{Lm}\p{Lo}\p{Nl}\p{Nd}\p{Pc}\p{Mn}\p{Mc}\p{Cf}]*$", RegexOptions.Compiled | RegexOptions.CultureInvariant); - - // https://msdn.microsoft.com/en-us/library/aa664669.aspx - private static readonly Regex _unicodeEscapeSequence = new(@"\\u(?[0-9a-fA-F]{4})|\\U(?[0-9a-fA-F]{8})", RegexOptions.Compiled | RegexOptions.CultureInvariant); - - public static string Identifier(string? id) - { - return _tokenRe.Replace(id ?? string.Empty, match => + return _tokenRe.Replace( + id ?? string.Empty, + match => { var token = match.Groups["id"].Value; return IsCSharpKeyword(token) ? "@" + token : token; - }); - } - - public static bool IsValidNamespace(string? ns) - { - if (string.IsNullOrEmpty(ns)) - return false; + } + ); + } - return ns!.Split(new[] { '.' }, StringSplitOptions.None) - .All(IsValidIdentifier); - } + public static bool IsValidNamespace(string? ns) + { + if (string.IsNullOrEmpty(ns)) + return false; - public static bool IsValidIdentifier(string? id) - { - if (string.IsNullOrEmpty(id)) - return false; + return ns.Split(['.'], StringSplitOptions.None) + .All(IsValidIdentifier); + } - id = ProcessUnicodeEscapeSequences(id!); + public static bool IsValidIdentifier(string? id) + { + if (string.IsNullOrEmpty(id)) + return false; - if (!_identifierRe.IsMatch(id)) - return false; + id = ProcessUnicodeEscapeSequences(id); - if (IsCSharpKeyword(id)) - return false; + if (!_identifierRe.IsMatch(id)) + return false; - return true; - } + if (IsCSharpKeyword(id)) + return false; - private static string ProcessUnicodeEscapeSequences(string value) - => _unicodeEscapeSequence.Replace(value, match => char.ConvertFromUtf32(Convert.ToInt32(match.Groups["hex"].Value, 16))); + return true; } + + private static string ProcessUnicodeEscapeSequences(string value) + => _unicodeEscapeSequence.Replace(value, match => char.ConvertFromUtf32(Convert.ToInt32(match.Groups["hex"].Value, 16))); } diff --git a/src/Abc.Zebus.MessageDsl/Generator/GeneratorBase.cs b/src/Abc.Zebus.MessageDsl/Generator/GeneratorBase.cs index eb6c67f..217adb7 100644 --- a/src/Abc.Zebus.MessageDsl/Generator/GeneratorBase.cs +++ b/src/Abc.Zebus.MessageDsl/Generator/GeneratorBase.cs @@ -4,82 +4,77 @@ using System.Text; using Abc.Zebus.MessageDsl.Support; -namespace Abc.Zebus.MessageDsl.Generator +namespace Abc.Zebus.MessageDsl.Generator; + +public abstract class GeneratorBase : IDisposable { - public abstract class GeneratorBase : IDisposable + protected static string GeneratorName { get; } = typeof(GeneratorBase).Assembly.GetName().Name!; + protected static Version GeneratorVersion { get; }= typeof(GeneratorBase).Assembly.GetName().Version!; + + private readonly StringBuilder _stringBuilder; + protected IndentedTextWriter Writer { get; } + + protected GeneratorBase() { - protected static string GeneratorName { get; } = typeof(GeneratorBase).Assembly.GetName().Name!; - protected static Version GeneratorVersion { get; }= typeof(GeneratorBase).Assembly.GetName().Version!; + _stringBuilder = new StringBuilder(); + Writer = new IndentedTextWriter(new StringWriter(_stringBuilder)); + } - private readonly StringBuilder _stringBuilder; - protected IndentedTextWriter Writer { get; } + protected void Reset() + { + _stringBuilder.Clear(); + Writer.Indent = 0; + } - protected GeneratorBase() - { - _stringBuilder = new StringBuilder(); - Writer = new IndentedTextWriter(new StringWriter(_stringBuilder)); - } + protected IDisposable Indent() + { + ++Writer.Indent; + return Disposable.Create(() => --Writer.Indent); + } - protected void Reset() - { - _stringBuilder.Clear(); - Writer.Indent = 0; - } + protected IDisposable Block() + { + Writer.WriteLine("{"); + ++Writer.Indent; - protected IDisposable Indent() + return Disposable.Create(() => { - ++Writer.Indent; - return Disposable.Create(() => --Writer.Indent); - } + --Writer.Indent; + Writer.WriteLine("}"); + }); + } - protected IDisposable Block() - { - Writer.WriteLine("{"); - ++Writer.Indent; - return Disposable.Create(() => - { - --Writer.Indent; - Writer.WriteLine("}"); - }); - } + protected ListHelper List(string separator = ", ") + => new(Writer, separator); - protected ListHelper List(string separator = ", ") - => new(Writer, separator); + protected string GeneratedOutput() + => _stringBuilder.ToString(); - protected string GeneratedOutput() => _stringBuilder.ToString(); + protected static string ParameterCase(string s) + => char.ToLowerInvariant(s[0]) + s.Substring(1); - protected static string ParameterCase(string s) => char.ToLowerInvariant(s[0]) + s.Substring(1); - protected static string MemberCase(string s) => char.ToUpperInvariant(s[0]) + s.Substring(1); + protected static string MemberCase(string s) + => char.ToUpperInvariant(s[0]) + s.Substring(1); - [System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Design", "CA1063:ImplementIDisposableCorrectly")] - public void Dispose() - { - Writer.Dispose(); - } + [System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Design", "CA1063:ImplementIDisposableCorrectly")] + public void Dispose() + { + Writer.Dispose(); + } - protected struct ListHelper + protected struct ListHelper(IndentedTextWriter writer, string separator) + { + private bool _firstItem = true; + + public void NextItem() { - private readonly IndentedTextWriter _writer; - private readonly string _separator; - private bool _firstItem; - - public ListHelper(IndentedTextWriter writer, string separator) - { - _writer = writer; - _separator = separator; - _firstItem = true; - } - - public void NextItem() - { - if (_firstItem) - _firstItem = false; - else - _writer.Write(_separator); - } - - public void Reset() - => _firstItem = true; + if (_firstItem) + _firstItem = false; + else + writer.Write(separator); } + + public void Reset() + => _firstItem = true; } } diff --git a/src/Abc.Zebus.MessageDsl/Generator/ProtoGenerator.cs b/src/Abc.Zebus.MessageDsl/Generator/ProtoGenerator.cs index 87caf21..77d14d1 100644 --- a/src/Abc.Zebus.MessageDsl/Generator/ProtoGenerator.cs +++ b/src/Abc.Zebus.MessageDsl/Generator/ProtoGenerator.cs @@ -2,180 +2,179 @@ using Abc.Zebus.MessageDsl.Analysis; using Abc.Zebus.MessageDsl.Ast; -namespace Abc.Zebus.MessageDsl.Generator -{ - public sealed class ProtoGenerator : GeneratorBase - { - private ParsedContracts Contracts { get; } +namespace Abc.Zebus.MessageDsl.Generator; - private ProtoGenerator(ParsedContracts contracts) - { - Contracts = contracts; - } +public sealed class ProtoGenerator : GeneratorBase +{ + private ParsedContracts Contracts { get; } - public static bool HasProtoOutput(ParsedContracts contracts) - { - return contracts.Messages.Any(i => i.Options.Proto) - || contracts.Enums.Any(i => i.Options.Proto); - } + private ProtoGenerator(ParsedContracts contracts) + { + Contracts = contracts; + } - public static string Generate(ParsedContracts contracts) - { - using var generator = new ProtoGenerator(contracts); - return generator.Generate(); - } + public static bool HasProtoOutput(ParsedContracts contracts) + { + return contracts.Messages.Any(i => i.Options.Proto) + || contracts.Enums.Any(i => i.Options.Proto); + } - private string Generate() - { - Reset(); + public static string Generate(ParsedContracts contracts) + { + using var generator = new ProtoGenerator(contracts); + return generator.Generate(); + } - WriteHeader(); + private string Generate() + { + Reset(); - foreach (var enumDef in Contracts.Enums.Where(msg => msg.Options.Proto)) - WriteEnum(enumDef); + WriteHeader(); - foreach (var message in Contracts.Messages.Where(msg => msg.Options.Proto)) - WriteMessage(message); + foreach (var enumDef in Contracts.Enums.Where(msg => msg.Options.Proto)) + WriteEnum(enumDef); - return GeneratedOutput(); - } + foreach (var message in Contracts.Messages.Where(msg => msg.Options.Proto)) + WriteMessage(message); - private void WriteHeader() - { - Writer.WriteLine(); - Writer.WriteLine("// Generated by {0} v{1}", GeneratorName, GeneratorVersion); - Writer.WriteLine(); + return GeneratedOutput(); + } - var requiresBclPackage = Contracts.Messages - .SelectMany(msg => msg.Parameters) - .Any(param => param.Type.ProtoBufType.StartsWith("bcl.")); + private void WriteHeader() + { + Writer.WriteLine(); + Writer.WriteLine("// Generated by {0} v{1}", GeneratorName, GeneratorVersion); + Writer.WriteLine(); - if (requiresBclPackage) - Writer.WriteLine("import \"bcl/bcl.proto\";"); + var requiresBclPackage = Contracts.Messages + .SelectMany(msg => msg.Parameters) + .Any(param => param.Type.ProtoBufType.StartsWith("bcl.")); - Writer.WriteLine("import \"servicebus.proto\";"); + if (requiresBclPackage) + Writer.WriteLine("import \"bcl/bcl.proto\";"); - if (!string.IsNullOrEmpty(Contracts.Namespace)) - { - Writer.WriteLine(); - Writer.WriteLine("package {0};", Contracts.Namespace); - } - } + Writer.WriteLine("import \"servicebus.proto\";"); - private void WriteEnum(EnumDefinition enumDef) + if (!string.IsNullOrEmpty(Contracts.Namespace)) { Writer.WriteLine(); - Writer.Write("enum {0} ", enumDef.Name); + Writer.WriteLine("package {0};", Contracts.Namespace); + } + } - using (Block()) - { - if (enumDef.Members.Where(i => i.ProtoValue != null).GroupBy(i => i.ProtoValue.GetValueOrDefault()).Any(g => g.Count() > 1)) - Writer.WriteLine("option allow_alias = true;"); + private void WriteEnum(EnumDefinition enumDef) + { + Writer.WriteLine(); + Writer.Write("enum {0} ", enumDef.Name); - foreach (var member in enumDef.Members) - Writer.WriteLine("{0} = {1};", member.Name, member.ProtoValue ?? (object)"TODO"); - } + using (Block()) + { + if (enumDef.Members.Where(i => i.ProtoValue != null).GroupBy(i => i.ProtoValue.GetValueOrDefault()).Any(g => g.Count() > 1)) + Writer.WriteLine("option allow_alias = true;"); + + foreach (var member in enumDef.Members) + Writer.WriteLine("{0} = {1};", member.Name, member.ProtoValue ?? (object)"TODO"); } + } - private void WriteMessage(MessageDefinition message) - { - Writer.WriteLine(); - Writer.Write("message {0} ", message.Name); + private void WriteMessage(MessageDefinition message) + { + Writer.WriteLine(); + Writer.Write("message {0} ", message.Name); - using (Block()) - { - WriteMessageOptions(message); + using (Block()) + { + WriteMessageOptions(message); - foreach (var param in message.Parameters) - WriteField(param); + foreach (var param in message.Parameters) + WriteField(param); - WriteIncludedMessages(message); - } + WriteIncludedMessages(message); } + } - private void WriteMessageOptions(MessageDefinition message) + private void WriteMessageOptions(MessageDefinition message) + { + if (message.Type != MessageType.Custom) { - if (message.Type != MessageType.Custom) - { - Writer.WriteLine("option (servicebus.message).type = {0};", message.Type == MessageType.Command ? "Command" : "Event"); + Writer.WriteLine("option (servicebus.message).type = {0};", message.Type == MessageType.Command ? "Command" : "Event"); - if (message.IsTransient) - Writer.WriteLine("option (servicebus.message).transient = true;"); + if (message.IsTransient) + Writer.WriteLine("option (servicebus.message).transient = true;"); - if (message.IsRoutable) - Writer.WriteLine("option (servicebus.message).routable = true;"); + if (message.IsRoutable) + Writer.WriteLine("option (servicebus.message).routable = true;"); - Writer.WriteLine(); - } + Writer.WriteLine(); } + } - private void WriteField(ParameterDefinition param) - { - Writer.Write( - "{0} {1} {2} = {3}", - param.Rules.ToString().ToLowerInvariant(), - param.Type.ProtoBufType, - MemberCase(param.Name), - param.Tag); + private void WriteField(ParameterDefinition param) + { + Writer.Write( + "{0} {1} {2} = {3}", + param.Rules.ToString().ToLowerInvariant(), + param.Type.ProtoBufType, + MemberCase(param.Name), + param.Tag); - WriteFieldOptions(param); + WriteFieldOptions(param); - Writer.WriteLine(";"); - } + Writer.WriteLine(";"); + } - private void WriteFieldOptions(ParameterDefinition param) - { - var first = true; + private void WriteFieldOptions(ParameterDefinition param) + { + var first = true; - if (param.IsPacked) - WriteFieldOption("packed", "true", ref first); + if (param.IsPacked) + WriteFieldOption("packed", "true", ref first); - if (param.Attributes.HasAttribute(KnownTypes.ObsoleteAttribute)) - WriteFieldOption("deprecated", "true", ref first); + if (param.Attributes.HasAttribute(KnownTypes.ObsoleteAttribute)) + WriteFieldOption("deprecated", "true", ref first); - if (param.RoutingPosition != null) - WriteFieldOption("(servicebus.routing_position)", param.RoutingPosition.ToString()!, ref first); + if (param.RoutingPosition != null) + WriteFieldOption("(servicebus.routing_position)", param.RoutingPosition.ToString()!, ref first); - if (!first) - Writer.Write("]"); - } + if (!first) + Writer.Write("]"); + } - private void WriteFieldOption(string key, string value, ref bool first) + private void WriteFieldOption(string key, string value, ref bool first) + { + if (first) { - if (first) - { - Writer.Write(" ["); - first = false; - } - else - { - Writer.Write(", "); - } - - Writer.Write("{0} = {1}", key, value); + Writer.Write(" ["); + first = false; } + else + { + Writer.Write(", "); + } + + Writer.Write("{0} = {1}", key, value); + } - private void WriteIncludedMessages(MessageDefinition message) + private void WriteIncludedMessages(MessageDefinition message) + { + foreach (var attr in message.Attributes) { - foreach (var attr in message.Attributes) + if (!Equals(attr.TypeName, KnownTypes.ProtoIncludeAttribute)) + continue; + + if (!AttributeInterpreter.TryParseProtoInclude(attr, out var tag, out var typeName)) { - if (!Equals(attr.TypeName, KnownTypes.ProtoIncludeAttribute)) - continue; - - if (!AttributeInterpreter.TryParseProtoInclude(attr, out var tag, out var typeName)) - { - Writer.WriteLine("// ERROR: bad sub type definition"); - continue; - } - - Writer.Write("optional "); - Writer.Write(typeName.ProtoBufType); - Writer.Write(" _subType"); - Writer.Write(typeName.ProtoBufType); - Writer.Write(" = "); - Writer.Write(tag); - Writer.WriteLine(";"); + Writer.WriteLine("// ERROR: bad sub type definition"); + continue; } + + Writer.Write("optional "); + Writer.Write(typeName.ProtoBufType); + Writer.Write(" _subType"); + Writer.Write(typeName.ProtoBufType); + Writer.Write(" = "); + Writer.Write(tag); + Writer.WriteLine(";"); } } } diff --git a/src/Abc.Zebus.MessageDsl/Support/CodeAnalysis.cs b/src/Abc.Zebus.MessageDsl/Support/CodeAnalysis.cs new file mode 100644 index 0000000..14c551f --- /dev/null +++ b/src/Abc.Zebus.MessageDsl/Support/CodeAnalysis.cs @@ -0,0 +1,18 @@ +#if !NETCOREAPP + +// ReSharper disable CheckNamespace +// ReSharper disable MemberCanBePrivate.Global +// ReSharper disable UnusedAutoPropertyAccessor.Global + +namespace System.Diagnostics.CodeAnalysis; + +[AttributeUsage(AttributeTargets.Parameter)] +internal sealed class NotNullWhenAttribute(bool returnValue) : Attribute +{ + public bool ReturnValue { get; } = returnValue; +} + +[AttributeUsage(AttributeTargets.Method, Inherited = false)] +internal sealed class DoesNotReturnAttribute : Attribute; + +#endif diff --git a/src/Abc.Zebus.MessageDsl/Support/Disposable.cs b/src/Abc.Zebus.MessageDsl/Support/Disposable.cs index 33734a1..26ea499 100644 --- a/src/Abc.Zebus.MessageDsl/Support/Disposable.cs +++ b/src/Abc.Zebus.MessageDsl/Support/Disposable.cs @@ -1,18 +1,18 @@ using System; using System.Threading; -namespace Abc.Zebus.MessageDsl.Support +namespace Abc.Zebus.MessageDsl.Support; + +internal class Disposable : IDisposable { - internal class Disposable : IDisposable - { - private Action? _onDispose; + private Action? _onDispose; + + private Disposable(Action onDispose) + => _onDispose = onDispose; - private Disposable(Action onDispose) - { - _onDispose = onDispose; - } + public static IDisposable Create(Action onDispose) + => new Disposable(onDispose); - public static IDisposable Create(Action onDispose) => new Disposable(onDispose); - public void Dispose() => Interlocked.Exchange(ref _onDispose, null)?.Invoke(); - } + public void Dispose() + => Interlocked.Exchange(ref _onDispose, null)?.Invoke(); } diff --git a/src/Abc.Zebus.MessageDsl/Support/Extensions.cs b/src/Abc.Zebus.MessageDsl/Support/Extensions.cs index df81381..862fe7a 100644 --- a/src/Abc.Zebus.MessageDsl/Support/Extensions.cs +++ b/src/Abc.Zebus.MessageDsl/Support/Extensions.cs @@ -2,39 +2,38 @@ using Antlr4.Runtime; using Antlr4.Runtime.Misc; -namespace Abc.Zebus.MessageDsl.Support +namespace Abc.Zebus.MessageDsl.Support; + +internal static class Extensions { - internal static class Extensions + public static TValue? GetValueOrDefault(this IDictionary dictionary, TKey key) + where TKey : notnull + { + return dictionary.TryGetValue(key, out var result) ? result : default; + } + + public static string GetFullText(this ParserRuleContext context) + { + if (context.Start == null || context.Stop == null || context.Start.StartIndex < 0 || context.Stop.StopIndex < 0) + return context.GetText(); + + return context.Start.InputStream.GetText(Interval.Of(context.Start.StartIndex, context.Stop.StopIndex)); + } + + public static string GetFullTextUntil(this IToken? startToken, IToken? endToken) + { + if (startToken == null || endToken == null || startToken.StartIndex < 0 || endToken.StopIndex < 0 || startToken.StartIndex > endToken.StartIndex || startToken.InputStream != endToken.InputStream) + return string.Empty; + + return startToken.InputStream.GetText(Interval.Of(startToken.StartIndex, endToken.StopIndex)); + } + + public static HashSet ToHashSet(this IEnumerable sequence) + => new(sequence); + + public static void AddRange(this ICollection collection, IEnumerable toAdd) { - public static TValue? GetValueOrDefault(this IDictionary dictionary, TKey key) - where TKey : notnull - { - return dictionary.TryGetValue(key, out var result) ? result : default!; - } - - public static string GetFullText(this ParserRuleContext context) - { - if (context.Start == null || context.Stop == null || context.Start.StartIndex < 0 || context.Stop.StopIndex < 0) - return context.GetText(); - - return context.Start.InputStream.GetText(Interval.Of(context.Start.StartIndex, context.Stop.StopIndex)); - } - - public static string GetFullTextUntil(this IToken? startToken, IToken? endToken) - { - if (startToken == null || endToken == null || startToken.StartIndex < 0 || endToken.StopIndex < 0 || startToken.StartIndex > endToken.StartIndex || startToken.InputStream != endToken.InputStream) - return string.Empty; - - return startToken.InputStream.GetText(Interval.Of(startToken.StartIndex, endToken.StopIndex)); - } - - public static HashSet ToHashSet(this IEnumerable sequence) - => new(sequence); - - public static void AddRange(this ICollection collection, IEnumerable toAdd) - { - foreach (var item in toAdd) - collection.Add(item); - } + foreach (var item in toAdd) + collection.Add(item); } } diff --git a/src/Abc.Zebus.MessageDsl/Support/Index.cs b/src/Abc.Zebus.MessageDsl/Support/Index.cs new file mode 100644 index 0000000..4d3b458 --- /dev/null +++ b/src/Abc.Zebus.MessageDsl/Support/Index.cs @@ -0,0 +1,84 @@ +#if !NETCOREAPP + +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.Runtime.CompilerServices; + +// ReSharper disable once CheckNamespace +namespace System; + +internal readonly struct Index : IEquatable +{ + private readonly int _value; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Index(int value, bool fromEnd = false) + { + if (value < 0) + ThrowValueArgumentOutOfRange_NeedNonNegNumException(); + + _value = fromEnd ? ~value : value; + } + + private Index(int value) + { + _value = value; + } + + public static Index Start => new(0); + public static Index End => new(~0); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Index FromStart(int value) + { + if (value < 0) + ThrowValueArgumentOutOfRange_NeedNonNegNumException(); + + return new Index(value); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Index FromEnd(int value) + { + if (value < 0) + ThrowValueArgumentOutOfRange_NeedNonNegNumException(); + + return new Index(~value); + } + + public int Value => _value < 0 ? ~_value : _value; + public bool IsFromEnd => _value < 0; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int GetOffset(int length) + { + var offset = _value; + if (IsFromEnd) + offset += length + 1; + return offset; + } + + public override bool Equals([NotNullWhen(true)] object? value) + => value is Index index && _value == index._value; + + public bool Equals(Index other) + => _value == other._value; + + public override int GetHashCode() + => _value; + + public static implicit operator Index(int value) + => FromStart(value); + + public override string ToString() + => IsFromEnd + ? '^' + Value.ToString(CultureInfo.InvariantCulture) + : ((uint)Value).ToString(CultureInfo.InvariantCulture); + + [DoesNotReturn] + [SuppressMessage("ReSharper", "NotResolvedInText")] + private static void ThrowValueArgumentOutOfRange_NeedNonNegNumException() + => throw new ArgumentOutOfRangeException("value", "Non-negative number required."); +} + +#endif