From a4a7ac038209da1c75106421ee045c0641083e64 Mon Sep 17 00:00:00 2001 From: Greg Oledzki Date: Fri, 25 Jul 2025 16:09:41 +0200 Subject: [PATCH 1/3] Revert "Remove UseForEachLoop for now (#662)" This reverts commit 2127a4e1ebfa654265909b15c79adc366a437563. --- .../staticanalysis/UseForEachLoop.java | 416 ++++++++++++++++++ .../staticanalysis/UseForEachLoopTest.java | 352 +++++++++++++++ 2 files changed, 768 insertions(+) create mode 100644 src/main/java/org/openrewrite/staticanalysis/UseForEachLoop.java create mode 100644 src/test/java/org/openrewrite/staticanalysis/UseForEachLoopTest.java diff --git a/src/main/java/org/openrewrite/staticanalysis/UseForEachLoop.java b/src/main/java/org/openrewrite/staticanalysis/UseForEachLoop.java new file mode 100644 index 000000000..356ac66a0 --- /dev/null +++ b/src/main/java/org/openrewrite/staticanalysis/UseForEachLoop.java @@ -0,0 +1,416 @@ +/* + * Copyright 2025 the original author or authors. + *

+ * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://docs.moderne.io/licensing/moderne-source-available-license + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.staticanalysis; + +import org.jspecify.annotations.Nullable; +import org.openrewrite.ExecutionContext; +import org.openrewrite.Recipe; +import org.openrewrite.Tree; +import org.openrewrite.TreeVisitor; +import org.openrewrite.java.JavaTemplate; +import org.openrewrite.java.JavaVisitor; +import org.openrewrite.java.VariableNameUtils; +import org.openrewrite.java.search.SemanticallyEqual; +import org.openrewrite.java.tree.*; +import org.openrewrite.marker.Markers; + +import java.time.Duration; +import java.util.Collections; + +import static org.openrewrite.java.VariableNameUtils.GenerationStrategy.INCREMENT_NUMBER; + +public class UseForEachLoop extends Recipe { + + @Override + public String getDisplayName() { + return "Use for-each loops instead of manual indexing"; + } + + @Override + public String getDescription() { + return "Replace traditional for loops that iterate over collections or arrays with enhanced for-each loops for improved readability."; + } + + @Override + public Duration getEstimatedEffortPerOccurrence() { + return Duration.ofMinutes(2); + } + + @Override + public TreeVisitor getVisitor() { + return new JavaVisitor() { + @Override + public J visitForLoop(J.ForLoop forLoop, ExecutionContext ctx) { + J.ForLoop.Control control = forLoop.getControl(); + + if (control.getInit().size() != 1 || control.getCondition() == null || control.getUpdate().size() != 1) { + return super.visitForLoop(forLoop, ctx); + } + + Statement init = control.getInit().get(0); + if (!(init instanceof J.VariableDeclarations)) { + return super.visitForLoop(forLoop, ctx); + } + + J.VariableDeclarations initVars = (J.VariableDeclarations) init; + if (initVars.getVariables().size() != 1) { + return super.visitForLoop(forLoop, ctx); + } + + J.VariableDeclarations.NamedVariable indexVar = initVars.getVariables().get(0); + if (indexVar.getInitializer() == null || !(indexVar.getInitializer() instanceof J.Literal)) { + return super.visitForLoop(forLoop, ctx); + } + + J.Literal initValue = (J.Literal) indexVar.getInitializer(); + if (!Integer.valueOf(0).equals(initValue.getValue())) { + return super.visitForLoop(forLoop, ctx); + } + + String indexVarName = indexVar.getSimpleName(); + + if (!(control.getCondition() instanceof J.Binary)) { + return super.visitForLoop(forLoop, ctx); + } + + J.Binary condition = (J.Binary) control.getCondition(); + if (condition.getOperator() != J.Binary.Type.LessThan) { + return super.visitForLoop(forLoop, ctx); + } + + if (!(condition.getLeft() instanceof J.Identifier) || + !((J.Identifier) condition.getLeft()).getSimpleName().equals(indexVarName)) { + return super.visitForLoop(forLoop, ctx); + } + + J collection; + if (condition.getRight() instanceof J.MethodInvocation) { + J.MethodInvocation sizeCall = (J.MethodInvocation) condition.getRight(); + if (!"size".equals(sizeCall.getSimpleName()) || !((sizeCall.getArguments().isEmpty()) || (sizeCall.getArguments().size() == 1 && sizeCall.getArguments().get(0) instanceof J.Empty))) { + return super.visitForLoop(forLoop, ctx); + } + collection = sizeCall.getSelect(); + } else if (condition.getRight() instanceof J.FieldAccess) { + J.FieldAccess lengthAccess = (J.FieldAccess) condition.getRight(); + if (!"length".equals(lengthAccess.getSimpleName())) { + return super.visitForLoop(forLoop, ctx); + } + collection = lengthAccess.getTarget(); + } else { + return super.visitForLoop(forLoop, ctx); + } + + Statement update = control.getUpdate().get(0); + if (!(update instanceof J.Unary)) { + return super.visitForLoop(forLoop, ctx); + } + + J.Unary unaryUpdate = (J.Unary) update; + if (unaryUpdate.getOperator() != J.Unary.Type.PostIncrement && unaryUpdate.getOperator() != J.Unary.Type.PreIncrement) { + return super.visitForLoop(forLoop, ctx); + } + + if (!(unaryUpdate.getExpression() instanceof J.Identifier) || + !((J.Identifier) unaryUpdate.getExpression()).getSimpleName().equals(indexVarName)) { + return super.visitForLoop(forLoop, ctx); + } + + if (!isValidForTransformation(forLoop.getBody(), indexVarName, collection)) { + return super.visitForLoop(forLoop, ctx); + } + + if (!isIterableOrArray(collection)) { + return super.visitForLoop(forLoop, ctx); + } + + String forEachVarName = determineForEachVariableName(forLoop.getBody(), indexVarName, collection); + + JavaTemplate template = JavaTemplate.builder("for (String " + forEachVarName + " : #{any()}) #{any()}") + .build(); + + Statement transformedBody = (Statement) new BodyTransformer(indexVarName, collection, forEachVarName).visit(forLoop.getBody(), getCursor()); + + J.ForEachLoop forEachLoop = template.apply(getCursor(), forLoop.getCoordinates().replace(), collection, transformedBody); + + J.ForEachLoop.Control foreachControl = forEachLoop.getControl(); + J iterable = foreachControl.getIterable(); + return forEachLoop.withControl( + foreachControl.withIterable(iterable.withPrefix(Space.format(" "))) + ); + } + + private String determineForEachVariableName(Statement body, String indexVarName, J collection) { + VariableNameDetector detector = new VariableNameDetector(indexVarName, collection); + detector.visit(body, null); + + String detectedName = detector.getDetectedVariableName(); + if (detectedName != null) { + return detectedName; + } + + String derivedName = deriveVariableNameFromCollection(collection); + + return VariableNameUtils.generateVariableName(derivedName, getCursor(), INCREMENT_NUMBER); + } + + private String deriveVariableNameFromCollection(J collection) { + String collectionName = getCollectionName(collection); + if (collectionName == null) { + return "item"; + } + + if (collectionName.endsWith("s") && collectionName.length() > 1) { + return collectionName.substring(0, collectionName.length() - 1); + } + + return collectionName + "Item"; + } + + private String getCollectionName(J collection) { + if (collection instanceof J.Identifier) { + return ((J.Identifier) collection).getSimpleName(); + } + if (collection instanceof J.FieldAccess) { + return ((J.FieldAccess) collection).getSimpleName(); + } + return null; + } + + private boolean isValidForTransformation(Statement body, String indexVarName, J collection) { + ValidationVisitor validator = new ValidationVisitor(indexVarName, collection); + validator.visit(body, null); + return validator.isValid(); + } + + private boolean isIterableOrArray(J collection) { + if (collection == null || !(collection instanceof TypedTree)) { + return false; + } + + JavaType type = ((TypedTree) collection).getType(); + if (type == null) { + return false; + } + + return type instanceof JavaType.Array || + TypeUtils.isAssignableTo("java.lang.Iterable", type); + } + + private boolean isCollectionOrArrayAccess(J initializer, String indexVarName, J collection) { + if (initializer instanceof J.MethodInvocation) { + J.MethodInvocation method = (J.MethodInvocation) initializer; + return "get".equals(method.getSimpleName()) && + method.getArguments().size() == 1 && + method.getArguments().get(0) instanceof J.Identifier && + indexVarName.equals(((J.Identifier) method.getArguments().get(0)).getSimpleName()) && + SemanticallyEqual.areEqual(method.getSelect(), collection); + } + if (initializer instanceof J.ArrayAccess) { + J.ArrayAccess arrayAccess = (J.ArrayAccess) initializer; + return arrayAccess.getDimension().getIndex() instanceof J.Identifier && + indexVarName.equals(((J.Identifier) arrayAccess.getDimension().getIndex()).getSimpleName()) && + SemanticallyEqual.areEqual(arrayAccess.getIndexed(), collection); + } + return false; + } + + private class VariableNameDetector extends JavaVisitor { + private final String indexVarName; + private final J collection; + private String detectedVariableName; + + public VariableNameDetector(String indexVarName, J collection) { + this.indexVarName = indexVarName; + this.collection = collection; + } + + public String getDetectedVariableName() { + return detectedVariableName; + } + + @Override + public J visitVariableDeclarations(J.VariableDeclarations variableDeclarations, Object o) { + if (variableDeclarations.getVariables().size() == 1) { + J.VariableDeclarations.NamedVariable variable = variableDeclarations.getVariables().get(0); + if (variable.getInitializer() != null && isCollectionOrArrayAccess(variable.getInitializer(), indexVarName, collection)) { + detectedVariableName = variable.getSimpleName(); + } + } + return super.visitVariableDeclarations(variableDeclarations, o); + } + } + + private class ValidationVisitor extends JavaVisitor { + private final String indexVarName; + private final J collection; + private boolean valid = true; + private boolean insideValidAccess; + private boolean insideInvalidAccess; + + public ValidationVisitor(String indexVarName, J collection) { + this.indexVarName = indexVarName; + this.collection = collection; + } + + public boolean isValid() { + return valid; + } + + @Override + public J visitIdentifier(J.Identifier identifier, Object o) { + if (indexVarName.equals(identifier.getSimpleName()) && !insideValidAccess) { + valid = false; + } + return super.visitIdentifier(identifier, o); + } + + @Override + public J visitMethodInvocation(J.MethodInvocation method, Object o) { + if ("get".equals(method.getSimpleName()) && + method.getArguments().size() == 1 && + method.getArguments().get(0) instanceof J.Identifier && + indexVarName.equals(((J.Identifier) method.getArguments().get(0)).getSimpleName())) { + + boolean wasInsideValidAccess = insideValidAccess; + if (SemanticallyEqual.areEqual(method.getSelect(), collection)) { + insideValidAccess = true; + } else { + valid = false; + } + + J result = super.visitMethodInvocation(method, o); + insideValidAccess = wasInsideValidAccess; + return result; + } + return super.visitMethodInvocation(method, o); + } + + @Override + public J visitArrayAccess(J.ArrayAccess arrayAccess, Object o) { + if (arrayAccess.getDimension().getIndex() instanceof J.Identifier && + indexVarName.equals(((J.Identifier) arrayAccess.getDimension().getIndex()).getSimpleName())) { + + boolean wasInsideValidAccess = insideValidAccess; + if (SemanticallyEqual.areEqual(arrayAccess.getIndexed(), collection)) { + insideValidAccess = true; + } else { + valid = false; + } + if (insideInvalidAccess) { + valid = false; + } + + J result = super.visitArrayAccess(arrayAccess, o); + insideValidAccess = wasInsideValidAccess; + return result; + } + return super.visitArrayAccess(arrayAccess, o); + } + + @Override + public J visitAssignment(J.Assignment assignment, Object o) { + this.insideInvalidAccess = true; + this.visit(assignment.getVariable(), o); + this.insideInvalidAccess = false; + return super.visitAssignment(assignment, o); + } + } + + private class BodyTransformer extends JavaVisitor { + private final String indexVarName; + private final J collection; + private final String newVariableName; + private String variableToReplace; + + public BodyTransformer(String indexVarName, J collection, String newVariableName) { + this.indexVarName = indexVarName; + this.collection = collection; + this.newVariableName = newVariableName; + } + + @Override + public @Nullable J visitVariableDeclarations(J.VariableDeclarations variableDeclarations, Object o) { + if (variableDeclarations.getVariables().size() == 1) { + J.VariableDeclarations.NamedVariable variable = variableDeclarations.getVariables().get(0); + if (variable.getInitializer() != null && isCollectionOrArrayAccess(variable.getInitializer(), indexVarName, collection)) { + variableToReplace = variable.getSimpleName(); + return null; + } + } + return super.visitVariableDeclarations(variableDeclarations, o); + } + + @Override + public J visitIdentifier(J.Identifier identifier, Object o) { + if (variableToReplace != null && variableToReplace.equals(identifier.getSimpleName())) { + return new J.Identifier( + Tree.randomId(), + identifier.getPrefix(), + Markers.EMPTY, + Collections.emptyList(), + newVariableName, + identifier.getType(), + null + ); + } + return super.visitIdentifier(identifier, o); + } + + @Override + public J visitMethodInvocation(J.MethodInvocation method, Object o) { + if ("get".equals(method.getSimpleName()) && + method.getArguments().size() == 1 && + method.getArguments().get(0) instanceof J.Identifier && + indexVarName.equals(((J.Identifier) method.getArguments().get(0)).getSimpleName()) && + SemanticallyEqual.areEqual(method.getSelect(), collection)) { + + return new J.Identifier( + Tree.randomId(), + method.getPrefix(), + Markers.EMPTY, + Collections.emptyList(), + newVariableName, + method.getType(), + null + ); + } + return super.visitMethodInvocation(method, o); + } + + + @Override + public J visitArrayAccess(J.ArrayAccess arrayAccess, Object o) { + if (arrayAccess.getDimension().getIndex() instanceof J.Identifier && + indexVarName.equals(((J.Identifier) arrayAccess.getDimension().getIndex()).getSimpleName()) && + SemanticallyEqual.areEqual(arrayAccess.getIndexed(), collection)) { + + return new J.Identifier( + Tree.randomId(), + arrayAccess.getPrefix(), + Markers.EMPTY, + Collections.emptyList(), + newVariableName, + arrayAccess.getType(), + null + ); + } + return super.visitArrayAccess(arrayAccess, o); + } + + } + }; + } +} diff --git a/src/test/java/org/openrewrite/staticanalysis/UseForEachLoopTest.java b/src/test/java/org/openrewrite/staticanalysis/UseForEachLoopTest.java new file mode 100644 index 000000000..505bdb2d7 --- /dev/null +++ b/src/test/java/org/openrewrite/staticanalysis/UseForEachLoopTest.java @@ -0,0 +1,352 @@ +/* + * Copyright 2025 the original author or authors. + *

+ * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://docs.moderne.io/licensing/moderne-source-available-license + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.staticanalysis; + +import org.junit.jupiter.api.Test; +import org.openrewrite.DocumentExample; +import org.openrewrite.test.RecipeSpec; +import org.openrewrite.test.RewriteTest; + +import static org.openrewrite.java.Assertions.java; + +@SuppressWarnings({"SimplifiableForEach", "ForLoopReplaceableByForEach"}) +class UseForEachLoopTest implements RewriteTest { + @Override + public void defaults(RecipeSpec spec) { + spec.recipe(new UseForEachLoop()); + } + + @DocumentExample + @Test + void transformListIteration() { + rewriteRun( + //language=java + java( + """ + import java.util.List; + + class Test { + void test(List names) { + for (int i = 0; i < names.size(); i++) { + System.out.println(names.get(i)); + } + } + } + """, + """ + import java.util.List; + + class Test { + void test(List names) { + for (String name : names) { + System.out.println(name); + } + } + } + """ + ) + ); + } + + @Test + void removeVariable() { + rewriteRun( + //language=java + java( + """ + import java.util.List; + + class Test { + void test(List names) { + for (int i = 0; i < names.size(); i++) { + String name = names.get(i); + System.out.println(name); + } + } + } + """, + """ + import java.util.List; + + class Test { + void test(List names) { + for (String name : names) { + System.out.println(name); + } + } + } + """ + ) + ); + } + + @Test + void transformArrayIteration() { + rewriteRun( + //language=java + java( + """ + class Test { + void test(String[] names) { + for (int i = 0; i < names.length; i++) { + System.out.println(names[i]); + } + } + } + """, + """ + class Test { + void test(String[] names) { + for (String name : names) { + System.out.println(name); + } + } + } + """ + ) + ); + } + + @Test + void preIncrementLoop() { + rewriteRun( + //language=java + java( + """ + import java.util.List; + + class Test { + void test(List numbers) { + for (int i = 0; i < numbers.size(); ++i) { + System.out.println(numbers.get(i)); + } + } + } + """, + """ + import java.util.List; + + class Test { + void test(List numbers) { + for (String number : numbers) { + System.out.println(number); + } + } + } + """ + ) + ); + } + + @Test + void notClashVariableNames() { + rewriteRun( + //language=java + java( + """ + import java.util.List; + + class Test { + void test(List numbers, java.util.Date number) { + for (int i = 0; i < numbers.size(); ++i) { + System.out.println(numbers.get(i)); + } + } + } + """, + """ + import java.util.List; + + class Test { + void test(List numbers, java.util.Date number) { + for (String number1 : numbers) { + System.out.println(number1); + } + } + } + """ + ) + ); + } + + @Test + void noChangeWhenIndexUsedForOtherPurposes() { + rewriteRun( + //language=java + java( + """ + import java.util.List; + + class Test { + void test(List names) { + for (int i = 0; i < names.size(); i++) { + System.out.println(i + ": " + names.get(i)); + } + } + } + """ + ) + ); + } + + @Test + void noChangeWhenNotStartingFromZero() { + rewriteRun( + //language=java + java( + """ + import java.util.List; + + class Test { + void test(List names) { + for (int i = 1; i < names.size(); i++) { + System.out.println(names.get(i)); + } + } + } + """ + ) + ); + } + + @Test + void noChangeWhenNotSimpleIncrement() { + rewriteRun( + //language=java + java( + """ + import java.util.List; + + class Test { + void test(List names) { + for (int i = 0; i < names.size(); i += 2) { + System.out.println(names.get(i)); + } + } + } + """ + ) + ); + } + + @Test + void noChangeWhenAccessingDifferentCollection() { + rewriteRun( + //language=java + java( + """ + import java.util.List; + + class Test { + void test(List names, List other) { + for (int i = 0; i < names.size(); i++) { + System.out.println(other.get(i)); + } + } + } + """ + ) + ); + } + + @Test + void noChangeWhenComplexCondition() { + rewriteRun( + //language=java + java( + """ + import java.util.List; + + class Test { + void test(List names) { + for (int i = 0; i < names.size() && i < 10; i++) { + System.out.println(names.get(i)); + } + } + } + """ + ) + ); + } + + @Test + void noChangeWhenNoCollectionAccess() { + rewriteRun( + //language=java + java( + """ + import java.util.List; + + class Test { + void test(List names) { + for (int i = 0; i < names.size(); i++) { + System.out.println("Processing item " + i); + } + } + } + """ + ) + ); + } + + @Test + void noChangeWhenCustomCollectionLikeClass() { + rewriteRun( + //language=java + java( + """ + class Test { + static class CustomContainer { + private String[] items = {"a", "b", "c"}; + + public String get(int index) { + return items[index]; + } + + public int size() { + return items.length; + } + } + + void test(CustomContainer container) { + for (int i = 0; i < container.size(); i++) { + System.out.println(container.get(i)); + } + } + } + """ + ) + ); + } + + @Test + void noChangeWhenArrayAccessOnLeftSideOfAssignment() { + rewriteRun( + //language=java + java( + """ + class Test { + void test(String[] names) { + for (int i = 0; i < names.length; i++) { + names[i] = "modified"; + } + } + } + """ + ) + ); + } +} From faf26b5b7eb8506ae6b0b0a569e7f2ca5ae5ad3c Mon Sep 17 00:00:00 2001 From: Greg Oledzki Date: Fri, 25 Jul 2025 15:21:49 +0200 Subject: [PATCH 2/3] Unit test --- .../staticanalysis/UseForEachLoopTest.java | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/src/test/java/org/openrewrite/staticanalysis/UseForEachLoopTest.java b/src/test/java/org/openrewrite/staticanalysis/UseForEachLoopTest.java index 505bdb2d7..e70d2a4e1 100644 --- a/src/test/java/org/openrewrite/staticanalysis/UseForEachLoopTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/UseForEachLoopTest.java @@ -349,4 +349,56 @@ void test(String[] names) { ) ); } + + @Test + void iteratingOverCustomType() { + rewriteRun( + //language=java + java( + """ + import java.util.List; + + class Test { + static class Person { + String name; + int age; + + Person(String name, int age) { + this.name = name; + this.age = age; + } + } + + void test(List people) { + for (int i = 0; i < people.size(); i++) { + Person person = people.get(i); + System.out.println(person.name + " is " + person.age + " years old"); + } + } + } + """, + """ + import java.util.List; + + class Test { + static class Person { + String name; + int age; + + Person(String name, int age) { + this.name = name; + this.age = age; + } + } + + void test(List people) { + for (Person person : people) { + System.out.println(person.name + " is " + person.age + " years old"); + } + } + } + """ + ) + ); + } } From e4943268ced693ac11ee73fac66ba47f39f746ac Mon Sep 17 00:00:00 2001 From: Greg Oledzki Date: Fri, 25 Jul 2025 15:31:50 +0200 Subject: [PATCH 3/3] getElementType --- .../staticanalysis/UseForEachLoop.java | 59 +++- .../staticanalysis/UseForEachLoopTest.java | 286 +++++++++--------- 2 files changed, 201 insertions(+), 144 deletions(-) diff --git a/src/main/java/org/openrewrite/staticanalysis/UseForEachLoop.java b/src/main/java/org/openrewrite/staticanalysis/UseForEachLoop.java index 356ac66a0..9ea5b607f 100644 --- a/src/main/java/org/openrewrite/staticanalysis/UseForEachLoop.java +++ b/src/main/java/org/openrewrite/staticanalysis/UseForEachLoop.java @@ -20,6 +20,7 @@ import org.openrewrite.Recipe; import org.openrewrite.Tree; import org.openrewrite.TreeVisitor; +import org.openrewrite.java.JavaParser; import org.openrewrite.java.JavaTemplate; import org.openrewrite.java.JavaVisitor; import org.openrewrite.java.VariableNameUtils; @@ -137,8 +138,16 @@ public J visitForLoop(J.ForLoop forLoop, ExecutionContext ctx) { } String forEachVarName = determineForEachVariableName(forLoop.getBody(), indexVarName, collection); + String elementTypeName = getElementTypeName(collection); - JavaTemplate template = JavaTemplate.builder("for (String " + forEachVarName + " : #{any()}) #{any()}") + if (elementTypeName == null) { + // If we can't determine the element type, don't make any changes + return super.visitForLoop(forLoop, ctx); + } + + JavaTemplate template = JavaTemplate.builder("for (" + elementTypeName + " " + forEachVarName + " : #{any()}) #{any()}") + .contextSensitive() + .javaParser(JavaParser.fromJavaVersion()) .build(); Statement transformedBody = (Statement) new BodyTransformer(indexVarName, collection, forEachVarName).visit(forLoop.getBody(), getCursor()); @@ -209,6 +218,54 @@ private boolean isIterableOrArray(J collection) { TypeUtils.isAssignableTo("java.lang.Iterable", type); } + private JavaType getElementType(J collection) { + if (collection == null || !(collection instanceof TypedTree)) { + return null; + } + + JavaType type = ((TypedTree) collection).getType(); + if (type == null) { + return null; + } + + if (type instanceof JavaType.Array) { + JavaType.Array arrayType = (JavaType.Array) type; + return arrayType.getElemType(); + } else if (type instanceof JavaType.Parameterized) { + JavaType.Parameterized parameterized = (JavaType.Parameterized) type; + if (!parameterized.getTypeParameters().isEmpty()) { + return parameterized.getTypeParameters().get(0); + } + } + + return null; + } + + private String getElementTypeName(J collection) { + JavaType elementType = getElementType(collection); + if (elementType == null) { + return null; + } + return getSimpleTypeName(elementType); + } + + private String getSimpleTypeName(JavaType type) { + if (type instanceof JavaType.FullyQualified) { + String className = ((JavaType.FullyQualified) type).getClassName(); + // Handle nested classes - extract just the simple name + int lastDot = className.lastIndexOf('.'); + if (lastDot > 0) { + return className.substring(lastDot + 1); + } + return className; + } else if (type instanceof JavaType.Primitive) { + return type.toString(); + } else if (type instanceof JavaType.GenericTypeVariable) { + return ((JavaType.GenericTypeVariable) type).getName(); + } + return type.toString(); + } + private boolean isCollectionOrArrayAccess(J initializer, String indexVarName, J collection) { if (initializer instanceof J.MethodInvocation) { J.MethodInvocation method = (J.MethodInvocation) initializer; diff --git a/src/test/java/org/openrewrite/staticanalysis/UseForEachLoopTest.java b/src/test/java/org/openrewrite/staticanalysis/UseForEachLoopTest.java index e70d2a4e1..91018f018 100644 --- a/src/test/java/org/openrewrite/staticanalysis/UseForEachLoopTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/UseForEachLoopTest.java @@ -35,28 +35,28 @@ void transformListIteration() { rewriteRun( //language=java java( - """ - import java.util.List; - - class Test { - void test(List names) { - for (int i = 0; i < names.size(); i++) { - System.out.println(names.get(i)); - } - } - } - """, """ - import java.util.List; - - class Test { - void test(List names) { - for (String name : names) { - System.out.println(name); - } - } - } + import java.util.List; + + class Test { + void test(List names) { + for (int i = 0; i < names.size(); i++) { + System.out.println(names.get(i)); + } + } + } + """, """ + import java.util.List; + + class Test { + void test(List names) { + for (String name : names) { + System.out.println(name); + } + } + } + """ ) ); } @@ -67,28 +67,28 @@ void removeVariable() { //language=java java( """ - import java.util.List; + import java.util.List; - class Test { - void test(List names) { - for (int i = 0; i < names.size(); i++) { - String name = names.get(i); - System.out.println(name); + class Test { + void test(List names) { + for (int i = 0; i < names.size(); i++) { + String name = names.get(i); + System.out.println(name); + } } } - } - """, - """ - import java.util.List; - - class Test { - void test(List names) { - for (String name : names) { - System.out.println(name); - } - } - } + """, """ + import java.util.List; + + class Test { + void test(List names) { + for (String name : names) { + System.out.println(name); + } + } + } + """ ) ); } @@ -99,23 +99,23 @@ void transformArrayIteration() { //language=java java( """ - class Test { - void test(String[] names) { - for (int i = 0; i < names.length; i++) { - System.out.println(names[i]); - } - } - } - """, - """ - class Test { - void test(String[] names) { - for (String name : names) { - System.out.println(name); - } - } - } + class Test { + void test(String[] names) { + for (int i = 0; i < names.length; i++) { + System.out.println(names[i]); + } + } + } + """, """ + class Test { + void test(String[] names) { + for (String name : names) { + System.out.println(name); + } + } + } + """ ) ); } @@ -126,27 +126,27 @@ void preIncrementLoop() { //language=java java( """ - import java.util.List; - - class Test { - void test(List numbers) { - for (int i = 0; i < numbers.size(); ++i) { - System.out.println(numbers.get(i)); - } - } - } - """, - """ - import java.util.List; - - class Test { - void test(List numbers) { - for (String number : numbers) { - System.out.println(number); - } - } - } + import java.util.List; + + class Test { + void test(List numbers) { + for (int i = 0; i < numbers.size(); ++i) { + System.out.println(numbers.get(i)); + } + } + } + """, """ + import java.util.List; + + class Test { + void test(List numbers) { + for (Integer number : numbers) { + System.out.println(number); + } + } + } + """ ) ); } @@ -157,27 +157,27 @@ void notClashVariableNames() { //language=java java( """ - import java.util.List; - - class Test { - void test(List numbers, java.util.Date number) { - for (int i = 0; i < numbers.size(); ++i) { - System.out.println(numbers.get(i)); - } - } - } - """, - """ - import java.util.List; - - class Test { - void test(List numbers, java.util.Date number) { - for (String number1 : numbers) { - System.out.println(number1); - } - } - } + import java.util.List; + + class Test { + void test(List numbers, java.util.Date number) { + for (int i = 0; i < numbers.size(); ++i) { + System.out.println(numbers.get(i)); + } + } + } + """, """ + import java.util.List; + + class Test { + void test(List numbers, java.util.Date number) { + for (Integer number1 : numbers) { + System.out.println(number1); + } + } + } + """ ) ); } @@ -338,14 +338,14 @@ void noChangeWhenArrayAccessOnLeftSideOfAssignment() { //language=java java( """ - class Test { - void test(String[] names) { - for (int i = 0; i < names.length; i++) { - names[i] = "modified"; - } - } - } - """ + class Test { + void test(String[] names) { + for (int i = 0; i < names.length; i++) { + names[i] = "modified"; + } + } + } + """ ) ); } @@ -356,48 +356,48 @@ void iteratingOverCustomType() { //language=java java( """ - import java.util.List; - - class Test { - static class Person { - String name; - int age; - - Person(String name, int age) { - this.name = name; - this.age = age; - } - } - - void test(List people) { - for (int i = 0; i < people.size(); i++) { - Person person = people.get(i); - System.out.println(person.name + " is " + person.age + " years old"); - } - } - } - """, - """ - import java.util.List; - - class Test { - static class Person { - String name; - int age; - - Person(String name, int age) { - this.name = name; - this.age = age; - } - } - - void test(List people) { - for (Person person : people) { - System.out.println(person.name + " is " + person.age + " years old"); - } - } - } + import java.util.List; + + class Test { + static class Person { + String name; + int age; + + Person(String name, int age) { + this.name = name; + this.age = age; + } + } + + void test(List people) { + for (int i = 0; i < people.size(); i++) { + Person person = people.get(i); + System.out.println(person.name + " is " + person.age + " years old"); + } + } + } + """, """ + import java.util.List; + + class Test { + static class Person { + String name; + int age; + + Person(String name, int age) { + this.name = name; + this.age = age; + } + } + + void test(List people) { + for (Person person : people) { + System.out.println(person.name + " is " + person.age + " years old"); + } + } + } + """ ) ); }