diff --git a/Source/Mockolate.Migration.Analyzers.CodeFixers/NSubstituteCodeFixProvider.cs b/Source/Mockolate.Migration.Analyzers.CodeFixers/NSubstituteCodeFixProvider.cs index ac9cd13..bfdb987 100644 --- a/Source/Mockolate.Migration.Analyzers.CodeFixers/NSubstituteCodeFixProvider.cs +++ b/Source/Mockolate.Migration.Analyzers.CodeFixers/NSubstituteCodeFixProvider.cs @@ -1,3 +1,4 @@ +using System.Collections.Generic; using System.Composition; using System.Linq; using System.Threading; @@ -7,6 +8,8 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; +#pragma warning disable S1192 // String literals should not be duplicated +#pragma warning disable S3776 // Cognitive Complexity of methods should not be too high namespace Mockolate.Migration.Analyzers; /// @@ -16,6 +19,19 @@ namespace Mockolate.Migration.Analyzers; [Shared] public class NSubstituteCodeFixProvider() : AssertionCodeFixProvider(Rules.NSubstituteRule) { + private static readonly HashSet SetupConfiguratorMethods = + [ + "Returns", + "ReturnsForAnyArgs", + "ReturnsNull", + "ReturnsNullForAnyArgs", + "Throws", + "ThrowsForAnyArgs", + "ThrowsAsync", + "ThrowsAsyncForAnyArgs", + "AndDoes", + ]; + /// protected override async Task ConvertAssertionAsync(CodeFixContext context, ExpressionSyntax expressionSyntax, CancellationToken cancellationToken) @@ -36,13 +52,38 @@ protected override async Task ConvertAssertionAsync(CodeFixContext con return document; } - ExpressionSyntax? replacement = BuildCreationReplacement(substituteCall); - if (replacement is null) + ExpressionSyntax? creationReplacement = BuildCreationReplacement(substituteCall); + if (creationReplacement is null) { return document; } - compilationUnit = compilationUnit.ReplaceNode(substituteCall, replacement.WithTriviaFrom(substituteCall)); + ISymbol? mockSymbol = GetDeclaredMockSymbol(semanticModel, substituteCall, cancellationToken); + IReadOnlyList allInvocations = + compilationUnit.DescendantNodes().OfType().ToList(); + + Dictionary setupReplacements = + FindAndBuildSetupReplacements(allInvocations, semanticModel, mockSymbol, cancellationToken); + + List nodesToReplace = [substituteCall,]; + nodesToReplace.AddRange(setupReplacements.Keys); + + compilationUnit = compilationUnit.ReplaceNodes( + nodesToReplace, + (original, _) => + { + if (original == substituteCall) + { + return creationReplacement.WithTriviaFrom(substituteCall); + } + + if (setupReplacements.TryGetValue(original, out SyntaxNode? replacement)) + { + return replacement; + } + + return original; + }); bool hasUsing = compilationUnit.Usings.Any(u => u.Name?.ToString() == "Mockolate"); if (!hasUsing) @@ -84,6 +125,230 @@ protected override async Task ConvertAssertionAsync(CodeFixContext con return null; } + private static ISymbol? GetDeclaredMockSymbol(SemanticModel? semanticModel, + InvocationExpressionSyntax substituteCall, CancellationToken cancellationToken) + { + if (semanticModel is null) + { + return null; + } + + // The substitute call may be wrapped: var sub = Substitute.For().Implementing(); we still want + // the variable declarator above. Walk up through expressions to find the EqualsValueClause. + SyntaxNode current = substituteCall; + while (current.Parent is ExpressionSyntax) + { + current = current.Parent; + } + + return current.Parent switch + { + EqualsValueClauseSyntax { Parent: VariableDeclaratorSyntax declarator, } + => semanticModel.GetDeclaredSymbol(declarator, cancellationToken), + EqualsValueClauseSyntax { Parent: PropertyDeclarationSyntax prop, } + => semanticModel.GetDeclaredSymbol(prop, cancellationToken), + _ => null, + }; + } + + private static Dictionary FindAndBuildSetupReplacements( + IReadOnlyList allInvocations, + SemanticModel? semanticModel, + ISymbol? mockSymbol, + CancellationToken cancellationToken) + { + if (semanticModel is null || mockSymbol is null) + { + return []; + } + + Dictionary result = []; + + foreach (InvocationExpressionSyntax outerInvocation in allInvocations) + { + if (outerInvocation.Expression is not MemberAccessExpressionSyntax outerAccess || + !SetupConfiguratorMethods.Contains(outerAccess.Name.Identifier.Text)) + { + continue; + } + + // The receiver of the configurator (e.g. .Returns) is either + // sub.Method(args) — pattern A + // sub.Property — pattern B + ExpressionSyntax receiver = outerAccess.Expression; + + if (receiver is InvocationExpressionSyntax targetInvocation && + targetInvocation.Expression is MemberAccessExpressionSyntax targetMemberAccess) + { + if (!IsTrackedMockReceiver(targetMemberAccess.Expression, semanticModel, mockSymbol, cancellationToken)) + { + continue; + } + + ArgumentListSyntax transformedArgs = + TransformNSubstituteArgReferences(targetInvocation.ArgumentList, semanticModel, cancellationToken); + + MemberAccessExpressionSyntax setupAccess = BuildSetupAccess( + targetMemberAccess.Expression, targetMemberAccess.Name); + InvocationExpressionSyntax setupInvocation = SyntaxFactory.InvocationExpression(setupAccess, transformedArgs) + .WithTriviaFrom(targetInvocation); + + result[targetInvocation] = setupInvocation; + continue; + } + + if (receiver is MemberAccessExpressionSyntax targetPropertyAccess) + { + if (!IsTrackedMockReceiver(targetPropertyAccess.Expression, semanticModel, mockSymbol, cancellationToken)) + { + continue; + } + + MemberAccessExpressionSyntax setupAccess = BuildSetupAccess( + targetPropertyAccess.Expression, targetPropertyAccess.Name); + + result[targetPropertyAccess] = setupAccess.WithTriviaFrom(targetPropertyAccess); + } + } + + return result; + } + + /// + /// Returns when ultimately resolves to the tracked + /// mock symbol — either directly or via a chain of property/field accesses (auto-mocked nested members). + /// + private static bool IsTrackedMockReceiver(ExpressionSyntax expression, + SemanticModel semanticModel, ISymbol mockSymbol, CancellationToken cancellationToken) + { + ExpressionSyntax current = expression; + while (true) + { + SymbolInfo info = semanticModel.GetSymbolInfo(current, cancellationToken); + if (SymbolEqualityComparer.Default.Equals(info.Symbol, mockSymbol)) + { + return true; + } + + if (current is MemberAccessExpressionSyntax memberAccess) + { + current = memberAccess.Expression; + continue; + } + + return false; + } + } + + private static MemberAccessExpressionSyntax BuildSetupAccess(ExpressionSyntax receiver, SimpleNameSyntax memberName) + { + MemberAccessExpressionSyntax mockAccess = SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + receiver, + SyntaxFactory.IdentifierName("Mock")); + MemberAccessExpressionSyntax setupAccess = SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + mockAccess, + SyntaxFactory.IdentifierName("Setup")); + return SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + setupAccess, + memberName); + } + + /// + /// Translates NSubstitute argument matchers to their Mockolate equivalents anywhere inside the supplied + /// argument list. Currently handles Arg.Any<T>, Arg.Is, and the Arg.Compat + /// mirrors. Uses the semantic model so fully-qualified usages + /// (NSubstitute.Arg.Any<T>(), aliased imports, etc.) are recognised too. + /// + private static ArgumentListSyntax TransformNSubstituteArgReferences(ArgumentListSyntax args, + SemanticModel semanticModel, CancellationToken cancellationToken) => + args.ReplaceNodes( + args.DescendantNodes().OfType() + .Where(invocation => IsNSubstituteArgCall(invocation, semanticModel, cancellationToken)) + .ToArray(), + TransformNSubstituteArgInvocation); + + private static bool IsNSubstituteArgCall(InvocationExpressionSyntax invocation, + SemanticModel semanticModel, CancellationToken cancellationToken) + { + if (semanticModel.GetSymbolInfo(invocation, cancellationToken).Symbol is not IMethodSymbol methodSymbol) + { + return false; + } + + // Walk up nested types so both Arg.X(...) and Arg.Compat.X(...) resolve to NSubstitute.Arg. + for (INamedTypeSymbol? containingType = methodSymbol.ContainingType; + containingType is not null; + containingType = containingType.ContainingType) + { + if (containingType.Name == "Arg" && + containingType.ContainingNamespace?.ToDisplayString() == "NSubstitute") + { + return true; + } + } + + return false; + } + + private static SyntaxNode TransformNSubstituteArgInvocation(InvocationExpressionSyntax original, + InvocationExpressionSyntax rewritten) + { + if (rewritten.Expression is not MemberAccessExpressionSyntax memberAccess) + { + return rewritten; + } + + string methodName = memberAccess.Name.Identifier.Text; + TypeArgumentListSyntax? typeArgs = (memberAccess.Name as GenericNameSyntax)?.TypeArgumentList; + + IdentifierNameSyntax itIdentifier = SyntaxFactory.IdentifierName("It"); + + switch (methodName) + { + case "Any": + // Arg.Any() → It.IsAny() + return BuildItInvocation(itIdentifier, "IsAny", typeArgs, SyntaxFactory.ArgumentList()) + .WithTriviaFrom(original); + + case "AnyType": + // Arg.AnyType — used as a type marker, not an invocation. Skip. + return rewritten; + + case "Is": + // Arg.Is(predicate) → It.Satisfies(predicate) + // Arg.Is(value) / Arg.Is(v) → It.Is(value) (or just inline value) + if (rewritten.ArgumentList.Arguments.Count == 1 && + rewritten.ArgumentList.Arguments[0].Expression is LambdaExpressionSyntax) + { + return BuildItInvocation(itIdentifier, "Satisfies", typeArgs, rewritten.ArgumentList) + .WithTriviaFrom(original); + } + + return BuildItInvocation(itIdentifier, "Is", typeArgs, rewritten.ArgumentList) + .WithTriviaFrom(original); + + default: + return rewritten; + } + } + + private static InvocationExpressionSyntax BuildItInvocation(IdentifierNameSyntax itIdentifier, string methodName, + TypeArgumentListSyntax? typeArgs, ArgumentListSyntax argList) + { + SimpleNameSyntax method = typeArgs is null + ? SyntaxFactory.IdentifierName(methodName) + : SyntaxFactory.GenericName(SyntaxFactory.Identifier(methodName)).WithTypeArgumentList(typeArgs); + return SyntaxFactory.InvocationExpression( + SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + itIdentifier, + method), + argList); + } + /// /// Translates the NSubstitute creation call to a Mockolate creation chain. Returns /// when the call cannot be migrated. @@ -196,3 +461,6 @@ private static string DetectLineEnding(SyntaxNode root) return "\n"; } } + +#pragma warning restore S3776 // Cognitive Complexity of methods should not be too high +#pragma warning restore S1192 diff --git a/Tests/Mockolate.Migration.Tests/NSubstituteCodeFixProviderTests.CreationTests.cs b/Tests/Mockolate.Migration.Tests/NSubstituteCodeFixProviderTests.CreationTests.cs index cd81ea5..4ba392f 100644 --- a/Tests/Mockolate.Migration.Tests/NSubstituteCodeFixProviderTests.CreationTests.cs +++ b/Tests/Mockolate.Migration.Tests/NSubstituteCodeFixProviderTests.CreationTests.cs @@ -3,7 +3,7 @@ namespace Mockolate.Migration.Tests; -public class NSubstituteCodeFixProviderTests +public partial class NSubstituteCodeFixProviderTests { public sealed class CreationTests { diff --git a/Tests/Mockolate.Migration.Tests/NSubstituteCodeFixProviderTests.SetupTests.cs b/Tests/Mockolate.Migration.Tests/NSubstituteCodeFixProviderTests.SetupTests.cs new file mode 100644 index 0000000..0ff2830 --- /dev/null +++ b/Tests/Mockolate.Migration.Tests/NSubstituteCodeFixProviderTests.SetupTests.cs @@ -0,0 +1,278 @@ +using Verifier = Mockolate.Migration.Tests.Verifiers.CSharpCodeFixVerifier; + +namespace Mockolate.Migration.Tests; + +public partial class NSubstituteCodeFixProviderTests +{ + public sealed class SetupTests + { + [Fact] + public async Task ArgAny_IsRewrittenToItIsAny() + => await Verifier.VerifyCodeFixAsync( + """ + using NSubstitute; + + public interface IFoo { int Bar(int x); } + + public class Tests + { + public void Test() + { + var sub = [|Substitute.For()|]; + sub.Bar(Arg.Any()).Returns(42); + } + } + """, + """ + using NSubstitute; + using Mockolate; + + public interface IFoo { int Bar(int x); } + + public class Tests + { + public void Test() + { + var sub = IFoo.CreateMock(); + sub.Mock.Setup.Bar(It.IsAny()).Returns(42); + } + } + """); + + [Fact] + public async Task ArgCompat_IsRewrittenToItMatchers() + => await Verifier.VerifyCodeFixAsync( + """ + using NSubstitute; + + public interface IFoo { int Sum(int x, int y); } + + public class Tests + { + public void Test() + { + var sub = [|Substitute.For()|]; + sub.Sum(Arg.Compat.Any(), Arg.Compat.Is(y => y > 0)).Returns(42); + } + } + """, + """ + using NSubstitute; + using Mockolate; + + public interface IFoo { int Sum(int x, int y); } + + public class Tests + { + public void Test() + { + var sub = IFoo.CreateMock(); + sub.Mock.Setup.Sum(It.IsAny(), It.Satisfies(y => y > 0)).Returns(42); + } + } + """); + + [Fact] + public async Task ArgIsPredicate_IsRewrittenToItSatisfies() + => await Verifier.VerifyCodeFixAsync( + """ + using NSubstitute; + + public interface IFoo { int Bar(int x); } + + public class Tests + { + public void Test() + { + var sub = [|Substitute.For()|]; + sub.Bar(Arg.Is(x => x > 0)).Returns(42); + } + } + """, + """ + using NSubstitute; + using Mockolate; + + public interface IFoo { int Bar(int x); } + + public class Tests + { + public void Test() + { + var sub = IFoo.CreateMock(); + sub.Mock.Setup.Bar(It.Satisfies(x => x > 0)).Returns(42); + } + } + """); + + [Fact] + public async Task ArgIsValue_IsRewrittenToItIs() + => await Verifier.VerifyCodeFixAsync( + """ + using NSubstitute; + + public interface IFoo { int Bar(int x); } + + public class Tests + { + public void Test() + { + var sub = [|Substitute.For()|]; + sub.Bar(Arg.Is(5)).Returns(42); + } + } + """, + """ + using NSubstitute; + using Mockolate; + + public interface IFoo { int Bar(int x); } + + public class Tests + { + public void Test() + { + var sub = IFoo.CreateMock(); + sub.Mock.Setup.Bar(It.Is(5)).Returns(42); + } + } + """); + + [Fact] + public async Task MethodReturns_IsRewrittenToMockSetup() + => await Verifier.VerifyCodeFixAsync( + """ + using NSubstitute; + + public interface IFoo { int Bar(int x); } + + public class Tests + { + public void Test() + { + var sub = [|Substitute.For()|]; + sub.Bar(1).Returns(42); + } + } + """, + """ + using NSubstitute; + using Mockolate; + + public interface IFoo { int Bar(int x); } + + public class Tests + { + public void Test() + { + var sub = IFoo.CreateMock(); + sub.Mock.Setup.Bar(1).Returns(42); + } + } + """); + + [Fact] + public async Task MethodThrowsGeneric_IsRewrittenToMockSetup() + => await Verifier.VerifyCodeFixAsync( + """ + using System; + using NSubstitute; + using NSubstitute.ExceptionExtensions; + + public interface IFoo { int Bar(int x); } + + public class Tests + { + public void Test() + { + var sub = [|Substitute.For()|]; + sub.Bar(1).Throws(); + } + } + """, + """ + using System; + using NSubstitute; + using NSubstitute.ExceptionExtensions; + using Mockolate; + + public interface IFoo { int Bar(int x); } + + public class Tests + { + public void Test() + { + var sub = IFoo.CreateMock(); + sub.Mock.Setup.Bar(1).Throws(); + } + } + """); + + [Fact] + public async Task MultipleMatchers_AreAllRewritten() + => await Verifier.VerifyCodeFixAsync( + """ + using NSubstitute; + + public interface IFoo { int Sum(int x, int y); } + + public class Tests + { + public void Test() + { + var sub = [|Substitute.For()|]; + sub.Sum(Arg.Any(), Arg.Is(y => y > 0)).Returns(42); + } + } + """, + """ + using NSubstitute; + using Mockolate; + + public interface IFoo { int Sum(int x, int y); } + + public class Tests + { + public void Test() + { + var sub = IFoo.CreateMock(); + sub.Mock.Setup.Sum(It.IsAny(), It.Satisfies(y => y > 0)).Returns(42); + } + } + """); + + [Fact] + public async Task PropertyReturns_IsRewrittenToMockSetup() + => await Verifier.VerifyCodeFixAsync( + """ + using NSubstitute; + + public interface IFoo { string Name { get; } } + + public class Tests + { + public void Test() + { + var sub = [|Substitute.For()|]; + sub.Name.Returns("bar"); + } + } + """, + """ + using NSubstitute; + using Mockolate; + + public interface IFoo { string Name { get; } } + + public class Tests + { + public void Test() + { + var sub = IFoo.CreateMock(); + sub.Mock.Setup.Name.Returns("bar"); + } + } + """); + } +}