Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Collections.Generic;
using System.Composition;
using System.Linq;
using System.Threading;
Expand All @@ -7,6 +8,8 @@
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;

#pragma warning disable S1192 // String literals should not be duplicated
#pragma warning disable S3776 // Cognitive Complexity of methods should not be too high
Comment thread
vbreuss marked this conversation as resolved.
namespace Mockolate.Migration.Analyzers;

/// <summary>
Expand All @@ -16,6 +19,19 @@ namespace Mockolate.Migration.Analyzers;
[Shared]
public class NSubstituteCodeFixProvider() : AssertionCodeFixProvider(Rules.NSubstituteRule)
{
private static readonly HashSet<string> SetupConfiguratorMethods =
[
"Returns",
"ReturnsForAnyArgs",
"ReturnsNull",
"ReturnsNullForAnyArgs",
"Throws",
"ThrowsForAnyArgs",
"ThrowsAsync",
"ThrowsAsyncForAnyArgs",
"AndDoes",
];

/// <inheritdoc />
protected override async Task<Document> ConvertAssertionAsync(CodeFixContext context,
ExpressionSyntax expressionSyntax, CancellationToken cancellationToken)
Expand All @@ -36,13 +52,38 @@ protected override async Task<Document> ConvertAssertionAsync(CodeFixContext con
return document;
}

ExpressionSyntax? replacement = BuildCreationReplacement(substituteCall);
if (replacement is null)
ExpressionSyntax? creationReplacement = BuildCreationReplacement(substituteCall);
if (creationReplacement is null)
{
return document;
}

compilationUnit = compilationUnit.ReplaceNode(substituteCall, replacement.WithTriviaFrom(substituteCall));
ISymbol? mockSymbol = GetDeclaredMockSymbol(semanticModel, substituteCall, cancellationToken);
IReadOnlyList<InvocationExpressionSyntax> allInvocations =
compilationUnit.DescendantNodes().OfType<InvocationExpressionSyntax>().ToList();

Dictionary<SyntaxNode, SyntaxNode> setupReplacements =
FindAndBuildSetupReplacements(allInvocations, semanticModel, mockSymbol, cancellationToken);

List<SyntaxNode> nodesToReplace = [substituteCall,];
nodesToReplace.AddRange(setupReplacements.Keys);

compilationUnit = compilationUnit.ReplaceNodes(
nodesToReplace,
(original, _) =>
{
if (original == substituteCall)
{
return creationReplacement.WithTriviaFrom(substituteCall);
}

if (setupReplacements.TryGetValue(original, out SyntaxNode? replacement))
{
return replacement;
}

return original;
});

bool hasUsing = compilationUnit.Usings.Any(u => u.Name?.ToString() == "Mockolate");
if (!hasUsing)
Expand Down Expand Up @@ -84,6 +125,230 @@ protected override async Task<Document> ConvertAssertionAsync(CodeFixContext con
return null;
}

private static ISymbol? GetDeclaredMockSymbol(SemanticModel? semanticModel,
InvocationExpressionSyntax substituteCall, CancellationToken cancellationToken)
{
if (semanticModel is null)
{
return null;
}

// The substitute call may be wrapped: var sub = Substitute.For<T>().Implementing<T2>(); we still want
// the variable declarator above. Walk up through expressions to find the EqualsValueClause.
SyntaxNode current = substituteCall;
while (current.Parent is ExpressionSyntax)
{
current = current.Parent;
}

return current.Parent switch
{
EqualsValueClauseSyntax { Parent: VariableDeclaratorSyntax declarator, }
=> semanticModel.GetDeclaredSymbol(declarator, cancellationToken),
EqualsValueClauseSyntax { Parent: PropertyDeclarationSyntax prop, }
=> semanticModel.GetDeclaredSymbol(prop, cancellationToken),
_ => null,
};
}

private static Dictionary<SyntaxNode, SyntaxNode> FindAndBuildSetupReplacements(
IReadOnlyList<InvocationExpressionSyntax> allInvocations,
SemanticModel? semanticModel,
ISymbol? mockSymbol,
CancellationToken cancellationToken)
{
if (semanticModel is null || mockSymbol is null)
{
return [];
}

Dictionary<SyntaxNode, SyntaxNode> result = [];

foreach (InvocationExpressionSyntax outerInvocation in allInvocations)
{
if (outerInvocation.Expression is not MemberAccessExpressionSyntax outerAccess ||
!SetupConfiguratorMethods.Contains(outerAccess.Name.Identifier.Text))
{
continue;
}

// The receiver of the configurator (e.g. .Returns) is either
// sub.Method(args) — pattern A
// sub.Property — pattern B
ExpressionSyntax receiver = outerAccess.Expression;

if (receiver is InvocationExpressionSyntax targetInvocation &&
targetInvocation.Expression is MemberAccessExpressionSyntax targetMemberAccess)
{
if (!IsTrackedMockReceiver(targetMemberAccess.Expression, semanticModel, mockSymbol, cancellationToken))
{
continue;
}

ArgumentListSyntax transformedArgs =
TransformNSubstituteArgReferences(targetInvocation.ArgumentList, semanticModel, cancellationToken);

MemberAccessExpressionSyntax setupAccess = BuildSetupAccess(
targetMemberAccess.Expression, targetMemberAccess.Name);
InvocationExpressionSyntax setupInvocation = SyntaxFactory.InvocationExpression(setupAccess, transformedArgs)
.WithTriviaFrom(targetInvocation);

result[targetInvocation] = setupInvocation;
continue;
}

if (receiver is MemberAccessExpressionSyntax targetPropertyAccess)
{
if (!IsTrackedMockReceiver(targetPropertyAccess.Expression, semanticModel, mockSymbol, cancellationToken))
{
continue;
}

MemberAccessExpressionSyntax setupAccess = BuildSetupAccess(
targetPropertyAccess.Expression, targetPropertyAccess.Name);

result[targetPropertyAccess] = setupAccess.WithTriviaFrom(targetPropertyAccess);
}
}

return result;
}

/// <summary>
/// Returns <see langword="true" /> when <paramref name="expression" /> ultimately resolves to the tracked
/// mock symbol — either directly or via a chain of property/field accesses (auto-mocked nested members).
/// </summary>
private static bool IsTrackedMockReceiver(ExpressionSyntax expression,
SemanticModel semanticModel, ISymbol mockSymbol, CancellationToken cancellationToken)
{
ExpressionSyntax current = expression;
while (true)
{
SymbolInfo info = semanticModel.GetSymbolInfo(current, cancellationToken);
if (SymbolEqualityComparer.Default.Equals(info.Symbol, mockSymbol))
{
return true;
}

if (current is MemberAccessExpressionSyntax memberAccess)
{
current = memberAccess.Expression;
continue;
}

return false;
}
}

private static MemberAccessExpressionSyntax BuildSetupAccess(ExpressionSyntax receiver, SimpleNameSyntax memberName)
{
MemberAccessExpressionSyntax mockAccess = SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
receiver,
SyntaxFactory.IdentifierName("Mock"));
MemberAccessExpressionSyntax setupAccess = SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
mockAccess,
SyntaxFactory.IdentifierName("Setup"));
return SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
setupAccess,
memberName);
}

/// <summary>
/// Translates NSubstitute argument matchers to their Mockolate equivalents anywhere inside the supplied
/// argument list. Currently handles <c>Arg.Any&lt;T&gt;</c>, <c>Arg.Is</c>, and the <c>Arg.Compat</c>
/// mirrors. Uses the semantic model so fully-qualified usages
/// (<c>NSubstitute.Arg.Any&lt;T&gt;()</c>, aliased imports, etc.) are recognised too.
/// </summary>
private static ArgumentListSyntax TransformNSubstituteArgReferences(ArgumentListSyntax args,
SemanticModel semanticModel, CancellationToken cancellationToken) =>
args.ReplaceNodes(
args.DescendantNodes().OfType<InvocationExpressionSyntax>()
.Where(invocation => IsNSubstituteArgCall(invocation, semanticModel, cancellationToken))
.ToArray(),
TransformNSubstituteArgInvocation);

private static bool IsNSubstituteArgCall(InvocationExpressionSyntax invocation,
SemanticModel semanticModel, CancellationToken cancellationToken)
{
if (semanticModel.GetSymbolInfo(invocation, cancellationToken).Symbol is not IMethodSymbol methodSymbol)
{
return false;
}

// Walk up nested types so both Arg.X(...) and Arg.Compat.X(...) resolve to NSubstitute.Arg.
for (INamedTypeSymbol? containingType = methodSymbol.ContainingType;
containingType is not null;
containingType = containingType.ContainingType)
{
if (containingType.Name == "Arg" &&
containingType.ContainingNamespace?.ToDisplayString() == "NSubstitute")
{
return true;
}
}

return false;
}

private static SyntaxNode TransformNSubstituteArgInvocation(InvocationExpressionSyntax original,
InvocationExpressionSyntax rewritten)
{
if (rewritten.Expression is not MemberAccessExpressionSyntax memberAccess)
{
return rewritten;
}

string methodName = memberAccess.Name.Identifier.Text;
TypeArgumentListSyntax? typeArgs = (memberAccess.Name as GenericNameSyntax)?.TypeArgumentList;

IdentifierNameSyntax itIdentifier = SyntaxFactory.IdentifierName("It");

switch (methodName)
{
case "Any":
// Arg.Any<T>() → It.IsAny<T>()
return BuildItInvocation(itIdentifier, "IsAny", typeArgs, SyntaxFactory.ArgumentList())
.WithTriviaFrom(original);

case "AnyType":
// Arg.AnyType — used as a type marker, not an invocation. Skip.
return rewritten;

case "Is":
// Arg.Is<T>(predicate) → It.Satisfies<T>(predicate)
// Arg.Is<T>(value) / Arg.Is(v) → It.Is<T>(value) (or just inline value)
if (rewritten.ArgumentList.Arguments.Count == 1 &&
rewritten.ArgumentList.Arguments[0].Expression is LambdaExpressionSyntax)
{
return BuildItInvocation(itIdentifier, "Satisfies", typeArgs, rewritten.ArgumentList)
.WithTriviaFrom(original);
}

return BuildItInvocation(itIdentifier, "Is", typeArgs, rewritten.ArgumentList)
.WithTriviaFrom(original);

default:
return rewritten;
}
}

private static InvocationExpressionSyntax BuildItInvocation(IdentifierNameSyntax itIdentifier, string methodName,
TypeArgumentListSyntax? typeArgs, ArgumentListSyntax argList)
{
SimpleNameSyntax method = typeArgs is null
? SyntaxFactory.IdentifierName(methodName)
: SyntaxFactory.GenericName(SyntaxFactory.Identifier(methodName)).WithTypeArgumentList(typeArgs);
return SyntaxFactory.InvocationExpression(
SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
itIdentifier,
method),
argList);
}

/// <summary>
/// Translates the NSubstitute creation call to a Mockolate creation chain. Returns <see langword="null" />
/// when the call cannot be migrated.
Expand Down Expand Up @@ -196,3 +461,6 @@ private static string DetectLineEnding(SyntaxNode root)
return "\n";
}
}

#pragma warning restore S3776 // Cognitive Complexity of methods should not be too high
#pragma warning restore S1192
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

namespace Mockolate.Migration.Tests;

public class NSubstituteCodeFixProviderTests
public partial class NSubstituteCodeFixProviderTests
{
public sealed class CreationTests
{
Expand Down
Loading
Loading