Skip to content

Commit b3a82d5

Browse files
committed
Remove commented out / stubebd out logic in matrix compiletests
1 parent 49b7e2c commit b3a82d5

File tree

2 files changed

+22
-16
lines changed

2 files changed

+22
-16
lines changed

tests/compiletests/ui/warp/matrix/basic_matrix.rs

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,20 @@ pub unsafe fn test_warp_matrix_type_safe() {
1919
// Initialize accumulator
2020
c_fragment.fill(0.0);
2121

22+
// Create mock pointers for demonstration
23+
let a_matrix: *const f16 = core::ptr::null();
24+
let b_matrix: *const f16 = core::ptr::null();
25+
let c_matrix: *mut f32 = core::ptr::null_mut();
26+
2227
// Load operations would use compile-time stride validation
23-
// a_fragment.load::<64>(&a_matrix); // STRIDE validated at compile time
24-
// b_fragment.load::<64>(&b_matrix);
28+
a_fragment.load::<16>(a_matrix); // STRIDE validated at compile time
29+
b_fragment.load::<16>(b_matrix);
2530

2631
// Perform matrix multiply-accumulate
27-
// c_fragment.mma(&a_fragment, &b_fragment);
32+
c_fragment.mma_inplace(&a_fragment, &b_fragment);
2833

29-
// Store result
30-
// c_fragment.store::<layout::Row, 64>(&mut c_matrix);
34+
// Store result (f32 needs stride multiple of 4)
35+
c_fragment.store::<layout::Row, 16>(c_matrix);
3136
}
3237

3338
#[kernel]
@@ -84,6 +89,7 @@ pub unsafe fn test_layout_combinations() {
8489

8590
// Accumulator can have different layouts for storage
8691
let acc = tc.accumulator();
87-
// acc.store::<layout::Row, 64>(&mut output);
88-
// acc.store::<layout::Col, 64>(&mut output);
92+
let output: *mut f32 = core::ptr::null_mut();
93+
acc.store::<layout::Row, 16>(output);
94+
acc.store::<layout::Col, 16>(output);
8995
}

tests/compiletests/ui/warp/matrix/matrix_types.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -132,27 +132,27 @@ pub unsafe fn test_layout_combinations() {
132132
let tc = TensorCore::<f16, Shape>::new();
133133

134134
// Row-major matrix A
135-
let _a_row: MatrixA<f16, Shape, layout::Row> = tc.matrix_a();
135+
let a_row: MatrixA<f16, Shape, layout::Row> = tc.matrix_a();
136136

137137
// Column-major matrix A
138-
let _a_col: MatrixA<f16, Shape, layout::Col> = tc.matrix_a();
138+
let a_col: MatrixA<f16, Shape, layout::Col> = tc.matrix_a();
139139

140140
// Row-major matrix B
141-
let _b_row: MatrixB<f16, Shape, layout::Row> = tc.matrix_b();
141+
let b_row: MatrixB<f16, Shape, layout::Row> = tc.matrix_b();
142142

143143
// Column-major matrix B
144-
let _b_col: MatrixB<f16, Shape, layout::Col> = tc.matrix_b();
144+
let b_col: MatrixB<f16, Shape, layout::Col> = tc.matrix_b();
145145

146146
// All combinations are valid for MMA
147-
let acc = tc.accumulator();
147+
let mut acc = tc.accumulator();
148148
// Row-Row combination
149-
// acc.mma(&a_row, &b_row);
149+
let _result1 = acc.mma(&a_row, &b_row);
150150
// Row-Col combination
151-
// acc.mma(&a_row, &b_col);
151+
let _result2 = acc.mma(&a_row, &b_col);
152152
// Col-Row combination
153-
// acc.mma(&a_col, &b_row);
153+
let _result3 = acc.mma(&a_col, &b_row);
154154
// Col-Col combination
155-
// acc.mma(&a_col, &b_col);
155+
let _result4 = acc.mma(&a_col, &b_col);
156156
}
157157

158158
// Helper generic function (not a kernel)

0 commit comments

Comments
 (0)