// Test CUDA warp matrix functions (tensor core) compile correctly
// build-pass

4- use cuda_std:: kernel;
5- use cuda_std:: warp:: matrix:: { TensorCore , dims, layout, MatrixElement } ;
64use cuda_std:: half:: f16;
5+ use cuda_std:: kernel;
6+ use cuda_std:: warp:: matrix:: { dims, layout, MatrixElement , TensorCore } ;
77
88#[ kernel]
99pub unsafe fn test_warp_matrix_type_safe ( ) {
1010 // Create a tensor core operation with compile-time validated dimensions
1111 // Type-driven API: specify element type directly
1212 let tc_f16 = TensorCore :: < f16 , dims:: Shape < 16 , 16 , 16 > > :: new ( ) ;
13-
13+
1414 // Create matrix fragments with type-safe API
1515 let mut a_fragment = tc_f16. matrix_a :: < layout:: Row > ( ) ;
1616 let mut b_fragment = tc_f16. matrix_b :: < layout:: Col > ( ) ;
17- let mut c_fragment = tc_f16. accumulator ( ) ; // Returns f32 accumulator by default
18-
17+ let mut c_fragment = tc_f16. accumulator ( ) ; // Returns f32 accumulator by default
18+
1919 // Initialize accumulator
2020 c_fragment. fill ( 0.0 ) ;
21-
21+
2222 // Load operations would use compile-time stride validation
2323 // a_fragment.load::<64>(&a_matrix); // STRIDE validated at compile time
2424 // b_fragment.load::<64>(&b_matrix);
25-
25+
2626 // Perform matrix multiply-accumulate
2727 // c_fragment.mma(&a_fragment, &b_fragment);
28-
28+
2929 // Store result
3030 // c_fragment.store::<layout::Row, 64>(&mut c_matrix);
3131}
@@ -35,11 +35,11 @@ pub unsafe fn test_different_tensor_shapes() {
3535 // 16x16x16 - Most common configuration
3636 let tc_16x16 = TensorCore :: < f16 , dims:: Shape < 16 , 16 , 16 > > :: new ( ) ;
3737 let _a = tc_16x16. matrix_a :: < layout:: Row > ( ) ;
38-
38+
3939 // 32x8x16 - Tall and skinny
4040 let tc_32x8 = TensorCore :: < f16 , dims:: Shape < 32 , 8 , 16 > > :: new ( ) ;
4141 let _a = tc_32x8. matrix_a :: < layout:: Row > ( ) ;
42-
42+
4343 // 8x32x16 - Short and wide
4444 let tc_8x32 = TensorCore :: < f16 , dims:: Shape < 8 , 32 , 16 > > :: new ( ) ;
4545 let _a = tc_8x32. matrix_a :: < layout:: Row > ( ) ;
@@ -52,38 +52,38 @@ pub unsafe fn test_different_element_types() {
5252 let tc = TensorCore :: < f16 , dims:: Shape < 16 , 16 , 16 > > :: new ( ) ;
5353 let _a = tc. matrix_a :: < layout:: Row > ( ) ;
5454 let _b = tc. matrix_b :: < layout:: Col > ( ) ;
55- let _c = tc. accumulator ( ) ; // f32 by default for f16
55+ let _c = tc. accumulator ( ) ; // f32 by default for f16
5656 }
57-
57+
5858 // f16 input, f16 accumulator
5959 {
6060 let tc = TensorCore :: < f16 , dims:: Shape < 16 , 16 , 16 > > :: new ( ) ;
6161 let _a = tc. matrix_a :: < layout:: Row > ( ) ;
6262 let _b = tc. matrix_b :: < layout:: Col > ( ) ;
63- let _c = tc. accumulator_f16 ( ) ; // explicitly use f16 accumulator
63+ let _c = tc. accumulator_f16 ( ) ; // explicitly use f16 accumulator
6464 }
65-
65+
6666 // i8 input, i32 accumulator
6767 {
6868 let tc = TensorCore :: < i8 , dims:: Shape < 16 , 16 , 16 > > :: new ( ) ;
6969 let _a = tc. matrix_a :: < layout:: Row > ( ) ;
7070 let _b = tc. matrix_b :: < layout:: Col > ( ) ;
71- let _c = tc. accumulator ( ) ; // i32 for i8
71+ let _c = tc. accumulator ( ) ; // i32 for i8
7272 }
7373}
7474
7575#[ kernel]
7676pub unsafe fn test_layout_combinations ( ) {
7777 let tc = TensorCore :: < f16 , dims:: Shape < 16 , 16 , 16 > > :: new ( ) ;
78-
78+
7979 // All valid layout combinations
8080 let _a_row = tc. matrix_a :: < layout:: Row > ( ) ;
8181 let _a_col = tc. matrix_a :: < layout:: Col > ( ) ;
8282 let _b_row = tc. matrix_b :: < layout:: Row > ( ) ;
8383 let _b_col = tc. matrix_b :: < layout:: Col > ( ) ;
84-
84+
8585 // Accumulator can have different layouts for storage
8686 let acc = tc. accumulator ( ) ;
8787 // acc.store::<layout::Row, 64>(&mut output);
8888 // acc.store::<layout::Col, 64>(&mut output);
89- }
89+ }