From d479ecc36b42099bba9238ceefb495af7176ff70 Mon Sep 17 00:00:00 2001 From: Arnab Nandy Date: Fri, 3 Jul 2026 23:29:43 +0530 Subject: [PATCH] feat(migrate): support return statements in IfElseIfConstructToSwitch recipe Signed-off-by: Arnab Nandy --- .../lang/IfElseIfConstructToSwitch.java | 46 ++- .../lang/IfElseIfConstructToSwitchTest.java | 377 ++++++++++++++---- 2 files changed, 315 insertions(+), 108 deletions(-) diff --git a/src/main/java/org/openrewrite/java/migrate/lang/IfElseIfConstructToSwitch.java b/src/main/java/org/openrewrite/java/migrate/lang/IfElseIfConstructToSwitch.java index 70d10748e1..39c6e8e466 100644 --- a/src/main/java/org/openrewrite/java/migrate/lang/IfElseIfConstructToSwitch.java +++ b/src/main/java/org/openrewrite/java/migrate/lang/IfElseIfConstructToSwitch.java @@ -37,7 +37,6 @@ import org.openrewrite.staticanalysis.kotlin.KotlinFileChecker; import java.util.*; -import java.util.concurrent.atomic.AtomicBoolean; import static java.util.Collections.emptyList; import static java.util.Collections.singletonList; @@ -167,11 +166,6 @@ private boolean validatePotentialCandidate() { if (patternMatchers.keySet().stream().anyMatch(instanceOf -> instanceOf.getPattern() == null)) { return false; } - // The blocks cannot do a return as that would lead to all blocks having to do a return, - // the block/expression difference in return for switch statements / expressions being different... - if (returns(nullCheckedStatement) || patternMatchers.values().stream().anyMatch(this::returns) || returns(else_)) { - return false; - } // Do no harm -> If we do not know how to replace(yet), do not replace if (patternMatchers.keySet().stream().anyMatch(instanceOf -> { J clazz = instanceOf.getClazz(); @@ -188,15 +182,6 @@ private boolean validatePotentialCandidate() { (hasLastElseBlock ? 1 : 0); } - private boolean returns(@Nullable Statement statement) { - return statement != null && new JavaIsoVisitor() { - @Override - public J.Return visitReturn(J.Return return_, AtomicBoolean atomicBoolean) { - atomicBoolean.set(true); - return return_; - } - }.reduce(statement, new AtomicBoolean(false)).get(); - } public J.@Nullable Switch buildSwitchTemplate() { Optional switchOn = switchOn(); @@ -208,22 +193,39 @@ public J.Return visitReturn(J.Return return_, AtomicBoolean atomicBoolean) { StringBuilder switchBody = new StringBuilder("switch (#{any()}) {\n"); int i = 1; if (nullCheckedParameter != null) { - switchBody.append("case null -> #{any()};\n"); - arguments[i++] = getStatement(Objects.requireNonNull(nullCheckedStatement)); + Statement nullStmt = getStatement(Objects.requireNonNull(nullCheckedStatement)); + if (nullStmt instanceof J.Return) { + switchBody.append("case null -> {#{any()};}\n"); + } else { + switchBody.append("case null -> #{any()};\n"); + } + arguments[i++] = nullStmt; } for (Map.Entry entry : patternMatchers.entrySet()) { J.InstanceOf instanceOf = entry.getKey(); - switchBody.append("case #{}#{} -> #{any()};\n"); + Statement patternStmt = getStatement(entry.getValue()); + if (patternStmt instanceof J.Return) { + switchBody.append("case #{}#{} -> {#{any()};}\n"); + } else { + switchBody.append("case #{}#{} -> #{any()};\n"); + } arguments[i++] = getClassName(instanceOf); arguments[i++] = getPattern(instanceOf); - arguments[i++] = getStatement(entry.getValue()); + arguments[i++] = patternStmt; } - switchBody.append(nullCheckedParameter != null ? "default" : "case null, default").append(" -> #{any()};\n"); + Statement defaultStmt; if (else_ != null) { - arguments[i] = getStatement(else_); + defaultStmt = getStatement(else_); + } else { + defaultStmt = createEmptyBlock(); + } + String defaultPrefix = nullCheckedParameter != null ? "default" : "case null, default"; + if (defaultStmt instanceof J.Return) { + switchBody.append(defaultPrefix).append(" -> {#{any()};}\n"); } else { - arguments[i] = createEmptyBlock(); + switchBody.append(defaultPrefix).append(" -> #{any()};\n"); } + arguments[i] = defaultStmt; switchBody.append("}\n"); J.Switch result = JavaTemplate.apply(switchBody.toString(), cursor, if_.getCoordinates().replace(), arguments).withPrefix(if_.getPrefix()); diff --git a/src/test/java/org/openrewrite/java/migrate/lang/IfElseIfConstructToSwitchTest.java b/src/test/java/org/openrewrite/java/migrate/lang/IfElseIfConstructToSwitchTest.java index b1bae30106..f0e092a5ca 100644 --- a/src/test/java/org/openrewrite/java/migrate/lang/IfElseIfConstructToSwitchTest.java +++ b/src/test/java/org/openrewrite/java/migrate/lang/IfElseIfConstructToSwitchTest.java @@ -846,63 +846,9 @@ public Object getObj() { ); } - @Test - void noSwitchBlockWhenNullBlockReturns() { - rewriteRun( - //language=java - java( - """ - class Test { - static String formatter(Object obj) { - String formatted = "initialValue"; - if (obj == null) { - return "null"; - } else if (obj instanceof Integer i) - formatted = String.format("int %d", i); - else if (obj instanceof Long l) { - formatted = String.format("long %d", l); - } else if (obj instanceof Double d) { - formatted = String.format("double %f", d); - } else if (obj instanceof String s) { - String str = "String"; - formatted = String.format("%s %s", str, s); - } - return formatted; - } - } - """ - ) - ); - } - @Test - void noSwitchBlockWhenInstanceOfBlockReturns() { - rewriteRun( - //language=java - java( - """ - class Test { - static String formatter(Object obj) { - String formatted = "initialValue"; - if (obj == null) { - formatted = "null"; - } else if (obj instanceof Integer i) - return String.format("int %d", i); - else if (obj instanceof Long l) { - formatted = String.format("long %d", l); - } else if (obj instanceof Double d) { - formatted = String.format("double %f", d); - } else if (obj instanceof String s) { - String str = "String"; - formatted = String.format("%s %s", str, s); - } - return formatted; - } - } - """ - ) - ); - } + + @Test void noSwitchBlockWithTrailingNonEqualsBinaryCheck() { @@ -930,36 +876,7 @@ void test(Object o) { ); } - @Test - void noSwitchBlockWhenElseBlockReturns() { - rewriteRun( - //language=java - java( - """ - class Test { - static String formatter(Object obj) { - String formatted = "initialValue"; - if (obj == null) { - formatted = "null"; - } else if (obj instanceof Integer i) - formatted = String.format("int %d", i); - else if (obj instanceof Long l) { - formatted = String.format("long %d", l); - } else if (obj instanceof Double d) { - formatted = String.format("double %f", d); - } else if (obj instanceof String s) { - String str = "String"; - formatted = String.format("%s %s", str, s); - } else { - return "Unknown test result"; - } - return formatted; - } - } - """ - ) - ); - } + } @Test @@ -1142,4 +1059,292 @@ static String formatter(Object obj) { ); } } + + @Test + void switchBlockWhenNullBlockReturns() { + rewriteRun( + //language=java + java( + """ + class Test { + static String formatter(Object obj) { + String formatted = "initialValue"; + if (obj == null) { + return "null"; + } else if (obj instanceof Integer i) + formatted = String.format("int %d", i); + else if (obj instanceof Long l) { + formatted = String.format("long %d", l); + } else if (obj instanceof Double d) { + formatted = String.format("double %f", d); + } else if (obj instanceof String s) { + String str = "String"; + formatted = String.format("%s %s", str, s); + } + return formatted; + } + } + """, + """ + class Test { + static String formatter(Object obj) { + String formatted = "initialValue"; + switch (obj) { + case null -> { + return "null"; + } + case Integer i -> formatted = String.format("int %d", i); + case Long l -> formatted = String.format("long %d", l); + case Double d -> formatted = String.format("double %f", d); + case String s -> { + String str = "String"; + formatted = String.format("%s %s", str, s); + } + default -> {} + } + return formatted; + } + } + """ + ) + ); + } + + @Test + void switchBlockWhenInstanceOfBlockReturns() { + rewriteRun( + //language=java + java( + """ + class Test { + static String formatter(Object obj) { + String formatted = "initialValue"; + if (obj == null) { + formatted = "null"; + } else if (obj instanceof Integer i) + return String.format("int %d", i); + else if (obj instanceof Long l) { + formatted = String.format("long %d", l); + } else if (obj instanceof Double d) { + formatted = String.format("double %f", d); + } else if (obj instanceof String s) { + String str = "String"; + formatted = String.format("%s %s", str, s); + } + return formatted; + } + } + """, + """ + class Test { + static String formatter(Object obj) { + String formatted = "initialValue"; + switch (obj) { + case null -> formatted = "null"; + case Integer i -> { + return String.format("int %d", i); + } + case Long l -> formatted = String.format("long %d", l); + case Double d -> formatted = String.format("double %f", d); + case String s -> { + String str = "String"; + formatted = String.format("%s %s", str, s); + } + default -> {} + } + return formatted; + } + } + """ + ) + ); + } + + @Test + void switchBlockWhenElseBlockReturns() { + rewriteRun( + //language=java + java( + """ + class Test { + static String formatter(Object obj) { + String formatted = "initialValue"; + if (obj == null) { + formatted = "null"; + } else if (obj instanceof Integer i) + formatted = String.format("int %d", i); + else if (obj instanceof Long l) { + formatted = String.format("long %d", l); + } else if (obj instanceof Double d) { + formatted = String.format("double %f", d); + } else if (obj instanceof String s) { + String str = "String"; + formatted = String.format("%s %s", str, s); + } else { + return "Unknown test result"; + } + return formatted; + } + } + """, + """ + class Test { + static String formatter(Object obj) { + String formatted = "initialValue"; + switch (obj) { + case null -> formatted = "null"; + case Integer i -> formatted = String.format("int %d", i); + case Long l -> formatted = String.format("long %d", l); + case Double d -> formatted = String.format("double %f", d); + case String s -> { + String str = "String"; + formatted = String.format("%s %s", str, s); + } + default -> { + return "Unknown test result"; + } + } + return formatted; + } + } + """ + ) + ); + } + + @Test + void switchBlockWithBracelessReturns() { + rewriteRun( + //language=java + java( + """ + class Test { + static String formatter(Object obj) { + if (obj == null) + return "null"; + else if (obj instanceof Integer i) + return String.format("int %d", i); + else if (obj instanceof Long l) + return String.format("long %d", l); + else + return "unknown"; + } + } + """, + """ + class Test { + static String formatter(Object obj) { + switch (obj) { + case null -> { + return "null"; + } + case Integer i -> { + return String.format("int %d", i); + } + case Long l -> { + return String.format("long %d", l); + } + default -> { + return "unknown"; + } + } + } + } + """ + ) + ); + } + + @Test + void switchBlockWithMixedReturnBranches() { + rewriteRun( + //language=java + java( + """ + class Test { + static String formatter(Object obj) { + String formatted = "initialValue"; + if (obj == null) + return "null"; + else if (obj instanceof Integer i) + formatted = String.format("int %d", i); + else if (obj instanceof Long l) + return String.format("long %d", l); + else + formatted = "unknown"; + return formatted; + } + } + """, + """ + class Test { + static String formatter(Object obj) { + String formatted = "initialValue"; + switch (obj) { + case null -> { + return "null"; + } + case Integer i -> formatted = String.format("int %d", i); + case Long l -> { + return String.format("long %d", l); + } + default -> formatted = "unknown"; + } + return formatted; + } + } + """ + ) + ); + } + + @Test + void switchBlockWithNestedReturnsInLoopOrLambda() { + rewriteRun( + //language=java + java( + """ + import java.util.List; + class Test { + static void process(Object obj, List list) { + if (obj == null) { + list.forEach(s -> { + if (s.isEmpty()) return; + System.out.println(s); + }); + } else if (obj instanceof Integer i) { + for (int j = 0; j < i; j++) { + if (j == 5) return; + System.out.println(j); + } + } else if (obj instanceof Long l) { + System.out.println(l); + } + } + } + """, + """ + import java.util.List; + class Test { + static void process(Object obj, List list) { + switch (obj) { + case null -> list.forEach(s -> { + if (s.isEmpty()) return; + System.out.println(s); + }); + case Integer i -> { + for (int j = 0; j < i; j++) { + if (j == 5) return; + System.out.println(j); + } + } + case Long l -> System.out.println(l); + default -> {} + } + } + } + """ + ) + ); + } }