diff --git a/blas/src/main/java/dev/ludovic/netlib/blas/AbstractBLAS.java b/blas/src/main/java/dev/ludovic/netlib/blas/AbstractBLAS.java index ff96558..93f5cc8 100644 --- a/blas/src/main/java/dev/ludovic/netlib/blas/AbstractBLAS.java +++ b/blas/src/main/java/dev/ludovic/netlib/blas/AbstractBLAS.java @@ -291,9 +291,9 @@ public void dgemm(String transa, String transb, int m, int n, int k, double alph requireNonNull(a); requireNonNull(b); requireNonNull(c); - checkIndex(offseta + (lsame("N", transa) ? k : m) * lda - 1, a.length); - checkIndex(offsetb + (lsame("N", transb) ? n : k) * ldb - 1, b.length); - checkIndex(offsetc + m * n - 1, c.length); + checkIndex(offseta + (lsame("N", transa) ? (k - 1) * lda + (m - 1) : (m - 1) * lda + (k - 1)), a.length); + checkIndex(offsetb + (lsame("N", transb) ? (n - 1) * ldb + (k - 1) : (k - 1) * ldb + (n - 1)), b.length); + checkIndex(offsetc + (n - 1) * ldc + (m - 1), c.length); dgemmK(transa, transb, m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); } @@ -321,9 +321,9 @@ public void sgemm(String transa, String transb, int m, int n, int k, float alpha requireNonNull(a); requireNonNull(b); requireNonNull(c); - checkIndex(offseta + (lsame("N", transa) ? k : m) * lda - 1, a.length); - checkIndex(offsetb + (lsame("N", transb) ? n : k) * ldb - 1, b.length); - checkIndex(offsetc + m * n - 1, c.length); + checkIndex(offseta + (lsame("N", transa) ? (k - 1) * lda + (m - 1) : (m - 1) * lda + (k - 1)), a.length); + checkIndex(offsetb + (lsame("N", transb) ? (n - 1) * ldb + (k - 1) : (k - 1) * ldb + (n - 1)), b.length); + checkIndex(offsetc + (n - 1) * ldc + (m - 1), c.length); sgemmK(transa, transb, m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc); } @@ -349,7 +349,7 @@ public void dgemv(String trans, int m, int n, double alpha, double[] a, int offs requireNonNull(a); requireNonNull(x); requireNonNull(y); - checkIndex(offseta + n * lda - 1, a.length); + checkIndex(offseta + (n - 1) * lda + (m - 1), a.length); checkIndex(offsetx + ((lsame("N", trans) ? n : m) - 1) * Math.abs(incx), x.length); checkIndex(offsety + ((lsame("N", trans) ? m : n) - 1) * Math.abs(incy), y.length); dgemvK(trans, m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); @@ -377,7 +377,7 @@ public void sgemv(String trans, int m, int n, float alpha, float[] a, int offset requireNonNull(a); requireNonNull(x); requireNonNull(y); - checkIndex(offseta + n * lda - 1, a.length); + checkIndex(offseta + (n - 1) * lda + (m - 1), a.length); checkIndex(offsetx + ((lsame("N", trans) ? n : m) - 1) * Math.abs(incx), x.length); checkIndex(offsety + ((lsame("N", trans) ? m : n) - 1) * Math.abs(incy), y.length); sgemvK(trans, m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); diff --git a/blas/src/test/java/dev/ludovic/netlib/blas/DgemmTest.java b/blas/src/test/java/dev/ludovic/netlib/blas/DgemmTest.java index ae43c30..e6358da 100644 --- a/blas/src/test/java/dev/ludovic/netlib/blas/DgemmTest.java +++ b/blas/src/test/java/dev/ludovic/netlib/blas/DgemmTest.java @@ -185,4 +185,46 @@ void testSanity(BLAS blas) { blas.dgemm("T", "T", M, N/2, K, 0.0, dgeAT, K, dgeBT, N/2, 1.0, dgeCcopy = dgeC.clone(), M); assertArrayEquals(expected, dgeCcopy, depsilon); } + + @ParameterizedTest + @MethodSource("BLASImplementations") + void testOffset1BoundsChecking(BLAS blas) { + // Test case that reproduces the original bounds checking issue + // Matrix A (2x3) with offset=1, stored in array of length 9 + double[] a = {1.0, 4.0, 7.0, 2.0, 5.0, 8.0, 3.0, 6.0, 9.0}; + double[] b = {1.0, 2.0, 3.0}; + double[] c = new double[2]; + double[] cExpected = new double[2]; + + // This should not throw IndexOutOfBoundsException + assertDoesNotThrow(() -> { + blas.dgemm("N", "N", 2, 1, 3, 1.0, a, 1, 3, b, 0, 3, 0.0, c, 0, 2); + }); + + // Compare with F2J result + f2j.dgemm("N", "N", 2, 1, 3, 1.0, a, 1, 3, b, 0, 3, 0.0, cExpected, 0, 2); + blas.dgemm("N", "N", 2, 1, 3, 1.0, a, 1, 3, b, 0, 3, 0.0, c, 0, 2); + assertArrayEquals(cExpected, c, depsilon); + } + + @ParameterizedTest + @MethodSource("BLASImplementations") + void testOffset2BoundsChecking(BLAS blas) { + // Test case that reproduces the original bounds checking issue + // Matrix A (2x3) with offset=1, stored in array of length 10 + double[] a = {1.0, 4.0, 7.0, 2.0, 5.0, 8.0, 3.0, 6.0, 9.0, 10.0}; + double[] b = {1.0, 2.0, 3.0}; + double[] c = new double[2]; + double[] cExpected = new double[2]; + + // This should not throw IndexOutOfBoundsException + assertDoesNotThrow(() -> { + blas.dgemm("N", "N", 2, 1, 3, 1.0, a, 2, 3, b, 0, 3, 0.0, c, 0, 2); + }); + + // Compare with F2J result + f2j.dgemm("N", "N", 2, 1, 3, 1.0, a, 2, 3, b, 0, 3, 0.0, cExpected, 0, 2); + blas.dgemm("N", "N", 2, 1, 3, 1.0, a, 2, 3, b, 0, 3, 0.0, c, 0, 2); + assertArrayEquals(cExpected, c, depsilon); + } } diff --git a/blas/src/test/java/dev/ludovic/netlib/blas/DgemvTest.java b/blas/src/test/java/dev/ludovic/netlib/blas/DgemvTest.java index 3801911..0a6c1d5 100644 --- a/blas/src/test/java/dev/ludovic/netlib/blas/DgemvTest.java +++ b/blas/src/test/java/dev/ludovic/netlib/blas/DgemvTest.java @@ -128,4 +128,46 @@ void testSanity(BLAS blas) { // } assertArrayEquals(expected, dYcopy, depsilon); } + + @ParameterizedTest + @MethodSource("BLASImplementations") + void testOffset1BoundsChecking(BLAS blas) { + // Test case that reproduces the original bounds checking issue for dgemv + // Matrix A (2x3) with offset=1, stored in array of length 9 + double[] a = {1.0, 4.0, 7.0, 2.0, 5.0, 8.0, 3.0, 6.0, 9.0}; + double[] x = {1.0, 2.0, 3.0}; + double[] y = new double[2]; + double[] yExpected = new double[2]; + + // This should not throw IndexOutOfBoundsException + assertDoesNotThrow(() -> { + blas.dgemv("N", 2, 3, 1.0, a, 1, 3, x, 0, 1, 0.0, y, 0, 1); + }); + + // Compare with F2J result + f2j.dgemv("N", 2, 3, 1.0, a, 1, 3, x, 0, 1, 0.0, yExpected, 0, 1); + blas.dgemv("N", 2, 3, 1.0, a, 1, 3, x, 0, 1, 0.0, y, 0, 1); + assertArrayEquals(yExpected, y, depsilon); + } + + @ParameterizedTest + @MethodSource("BLASImplementations") + void testOffset2BoundsChecking(BLAS blas) { + // Test case that reproduces the original bounds checking issue for dgemv + // Matrix A (2x3) with offset=1, stored in array of length 9 + double[] a = {1.0, 4.0, 7.0, 2.0, 5.0, 8.0, 3.0, 6.0, 9.0, 10.0}; + double[] x = {1.0, 2.0, 3.0}; + double[] y = new double[2]; + double[] yExpected = new double[2]; + + // This should not throw IndexOutOfBoundsException + assertDoesNotThrow(() -> { + blas.dgemv("N", 2, 3, 1.0, a, 2, 3, x, 0, 1, 0.0, y, 0, 1); + }); + + // Compare with F2J result + f2j.dgemv("N", 2, 3, 1.0, a, 2, 3, x, 0, 1, 0.0, yExpected, 0, 1); + blas.dgemv("N", 2, 3, 1.0, a, 2, 3, x, 0, 1, 0.0, y, 0, 1); + assertArrayEquals(yExpected, y, depsilon); + } } diff --git a/blas/src/test/java/dev/ludovic/netlib/blas/SgemmTest.java b/blas/src/test/java/dev/ludovic/netlib/blas/SgemmTest.java index a5fee1d..5f98469 100644 --- a/blas/src/test/java/dev/ludovic/netlib/blas/SgemmTest.java +++ b/blas/src/test/java/dev/ludovic/netlib/blas/SgemmTest.java @@ -181,4 +181,46 @@ void testSanity(BLAS blas) { blas.sgemm("T", "T", M, N/2, K, 0.0f, sgeAT, K, sgeBT, N/2, 1.0f, sgeCcopy = sgeC.clone(), M); assertArrayEquals(expected, sgeCcopy, sepsilon); } + + @ParameterizedTest + @MethodSource("BLASImplementations") + void testOffset1BoundsChecking(BLAS blas) { + // Test case that reproduces the original bounds checking issue for sgemm + // Matrix A (2x3) with offset=1, stored in array of length 9 + float[] a = {1.0f, 4.0f, 7.0f, 2.0f, 5.0f, 8.0f, 3.0f, 6.0f, 9.0f}; + float[] b = {1.0f, 2.0f, 3.0f}; + float[] c = new float[2]; + float[] cExpected = new float[2]; + + // This should not throw IndexOutOfBoundsException + assertDoesNotThrow(() -> { + blas.sgemm("N", "N", 2, 1, 3, 1.0f, a, 1, 3, b, 0, 3, 0.0f, c, 0, 2); + }); + + // Compare with F2J result + f2j.sgemm("N", "N", 2, 1, 3, 1.0f, a, 1, 3, b, 0, 3, 0.0f, cExpected, 0, 2); + blas.sgemm("N", "N", 2, 1, 3, 1.0f, a, 1, 3, b, 0, 3, 0.0f, c, 0, 2); + assertArrayEquals(cExpected, c, sepsilon); + } + + @ParameterizedTest + @MethodSource("BLASImplementations") + void testOffset2BoundsChecking(BLAS blas) { + // Test case that reproduces the original bounds checking issue for sgemm + // Matrix A (2x3) with offset=1, stored in array of length 9 + float[] a = {1.0f, 4.0f, 7.0f, 2.0f, 5.0f, 8.0f, 3.0f, 6.0f, 9.0f, 10.0f}; + float[] b = {1.0f, 2.0f, 3.0f}; + float[] c = new float[2]; + float[] cExpected = new float[2]; + + // This should not throw IndexOutOfBoundsException + assertDoesNotThrow(() -> { + blas.sgemm("N", "N", 2, 1, 3, 1.0f, a, 2, 3, b, 0, 3, 0.0f, c, 0, 2); + }); + + // Compare with F2J result + f2j.sgemm("N", "N", 2, 1, 3, 1.0f, a, 2, 3, b, 0, 3, 0.0f, cExpected, 0, 2); + blas.sgemm("N", "N", 2, 1, 3, 1.0f, a, 2, 3, b, 0, 3, 0.0f, c, 0, 2); + assertArrayEquals(cExpected, c, sepsilon); + } } diff --git a/blas/src/test/java/dev/ludovic/netlib/blas/SgemvTest.java b/blas/src/test/java/dev/ludovic/netlib/blas/SgemvTest.java index 824f7c4..fce483f 100644 --- a/blas/src/test/java/dev/ludovic/netlib/blas/SgemvTest.java +++ b/blas/src/test/java/dev/ludovic/netlib/blas/SgemvTest.java @@ -93,4 +93,46 @@ void testSanity(BLAS blas) { blas.sgemv("T", M, N, 1.0f, sgeA, M, sX, 1, 0.0f, sYcopy = sY.clone(), 1); assertArrayEquals(expected, sYcopy, sepsilon); } + + @ParameterizedTest + @MethodSource("BLASImplementations") + void testOffset1BoundsChecking(BLAS blas) { + // Test case that reproduces the original bounds checking issue for sgemv + // Matrix A (2x3) with offset=1, stored in array of length 9 + float[] a = {1.0f, 4.0f, 7.0f, 2.0f, 5.0f, 8.0f, 3.0f, 6.0f, 9.0f}; + float[] x = {1.0f, 2.0f, 3.0f}; + float[] y = new float[2]; + float[] yExpected = new float[2]; + + // This should not throw IndexOutOfBoundsException + assertDoesNotThrow(() -> { + blas.sgemv("N", 2, 3, 1.0f, a, 1, 3, x, 0, 1, 0.0f, y, 0, 1); + }); + + // Compare with F2J result + f2j.sgemv("N", 2, 3, 1.0f, a, 1, 3, x, 0, 1, 0.0f, yExpected, 0, 1); + blas.sgemv("N", 2, 3, 1.0f, a, 1, 3, x, 0, 1, 0.0f, y, 0, 1); + assertArrayEquals(yExpected, y, sepsilon); + } + + @ParameterizedTest + @MethodSource("BLASImplementations") + void testOffset2BoundsChecking(BLAS blas) { + // Test case that reproduces the original bounds checking issue for sgemv + // Matrix A (2x3) with offset=1, stored in array of length 9 + float[] a = {1.0f, 4.0f, 7.0f, 2.0f, 5.0f, 8.0f, 3.0f, 6.0f, 9.0f, 10.0f}; + float[] x = {1.0f, 2.0f, 3.0f}; + float[] y = new float[2]; + float[] yExpected = new float[2]; + + // This should not throw IndexOutOfBoundsException + assertDoesNotThrow(() -> { + blas.sgemv("N", 2, 3, 1.0f, a, 2, 3, x, 0, 1, 0.0f, y, 0, 1); + }); + + // Compare with F2J result + f2j.sgemv("N", 2, 3, 1.0f, a, 2, 3, x, 0, 1, 0.0f, yExpected, 0, 1); + blas.sgemv("N", 2, 3, 1.0f, a, 2, 3, x, 0, 1, 0.0f, y, 0, 1); + assertArrayEquals(yExpected, y, sepsilon); + } } diff --git a/pom.xml b/pom.xml index 56e9a83..d25096d 100644 --- a/pom.xml +++ b/pom.xml @@ -216,7 +216,10 @@ information or have any questions. maven-surefire-plugin - 3.0.0-M5 + 3.5.3 + + plain + org.apache.maven.plugins