Skip to content

Commit 2f5341a

Browse files
CopilotQuafadas
andcommitted
Fix bounds checking in dgemv and sgemv, add corresponding tests
Co-authored-by: Quafadas <[email protected]>
1 parent 0ef145c commit 2f5341a

File tree

3 files changed

+44
-2
lines changed

3 files changed

+44
-2
lines changed

blas/src/main/java/dev/ludovic/netlib/blas/AbstractBLAS.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ public void dgemv(String trans, int m, int n, double alpha, double[] a, int offs
349349
requireNonNull(a);
350350
requireNonNull(x);
351351
requireNonNull(y);
352-
checkIndex(offseta + n * lda - 1, a.length);
352+
checkIndex(offseta + (n - 1) * lda + (m - 1), a.length);
353353
checkIndex(offsetx + ((lsame("N", trans) ? n : m) - 1) * Math.abs(incx), x.length);
354354
checkIndex(offsety + ((lsame("N", trans) ? m : n) - 1) * Math.abs(incy), y.length);
355355
dgemvK(trans, m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
@@ -377,7 +377,7 @@ public void sgemv(String trans, int m, int n, float alpha, float[] a, int offset
377377
requireNonNull(a);
378378
requireNonNull(x);
379379
requireNonNull(y);
380-
checkIndex(offseta + n * lda - 1, a.length);
380+
checkIndex(offseta + (n - 1) * lda + (m - 1), a.length);
381381
checkIndex(offsetx + ((lsame("N", trans) ? n : m) - 1) * Math.abs(incx), x.length);
382382
checkIndex(offsety + ((lsame("N", trans) ? m : n) - 1) * Math.abs(incy), y.length);
383383
sgemvK(trans, m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);

blas/src/test/java/dev/ludovic/netlib/blas/DgemvTest.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,4 +128,25 @@ void testSanity(BLAS blas) {
128128
// }
129129
assertArrayEquals(expected, dYcopy, depsilon);
130130
}
131+
132+
@ParameterizedTest
133+
@MethodSource("BLASImplementations")
134+
void testOffsetBoundsChecking(BLAS blas) {
135+
// Test case that reproduces the original bounds checking issue for dgemv
136+
// Matrix A (2x3) with offset=1, stored in array of length 9
137+
double[] a = {1.0, 4.0, 7.0, 2.0, 5.0, 8.0, 3.0, 6.0, 9.0};
138+
double[] x = {1.0, 2.0, 3.0};
139+
double[] y = new double[2];
140+
double[] yExpected = new double[2];
141+
142+
// This should not throw IndexOutOfBoundsException
143+
assertDoesNotThrow(() -> {
144+
blas.dgemv("N", 2, 3, 1.0, a, 1, 3, x, 0, 1, 0.0, y, 0, 1);
145+
});
146+
147+
// Compare with F2J result
148+
f2j.dgemv("N", 2, 3, 1.0, a, 1, 3, x, 0, 1, 0.0, yExpected, 0, 1);
149+
blas.dgemv("N", 2, 3, 1.0, a, 1, 3, x, 0, 1, 0.0, y, 0, 1);
150+
assertArrayEquals(yExpected, y, depsilon);
151+
}
131152
}

blas/src/test/java/dev/ludovic/netlib/blas/SgemvTest.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,4 +93,25 @@ void testSanity(BLAS blas) {
9393
blas.sgemv("T", M, N, 1.0f, sgeA, M, sX, 1, 0.0f, sYcopy = sY.clone(), 1);
9494
assertArrayEquals(expected, sYcopy, sepsilon);
9595
}
96+
97+
@ParameterizedTest
98+
@MethodSource("BLASImplementations")
99+
void testOffsetBoundsChecking(BLAS blas) {
100+
// Test case that reproduces the original bounds checking issue for sgemv
101+
// Matrix A (2x3) with offset=1, stored in array of length 9
102+
float[] a = {1.0f, 4.0f, 7.0f, 2.0f, 5.0f, 8.0f, 3.0f, 6.0f, 9.0f};
103+
float[] x = {1.0f, 2.0f, 3.0f};
104+
float[] y = new float[2];
105+
float[] yExpected = new float[2];
106+
107+
// This should not throw IndexOutOfBoundsException
108+
assertDoesNotThrow(() -> {
109+
blas.sgemv("N", 2, 3, 1.0f, a, 1, 3, x, 0, 1, 0.0f, y, 0, 1);
110+
});
111+
112+
// Compare with F2J result
113+
f2j.sgemv("N", 2, 3, 1.0f, a, 1, 3, x, 0, 1, 0.0f, yExpected, 0, 1);
114+
blas.sgemv("N", 2, 3, 1.0f, a, 1, 3, x, 0, 1, 0.0f, y, 0, 1);
115+
assertArrayEquals(yExpected, y, sepsilon);
116+
}
96117
}

0 commit comments

Comments
 (0)