From 209ba5d4649a6df7f4e4c71c37b4146b9f2a8731 Mon Sep 17 00:00:00 2001 From: Daniel Cazzulino Date: Sun, 22 Dec 2024 18:11:03 -0300 Subject: [PATCH] Improve detection of IStructId namespace from compilation The existing approach of using analyzer config options breaks if there are transitive project references involved, since the options will contain potentially the wrong namespace (the one for the project being built rather than the one where IStructId actually exists). This changes the approach to looking up the types by their metadata name plus a codegen attribute which we assume users won't be using for their own code (even if they do happen to use `IStructId` for some other purpose). --- src/StructId.Analyzer/AnalysisExtensions.cs | 17 +++++--- src/StructId.Analyzer/BaseGenerator.cs | 5 +-- src/StructId.Analyzer/CodeTemplate.cs | 40 +++++++++++++++---- src/StructId.Analyzer/KnownTypes.cs | 22 +++++----- .../NewtonsoftJsonGenerator.cs | 14 ++++--- src/StructId.Analyzer/RecordAnalyzer.cs | 6 +-- src/StructId.Analyzer/TemplateAnalyzer.cs | 2 - src/StructId.Analyzer/TemplatedGenerator.cs | 30 +++++++------- .../TemplatizedTValueExtensions.cs | 5 +-- src/StructId.Package/StructId.props | 4 -- src/StructId/IStructId.cs | 3 ++ src/StructId/IStructIdT.cs | 3 ++ 12 files changed, 89 insertions(+), 62 deletions(-) diff --git a/src/StructId.Analyzer/AnalysisExtensions.cs b/src/StructId.Analyzer/AnalysisExtensions.cs index cc8aee1..3eea604 100644 --- a/src/StructId.Analyzer/AnalysisExtensions.cs +++ b/src/StructId.Analyzer/AnalysisExtensions.cs @@ -29,6 +29,11 @@ public static CSharpParseOptions GetParseOptions(this Compilation compilation) => (CSharpParseOptions?)compilation.SyntaxTrees.FirstOrDefault()?.Options ?? CSharpParseOptions.Default.WithLanguageVersion(LanguageVersion.Latest); + public static bool IsGeneratedByStructId(this ISymbol symbol) + => symbol.GetAttributes().Any(a + => a.AttributeClass?.Name == "GeneratedCodeAttribute" && + a.ConstructorArguments.Select(c => c.Value).OfType().Any(v => v == nameof(StructId))); + /// /// Checks whether the type inherits or implements the /// type, even if it's a generic type. @@ -62,12 +67,6 @@ @this is INamedTypeSymbol namedActual && return Is(@this.BaseType, baseTypeOrInterface, looseGenerics); } - public static string GetStructIdNamespace(this AnalyzerConfigOptions options) - => options.TryGetValue("build_property.StructIdNamespace", out var ns) && !string.IsNullOrEmpty(ns) ? ns : "StructId"; - - public static IncrementalValueProvider GetStructIdNamespace(this IncrementalValueProvider options) - => options.Select((x, _) => x.GlobalOptions.TryGetValue("build_property.StructIdNamespace", out var ns) ? ns : "StructId"); - public static bool ImplementsExplicitly(this INamedTypeSymbol namedTypeSymbol, INamedTypeSymbol interfaceTypeSymbol) { if (interfaceTypeSymbol.IsUnboundGenericType && interfaceTypeSymbol.TypeParameters.Length == 1) @@ -156,6 +155,9 @@ public static string ToFileName(this ITypeSymbol type) public static bool IsStructId(this ITypeSymbol type) => type.AllInterfaces.Any(x => x.Name == "IStructId"); + public static bool IsValueTemplate(this INamedTypeSymbol symbol) + => symbol.GetAttributes().Any(IsValueTemplate); + public static bool IsValueTemplate(this AttributeData attribute) => attribute.AttributeClass?.Name == "TValue" || attribute.AttributeClass?.Name == "TValueAttribute"; @@ -163,6 +165,9 @@ public static bool IsValueTemplate(this AttributeData attribute) public static bool IsValueTemplate(this AttributeSyntax attribute) => attribute.Name.ToString() == "TValue" || attribute.Name.ToString() == "TValueAttribute"; + public static bool IsStructIdTemplate(this INamedTypeSymbol symbol) + => symbol.GetAttributes().Any(IsStructIdTemplate); + public static bool IsStructIdTemplate(this AttributeData attribute) => attribute.AttributeClass?.Name == "TStructId" || attribute.AttributeClass?.Name == "TStructIdAttribute"; diff --git a/src/StructId.Analyzer/BaseGenerator.cs b/src/StructId.Analyzer/BaseGenerator.cs index fdaf3f0..8be9f39 100644 --- a/src/StructId.Analyzer/BaseGenerator.cs +++ b/src/StructId.Analyzer/BaseGenerator.cs @@ -26,11 +26,8 @@ protected record struct TemplateArgs(INamedTypeSymbol TSelf, INamedTypeSymbol TV public virtual void Initialize(IncrementalGeneratorInitializationContext context) { - var structIdNamespace = context.AnalyzerConfigOptionsProvider.GetStructIdNamespace(); - var known = context.CompilationProvider - .Combine(structIdNamespace) - .Select((x, _) => new KnownTypes(x.Left, x.Right)); + .Select((x, _) => new KnownTypes(x)); // Locate the required type var types = context.CompilationProvider diff --git a/src/StructId.Analyzer/CodeTemplate.cs b/src/StructId.Analyzer/CodeTemplate.cs index 466c8d1..1a3ee38 100644 --- a/src/StructId.Analyzer/CodeTemplate.cs +++ b/src/StructId.Analyzer/CodeTemplate.cs @@ -31,14 +31,14 @@ public static string Apply(string template, string structIdType, string valueTyp public static string Apply(string template, string valueType, bool normalizeWhitespace = false) { - var applied = ApplyImpl(Parse(template), valueType); + var applied = ApplyValueImpl(Parse(template), valueType); return normalizeWhitespace ? applied.NormalizeWhitespace().ToFullString().Trim() : applied.ToFullString().Trim(); } - public static SyntaxNode ApplyValue(this SyntaxNode node, INamedTypeSymbol valueType) => ApplyImpl(node, valueType.ToFullName()); + public static SyntaxNode ApplyValue(this SyntaxNode node, INamedTypeSymbol valueType) => ApplyValueImpl(node, valueType.ToFullName()); public static SyntaxNode Apply(this SyntaxNode node, INamedTypeSymbol structId) { @@ -59,7 +59,7 @@ public static SyntaxNode Apply(this SyntaxNode node, INamedTypeSymbol structId) return ApplyImpl(root, structId.Name, tid, targetNamespace, corens); } - static SyntaxNode ApplyImpl(this SyntaxNode node, string valueType) + static SyntaxNode ApplyValueImpl(this SyntaxNode node, string valueType) { var root = node.SyntaxTree.GetCompilationUnitRoot(); if (root == null) @@ -194,7 +194,7 @@ bool IsFileLocal(TypeDeclarationSyntax node) => !node.AttributeLists.Any(list => list.Attributes.Any(a => a.IsValueTemplate())); } - class TemplateRewriter(string tself, string tid) : CSharpSyntaxRewriter + class TemplateRewriter(string tself, string tvalue) : CSharpSyntaxRewriter { public override SyntaxNode? VisitRecordDeclaration(RecordDeclarationSyntax node) { @@ -282,8 +282,20 @@ class TemplateRewriter(string tself, string tid) : CSharpSyntaxRewriter return IdentifierName(tself) .WithLeadingTrivia(node.Identifier.LeadingTrivia) .WithTrailingTrivia(node.Identifier.TrailingTrivia); - else if (node.Identifier.Text == "TId" || node.Identifier.Text == "TValue") - return IdentifierName(tid) + + if (node.Identifier.Text.StartsWith("TSelf_")) + return IdentifierName(node.Identifier.Text.Replace("TSelf_", tvalue.Replace('.', '_') + "_")) + .WithLeadingTrivia(node.Identifier.LeadingTrivia) + .WithTrailingTrivia(node.Identifier.TrailingTrivia); + + // TODO: remove TId as it's legacy + if (node.Identifier.Text == "TId" || node.Identifier.Text == "TValue") + return IdentifierName(tvalue) + .WithLeadingTrivia(node.Identifier.LeadingTrivia) + .WithTrailingTrivia(node.Identifier.TrailingTrivia); + + if (node.Identifier.Text.StartsWith("TValue_")) + return IdentifierName(node.Identifier.Text.Replace("TValue_", tvalue.Replace('.', '_') + "_")) .WithLeadingTrivia(node.Identifier.LeadingTrivia) .WithTrailingTrivia(node.Identifier.TrailingTrivia); @@ -297,8 +309,20 @@ public override SyntaxToken VisitToken(SyntaxToken token) return Identifier(tself) .WithLeadingTrivia(token.LeadingTrivia) .WithTrailingTrivia(token.TrailingTrivia); - else if (token.IsKind(SyntaxKind.IdentifierToken) && (token.Text == "TId" || token.Text == "TValue")) - return Identifier(tid) + + if (token.IsKind(SyntaxKind.IdentifierToken) && token.Text.StartsWith("TSelf_")) + return Identifier(token.Text.Replace("TSelf_", tvalue.Replace('.', '_') + "_")) + .WithLeadingTrivia(token.LeadingTrivia) + .WithTrailingTrivia(token.TrailingTrivia); + + // TODO: remove TId as it's legacy + if (token.IsKind(SyntaxKind.IdentifierToken) && (token.Text == "TId" || token.Text == "TValue")) + return Identifier(tvalue) + .WithLeadingTrivia(token.LeadingTrivia) + .WithTrailingTrivia(token.TrailingTrivia); + + if (token.IsKind(SyntaxKind.IdentifierToken) && token.Text.StartsWith("TValue_")) + return Identifier(token.Text.Replace("TValue_", tvalue.Replace('.', '_') + "_")) .WithLeadingTrivia(token.LeadingTrivia) .WithTrailingTrivia(token.TrailingTrivia); diff --git a/src/StructId.Analyzer/KnownTypes.cs b/src/StructId.Analyzer/KnownTypes.cs index 6cc419b..2b92880 100644 --- a/src/StructId.Analyzer/KnownTypes.cs +++ b/src/StructId.Analyzer/KnownTypes.cs @@ -1,4 +1,5 @@ -using Microsoft.CodeAnalysis; +using System.Linq; +using Microsoft.CodeAnalysis; namespace StructId; @@ -6,23 +7,26 @@ namespace StructId; /// Provides access to some common types and properties used in the compilation. /// /// The compilation used to resolve the known types. -/// The namespace for StructId types. -public record KnownTypes(Compilation Compilation, string StructIdNamespace) +public record KnownTypes(Compilation Compilation) { + public string StructIdNamespace => IStructId?.ContainingNamespace.ToFullName() ?? "StructId"; + /// /// System.String /// public INamedTypeSymbol String { get; } = Compilation.GetTypeByMetadataName("System.String")!; + /// /// StructId.IStructId /// - public INamedTypeSymbol? IStructId { get; } = Compilation.GetTypeByMetadataName($"{StructIdNamespace}.IStructId"); + public INamedTypeSymbol? IStructId { get; } = Compilation + .GetAllTypes(true) + .FirstOrDefault(x => x.MetadataName == "IStructId" && x.IsGeneratedByStructId()); + /// /// StructId.IStructId{T} /// - public INamedTypeSymbol? IStructIdT { get; } = Compilation.GetTypeByMetadataName($"{StructIdNamespace}.IStructId`1"); - /// - /// StructId.TStructIdAttribute - /// - public INamedTypeSymbol? TStructId { get; } = Compilation.GetTypeByMetadataName($"{StructIdNamespace}.TStructIdAttribute"); + public INamedTypeSymbol? IStructIdT { get; } = Compilation + .GetAllTypes(true) + .FirstOrDefault(x => x.MetadataName == "IStructId`1" && x.IsGeneratedByStructId()); } \ No newline at end of file diff --git a/src/StructId.Analyzer/NewtonsoftJsonGenerator.cs b/src/StructId.Analyzer/NewtonsoftJsonGenerator.cs index a3273c8..d5125ec 100644 --- a/src/StructId.Analyzer/NewtonsoftJsonGenerator.cs +++ b/src/StructId.Analyzer/NewtonsoftJsonGenerator.cs @@ -15,19 +15,21 @@ public override void Initialize(IncrementalGeneratorInitializationContext contex { base.Initialize(context); + var source = context.CompilationProvider + .Select((x, _) => (new KnownTypes(x), x.GetTypeByMetadataName("Newtonsoft.Json.JsonConverter`1"))); + context.RegisterSourceOutput( - context.CompilationProvider - .Select((x, _) => x.GetTypeByMetadataName("Newtonsoft.Json.JsonConverter`1")) - .Combine(context.AnalyzerConfigOptionsProvider.GetStructIdNamespace()), + source, (context, source) => { - if (source.Left == null) + (var known, var converter) = source; + if (converter == null) return; context.AddSource("NewtonsoftJsonConverter.cs", SourceText.From( ThisAssembly.Resources.Templates.NewtonsoftJsonConverter_1.Text - .Replace("namespace StructId;", $"namespace {source.Right};") - .Replace("using StructId;", $"using {source.Right};"), + .Replace("namespace StructId;", $"namespace {known.StructIdNamespace};") + .Replace("using StructId;", $"using {known.StructIdNamespace};"), Encoding.UTF8)); }); } diff --git a/src/StructId.Analyzer/RecordAnalyzer.cs b/src/StructId.Analyzer/RecordAnalyzer.cs index e6f12da..c6678f6 100644 --- a/src/StructId.Analyzer/RecordAnalyzer.cs +++ b/src/StructId.Analyzer/RecordAnalyzer.cs @@ -30,11 +30,11 @@ public override void Initialize(AnalysisContext context) static void Analyze(SyntaxNodeAnalysisContext context) { - var ns = context.Options.AnalyzerConfigOptionsProvider.GlobalOptions.GetStructIdNamespace(); + var known = new KnownTypes(context.Compilation); if (context.Node is not TypeDeclarationSyntax typeDeclaration || - context.Compilation.GetTypeByMetadataName($"{ns}.IStructId`1") is not { } structIdTypeOfT || - context.Compilation.GetTypeByMetadataName($"{ns}.IStructId") is not { } structIdType) + known.IStructIdT is not { } structIdTypeOfT || + known.IStructId is not { } structIdType) return; var symbol = context.SemanticModel.GetDeclaredSymbol(typeDeclaration); diff --git a/src/StructId.Analyzer/TemplateAnalyzer.cs b/src/StructId.Analyzer/TemplateAnalyzer.cs index 306fadd..1883b8a 100644 --- a/src/StructId.Analyzer/TemplateAnalyzer.cs +++ b/src/StructId.Analyzer/TemplateAnalyzer.cs @@ -30,8 +30,6 @@ public override void Initialize(AnalysisContext context) static void Analyze(SyntaxNodeAnalysisContext context) { - var ns = context.Options.AnalyzerConfigOptionsProvider.GlobalOptions.GetStructIdNamespace(); - if (context.Node is not TypeDeclarationSyntax typeDeclaration || !typeDeclaration.AttributeLists.Any(list => list.Attributes.Any(attr => attr.IsStructIdTemplate()))) return; diff --git a/src/StructId.Analyzer/TemplatedGenerator.cs b/src/StructId.Analyzer/TemplatedGenerator.cs index f5890f4..19b4af5 100644 --- a/src/StructId.Analyzer/TemplatedGenerator.cs +++ b/src/StructId.Analyzer/TemplatedGenerator.cs @@ -80,11 +80,8 @@ public bool AppliesTo(INamedTypeSymbol valueType) public void Initialize(IncrementalGeneratorInitializationContext context) { - var structIdNamespace = context.AnalyzerConfigOptionsProvider.GetStructIdNamespace(); - var known = context.CompilationProvider - .Combine(structIdNamespace) - .Select((x, _) => new KnownTypes(x.Left, x.Right)); + .Select((x, _) => new KnownTypes(x)); var templates = context.CompilationProvider .SelectMany((x, _) => x.GetAllTypes(includeReferenced: true).OfType()) @@ -99,38 +96,39 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .Combine(known) .Select((x, cancellation) => { - var (structId, known) = x; + var (tself, known) = x; // We infer the idType from the required primary constructor Value parameter type - var idType = (INamedTypeSymbol)structId.GetMembers().OfType().First(p => p.Name == "Value").Type; - var attribute = structId.GetAttributes().First(a => a.AttributeClass != null && a.AttributeClass.Is(known.TStructId)); + var tvalue = (INamedTypeSymbol)tself.GetMembers().OfType().First(p => p.Name == "Value").Type; + var attribute = tself.GetAttributes().First(a => a.IsStructIdTemplate()); // The id type isn't declared in the same file, so we don't do anything fancy with it. - if (idType.DeclaringSyntaxReferences.Length == 0) - return new Template(structId, idType, attribute, known); + if (tvalue.DeclaringSyntaxReferences.Length == 0) + return new Template(tself, tvalue, attribute, known); // Otherwise, the idType is a file-local type with a single interface - var type = idType.DeclaringSyntaxReferences[0].GetSyntax(cancellation) as TypeDeclarationSyntax; + var type = tvalue.DeclaringSyntaxReferences[0].GetSyntax(cancellation) as TypeDeclarationSyntax; var iface = type?.BaseList?.Types.FirstOrDefault()?.Type; if (type == null || iface == null) - return new Template(structId, idType, attribute, known) { OriginalTValue = idType }; + return new Template(tself, tvalue, attribute, known) { OriginalTValue = tvalue }; if (x.Right.Compilation.GetSemanticModel(type.SyntaxTree).GetSymbolInfo(iface).Symbol is not INamedTypeSymbol ifaceType) - return new Template(structId, idType, attribute, known); + return new Template(tself, tvalue, attribute, known); // if the interface is a generic type with a single type argument that is the same as the idType // make it an unbound generic type. We'll bind it to the actual idType later at template render time. - if (ifaceType.IsGenericType && ifaceType.TypeArguments.Length == 1 && ifaceType.TypeArguments[0].Equals(idType, SymbolEqualityComparer.Default)) + if (ifaceType.IsGenericType && ifaceType.TypeArguments.Length == 1 && ifaceType.TypeArguments[0].Equals(tvalue, SymbolEqualityComparer.Default)) ifaceType = ifaceType.ConstructUnboundGenericType(); - return new Template(structId, ifaceType, attribute, known) + return new Template(tself, ifaceType, attribute, known) { - OriginalTValue = idType + OriginalTValue = tvalue }; }) .Collect(); var ids = context.CompilationProvider - .SelectMany((x, _) => x.Assembly.GetAllTypes().OfType()) + .SelectMany((x, _) => x.Assembly.GetAllTypes().OfType() + .Where(t => !t.IsValueTemplate() && !t.IsStructIdTemplate())) .Where(x => x.IsRecord && x.IsValueType && x.IsPartial()) .Combine(known) .Where(x => x.Left.Is(x.Right.IStructId) || x.Left.Is(x.Right.IStructIdT)) diff --git a/src/StructId.Analyzer/TemplatizedTValueExtensions.cs b/src/StructId.Analyzer/TemplatizedTValueExtensions.cs index a8a85af..6e1c9fb 100644 --- a/src/StructId.Analyzer/TemplatizedTValueExtensions.cs +++ b/src/StructId.Analyzer/TemplatizedTValueExtensions.cs @@ -89,11 +89,8 @@ static class TemplatizedTValueExtensions /// public static IncrementalValuesProvider SelectTemplatizedValues(this IncrementalGeneratorInitializationContext context) { - var structIdNamespace = context.AnalyzerConfigOptionsProvider.GetStructIdNamespace(); - var known = context.CompilationProvider - .Combine(structIdNamespace) - .Select((x, _) => new KnownTypes(x.Left, x.Right)); + .Select((x, _) => new KnownTypes(x)); var templates = context.CompilationProvider .SelectMany((x, _) => x.GetAllTypes(includeReferenced: true).OfType()) diff --git a/src/StructId.Package/StructId.props b/src/StructId.Package/StructId.props index ba6cd04..4de98b5 100644 --- a/src/StructId.Package/StructId.props +++ b/src/StructId.Package/StructId.props @@ -1,7 +1,3 @@ - - - - \ No newline at end of file diff --git a/src/StructId/IStructId.cs b/src/StructId/IStructId.cs index d0296cf..9b41dbb 100644 --- a/src/StructId/IStructId.cs +++ b/src/StructId/IStructId.cs @@ -1,10 +1,13 @@ // +using System.CodeDom.Compiler; + namespace StructId; /// /// Interface for string-based identifiers. /// +[GeneratedCode("StructId", default)] public partial interface IStructId { /// diff --git a/src/StructId/IStructIdT.cs b/src/StructId/IStructIdT.cs index d03e9bd..9e2557c 100644 --- a/src/StructId/IStructIdT.cs +++ b/src/StructId/IStructIdT.cs @@ -1,11 +1,14 @@ // +using System.CodeDom.Compiler; + namespace StructId; /// /// Interface for struct-based identifiers. /// /// The struct type for the inner of the identifier. +[GeneratedCode("StructId", default)] public partial interface IStructId where TValue : struct { ///