Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,14 @@ protected override async Task<Document> ConvertAssertionAsync(CodeFixContext con
Dictionary<InvocationExpressionSyntax, InvocationExpressionSyntax> clearReplacements =
FindAndBuildClearReceivedCallsReplacements(allInvocations, semanticModel, mockSymbol, cancellationToken);

Dictionary<AssignmentExpressionSyntax, InvocationExpressionSyntax> raiseReplacements =
FindAndBuildRaiseReplacements(compilationUnit, semanticModel, mockSymbol, cancellationToken);

List<SyntaxNode> nodesToReplace = [substituteCall,];
nodesToReplace.AddRange(setupReplacements.Keys);
nodesToReplace.AddRange(verifyReplacements.Keys);
nodesToReplace.AddRange(clearReplacements.Keys);
nodesToReplace.AddRange(raiseReplacements.Keys);

compilationUnit = compilationUnit.ReplaceNodes(
nodesToReplace,
Expand Down Expand Up @@ -103,6 +107,12 @@ protected override async Task<Document> ConvertAssertionAsync(CodeFixContext con
}
}

if (original is AssignmentExpressionSyntax assignment &&
raiseReplacements.TryGetValue(assignment, out InvocationExpressionSyntax? raiseReplacement))
{
return raiseReplacement;
}

return original;
});

Expand Down Expand Up @@ -245,6 +255,131 @@ private static Dictionary<SyntaxNode, SyntaxNode> FindAndBuildSetupReplacements(
return result;
}

private static Dictionary<AssignmentExpressionSyntax, InvocationExpressionSyntax> FindAndBuildRaiseReplacements(
CompilationUnitSyntax compilationUnit,
SemanticModel? semanticModel,
ISymbol? mockSymbol,
CancellationToken cancellationToken)
{
if (semanticModel is null || mockSymbol is null)
{
return [];
}

Dictionary<AssignmentExpressionSyntax, InvocationExpressionSyntax> result = [];

foreach (AssignmentExpressionSyntax assignment in compilationUnit.DescendantNodes().OfType<AssignmentExpressionSyntax>())
{
if (!assignment.IsKind(SyntaxKind.AddAssignmentExpression))
{
continue;
}

if (assignment.Left is not MemberAccessExpressionSyntax eventAccess ||
assignment.Right is not InvocationExpressionSyntax raiseInvocation)
{
continue;
}

Comment thread
vbreuss marked this conversation as resolved.
if (!IsTrackedMockReceiver(eventAccess.Expression, semanticModel, mockSymbol, cancellationToken))
{
continue;
}

if (semanticModel.GetSymbolInfo(eventAccess, cancellationToken).Symbol is not IEventSymbol)
{
continue;
}

if (raiseInvocation.Expression is not MemberAccessExpressionSyntax raiseAccess)
{
continue;
}

if (semanticModel.GetSymbolInfo(raiseInvocation, cancellationToken).Symbol is not IMethodSymbol raiseMethodSymbol ||
raiseMethodSymbol.ContainingType?.Name != "Raise" ||
raiseMethodSymbol.ContainingType.ContainingNamespace?.ToDisplayString() != "NSubstitute")
{
continue;
}

string raiseMethod = raiseAccess.Name.Identifier.Text;
ArgumentListSyntax raiseArgs = BuildRaiseArguments(raiseInvocation.ArgumentList, raiseMethod);

MemberAccessExpressionSyntax mockAccess = SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
eventAccess.Expression,
SyntaxFactory.IdentifierName("Mock"));
MemberAccessExpressionSyntax raiseMember = SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
mockAccess,
SyntaxFactory.IdentifierName("Raise"));
MemberAccessExpressionSyntax raiseEventName = SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
raiseMember,
eventAccess.Name.WithoutTrivia());

result[assignment] = SyntaxFactory.InvocationExpression(raiseEventName, raiseArgs)
.WithTriviaFrom(assignment);
}

return result;
}

/// <summary>
/// Translates the argument list of an NSubstitute <c>Raise.X(...)</c> call into the corresponding
/// <c>Mock.Raise.EventName(...)</c> argument list.
/// </summary>
private static ArgumentListSyntax BuildRaiseArguments(ArgumentListSyntax raiseArgs, string raiseMethod)
{
// Raise.Event<TDelegate>(args...) — non-EventHandler delegates, just forward the args.
// Raise.EventWith(args) — single arg means EventArgs only (sender omitted, defaults to null).
// Raise.EventWith(sender, ea) — two args, pass through.
// Raise.Event() / Raise.EventWith() — empty, default to (null, EventArgs.Empty) for EventHandler.
if (raiseMethod is "Event")
{
if (raiseArgs.Arguments.Count == 0)
{
return SyntaxFactory.ArgumentList(SyntaxFactory.SeparatedList(
[
SyntaxFactory.Argument(SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression)),
EventArgsEmptyArgument(),
]));
}

return raiseArgs;
}

if (raiseMethod is "EventWith")
{
if (raiseArgs.Arguments.Count == 0)
{
return SyntaxFactory.ArgumentList(SyntaxFactory.SeparatedList(
[
SyntaxFactory.Argument(SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression)),
EventArgsEmptyArgument(),
]));
Comment thread
vbreuss marked this conversation as resolved.
}

if (raiseArgs.Arguments.Count == 1)
{
return SyntaxFactory.ArgumentList(SyntaxFactory.SeparatedList(
[
SyntaxFactory.Argument(SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression)),
raiseArgs.Arguments[0],
]));
}

return raiseArgs;
}

return raiseArgs;
}

// Fully qualified so the rewrite compiles even when the source file does not have `using System;`.
private static ArgumentSyntax EventArgsEmptyArgument() =>
SyntaxFactory.Argument(SyntaxFactory.ParseExpression("global::System.EventArgs.Empty"));

private static Dictionary<InvocationExpressionSyntax, InvocationExpressionSyntax> FindAndBuildClearReceivedCallsReplacements(
IReadOnlyList<InvocationExpressionSyntax> allInvocations,
SemanticModel? semanticModel,
Expand Down
Loading
Loading