diff --git a/src/StructId.Analyzer/RecordAnalyzer.cs b/src/StructId.Analyzer/RecordAnalyzer.cs index c6678f6..c62513b 100644 --- a/src/StructId.Analyzer/RecordAnalyzer.cs +++ b/src/StructId.Analyzer/RecordAnalyzer.cs @@ -22,31 +22,42 @@ public override void Initialize(AnalysisContext context) if (!Debugger.IsAttached) context.EnableConcurrentExecution(); - context.RegisterSyntaxNodeAction(Analyze, SyntaxKind.ClassDeclaration); - context.RegisterSyntaxNodeAction(Analyze, SyntaxKind.StructDeclaration); - context.RegisterSyntaxNodeAction(Analyze, SyntaxKind.RecordDeclaration); - context.RegisterSyntaxNodeAction(Analyze, SyntaxKind.RecordStructDeclaration); + context.RegisterCompilationStartAction(start => + { + var known = new KnownTypes(start.Compilation); + if (known.IStructId is null || known.IStructIdT is null) + return; + + start.RegisterSymbolAction(AnalyzeSymbol, SymbolKind.NamedType); + }); } - static void Analyze(SyntaxNodeAnalysisContext context) + static void AnalyzeSymbol(SymbolAnalysisContext context) { + if (context.Symbol is not INamedTypeSymbol symbol) + return; + var known = new KnownTypes(context.Compilation); - if (context.Node is not TypeDeclarationSyntax typeDeclaration || - known.IStructIdT is not { } structIdTypeOfT || - known.IStructId is not { } structIdType) + // We only care about IStructId and IStructId + if (!symbol.Is(known.IStructId) && !symbol.Is(known.IStructIdT)) return; - var symbol = context.SemanticModel.GetDeclaredSymbol(typeDeclaration); - if (symbol is null) + // We can only analyze if there's a declaration in source. + if (symbol.DeclaringSyntaxReferences.Length == 0 || + symbol.DeclaringSyntaxReferences + .Select(x => x.GetSyntax()) + .OfType() + .FirstOrDefault() is not { } typeDeclaration) return; - if (!symbol.Is(structIdType) && !symbol.Is(structIdTypeOfT)) - return; + // TODO: report or ignore if more than one declaration? // If there's only one declaration and it's not partial - var report = symbol.DeclaringSyntaxReferences.Length == 1 && !typeDeclaration.Modifiers.Any(SyntaxKind.PartialKeyword); - report |= !typeDeclaration.IsKind(SyntaxKind.RecordStructDeclaration) || !symbol.IsReadOnly; + var report = symbol.DeclaringSyntaxReferences.Length == 1 && + !typeDeclaration.Modifiers.Any(SyntaxKind.PartialKeyword); + + report |= !symbol.IsRecord || symbol.TypeKind != TypeKind.Struct || !symbol.IsReadOnly; if (report) { @@ -55,7 +66,7 @@ known.IStructIdT is not { } structIdTypeOfT || else if (typeDeclaration.BaseList?.Types.FirstOrDefault(t => t.Type is IdentifierNameSyntax { Identifier.Text: "IStructId" }) is { } implementation) context.ReportDiagnostic(Diagnostic.Create(MustBeRecordStruct, implementation.GetLocation(), symbol.Name)); else - context.ReportDiagnostic(Diagnostic.Create(MustBeRecordStruct, typeDeclaration.Identifier.GetLocation(), symbol.Name)); + context.ReportDiagnostic(Diagnostic.Create(MustBeRecordStruct, symbol.Locations.FirstOrDefault(), symbol.Name)); } if (typeDeclaration.ParameterList is null) @@ -72,6 +83,5 @@ known.IStructIdT is not { } structIdTypeOfT || var parameter = typeDeclaration.ParameterList.Parameters[0]; if (parameter.Identifier.Text != "Value") context.ReportDiagnostic(Diagnostic.Create(MustHaveValueConstructor, parameter.Identifier.GetLocation(), symbol.Name)); - } }