diff --git a/src/main/java/org/openrewrite/java/testing/mockito/AddMockitoExtensionIfAnnotationsUsed.java b/src/main/java/org/openrewrite/java/testing/mockito/AddMockitoExtensionIfAnnotationsUsed.java
new file mode 100644
index 000000000..6a969f755
--- /dev/null
+++ b/src/main/java/org/openrewrite/java/testing/mockito/AddMockitoExtensionIfAnnotationsUsed.java
@@ -0,0 +1,79 @@
+/*
+ * Copyright 2025 the original author or authors.
+ *
+ * Licensed under the Moderne Source Available License (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://docs.moderne.io/licensing/moderne-source-available-license
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.openrewrite.java.testing.mockito;
+
+import org.openrewrite.ExecutionContext;
+import org.openrewrite.Recipe;
+import org.openrewrite.TreeVisitor;
+import org.openrewrite.java.JavaIsoVisitor;
+import org.openrewrite.java.JavaParser;
+import org.openrewrite.java.JavaTemplate;
+import org.openrewrite.java.search.FindAnnotations;
+import org.openrewrite.java.search.FindTypes;
+import org.openrewrite.java.search.IsLikelyTest;
+import org.openrewrite.java.tree.J;
+
+import static java.util.Comparator.comparing;
+import static org.openrewrite.Preconditions.*;
+
+public class AddMockitoExtensionIfAnnotationsUsed extends Recipe {
+ @Override
+ public String getDisplayName() {
+ return "Adds Mockito extensions to Mockito tests";
+ }
+
+ @Override
+ public String getDescription() {
+ return "Adds `@ExtendWith(MockitoExtension.class)` to tests using `@Mock` or `@Captor`.";
+ }
+
+ @Override
+ public TreeVisitor, ExecutionContext> getVisitor() {
+
+ TreeVisitor, ExecutionContext> hasExtendedWithAnnotation = new FindAnnotations("org.junit.jupiter.api.extension.ExtendWith(org.mockito.junit.jupiter.MockitoExtension.class)", false).getVisitor();
+ @SuppressWarnings("unchecked")
+ TreeVisitor, ExecutionContext>[] hasAnyMockitoAnnotation = new TreeVisitor[]{
+ // see https://www.baeldung.com/mockito-annotations for examples
+ new FindAnnotations("org.mockito.Captor", false).getVisitor(),
+ new FindAnnotations("org.mockito.Mock", false).getVisitor(),
+ new FindAnnotations("org.mockito.Spy", false).getVisitor(),
+ new FindAnnotations("org.mockito.InjectMocks", false).getVisitor(),
+ };
+
+ return check(and(new IsLikelyTest().getVisitor(),
+ // check to only migrate JUnit 5 tests
+ new FindTypes("org.junit.jupiter..*", false).getVisitor(),
+ // prevent addition if present
+ not(hasExtendedWithAnnotation),
+ or(hasAnyMockitoAnnotation)),
+ new JavaIsoVisitor() {
+ @Override
+ public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, ExecutionContext ctx) {
+
+ maybeAddImport("org.mockito.junit.jupiter.MockitoExtension");
+ maybeAddImport("org.junit.jupiter.api.extension.ExtendWith");
+
+ return JavaTemplate.builder("@ExtendWith(MockitoExtension.class)")
+ .imports("org.mockito.junit.jupiter.MockitoExtension")
+ .imports("org.junit.jupiter.api.extension.ExtendWith")
+ .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "junit-jupiter-api", "mockito-junit-jupiter"))
+ .build()
+ .apply(getCursor(), classDecl.getCoordinates().addAnnotation(comparing(J.Annotation::getSimpleName)));
+ }
+ });
+ }
+}
diff --git a/src/main/resources/META-INF/rewrite/mockito.yml b/src/main/resources/META-INF/rewrite/mockito.yml
index bfcd46bea..5b05d3b41 100644
--- a/src/main/resources/META-INF/rewrite/mockito.yml
+++ b/src/main/resources/META-INF/rewrite/mockito.yml
@@ -168,6 +168,7 @@ recipeList:
- org.openrewrite.java.testing.mockito.CleanupMockitoImports
- org.openrewrite.java.testing.mockito.MockUtilsToStatic
- org.openrewrite.java.testing.junit5.MockitoJUnitToMockitoExtension
+ - org.openrewrite.java.testing.mockito.AddMockitoExtensionIfAnnotationsUsed
- org.openrewrite.java.testing.mockito.RemoveInitMocksIfRunnersSpecified
- org.openrewrite.java.testing.mockito.ReplacePowerMockito
- org.openrewrite.java.dependencies.ChangeDependency:
diff --git a/src/test/java/org/openrewrite/java/testing/jmockit/JMockitExpectationsToMockitoTest.java b/src/test/java/org/openrewrite/java/testing/jmockit/JMockitExpectationsToMockitoTest.java
index ad76bbae9..7fbbf9ec1 100644
--- a/src/test/java/org/openrewrite/java/testing/jmockit/JMockitExpectationsToMockitoTest.java
+++ b/src/test/java/org/openrewrite/java/testing/jmockit/JMockitExpectationsToMockitoTest.java
@@ -43,14 +43,14 @@ void whenTimesAndResult() {
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
Object myObject;
-
+
void test() {
new Expectations() {{
myObject.toString();
@@ -66,15 +66,15 @@ void test() {
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.*;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
Object myObject;
-
+
void test() {
when(myObject.toString()).thenReturn("foo");
assertEquals("foo", myObject.toString());
@@ -97,12 +97,12 @@ void whenNoResultNoTimes() {
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
Object myObject;
-
+
void test() {
new Expectations() {{
myObject.wait(anyLong, anyInt);
@@ -115,14 +115,14 @@ void test() {
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
import static org.mockito.Mockito.*;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
Object myObject;
-
+
void test() {
myObject.wait(10L, 10);
verify(myObject).wait(anyLong(), anyInt());
@@ -143,12 +143,12 @@ void whenNoResultNoTimesNoArgs() {
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
Object myObject;
-
+
void test() {
new Expectations() {{
myObject.wait();
@@ -161,14 +161,14 @@ void test() {
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
import static org.mockito.Mockito.verify;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
Object myObject;
-
+
void test() {
myObject.wait(10L, 10);
verify(myObject).wait();
@@ -189,14 +189,14 @@ void whenHasResultNoTimes() {
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
Object myObject;
-
+
void test() {
new Expectations() {{
myObject.toString();
@@ -210,15 +210,15 @@ void test() {
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
Object myObject;
-
+
void test() {
when(myObject.toString()).thenReturn("foo");
assertEquals("foo", myObject.toString());
@@ -248,14 +248,14 @@ public String getSomeField() {
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
import static org.junit.jupiter.api.Assertions.assertNull;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
void test() {
new Expectations() {{
myObject.getSomeField();
@@ -269,15 +269,15 @@ void test() {
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.Mockito.when;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
MyObject myObject;
-
+
void test() {
when(myObject.getSomeField()).thenReturn(null);
assertNull(myObject.getSomeField());
@@ -307,14 +307,14 @@ public int getSomeField() {
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
void test() {
new Expectations() {{
myObject.getSomeField();
@@ -333,15 +333,15 @@ void test() {
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
MyObject myObject;
-
+
void test() {
when(myObject.getSomeField()).thenReturn(10);
assertEquals(10, myObject.getSomeField());
@@ -373,14 +373,14 @@ public String getSomeField(String s) {
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
void test() {
new Expectations() {{
myObject.getSomeField(anyString);
@@ -394,16 +394,16 @@ void test() {
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.when;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
MyObject myObject;
-
+
void test() {
when(myObject.getSomeField(anyString())).thenReturn("foo");
assertEquals("foo", myObject.getSomeField("bar"));
@@ -433,16 +433,16 @@ public String getSomeField() {
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
String expected = "expected";
-
+
void test() {
new Expectations() {{
myObject.getSomeField();
@@ -456,17 +456,17 @@ void test() {
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
MyObject myObject;
-
+
String expected = "expected";
-
+
void test() {
when(myObject.getSomeField()).thenReturn(expected);
assertEquals(expected, myObject.getSomeField());
@@ -496,14 +496,14 @@ public Object getSomeField() {
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
import static org.junit.jupiter.api.Assertions.assertNotNull;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
void test() {
new Expectations() {{
myObject.getSomeField();
@@ -517,15 +517,15 @@ void test() {
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.mockito.Mockito.when;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
MyObject myObject;
-
+
void test() {
when(myObject.getSomeField()).thenReturn(new Object());
assertNotNull(myObject.getSomeField());
@@ -555,12 +555,12 @@ public String getSomeField() {
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
void test() throws RuntimeException {
new Expectations() {{
myObject.getSomeField();
@@ -574,14 +574,14 @@ void test() throws RuntimeException {
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
import static org.mockito.Mockito.when;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
MyObject myObject;
-
+
void test() throws RuntimeException {
when(myObject.getSomeField()).thenThrow(new RuntimeException());
myObject.getSomeField();
@@ -611,14 +611,14 @@ public String getSomeField() {
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
void test() throws RuntimeException {
new Expectations() {{
myObject.getSomeField();
@@ -633,15 +633,15 @@ void test() throws RuntimeException {
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
MyObject myObject;
-
+
void test() throws RuntimeException {
when(myObject.getSomeField()).thenReturn("foo", "bar");
assertEquals("foo", myObject.getSomeField());
@@ -660,7 +660,7 @@ void whenClassArgumentMatcher() {
java(
"""
import java.util.List;
-
+
class MyObject {
public String getSomeField(List input) {
return "X";
@@ -678,19 +678,19 @@ public String getSomeArrayField(Object input) {
"""
import java.util.ArrayList;
import java.util.List;
-
+
import mockit.Expectations;
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
import static org.junit.jupiter.api.Assertions.assertNull;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
void test() {
new Expectations() {{
myObject.getSomeField((List) any);
@@ -709,19 +709,19 @@ void test() {
"""
import java.util.ArrayList;
import java.util.List;
-
+
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.Mockito.*;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
MyObject myObject;
-
+
void test() {
when(myObject.getSomeField(anyList())).thenReturn(null);
when(myObject.getSomeOtherField(any(Object.class))).thenReturn(null);
@@ -743,7 +743,7 @@ void whenNoArguments() {
java(
"""
import java.util.List;
-
+
class MyObject {
public String getSomeField() {
return "X";
@@ -755,19 +755,19 @@ public String getSomeField() {
"""
import java.util.ArrayList;
import java.util.List;
-
+
import mockit.Expectations;
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
import static org.junit.jupiter.api.Assertions.assertNull;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
void test() {
new Expectations() {{
myObject.getSomeField();
@@ -780,19 +780,19 @@ void test() {
"""
import java.util.ArrayList;
import java.util.List;
-
+
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.Mockito.when;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
MyObject myObject;
-
+
void test() {
when(myObject.getSomeField()).thenReturn(null);
assertNull(myObject.getSomeField());
@@ -810,7 +810,7 @@ void whenMixedArgumentMatcher() {
java(
"""
import java.util.List;
-
+
class MyObject {
public String getSomeField(String s, String s2, String s3, long l1) {
return "X";
@@ -822,19 +822,19 @@ public String getSomeField(String s, String s2, String s3, long l1) {
"""
import java.util.ArrayList;
import java.util.List;
-
+
import mockit.Expectations;
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
import static org.junit.jupiter.api.Assertions.assertNull;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
void test() {
String bazz = "bazz";
new Expectations() {{
@@ -848,19 +848,19 @@ void test() {
"""
import java.util.ArrayList;
import java.util.List;
-
+
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.Mockito.*;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
MyObject myObject;
-
+
void test() {
String bazz = "bazz";
when(myObject.getSomeField(eq("foo"), anyString(), eq(bazz), eq(10L))).thenReturn(null);
@@ -879,7 +879,7 @@ void whenSetupStatements() {
java(
"""
class MyObject {
-
+
public String getSomeField(String s) {
return "X";
}
@@ -895,26 +895,26 @@ public String getString() {
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
void test() {
String a = "a";
String s = "s";
-
+
new Expectations() {{
myObject.getSomeField(anyString);
result = s;
-
+
myObject.getString();
result = a;
}};
-
+
assertEquals("s", myObject.getSomeField("foo"));
assertEquals("a", myObject.getString());
}
@@ -924,22 +924,22 @@ void test() {
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.when;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
MyObject myObject;
-
+
void test() {
String a = "a";
String s = "s";
when(myObject.getSomeField(anyString())).thenReturn(s);
when(myObject.getString()).thenReturn(a);
-
+
assertEquals("s", myObject.getSomeField("foo"));
assertEquals("a", myObject.getString());
}
@@ -968,14 +968,14 @@ public String getSomeField(String s) {
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
void test() {
String a = "a";
new Expectations() {{
@@ -984,7 +984,7 @@ void test() {
String b = "b";
result = s;
}};
-
+
assertEquals("s", myObject.getSomeField("foo"));
}
}
@@ -993,22 +993,22 @@ void test() {
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.when;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
MyObject myObject;
-
+
void test() {
String a = "a";
String s = "s";
String b = "b";
when(myObject.getSomeField(anyString())).thenReturn(s);
-
+
assertEquals("s", myObject.getSomeField("foo"));
}
}
@@ -1027,12 +1027,12 @@ void whenTimes() {
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
Object myObject;
-
+
void test() {
new Expectations() {{
myObject.wait(anyLong, anyInt);
@@ -1048,14 +1048,14 @@ void test() {
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
import static org.mockito.Mockito.*;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
Object myObject;
-
+
void test() {
myObject.wait(10L, 10);
myObject.wait(10L, 10);
@@ -1078,12 +1078,12 @@ void whenMinTimes() {
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
Object myObject;
-
+
void test() {
new Expectations() {{
myObject.wait(anyLong, anyInt);
@@ -1097,14 +1097,14 @@ void test() {
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
import static org.mockito.Mockito.*;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
Object myObject;
-
+
void test() {
myObject.wait(10L, 10);
verify(myObject, atLeast(2)).wait(anyLong(), anyInt());
@@ -1125,12 +1125,12 @@ void whenMaxTimes() {
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
Object myObject;
-
+
void test() {
new Expectations() {{
myObject.wait(anyLong, anyInt);
@@ -1144,14 +1144,14 @@ void test() {
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
import static org.mockito.Mockito.*;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
Object myObject;
-
+
void test() {
myObject.wait(10L, 10);
verify(myObject, atMost(5)).wait(anyLong(), anyInt());
@@ -1172,12 +1172,12 @@ void whenMinTimesMaxTimes() {
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
Object myObject;
-
+
void test() {
new Expectations() {{
myObject.wait(anyLong, anyInt);
@@ -1192,14 +1192,14 @@ void test() {
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
import static org.mockito.Mockito.*;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
Object myObject;
-
+
void test() {
myObject.wait(10L, 10);
verify(myObject, atLeast(1)).wait(anyLong(), anyInt());
@@ -1230,14 +1230,14 @@ public String getSomeField() {
import mockit.Tested;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Tested
MyObject myObject;
-
+
void test() {
new Expectations(myObject) {{
myObject.getSomeField();
@@ -1251,15 +1251,15 @@ void test() {
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@InjectMocks
MyObject myObject;
-
+
void test() {
when(myObject.getSomeField()).thenReturn("foo");
assertEquals("foo", myObject.getSomeField());
@@ -1296,18 +1296,18 @@ public void doSomething() {}
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
Object myObject;
-
+
@Mocked
MyObject myOtherObject;
-
+
void test() {
new Expectations() {{
myObject.hashCode();
@@ -1329,19 +1329,19 @@ void test() {
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.Mockito.*;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
Object myObject;
-
+
@Mock
MyObject myOtherObject;
-
+
void test() {
when(myObject.hashCode()).thenReturn(10);
when(myOtherObject.getSomeObjectField()).thenReturn(null);
@@ -1377,15 +1377,15 @@ public String getSomeStringField() {
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
void test() {
new Expectations() {{
myObject.getSomeStringField();
@@ -1404,16 +1404,16 @@ void test() {
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.Mockito.when;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
MyObject myObject;
-
+
void test() {
when(myObject.getSomeStringField()).thenReturn("a");
assertEquals("a", myObject.getSomeStringField());
@@ -1448,15 +1448,15 @@ public String getY() {
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
void test() {
new Expectations() {
{
@@ -1477,16 +1477,16 @@ void test() {
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.Mockito.when;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
MyObject myObject;
-
+
void test() {
when(myObject.getX()).thenReturn("x1");
when(myObject.getY()).thenReturn("y1");
@@ -1509,15 +1509,15 @@ void whenMultipleExpectationsNoResults() {
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
Object myObject;
-
+
void test() {
new Expectations() {{
myObject.wait(anyLong);
@@ -1534,17 +1534,17 @@ void test() {
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.Mockito.anyLong;
import static org.mockito.Mockito.verify;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
Object myObject;
-
+
void test() {
myObject.wait(1L);
myObject.wait();
@@ -1567,15 +1567,15 @@ void whenWithRedundantThisModifier() {
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
Object myObject;
-
+
void test() {
new Expectations() {{
myObject.wait(this.anyLong, anyInt);
@@ -1588,16 +1588,16 @@ void test() {
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.Mockito.*;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
Object myObject;
-
+
void test() {
myObject.wait();
verify(myObject).wait(anyLong(), anyInt());
@@ -1628,15 +1628,15 @@ public String getSomeStringField() {
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
void test() {
new Expectations() {{
// comments for this line below
@@ -1656,16 +1656,16 @@ void test() {
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.Mockito.when;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
MyObject myObject;
-
+
void test() {
// comments for this line below
when(myObject.getSomeStringField()).thenReturn("a");
@@ -1689,12 +1689,12 @@ void whenEmptyBlock() {
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
-
+
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
Object myObject;
-
+
void test() {
new Expectations() {{
}};
@@ -1706,12 +1706,12 @@ void test() {
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-
+
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
Object myObject;
-
+
void test() {
myObject.wait(1L);
}
@@ -1799,8 +1799,11 @@ interface MyInterface {
""",
"""
import org.junit.jupiter.api.Test;
+ import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
+ import org.mockito.junit.jupiter.MockitoExtension;
+ @ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
private MyInterface myMock;
diff --git a/src/test/java/org/openrewrite/java/testing/junit6/MigrateMethodOrdererAlphanumericTest.java b/src/test/java/org/openrewrite/java/testing/junit6/MigrateMethodOrdererAlphanumericTest.java
index 6f2ce5449..735c004e8 100644
--- a/src/test/java/org/openrewrite/java/testing/junit6/MigrateMethodOrdererAlphanumericTest.java
+++ b/src/test/java/org/openrewrite/java/testing/junit6/MigrateMethodOrdererAlphanumericTest.java
@@ -30,7 +30,7 @@ class MigrateMethodOrdererAlphanumericTest implements RewriteTest {
public void defaults(RecipeSpec spec) {
spec
.parser(JavaParser.fromJavaVersion()
- .classpathFromResources(new InMemoryExecutionContext(), "junit-jupiter-api"))
+ .classpathFromResources(new InMemoryExecutionContext(), "junit-jupiter-api-5"))
.recipe(new MigrateMethodOrdererAlphanumeric());
}
diff --git a/src/test/java/org/openrewrite/java/testing/mockito/AddMockitoExtensionIfAnnotationsUsedTest.java b/src/test/java/org/openrewrite/java/testing/mockito/AddMockitoExtensionIfAnnotationsUsedTest.java
new file mode 100644
index 000000000..34abfe13b
--- /dev/null
+++ b/src/test/java/org/openrewrite/java/testing/mockito/AddMockitoExtensionIfAnnotationsUsedTest.java
@@ -0,0 +1,189 @@
+/*
+ * Copyright 2025 the original author or authors.
+ *
+ * Licensed under the Moderne Source Available License (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://docs.moderne.io/licensing/moderne-source-available-license
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.openrewrite.java.testing.mockito;
+
+import org.junit.jupiter.api.Test;
+import org.openrewrite.DocumentExample;
+import org.openrewrite.InMemoryExecutionContext;
+import org.openrewrite.java.JavaParser;
+import org.openrewrite.test.RecipeSpec;
+import org.openrewrite.test.RewriteTest;
+
+import static org.openrewrite.java.Assertions.java;
+
+class AddMockitoExtensionIfAnnotationsUsedTest implements RewriteTest {
+
+ @Override
+ public void defaults(RecipeSpec spec) {
+ spec.recipe(new AddMockitoExtensionIfAnnotationsUsed())
+ .parser(JavaParser.fromJavaVersion()
+ .classpathFromResources(new InMemoryExecutionContext(), "junit-jupiter-api", "mockito-junit-jupiter", "mockito-core")
+ .dependsOn("public class Service {}"));
+ }
+
+ @DocumentExample
+ @Test
+ void addForMock() {
+ rewriteRun(
+ //language=java
+ java(
+ """
+ import org.junit.jupiter.api.Test;
+ import org.mockito.Mock;
+
+ class Test {
+ @Mock
+ Service service;
+ @Test
+ void test() {}
+ }
+ """,
+ """
+ import org.junit.jupiter.api.Test;
+ import org.junit.jupiter.api.extension.ExtendWith;
+ import org.mockito.Mock;
+ import org.mockito.junit.jupiter.MockitoExtension;
+
+ @ExtendWith(MockitoExtension.class)
+ class Test {
+ @Mock
+ Service service;
+ @Test
+ void test() {}
+ }
+ """
+ )
+ );
+ }
+
+ @Test
+ void addForCaptor() {
+ rewriteRun(
+ //language=java
+ java(
+ """
+ import org.junit.jupiter.api.Test;
+ import org.mockito.Captor;
+
+ class Test {
+ @Captor
+ Service service;
+ @Test
+ void test() {}
+ }
+ """,
+ """
+ import org.junit.jupiter.api.Test;
+ import org.junit.jupiter.api.extension.ExtendWith;
+ import org.mockito.Captor;
+ import org.mockito.junit.jupiter.MockitoExtension;
+
+ @ExtendWith(MockitoExtension.class)
+ class Test {
+ @Captor
+ Service service;
+ @Test
+ void test() {}
+ }
+ """
+ )
+ );
+ }
+
+ @Test
+ void dontAddIfPresent() {
+ rewriteRun(
+ //language=java
+ java(
+ """
+ import org.junit.jupiter.api.extension.ExtendWith;
+ import org.mockito.Captor;
+ import org.mockito.Mock;
+ import org.mockito.junit.jupiter.MockitoExtension;
+
+ @ExtendWith(MockitoExtension.class)
+ class Test {
+ @Captor
+ Service service;
+ @Mock
+ Service service;
+ }
+ """
+ )
+ );
+ }
+
+ @Test
+ void dontAddIfJunit4() {
+ rewriteRun(
+ //language=java
+ java(
+ """
+ import org.junit.Test;
+ import org.junit.jupiter.api.extension.ExtendWith;
+ import org.mockito.Mock;
+ import org.mockito.junit.jupiter.MockitoExtension;
+
+ class Test {
+ @Mock
+ Service service;
+ @Test
+ void test() {}
+ }
+ """
+ )
+ );
+ }
+
+ @Test
+ void notInferWithExistingAnnotations() {
+ rewriteRun(
+ //language=java
+ java(
+ """
+ import org.junit.jupiter.api.Disabled;
+ import org.mockito.Captor;
+ import org.mockito.Mock;
+
+ @Disabled
+ class Test {
+ @Mock
+ Service service;
+ @Test
+ void test() {}
+ }
+ """,
+ """
+ import org.junit.jupiter.api.Disabled;
+ import org.junit.jupiter.api.extension.ExtendWith;
+ import org.mockito.Captor;
+ import org.mockito.Mock;
+ import org.mockito.junit.jupiter.MockitoExtension;
+
+ @Disabled
+ @ExtendWith(MockitoExtension.class)
+ class Test {
+ @Mock
+ Service service;
+ @Test
+ void test() {}
+ }
+ """
+ )
+ );
+ }
+}
diff --git a/src/test/java/org/openrewrite/java/testing/mockito/JunitMockitoUpgradeIntegrationTest.java b/src/test/java/org/openrewrite/java/testing/mockito/JunitMockitoUpgradeIntegrationTest.java
index e3269d016..69fd31ecc 100755
--- a/src/test/java/org/openrewrite/java/testing/mockito/JunitMockitoUpgradeIntegrationTest.java
+++ b/src/test/java/org/openrewrite/java/testing/mockito/JunitMockitoUpgradeIntegrationTest.java
@@ -89,26 +89,20 @@ public void usingAnnotationBasedMock() {
"""
package org.openrewrite.java.testing.junit5;
- import org.junit.jupiter.api.AfterEach;
- import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
+ import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
- import org.mockito.MockitoAnnotations;
+ import org.mockito.junit.jupiter.MockitoExtension;
import java.util.List;
import static org.mockito.Mockito.verify;
+ @ExtendWith(MockitoExtension.class)
public class MockitoTests {
- private AutoCloseable mocks;
@Mock
List mockedList;
- @BeforeEach
- public void initMocks() {
- mocks = MockitoAnnotations.openMocks(this);
- }
-
@Test
public void usingAnnotationBasedMock() {
@@ -118,11 +112,6 @@ public void usingAnnotationBasedMock() {
verify(mockedList).add("one");
verify(mockedList).clear();
}
-
- @AfterEach
- void tearDown() throws Exception {
- mocks.close();
- }
}
"""
)
diff --git a/src/test/java/org/openrewrite/java/testing/mockito/Mockito1to3MigrationTest.java b/src/test/java/org/openrewrite/java/testing/mockito/Mockito1to3MigrationTest.java
index 092584573..460c8b550 100644
--- a/src/test/java/org/openrewrite/java/testing/mockito/Mockito1to3MigrationTest.java
+++ b/src/test/java/org/openrewrite/java/testing/mockito/Mockito1to3MigrationTest.java
@@ -173,13 +173,16 @@ void someTest() {
"""
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
+ import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
+ import org.mockito.junit.jupiter.MockitoExtension;
import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.when;
+ @ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
Object objectMock;