Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 135 additions & 1 deletion Source/Mockolate.Migration.Analyzers.CodeFixers/MoqCodeFixProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ protected override async Task<Document> ConvertAssertionAsync(CodeFixContext con
Dictionary<InvocationExpressionSyntax, InvocationExpressionSyntax> verifyCallReplacements =
FindAndBuildVerifyCallReplacements(allInvocations, semanticModel, mockSymbol, cancellationToken);

Dictionary<InvocationExpressionSyntax, InvocationExpressionSyntax> verifyEventCallReplacements =
FindAndBuildVerifyEventCallReplacements(allInvocations, semanticModel, mockSymbol, cancellationToken);

Dictionary<InvocationExpressionSyntax, InvocationExpressionSyntax> raiseCallReplacements =
FindAndBuildRaiseCallReplacements(allInvocations, semanticModel, mockSymbol, cancellationToken);

Expand All @@ -105,6 +108,7 @@ protected override async Task<Document> ConvertAssertionAsync(CodeFixContext con
nodesToReplace.AddRange(setupPropertyCallReplacements.Keys);
nodesToReplace.AddRange(callbackReplacements.Keys);
nodesToReplace.AddRange(verifyCallReplacements.Keys);
nodesToReplace.AddRange(verifyEventCallReplacements.Keys);
nodesToReplace.AddRange(raiseCallReplacements.Keys);

compilationUnit = compilationUnit.ReplaceNodes(
Expand Down Expand Up @@ -148,6 +152,11 @@ protected override async Task<Document> ConvertAssertionAsync(CodeFixContext con
return verifyReplacement;
}

if (verifyEventCallReplacements.TryGetValue(invocation, out InvocationExpressionSyntax? verifyEventReplacement))
{
return verifyEventReplacement;
}

if (raiseCallReplacements.TryGetValue(invocation, out InvocationExpressionSyntax? raiseReplacement))
{
return raiseReplacement;
Expand All @@ -166,7 +175,7 @@ protected override async Task<Document> ConvertAssertionAsync(CodeFixContext con
compilationUnit = compilationUnit.AddUsings(usingDirective);
}

if (verifyCallReplacements.Count > 0)
if (verifyCallReplacements.Count > 0 || verifyEventCallReplacements.Count > 0)
{
bool hasVerifyUsing = compilationUnit.Usings.Any(u => u.Name?.ToString() == "Mockolate.Verify");
if (!hasVerifyUsing)
Expand Down Expand Up @@ -966,6 +975,131 @@ private static Dictionary<InvocationExpressionSyntax, InvocationExpressionSyntax
return result;
}

private static Dictionary<InvocationExpressionSyntax, InvocationExpressionSyntax> FindAndBuildVerifyEventCallReplacements(
IReadOnlyList<InvocationExpressionSyntax> allInvocations,
SemanticModel? semanticModel,
ISymbol? mockSymbol,
CancellationToken cancellationToken)
{
if (semanticModel is null || mockSymbol is null)
{
return [];
}

Dictionary<InvocationExpressionSyntax, InvocationExpressionSyntax> result = [];
foreach (InvocationExpressionSyntax invocation in allInvocations)
{
if (invocation.Expression is not MemberAccessExpressionSyntax memberAccess)
{
continue;
}

string methodName = memberAccess.Name.Identifier.Text;
string? eventSubscriptionMethod = methodName switch
{
"VerifyAdd" => "Subscribed",
"VerifyRemove" => "Unsubscribed",
_ => null,
};

if (eventSubscriptionMethod is null)
{
continue;
}

SymbolInfo symbolInfo = semanticModel.GetSymbolInfo(memberAccess.Expression, cancellationToken);
if (!SymbolEqualityComparer.Default.Equals(symbolInfo.Symbol, mockSymbol))
{
continue;
}

if (invocation.ArgumentList.Arguments.Count is 0 or > 2 ||
invocation.ArgumentList.Arguments[0].Expression is not LambdaExpressionSyntax lambda)
{
continue;
}

SyntaxKind expectedAssignment = methodName == "VerifyAdd"
? SyntaxKind.AddAssignmentExpression
: SyntaxKind.SubtractAssignmentExpression;

if (lambda.Body is not AssignmentExpressionSyntax assignment ||
!assignment.IsKind(expectedAssignment) ||
assignment.Left is not MemberAccessExpressionSyntax eventAccess)
{
continue;
}

Comment thread
vbreuss marked this conversation as resolved.
// LHS of += / -= must bind to an event, not a delegate field/property — otherwise
// the generated Verify.<name>.Subscribed() chain would reference a non-existent member.
if (semanticModel.GetSymbolInfo(eventAccess, cancellationToken).Symbol is not IEventSymbol)
{
continue;
}

string? lambdaParamName = GetSingleLambdaParamName(lambda);
if (lambdaParamName is null)
{
continue;
}

List<SimpleNameSyntax>? navigationChain = ExtractNavigationChain(eventAccess.Expression, lambdaParamName);
if (navigationChain is null)
{
continue;
}

SimpleNameSyntax eventNameSyntax = eventAccess.Name;

ExpressionSyntax receiver = memberAccess.Expression;
foreach (SimpleNameSyntax nav in navigationChain)
{
receiver = SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
receiver,
nav);
}

MemberAccessExpressionSyntax mockAccess = SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
receiver,
SyntaxFactory.IdentifierName("Mock"));
MemberAccessExpressionSyntax verifyAccess = SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
mockAccess,
SyntaxFactory.IdentifierName("Verify"));
MemberAccessExpressionSyntax eventMemberAccess = SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
verifyAccess,
eventNameSyntax);
InvocationExpressionSyntax baseInvocation = SyntaxFactory.InvocationExpression(
SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
eventMemberAccess,
SyntaxFactory.IdentifierName(eventSubscriptionMethod)),
SyntaxFactory.ArgumentList());

InvocationExpressionSyntax atLeastOnceFallback = SyntaxFactory.InvocationExpression(
SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
baseInvocation,
SyntaxFactory.IdentifierName("AtLeastOnce")),
SyntaxFactory.ArgumentList());

// Fall back to AtLeastOnce when the Times argument can't be translated — the
// Mock<T>() construction is unconditionally rewritten, so leaving the original
// VerifyAdd/VerifyRemove in place would produce non-compiling code.
InvocationExpressionSyntax replacement = invocation.ArgumentList.Arguments.Count == 2
? (ApplyTimesChain(baseInvocation, invocation.ArgumentList.Arguments[1].Expression) ?? atLeastOnceFallback)
.WithTriviaFrom(invocation)
: atLeastOnceFallback.WithTriviaFrom(invocation);

result[invocation] = replacement;
}

return result;
}

private static InvocationExpressionSyntax? ApplyTimesChain(
InvocationExpressionSyntax baseInvocation, ExpressionSyntax timesArg)
{
Expand Down
20 changes: 20 additions & 0 deletions Tests/Mockolate.Migration.Example.Tests/MoqMigrationExamples.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Text.RegularExpressions;
using Mockolate.Verify;
using Moq;
using Range = Moq.Range;

Expand Down Expand Up @@ -70,6 +71,15 @@ public async Task ExpectedMigrationResult()
// alternatively, provide a default value for the stubbed property
mock.Mock.Setup.Name.InitializeWith("foo");

/* ------ Events ------ */
// subscribing to and raising an event
mock.MyEvent += (_, _) => { };
mock.Mock.Raise.MyEvent(null, EventArgs.Empty);

// verifying event subscription / unsubscription
mock.Mock.Verify.MyEvent.Subscribed().Once();
mock.Mock.Verify.MyEvent.Unsubscribed().Never();

await That(true).IsTrue();
}

Expand Down Expand Up @@ -139,6 +149,15 @@ public async Task MoqCreation()
mock.SetupProperty(f => f.Name);
// alternatively, provide a default value for the stubbed property
mock.SetupProperty(f => f.Name, "foo");

/* ------ Events ------ */
// subscribing to and raising an event
mock.Object.MyEvent += (_, _) => { };
mock.Raise(foo => foo.MyEvent += null, EventArgs.Empty);

// verifying event subscription / unsubscription
mock.VerifyAdd(foo => foo.MyEvent += Moq.It.IsAny<EventHandler>(), Times.Once());
mock.VerifyRemove(foo => foo.MyEvent -= Moq.It.IsAny<EventHandler>(), Times.Never);
await That(true).IsTrue();
}

Expand All @@ -147,6 +166,7 @@ public interface IFoo
Bar Bar { get; set; }
string Name { get; set; }
int Value { get; set; }
event EventHandler MyEvent;
bool DoSomething(string value);
bool DoSomething(int number, string value);
Task<bool> DoSomethingAsync();
Expand Down
Loading
Loading