diff --git a/src/main/java/com/williamfiset/algorithms/dp/EditDistanceIterative.java b/src/main/java/com/williamfiset/algorithms/dp/EditDistanceIterative.java index 25dfeea42..6d09fd1f9 100644 --- a/src/main/java/com/williamfiset/algorithms/dp/EditDistanceIterative.java +++ b/src/main/java/com/williamfiset/algorithms/dp/EditDistanceIterative.java @@ -1,74 +1,88 @@ +package com.williamfiset.algorithms.dp; + /** - * An implementation of the edit distance algorithm + * Edit Distance (Levenshtein Distance) — Iterative Bottom-Up DP + * + * Computes the minimum cost to transform string `a` into string `b` using + * three operations, each with a configurable cost: + * + * - Insert a character into `a` (cost: insertionCost) + * - Delete a character from `a` (cost: deletionCost) + * - Substitute a character in `a` (cost: substitutionCost, 0 if chars match) + * + * The DP table dp[i][j] represents the cost of converting the first i + * characters of `a` into the first j characters of `b`. Each cell is + * computed from three neighbors: diagonal (substitute/match), above (delete), + * and left (insert). + * + * See also: EditDistanceRecursive for a top-down memoized approach. * - *

Time Complexity: O(nm) + * Tested against: https://leetcode.com/problems/edit-distance + * + * Time: O(n*m) where n = a.length(), m = b.length() + * Space: O(n*m) * * @author Micah Stairs */ -package com.williamfiset.algorithms.dp; - public class EditDistanceIterative { - // Computes the cost to convert a string 'a' into a string 'b' using dynamic - // programming given the insertionCost, deletionCost and substitutionCost, O(nm) + /** + * Computes the minimum cost to convert string `a` into string `b`. + * + * @param a the source string + * @param b the target string + * @param insertionCost cost of inserting one character + * @param deletionCost cost of deleting one character + * @param substitutionCost cost of substituting one character (0 cost if chars already match) + * @return the minimum edit distance + * + * Time: O(n*m) + * Space: O(n*m) + */ public static int editDistance( String a, String b, int insertionCost, int deletionCost, int substitutionCost) { - - final int AL = a.length(), BL = b.length(); - int[][] dp = new int[AL + 1][BL + 1]; - - for (int i = 0; i <= AL; i++) { - for (int j = (i == 0 ? 1 : 0); j <= BL; j++) { - - int min = Integer.MAX_VALUE; - - // Substitution - if (i > 0 && j > 0) - min = dp[i - 1][j - 1] + (a.charAt(i - 1) == b.charAt(j - 1) ? 
0 : substitutionCost); - - // Deletion - if (i > 0) min = Math.min(min, dp[i - 1][j] + deletionCost); - - // Insertion - if (j > 0) min = Math.min(min, dp[i][j - 1] + insertionCost); - - dp[i][j] = min; + if (a == null || b == null) throw new IllegalArgumentException("Input strings must not be null"); + + final int n = a.length(), m = b.length(); + int[][] dp = new int[n + 1][m + 1]; + + // Base cases: transforming a prefix of `a` into empty string (deletions only) + for (int i = 1; i <= n; i++) + dp[i][0] = i * deletionCost; + + // Base cases: transforming empty string into a prefix of `b` (insertions only) + for (int j = 1; j <= m; j++) + dp[0][j] = j * insertionCost; + + // Fill the DP table + for (int i = 1; i <= n; i++) { + for (int j = 1; j <= m; j++) { + // If characters match, no substitution cost; otherwise pay substitutionCost + int substitute = dp[i - 1][j - 1] + + (a.charAt(i - 1) == b.charAt(j - 1) ? 0 : substitutionCost); + int delete = dp[i - 1][j] + deletionCost; + int insert = dp[i][j - 1] + insertionCost; + dp[i][j] = Math.min(substitute, Math.min(delete, insert)); } } - return dp[AL][BL]; + return dp[n][m]; } public static void main(String[] args) { + // Identical strings — cost is 0 + System.out.println(editDistance("abcdefg", "abcdefg", 10, 10, 10)); // 0 - String a = "abcdefg"; - String b = "abcdefg"; - - // The strings are the same so the cost is zero - System.out.println(EditDistanceIterative.editDistance(a, b, 10, 10, 10)); - - a = "aaa"; - b = "aaabbb"; - - // 10*3 = 30 because of three insertions - System.out.println(EditDistanceIterative.editDistance(a, b, 10, 2, 3)); - - a = "1023"; - b = "10101010"; - - // Outputs 2*2 + 4*5 = 24 for 2 substitutions and 4 insertions - System.out.println(EditDistanceIterative.editDistance(a, b, 5, 7, 2)); - - a = "923456789"; - b = "12345"; + // 3 insertions at cost 10 each = 30 + System.out.println(editDistance("aaa", "aaabbb", 10, 2, 3)); // 30 - // Outputs 4*4 + 1 = 16 because we need to delete 4 - // 
characters and perform one substitution - System.out.println(EditDistanceIterative.editDistance(a, b, 2, 4, 1)); + // 2 substitutions (cost 2) + 4 insertions (cost 5) = 24 + System.out.println(editDistance("1023", "10101010", 5, 7, 2)); // 24 - a = "aaaaa"; - b = "aabaa"; + // 1 substitution (cost 1) + 4 deletions (cost 4) = 17 + System.out.println(editDistance("923456789", "12345", 2, 4, 1)); // 17 - System.out.println(EditDistanceIterative.editDistance(a, b, 2, 3, 10)); + // Insert 'b' then delete 'a' is cheaper than substituting 'a'->'b' + System.out.println(editDistance("aaaaa", "aabaa", 2, 3, 10)); // 5 } } diff --git a/src/main/java/com/williamfiset/algorithms/dp/EditDistanceRecursive.java b/src/main/java/com/williamfiset/algorithms/dp/EditDistanceRecursive.java index bb79e33b5..554f19d52 100644 --- a/src/main/java/com/williamfiset/algorithms/dp/EditDistanceRecursive.java +++ b/src/main/java/com/williamfiset/algorithms/dp/EditDistanceRecursive.java @@ -1,21 +1,47 @@ +package com.williamfiset.algorithms.dp; + /** - * A solution to the edit distance problem + * Edit Distance (Levenshtein Distance) — Top-Down Recursive with Memoization + * + * Computes the minimum cost to transform string `a` into string `b` using + * three operations, each with a configurable cost: + * + * - Insert a character into `a` (cost: insertionCost) + * - Delete a character from `a` (cost: deletionCost) + * - Substitute a character in `a` (cost: substitutionCost, 0 if chars match) + * + * The recursive function f(i, j) returns the cost of converting a[i..] into + * b[j..]. At each step it considers three choices — substitute/match, delete, + * insert — and memoizes results in a 2D table. + * + * Note: every state (i, j) is reachable from (0, 0), so this approach covers the + * same O(n*m) states as EditDistanceIterative while recursing up to n + m deep. + * + * Tested against: https://leetcode.com/problems/edit-distance * - *

Tested against: https://leetcode.com/problems/edit-distance + * Time: O(n*m) where n = a.length(), m = b.length() + * Space: O(n*m) * * @author William Fiset, william.alexandre.fiset@gmail.com */ -package com.williamfiset.algorithms.dp; - public class EditDistanceRecursive { - final char[] a, b; - final int insertionCost, deletionCost, substitutionCost; + private final char[] a, b; + private final int insertionCost, deletionCost, substitutionCost; + /** + * Creates an edit distance solver for the given strings and operation costs. + * + * @param a the source string + * @param b the target string + * @param insertionCost cost of inserting one character + * @param deletionCost cost of deleting one character + * @param substitutionCost cost of substituting one character (0 cost if chars already match) + */ public EditDistanceRecursive( String a, String b, int insertionCost, int deletionCost, int substitutionCost) { if (a == null || b == null) { - throw new IllegalArgumentException("Input string must not be null"); + throw new IllegalArgumentException("Input strings must not be null"); } this.a = a.toCharArray(); this.b = b.toCharArray(); @@ -24,70 +50,63 @@ public EditDistanceRecursive( this.substitutionCost = substitutionCost; } - private static int min(int... values) { - int m = Integer.MAX_VALUE; - for (int v : values) { - if (v < m) { - m = v; - } - } - return m; - } - - // Returns the Levenshtein distance to transform string `a` into string `b`. + /** + * Computes and returns the minimum edit distance from `a` to `b`. + * + * Time: O(n*m) + * Space: O(n*m) + */ public int editDistance() { Integer[][] dp = new Integer[a.length + 1][b.length + 1]; return f(dp, 0, 0); } + /** + * Recursive helper: returns the min cost to convert a[i..] into b[j..]. 
+ */ private int f(Integer[][] dp, int i, int j) { - if (i == a.length && j == b.length) { - return 0; - } - if (i == a.length) { - return (b.length - j) * insertionCost; - } - if (j == b.length) { - return (a.length - i) * deletionCost; - } - if (dp[i][j] != null) { - return dp[i][j]; - } - int substituteOrSkip = f(dp, i + 1, j + 1) + (a[i] == b[j] ? 0 : substitutionCost); + // Both strings fully consumed — nothing left to do + if (i == a.length && j == b.length) return 0; + + // Remaining characters in `b` must be inserted + if (i == a.length) return (b.length - j) * insertionCost; + + // Remaining characters in `a` must be deleted + if (j == b.length) return (a.length - i) * deletionCost; + + if (dp[i][j] != null) return dp[i][j]; + + // Match (free) or substitute, then advance both pointers + int substitute = f(dp, i + 1, j + 1) + (a[i] == b[j] ? 0 : substitutionCost); + + // Delete a[i], advance i only int delete = f(dp, i + 1, j) + deletionCost; + + // Insert b[j] into a, advance j only int insert = f(dp, i, j + 1) + insertionCost; - return dp[i][j] = min(substituteOrSkip, delete, insert); + + return dp[i][j] = Math.min(substitute, Math.min(delete, insert)); } public static void main(String[] args) { - String a = "923456789"; - String b = "12345"; - EditDistanceRecursive solver = new EditDistanceRecursive(a, b, 100, 4, 2); - System.out.println(solver.editDistance()); - - a = "12345"; - b = "923456789"; - solver = new EditDistanceRecursive(a, b, 100, 4, 2); - System.out.println(solver.editDistance()); - - a = "aaa"; - b = "aaabbb"; - solver = new EditDistanceRecursive(a, b, 10, 2, 3); - System.out.println(solver.editDistance()); - - a = "1023"; - b = "10101010"; - solver = new EditDistanceRecursive(a, b, 5, 7, 2); - System.out.println(solver.editDistance()); - - a = "923456789"; - b = "12345"; - EditDistanceRecursive solver2 = new EditDistanceRecursive(a, b, 100, 4, 2); - System.out.println(solver2.editDistance()); - - a = "aaaaa"; - b = "aabaa"; - solver = 
new EditDistanceRecursive(a, b, 2, 3, 10); - System.out.println(solver.editDistance()); + // 1 substitution (cost 2) + 4 deletions (cost 4) = 18 + System.out.println( + new EditDistanceRecursive("923456789", "12345", 100, 4, 2).editDistance()); // 18 + + // Reverse direction: 1 substitution (cost 2) + 4 insertions (cost 100) = 402 + System.out.println( + new EditDistanceRecursive("12345", "923456789", 100, 4, 2).editDistance()); // 402 + + // 3 insertions at cost 10 each = 30 + System.out.println( + new EditDistanceRecursive("aaa", "aaabbb", 10, 2, 3).editDistance()); // 30 + + // 2 substitutions (cost 2) + 4 insertions (cost 5) = 24 + System.out.println( + new EditDistanceRecursive("1023", "10101010", 5, 7, 2).editDistance()); // 24 + + // Insert 'b' then delete 'a' is cheaper than substituting 'a'->'b' + System.out.println( + new EditDistanceRecursive("aaaaa", "aabaa", 2, 3, 10).editDistance()); // 5 } } diff --git a/src/test/java/com/williamfiset/algorithms/dp/BUILD b/src/test/java/com/williamfiset/algorithms/dp/BUILD index 91a1c716e..e19e69782 100644 --- a/src/test/java/com/williamfiset/algorithms/dp/BUILD +++ b/src/test/java/com/williamfiset/algorithms/dp/BUILD @@ -50,5 +50,16 @@ java_test( deps = TEST_DEPS, ) +# bazel test //src/test/java/com/williamfiset/algorithms/dp:EditDistanceIterativeTest +java_test( + name = "EditDistanceIterativeTest", + srcs = ["EditDistanceIterativeTest.java"], + main_class = "org.junit.platform.console.ConsoleLauncher", + use_testrunner = False, + args = ["--select-class=com.williamfiset.algorithms.dp.EditDistanceIterativeTest"], + runtime_deps = JUNIT5_RUNTIME_DEPS, + deps = TEST_DEPS, +) + # Run all tests # bazel test //src/test/java/com/williamfiset/algorithms/dp:all diff --git a/src/test/java/com/williamfiset/algorithms/dp/EditDistanceIterativeTest.java b/src/test/java/com/williamfiset/algorithms/dp/EditDistanceIterativeTest.java new file mode 100644 index 000000000..78b8b8269 --- /dev/null +++ 
b/src/test/java/com/williamfiset/algorithms/dp/EditDistanceIterativeTest.java @@ -0,0 +1,110 @@ +package com.williamfiset.algorithms.dp; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import org.junit.jupiter.api.Test; + +public class EditDistanceIterativeTest { + + @Test + public void testNullInputA() { + assertThrows( + IllegalArgumentException.class, + () -> EditDistanceIterative.editDistance(null, "abc", 1, 1, 1)); + } + + @Test + public void testNullInputB() { + assertThrows( + IllegalArgumentException.class, + () -> EditDistanceIterative.editDistance("abc", null, 1, 1, 1)); + } + + @Test + public void testIdenticalStrings() { + assertThat(EditDistanceIterative.editDistance("abcdefg", "abcdefg", 10, 10, 10)).isEqualTo(0); + } + + @Test + public void testBothEmpty() { + assertThat(EditDistanceIterative.editDistance("", "", 1, 1, 1)).isEqualTo(0); + } + + @Test + public void testEmptyToNonEmpty() { + // Converting "" to "abc" requires 3 insertions at cost 5 each + assertThat(EditDistanceIterative.editDistance("", "abc", 5, 1, 1)).isEqualTo(15); + } + + @Test + public void testNonEmptyToEmpty() { + // Converting "abc" to "" requires 3 deletions at cost 4 each + assertThat(EditDistanceIterative.editDistance("abc", "", 1, 4, 1)).isEqualTo(12); + } + + @Test + public void testInsertionsOnly() { + // "aaa" -> "aaabbb" requires 3 insertions at cost 10 + assertThat(EditDistanceIterative.editDistance("aaa", "aaabbb", 10, 2, 3)).isEqualTo(30); + } + + @Test + public void testSubstitutionsAndInsertions() { + // "1023" -> "10101010": 2 substitutions (cost 2) + 4 insertions (cost 5) = 24 + assertThat(EditDistanceIterative.editDistance("1023", "10101010", 5, 7, 2)).isEqualTo(24); + } + + @Test + public void testDeletionsAndSubstitution() { + // "923456789" -> "12345": 1 substitution (cost 1) + 4 deletions (cost 4) = 17 + assertThat(EditDistanceIterative.editDistance("923456789", "12345", 2, 4, 
1)).isEqualTo(17); + } + + /** When substitution is expensive, insert+delete can be cheaper. */ + @Test + public void testInsertDeleteCheaperThanSubstitute() { + // "aaaaa" -> "aabaa": substituting costs 10, but insert 'b' (2) + delete 'a' (3) = 5 + assertThat(EditDistanceIterative.editDistance("aaaaa", "aabaa", 2, 3, 10)).isEqualTo(5); + } + + @Test + public void testSingleCharSubstitution() { + assertThat(EditDistanceIterative.editDistance("a", "b", 1, 1, 1)).isEqualTo(1); + } + + @Test + public void testSingleCharInsertion() { + assertThat(EditDistanceIterative.editDistance("a", "ab", 3, 1, 1)).isEqualTo(3); + } + + @Test + public void testSingleCharDeletion() { + assertThat(EditDistanceIterative.editDistance("ab", "a", 1, 7, 1)).isEqualTo(7); + } + + /** Verify iterative and recursive solvers agree on the same inputs. */ + @Test + public void testMatchesRecursiveSolver() { + String[][] pairs = { + {"abcdefg", "abcdefg"}, + {"aaa", "aaabbb"}, + {"1023", "10101010"}, + {"923456789", "12345"}, + {"aaaaa", "aabaa"}, + {"kitten", "sitting"}, + {"", "hello"}, + {"world", ""}, + }; + int[][] costs = {{10, 10, 10}, {10, 2, 3}, {5, 7, 2}, {2, 4, 1}, {2, 3, 10}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}; + + for (int k = 0; k < pairs.length; k++) { + String a = pairs[k][0], b = pairs[k][1]; + int ins = costs[k][0], del = costs[k][1], sub = costs[k][2]; + + int iterative = EditDistanceIterative.editDistance(a, b, ins, del, sub); + int recursive = new EditDistanceRecursive(a, b, ins, del, sub).editDistance(); + assertThat(iterative).isEqualTo(recursive); + } + } +}