Skip to content

Commit 0ef145c

Browse files
CopilotQuafadas
andcommitted
Fix bounds checking in dgemm and sgemm, add offset test cases
Co-authored-by: Quafadas <[email protected]>
1 parent 2871712 commit 0ef145c

File tree

3 files changed

+48
-6
lines changed

3 files changed

+48
-6
lines changed

blas/src/main/java/dev/ludovic/netlib/blas/AbstractBLAS.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -291,9 +291,9 @@ public void dgemm(String transa, String transb, int m, int n, int k, double alph
291291
requireNonNull(a);
292292
requireNonNull(b);
293293
requireNonNull(c);
294-
checkIndex(offseta + (lsame("N", transa) ? k : m) * lda - 1, a.length);
295-
checkIndex(offsetb + (lsame("N", transb) ? n : k) * ldb - 1, b.length);
296-
checkIndex(offsetc + m * n - 1, c.length);
294+
checkIndex(offseta + (lsame("N", transa) ? (k - 1) * lda + (m - 1) : (m - 1) * lda + (k - 1)), a.length);
295+
checkIndex(offsetb + (lsame("N", transb) ? (n - 1) * ldb + (k - 1) : (k - 1) * ldb + (n - 1)), b.length);
296+
checkIndex(offsetc + (n - 1) * ldc + (m - 1), c.length);
297297
dgemmK(transa, transb, m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
298298
}
299299

@@ -321,9 +321,9 @@ public void sgemm(String transa, String transb, int m, int n, int k, float alpha
321321
requireNonNull(a);
322322
requireNonNull(b);
323323
requireNonNull(c);
324-
checkIndex(offseta + (lsame("N", transa) ? k : m) * lda - 1, a.length);
325-
checkIndex(offsetb + (lsame("N", transb) ? n : k) * ldb - 1, b.length);
326-
checkIndex(offsetc + m * n - 1, c.length);
324+
checkIndex(offseta + (lsame("N", transa) ? (k - 1) * lda + (m - 1) : (m - 1) * lda + (k - 1)), a.length);
325+
checkIndex(offsetb + (lsame("N", transb) ? (n - 1) * ldb + (k - 1) : (k - 1) * ldb + (n - 1)), b.length);
326+
checkIndex(offsetc + (n - 1) * ldc + (m - 1), c.length);
327327
sgemmK(transa, transb, m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
328328
}
329329

blas/src/test/java/dev/ludovic/netlib/blas/DgemmTest.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,4 +185,25 @@ void testSanity(BLAS blas) {
185185
blas.dgemm("T", "T", M, N/2, K, 0.0, dgeAT, K, dgeBT, N/2, 1.0, dgeCcopy = dgeC.clone(), M);
186186
assertArrayEquals(expected, dgeCcopy, depsilon);
187187
}
188+
189+
@ParameterizedTest
190+
@MethodSource("BLASImplementations")
191+
void testOffsetBoundsChecking(BLAS blas) {
192+
// Test case that reproduces the original bounds checking issue
193+
// Matrix A (2x3) with offset=1, stored in array of length 9
194+
double[] a = {1.0, 4.0, 7.0, 2.0, 5.0, 8.0, 3.0, 6.0, 9.0};
195+
double[] b = {1.0, 2.0, 3.0};
196+
double[] c = new double[2];
197+
double[] cExpected = new double[2];
198+
199+
// This should not throw IndexOutOfBoundsException
200+
assertDoesNotThrow(() -> {
201+
blas.dgemm("N", "N", 2, 1, 3, 1.0, a, 1, 3, b, 0, 3, 0.0, c, 0, 2);
202+
});
203+
204+
// Compare with F2J result
205+
f2j.dgemm("N", "N", 2, 1, 3, 1.0, a, 1, 3, b, 0, 3, 0.0, cExpected, 0, 2);
206+
blas.dgemm("N", "N", 2, 1, 3, 1.0, a, 1, 3, b, 0, 3, 0.0, c, 0, 2);
207+
assertArrayEquals(cExpected, c, depsilon);
208+
}
188209
}

blas/src/test/java/dev/ludovic/netlib/blas/SgemmTest.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,4 +181,25 @@ void testSanity(BLAS blas) {
181181
blas.sgemm("T", "T", M, N/2, K, 0.0f, sgeAT, K, sgeBT, N/2, 1.0f, sgeCcopy = sgeC.clone(), M);
182182
assertArrayEquals(expected, sgeCcopy, sepsilon);
183183
}
184+
185+
@ParameterizedTest
186+
@MethodSource("BLASImplementations")
187+
void testOffsetBoundsChecking(BLAS blas) {
188+
// Test case that reproduces the original bounds checking issue for sgemm
189+
// Matrix A (2x3) with offset=1, stored in array of length 9
190+
float[] a = {1.0f, 4.0f, 7.0f, 2.0f, 5.0f, 8.0f, 3.0f, 6.0f, 9.0f};
191+
float[] b = {1.0f, 2.0f, 3.0f};
192+
float[] c = new float[2];
193+
float[] cExpected = new float[2];
194+
195+
// This should not throw IndexOutOfBoundsException
196+
assertDoesNotThrow(() -> {
197+
blas.sgemm("N", "N", 2, 1, 3, 1.0f, a, 1, 3, b, 0, 3, 0.0f, c, 0, 2);
198+
});
199+
200+
// Compare with F2J result
201+
f2j.sgemm("N", "N", 2, 1, 3, 1.0f, a, 1, 3, b, 0, 3, 0.0f, cExpected, 0, 2);
202+
blas.sgemm("N", "N", 2, 1, 3, 1.0f, a, 1, 3, b, 0, 3, 0.0f, c, 0, 2);
203+
assertArrayEquals(cExpected, c, sepsilon);
204+
}
184205
}

0 commit comments

Comments
 (0)