Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,12 @@ protected override async Task<Document> ConvertAssertionAsync(CodeFixContext con
Dictionary<SyntaxNode, SyntaxNode> setupReplacements =
FindAndBuildSetupReplacements(allInvocations, semanticModel, mockSymbol, cancellationToken);

Dictionary<InvocationExpressionSyntax, InvocationExpressionSyntax> verifyReplacements =
FindAndBuildVerifyReplacements(allInvocations, semanticModel, mockSymbol, cancellationToken);

List<SyntaxNode> nodesToReplace = [substituteCall,];
nodesToReplace.AddRange(setupReplacements.Keys);
nodesToReplace.AddRange(verifyReplacements.Keys);

compilationUnit = compilationUnit.ReplaceNodes(
nodesToReplace,
Expand All @@ -77,9 +81,15 @@ protected override async Task<Document> ConvertAssertionAsync(CodeFixContext con
return creationReplacement.WithTriviaFrom(substituteCall);
}

if (setupReplacements.TryGetValue(original, out SyntaxNode? replacement))
if (setupReplacements.TryGetValue(original, out SyntaxNode? setupReplacement))
{
return replacement;
return setupReplacement;
}

if (original is InvocationExpressionSyntax invocation &&
verifyReplacements.TryGetValue(invocation, out InvocationExpressionSyntax? verifyReplacement))
{
return verifyReplacement;
}

return original;
Expand All @@ -92,6 +102,16 @@ protected override async Task<Document> ConvertAssertionAsync(CodeFixContext con
compilationUnit = compilationUnit.AddUsings(usingDirective);
}

if (verifyReplacements.Count > 0)
{
bool hasVerifyUsing = compilationUnit.Usings.Any(u => u.Name?.ToString() == "Mockolate.Verify");
if (!hasVerifyUsing)
{
UsingDirectiveSyntax verifyUsingDirective = BuildUsingDirective(compilationUnit, "Mockolate.Verify");
compilationUnit = compilationUnit.AddUsings(verifyUsingDirective);
}
}

return document.WithSyntaxRoot(compilationUnit);
}

Expand Down Expand Up @@ -214,6 +234,111 @@ private static Dictionary<SyntaxNode, SyntaxNode> FindAndBuildSetupReplacements(
return result;
}

private static Dictionary<InvocationExpressionSyntax, InvocationExpressionSyntax> FindAndBuildVerifyReplacements(
IReadOnlyList<InvocationExpressionSyntax> allInvocations,
SemanticModel? semanticModel,
ISymbol? mockSymbol,
CancellationToken cancellationToken)
{
if (semanticModel is null || mockSymbol is null)
{
return [];
}

Dictionary<InvocationExpressionSyntax, InvocationExpressionSyntax> result = [];

foreach (InvocationExpressionSyntax outerInvocation in allInvocations)
{
if (outerInvocation.Expression is not MemberAccessExpressionSyntax outerAccess)
{
continue;
}

// outerInvocation is `something.MethodName(args)`; receiver `something` should be `sub.Received(...)`
// (or DidNotReceive). The receiver of `sub.Received()` is the tracked mock symbol.
if (outerAccess.Expression is not InvocationExpressionSyntax receiverCall ||
receiverCall.Expression is not MemberAccessExpressionSyntax receiverAccess)
{
continue;
}

string receivedMethod = receiverAccess.Name.Identifier.Text;
if (receivedMethod is not ("Received" or "DidNotReceive"))
{
continue;
}

if (!IsTrackedMockReceiver(receiverAccess.Expression, semanticModel, mockSymbol, cancellationToken))
{
continue;
}

ExpressionSyntax mockReceiver = receiverAccess.Expression;
ArgumentListSyntax transformedArgs =
TransformNSubstituteArgReferences(outerInvocation.ArgumentList, semanticModel, cancellationToken);
SimpleNameSyntax methodNameSyntax = outerAccess.Name;

MemberAccessExpressionSyntax verifyAccess = BuildVerifyAccess(mockReceiver, methodNameSyntax);
InvocationExpressionSyntax verifyInvocation = SyntaxFactory.InvocationExpression(verifyAccess, transformedArgs);

InvocationExpressionSyntax suffix = BuildVerifySuffix(verifyInvocation, receivedMethod, receiverCall.ArgumentList);

result[outerInvocation] = suffix.WithTriviaFrom(outerInvocation);
}

return result;
}

private static MemberAccessExpressionSyntax BuildVerifyAccess(ExpressionSyntax receiver, SimpleNameSyntax memberName)
{
MemberAccessExpressionSyntax mockAccess = SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
receiver,
SyntaxFactory.IdentifierName("Mock"));
MemberAccessExpressionSyntax verifyMember = SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
mockAccess,
SyntaxFactory.IdentifierName("Verify"));
return SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
verifyMember,
memberName);
}

private static InvocationExpressionSyntax BuildVerifySuffix(InvocationExpressionSyntax verifyInvocation,
string receivedMethod, ArgumentListSyntax receivedArgs)
{
// DidNotReceive() → .Never(); Received() → .AtLeastOnce(); Received(n) → .Exactly(n) or .Once() when n is 1.
if (receivedMethod == "DidNotReceive")
{
return AppendCountCall(verifyInvocation, "Never", SyntaxFactory.ArgumentList());
}

if (receivedArgs.Arguments.Count == 0)
{
return AppendCountCall(verifyInvocation, "AtLeastOnce", SyntaxFactory.ArgumentList());
}

// Received(n): when n is the literal integer 1 we collapse to Once(); otherwise pass through to Exactly(n).
if (receivedArgs.Arguments.Count == 1 &&
receivedArgs.Arguments[0].Expression is LiteralExpressionSyntax literal &&
literal.Token.Value is 1)
{
return AppendCountCall(verifyInvocation, "Once", SyntaxFactory.ArgumentList());
}

return AppendCountCall(verifyInvocation, "Exactly", receivedArgs);
}

private static InvocationExpressionSyntax AppendCountCall(InvocationExpressionSyntax verifyInvocation,
string methodName, ArgumentListSyntax argList) =>
SyntaxFactory.InvocationExpression(
SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
verifyInvocation,
SyntaxFactory.IdentifierName(methodName)),
argList);

/// <summary>
/// Returns <see langword="true" /> when <paramref name="expression" /> ultimately resolves to the tracked
/// mock symbol — either directly or via a chain of property/field accesses (auto-mocked nested members).
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
using Verifier = Mockolate.Migration.Tests.Verifiers.CSharpCodeFixVerifier<Mockolate.Migration.Analyzers.NSubstituteAnalyzer,
Mockolate.Migration.Analyzers.NSubstituteCodeFixProvider>;

namespace Mockolate.Migration.Tests;

public partial class NSubstituteCodeFixProviderTests
{
public sealed class VerifyTests
{
[Fact]
public async Task DidNotReceive_IsRewrittenToNever()
=> await Verifier.VerifyCodeFixAsync(
"""
using NSubstitute;

public interface IFoo { void Bar(int x); }

public class Tests
{
public void Test()
{
var sub = [|Substitute.For<IFoo>()|];
sub.DidNotReceive().Bar(1);
}
}
""",
"""
using NSubstitute;
using Mockolate;
using Mockolate.Verify;

public interface IFoo { void Bar(int x); }

public class Tests
{
public void Test()
{
var sub = IFoo.CreateMock();
sub.Mock.Verify.Bar(1).Never();
}
}
""");

[Fact]
public async Task Received_IsRewrittenToAtLeastOnce()
=> await Verifier.VerifyCodeFixAsync(
"""
using NSubstitute;

public interface IFoo { void Bar(int x); }

public class Tests
{
public void Test()
{
var sub = [|Substitute.For<IFoo>()|];
sub.Bar(1);
sub.Received().Bar(1);
}
}
""",
"""
using NSubstitute;
using Mockolate;
using Mockolate.Verify;

public interface IFoo { void Bar(int x); }

public class Tests
{
public void Test()
{
var sub = IFoo.CreateMock();
sub.Bar(1);
sub.Mock.Verify.Bar(1).AtLeastOnce();
}
}
""");

[Fact]
public async Task ReceivedExactCount_IsRewrittenToExactly()
=> await Verifier.VerifyCodeFixAsync(
"""
using NSubstitute;

public interface IFoo { void Bar(int x); }

public class Tests
{
public void Test()
{
var sub = [|Substitute.For<IFoo>()|];
sub.Received(3).Bar(1);
}
}
""",
"""
using NSubstitute;
using Mockolate;
using Mockolate.Verify;

public interface IFoo { void Bar(int x); }

public class Tests
{
public void Test()
{
var sub = IFoo.CreateMock();
sub.Mock.Verify.Bar(1).Exactly(3);
}
}
""");

[Fact]
public async Task ReceivedOne_IsRewrittenToOnce()
=> await Verifier.VerifyCodeFixAsync(
"""
using NSubstitute;

public interface IFoo { void Bar(int x); }

public class Tests
{
public void Test()
{
var sub = [|Substitute.For<IFoo>()|];
sub.Received(1).Bar(1);
}
}
""",
"""
using NSubstitute;
using Mockolate;
using Mockolate.Verify;

public interface IFoo { void Bar(int x); }

public class Tests
{
public void Test()
{
var sub = IFoo.CreateMock();
sub.Mock.Verify.Bar(1).Once();
}
}
""");

[Fact]
public async Task ReceivedWithArgMatcher_TransformsMatcher()
=> await Verifier.VerifyCodeFixAsync(
"""
using NSubstitute;

public interface IFoo { void Bar(int x); }

public class Tests
{
public void Test()
{
var sub = [|Substitute.For<IFoo>()|];
sub.Received().Bar(Arg.Any<int>());
}
}
""",
"""
using NSubstitute;
using Mockolate;
using Mockolate.Verify;

public interface IFoo { void Bar(int x); }

public class Tests
{
public void Test()
{
var sub = IFoo.CreateMock();
sub.Mock.Verify.Bar(It.IsAny<int>()).AtLeastOnce();
}
}
""");
}
}
Loading