diff --git a/Source/Mockolate.Migration.Analyzers.CodeFixers/MoqCodeFixProvider.cs b/Source/Mockolate.Migration.Analyzers.CodeFixers/MoqCodeFixProvider.cs index fea4881..296f15c 100644 --- a/Source/Mockolate.Migration.Analyzers.CodeFixers/MoqCodeFixProvider.cs +++ b/Source/Mockolate.Migration.Analyzers.CodeFixers/MoqCodeFixProvider.cs @@ -71,18 +71,27 @@ protected override async Task ConvertAssertionAsync(CodeFixContext con bool replaceDeclarationType = declarationType is not null && declarationType is not IdentifierNameSyntax { IsVar: true, }; ISymbol? mockSymbol = GetDeclaredMockSymbol(semanticModel, expressionSyntax, cancellationToken); + IReadOnlyList allInvocations = + compilationUnit.DescendantNodes().OfType().ToList(); + List objectAccesses = FindObjectAccesses(compilationUnit, semanticModel, mockSymbol, cancellationToken); Dictionary setupCallReplacements = - FindAndBuildSetupCallReplacements(compilationUnit, semanticModel, mockSymbol, cancellationToken); + FindAndBuildSetupCallReplacements(allInvocations, semanticModel, mockSymbol, cancellationToken); + + Dictionary setupPropertyAccessReplacements = + FindAndBuildSetupPropertyAccessReplacements(allInvocations, semanticModel, mockSymbol, cancellationToken); + + Dictionary setupPropertyCallReplacements = + FindAndBuildSetupPropertyCallReplacements(allInvocations, semanticModel, mockSymbol, cancellationToken); Dictionary callbackReplacements = - FindAndBuildCallbackReplacements(compilationUnit, setupCallReplacements, out HashSet setupsWrappedByCallbacks); + FindAndBuildCallbackReplacements(allInvocations, setupCallReplacements, out HashSet setupsWrappedByCallbacks); Dictionary verifyCallReplacements = - FindAndBuildVerifyCallReplacements(compilationUnit, semanticModel, mockSymbol, cancellationToken); + FindAndBuildVerifyCallReplacements(allInvocations, semanticModel, mockSymbol, cancellationToken); Dictionary raiseCallReplacements = - FindAndBuildRaiseCallReplacements(compilationUnit, semanticModel, mockSymbol, cancellationToken); + FindAndBuildRaiseCallReplacements(allInvocations, semanticModel, mockSymbol, cancellationToken); List nodesToReplace = [expressionSyntax,]; if (replaceDeclarationType) @@ -92,6 +101,8 @@ protected override async Task ConvertAssertionAsync(CodeFixContext con nodesToReplace.AddRange(objectAccesses); nodesToReplace.AddRange(setupCallReplacements.Keys.Where(k => !setupsWrappedByCallbacks.Contains(k))); + nodesToReplace.AddRange(setupPropertyAccessReplacements.Keys); + nodesToReplace.AddRange(setupPropertyCallReplacements.Keys); nodesToReplace.AddRange(callbackReplacements.Keys); nodesToReplace.AddRange(verifyCallReplacements.Keys); nodesToReplace.AddRange(raiseCallReplacements.Keys); @@ -122,6 +133,16 @@ protected override async Task ConvertAssertionAsync(CodeFixContext con return setupReplacement; } + if (setupPropertyAccessReplacements.TryGetValue(invocation, out MemberAccessExpressionSyntax? propertyAccessReplacement)) + { + return propertyAccessReplacement; + } + + if (setupPropertyCallReplacements.TryGetValue(invocation, out InvocationExpressionSyntax? setupPropertyCallReplacement)) + { + return setupPropertyCallReplacement; + } + if (verifyCallReplacements.TryGetValue(invocation, out InvocationExpressionSyntax? verifyReplacement)) { return verifyReplacement; @@ -237,7 +258,7 @@ private static bool HasStrictMockBehavior(ExpressionSyntax expressionSyntax, Sem } private static Dictionary FindAndBuildRaiseCallReplacements( - CompilationUnitSyntax compilationUnit, + IReadOnlyList allInvocations, SemanticModel? semanticModel, ISymbol? mockSymbol, CancellationToken cancellationToken) @@ -248,7 +269,7 @@ private static Dictionary result = []; - foreach (InvocationExpressionSyntax invocation in compilationUnit.DescendantNodes().OfType()) + foreach (InvocationExpressionSyntax invocation in allInvocations) { if (invocation.Expression is not MemberAccessExpressionSyntax memberAccess) { @@ -279,13 +300,7 @@ private static Dictionary simple.Parameter.Identifier.Text, - ParenthesizedLambdaExpressionSyntax { ParameterList.Parameters: { Count: 1, } parms, } - => parms[0].Identifier.Text, - _ => null, - }; + string? lambdaParamName = GetSingleLambdaParamName(lambda); if (lambdaParamName is null) { @@ -385,7 +400,7 @@ private static List FindObjectAccesses( } private static Dictionary FindAndBuildSetupCallReplacements( - CompilationUnitSyntax compilationUnit, + IReadOnlyList allInvocations, SemanticModel? semanticModel, ISymbol? mockSymbol, CancellationToken cancellationToken) @@ -396,7 +411,7 @@ private static Dictionary result = []; - foreach (InvocationExpressionSyntax invocation in compilationUnit.DescendantNodes().OfType()) + foreach (InvocationExpressionSyntax invocation in allInvocations) { if (invocation.Expression is not MemberAccessExpressionSyntax memberAccess) { @@ -426,13 +441,7 @@ private static Dictionary simple.Parameter.Identifier.Text, - ParenthesizedLambdaExpressionSyntax { ParameterList.Parameters: { Count: 1, } parms, } - => parms[0].Identifier.Text, - _ => null, - }; + string? lambdaParamName = GetSingleLambdaParamName(lambda); if (lambdaParamName is null) { @@ -501,14 +510,248 @@ private static Dictionary FindAndBuildSetupPropertyAccessReplacements( + IReadOnlyList allInvocations, + SemanticModel? semanticModel, + ISymbol? mockSymbol, + CancellationToken cancellationToken) + { + if (semanticModel is null || mockSymbol is null) + { + return []; + } + + Dictionary result = []; + foreach (InvocationExpressionSyntax invocation in allInvocations) + { + if (invocation.Expression is not MemberAccessExpressionSyntax memberAccess) + { + continue; + } + + if (memberAccess.Name.Identifier.Text != "Setup") + { + continue; + } + + SymbolInfo symbolInfo = semanticModel.GetSymbolInfo(memberAccess.Expression, cancellationToken); + if (!SymbolEqualityComparer.Default.Equals(symbolInfo.Symbol, mockSymbol)) + { + continue; + } + + if (invocation.ArgumentList.Arguments.Count != 1 || + invocation.ArgumentList.Arguments[0].Expression is not LambdaExpressionSyntax lambda) + { + continue; + } + + // Only handle property access (not method calls — those are handled by FindAndBuildSetupCallReplacements) + if (lambda.Body is not MemberAccessExpressionSyntax lambdaMemberAccess) + { + continue; + } + + string? lambdaParamName = GetSingleLambdaParamName(lambda); + + if (lambdaParamName is null) + { + continue; + } + + List? navigationChain = ExtractNavigationChain(lambdaMemberAccess.Expression, lambdaParamName); + if (navigationChain is null) + { + continue; + } + + SimpleNameSyntax propertyNameSyntax = lambdaMemberAccess.Name; + + MemberAccessExpressionSyntax replacement; + if (navigationChain.Count == 0) + { + // Direct setup: mock.Mock.Setup.Property + MemberAccessExpressionSyntax mockAccess = SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + memberAccess.Expression, + SyntaxFactory.IdentifierName("Mock")); + MemberAccessExpressionSyntax setupAccess = SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + mockAccess, + SyntaxFactory.IdentifierName("Setup")); + replacement = SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + setupAccess, + propertyNameSyntax); + } + else + { + // Nested setup: mock.Nav1.Nav2.Mock.Setup.Property + ExpressionSyntax navChain = memberAccess.Expression; + foreach (SimpleNameSyntax nav in navigationChain) + { + navChain = SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + navChain, + nav); + } + + MemberAccessExpressionSyntax mockAccess = SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + navChain, + SyntaxFactory.IdentifierName("Mock")); + MemberAccessExpressionSyntax setupAccess = SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + mockAccess, + SyntaxFactory.IdentifierName("Setup")); + replacement = SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + setupAccess, + propertyNameSyntax); + } + + result[invocation] = replacement.WithTriviaFrom(invocation); + } + + return result; + } + + private static Dictionary FindAndBuildSetupPropertyCallReplacements( + IReadOnlyList allInvocations, + SemanticModel? semanticModel, + ISymbol? mockSymbol, + CancellationToken cancellationToken) + { + if (semanticModel is null || mockSymbol is null) + { + return []; + } + + Dictionary result = []; + foreach (InvocationExpressionSyntax invocation in allInvocations) + { + if (invocation.Expression is not MemberAccessExpressionSyntax memberAccess) + { + continue; + } + + if (memberAccess.Name.Identifier.Text != "SetupProperty") + { + continue; + } + + SymbolInfo symbolInfo = semanticModel.GetSymbolInfo(memberAccess.Expression, cancellationToken); + if (!SymbolEqualityComparer.Default.Equals(symbolInfo.Symbol, mockSymbol)) + { + continue; + } + + if (invocation.ArgumentList.Arguments.Count is 0 or > 2 || + invocation.ArgumentList.Arguments[0].Expression is not LambdaExpressionSyntax lambda) + { + continue; + } + + if (lambda.Body is not MemberAccessExpressionSyntax lambdaMemberAccess) + { + continue; + } + + string? lambdaParamName = GetSingleLambdaParamName(lambda); + + if (lambdaParamName is null) + { + continue; + } + + List? navigationChain = ExtractNavigationChain(lambdaMemberAccess.Expression, lambdaParamName); + if (navigationChain is null) + { + continue; + } + + SimpleNameSyntax propertyNameSyntax = lambdaMemberAccess.Name; + + ExpressionSyntax propertyAccess; + if (navigationChain.Count == 0) + { + MemberAccessExpressionSyntax mockAccess = SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + memberAccess.Expression, + SyntaxFactory.IdentifierName("Mock")); + MemberAccessExpressionSyntax setupAccess = SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + mockAccess, + SyntaxFactory.IdentifierName("Setup")); + propertyAccess = SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + setupAccess, + propertyNameSyntax); + } + else + { + ExpressionSyntax navChain = memberAccess.Expression; + foreach (SimpleNameSyntax nav in navigationChain) + { + navChain = SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + navChain, + nav); + } + + MemberAccessExpressionSyntax mockAccess = SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + navChain, + SyntaxFactory.IdentifierName("Mock")); + MemberAccessExpressionSyntax setupAccess = SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + mockAccess, + SyntaxFactory.IdentifierName("Setup")); + propertyAccess = SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + setupAccess, + propertyNameSyntax); + } + + InvocationExpressionSyntax replacement; + if (invocation.ArgumentList.Arguments.Count == 2) + { + // mock.SetupProperty(f => f.Name, "foo") → mock.Mock.Setup.Name.InitializeWith("foo") + ArgumentSyntax defaultValueArg = invocation.ArgumentList.Arguments[1]; + replacement = SyntaxFactory.InvocationExpression( + SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + propertyAccess, + SyntaxFactory.IdentifierName("InitializeWith")), + SyntaxFactory.ArgumentList(SyntaxFactory.SeparatedList([defaultValueArg,]))) + .WithTriviaFrom(invocation); + } + else + { + // mock.SetupProperty(f => f.Name) → mock.Mock.Setup.Name.Register() + replacement = SyntaxFactory.InvocationExpression( + SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + propertyAccess, + SyntaxFactory.IdentifierName("Register")), + SyntaxFactory.ArgumentList()) + .WithTriviaFrom(invocation); + } + + result[invocation] = replacement; + } + + return result; + } + private static Dictionary FindAndBuildCallbackReplacements( - CompilationUnitSyntax compilationUnit, + IReadOnlyList allInvocations, Dictionary setupCallReplacements, out HashSet setupsWrappedByCallbacks) { setupsWrappedByCallbacks = []; Dictionary result = []; - foreach (InvocationExpressionSyntax invocation in compilationUnit.DescendantNodes().OfType()) + foreach (InvocationExpressionSyntax invocation in allInvocations) { if (invocation.Expression is not MemberAccessExpressionSyntax memberAccess) { @@ -593,7 +836,7 @@ chainedInv.Expression is MemberAccessExpressionSyntax chainedMemberAccess && } private static Dictionary FindAndBuildVerifyCallReplacements( - CompilationUnitSyntax compilationUnit, + IReadOnlyList allInvocations, SemanticModel? semanticModel, ISymbol? mockSymbol, CancellationToken cancellationToken) @@ -604,7 +847,7 @@ private static Dictionary result = []; - foreach (InvocationExpressionSyntax invocation in compilationUnit.DescendantNodes().OfType()) + foreach (InvocationExpressionSyntax invocation in allInvocations) { if (invocation.Expression is not MemberAccessExpressionSyntax memberAccess) { @@ -634,13 +877,7 @@ private static Dictionary simple.Parameter.Identifier.Text, - ParenthesizedLambdaExpressionSyntax { ParameterList.Parameters: { Count: 1, } parms, } - => parms[0].Identifier.Text, - _ => null, - }; + string? lambdaParamName = GetSingleLambdaParamName(lambda); if (lambdaParamName is null) { @@ -841,6 +1078,15 @@ private static ExpressionSyntax AdjustIntBoundary(ExpressionSyntax expr, int del SyntaxFactory.LiteralExpression(SyntaxKind.NumericLiteralExpression, SyntaxFactory.Literal(-delta))); } + private static string? GetSingleLambdaParamName(LambdaExpressionSyntax lambda) => + lambda switch + { + SimpleLambdaExpressionSyntax simple => simple.Parameter.Identifier.Text, + ParenthesizedLambdaExpressionSyntax { ParameterList.Parameters: { Count: 1, } parms, } + => parms[0].Identifier.Text, + _ => null, + }; + private static List? ExtractNavigationChain(ExpressionSyntax expression, string lambdaParamName) { List chain = []; diff --git a/Tests/Mockolate.Migration.Example.Tests/MoqMigrationExamples.cs b/Tests/Mockolate.Migration.Example.Tests/MoqMigrationExamples.cs index 92a41e7..6cc9bb4 100644 --- a/Tests/Mockolate.Migration.Example.Tests/MoqMigrationExamples.cs +++ b/Tests/Mockolate.Migration.Example.Tests/MoqMigrationExamples.cs @@ -10,6 +10,8 @@ public class MoqMigrationExamples public async Task ExpectedMigrationResult() { IFoo mock = IFoo.CreateMock(); + mock.Mock.Setup.Bar.InitializeWith(Bar.CreateMock()); + mock.Bar.Mock.Setup.Baz.InitializeWith(Baz.CreateMock()); IFoo mock2 = IFoo.CreateMock(MockBehavior.Default.ThrowingWhenNotSetup()); mock2.Mock.Setup.DoSomething("ping").Returns(true); @@ -57,6 +59,17 @@ public async Task ExpectedMigrationResult() // matching regex mock.Mock.Setup.DoSomethingStringy(It.Matches("[a-d]+").AsRegex(RegexOptions.IgnoreCase)).Returns("foo"); + /* ------ Properties ------ */ + mock.Mock.Setup.Name.Returns("bar"); + + // auto-mocking hierarchies (a.k.a. recursive mocks) + mock.Bar.Baz.Mock.Setup.Name.Returns("baz"); + + // start "tracking" sets/gets to this property + mock.Mock.Setup.Name.Register(); + // alternatively, provide a default value for the stubbed property + mock.Mock.Setup.Name.InitializeWith("foo"); + await That(true).IsTrue(); } @@ -116,6 +129,16 @@ public async Task MoqCreation() // matching regex mock.Setup(x => x.DoSomethingStringy(Moq.It.IsRegex("[a-d]+", RegexOptions.IgnoreCase))).Returns("foo"); + /* ------ Properties ------ */ + mock.Setup(foo => foo.Name).Returns("bar"); + + // auto-mocking hierarchies (a.k.a. recursive mocks) + mock.Setup(foo => foo.Bar.Baz.Name).Returns("baz"); + + // start "tracking" sets/gets to this property + mock.SetupProperty(f => f.Name); + // alternatively, provide a default value for the stubbed property + mock.SetupProperty(f => f.Name, "foo"); await That(true).IsTrue(); } diff --git a/Tests/Mockolate.Migration.Tests/MoqCodeFixProviderTests.SetupTests.cs b/Tests/Mockolate.Migration.Tests/MoqCodeFixProviderTests.SetupTests.cs index b0167db..c6e2478 100644 --- a/Tests/Mockolate.Migration.Tests/MoqCodeFixProviderTests.SetupTests.cs +++ b/Tests/Mockolate.Migration.Tests/MoqCodeFixProviderTests.SetupTests.cs @@ -358,5 +358,143 @@ public void Test() } } """); + + [Fact] + public async Task Property_MigratesSetup() + => await Verifier.VerifyCodeFixAsync( + """ + using Moq; + + public interface IFoo { string Name { get; set; } } + + public class Tests + { + public void Test() + { + var mock = [|new Mock()|]; + mock.Setup(foo => foo.Name).Returns("bar"); + } + } + """, + """ + using Moq; + using Mockolate; + + public interface IFoo { string Name { get; set; } } + + public class Tests + { + public void Test() + { + var mock = IFoo.CreateMock(); + mock.Mock.Setup.Name.Returns("bar"); + } + } + """); + + [Fact] + public async Task Property_Nested_MigratesSetup() + => await Verifier.VerifyCodeFixAsync( + """ + using Moq; + + public class Baz { public virtual string Name { get; set; } = ""; } + public class Bar { public virtual Baz Baz { get; set; } } + + public interface IFoo { Bar Bar { get; set; } } + + public class Tests + { + public void Test() + { + var mock = [|new Mock()|]; + mock.Setup(foo => foo.Bar.Baz.Name).Returns("baz"); + } + } + """, + """ + using Moq; + using Mockolate; + + public class Baz { public virtual string Name { get; set; } = ""; } + public class Bar { public virtual Baz Baz { get; set; } } + + public interface IFoo { Bar Bar { get; set; } } + + public class Tests + { + public void Test() + { + var mock = IFoo.CreateMock(); + mock.Bar.Baz.Mock.Setup.Name.Returns("baz"); + } + } + """); + + [Fact] + public async Task SetupProperty_WithDefault_MigratesInitializeWith() + => await Verifier.VerifyCodeFixAsync( + """ + using Moq; + + public interface IFoo { string Name { get; set; } } + + public class Tests + { + public void Test() + { + var mock = [|new Mock()|]; + mock.SetupProperty(f => f.Name, "foo"); + } + } + """, + """ + using Moq; + using Mockolate; + + public interface IFoo { string Name { get; set; } } + + public class Tests + { + public void Test() + { + var mock = IFoo.CreateMock(); + mock.Mock.Setup.Name.InitializeWith("foo"); + } + } + """); + + [Fact] + public async Task SetupProperty_WithoutDefault_MigratesRegister() + => await Verifier.VerifyCodeFixAsync( + """ + using Moq; + + public interface IFoo { string Name { get; set; } } + + public class Tests + { + public void Test() + { + var mock = [|new Mock()|]; + mock.SetupProperty(f => f.Name); + } + } + """, + """ + using Moq; + using Mockolate; + + public interface IFoo { string Name { get; set; } } + + public class Tests + { + public void Test() + { + var mock = IFoo.CreateMock(); + mock.Mock.Setup.Name.Register(); + } + } + """); } }