1+ //! Macros to reduce repetition in matrix module
2+
/// Implements `MmaWithShapeAndLayout` for one element / accumulator / shape /
/// layout combination by unpacking fragment data into the scalar arguments of
/// an intrinsic.
///
/// Invocation:
///
/// ```ignore
/// impl_mma!(Elem, Acc, Shape, LayoutA, LayoutB, intrinsic_fn, A_LEN, B_LEN, C_LEN);
/// ```
///
/// * `Acc` must be written as one of the bare tokens `f32`, `f16` or `i32`;
///   those are the only accumulators with `@extract_c` / `@store_result` arms.
/// * `(A_LEN, B_LEN, C_LEN)` must be one of the integer-literal triples that
///   has a `@call_intrinsic` arm: `16, 16, 8`, `8, 8, 4`, `16, 8, 8`,
///   `8, 16, 8` or `2, 2, 2`.
///
/// `Acc` and the three lengths are captured as `tt` rather than `ty` /
/// `literal`. A captured `ty` or `literal` fragment becomes opaque once it is
/// forwarded to another macro invocation and can no longer be matched against
/// literal tokens, which is exactly what the `@`-prefixed helper arms do.
/// A macro that forwards these arguments here must therefore capture them as
/// `tt` as well.
macro_rules! impl_mma {
    // Standard MMA implementation pattern.
    (
        $elem:ty, $acc:tt, $shape:ty, $la:ty, $lb:ty, $intrinsic:ident,
        $a_size:tt, $b_size:tt, $c_size:tt
    ) => {
        impl MmaWithShapeAndLayout<$elem, $elem, $acc, $shape, $la, $lb> for $acc {
            type Output = $acc;

            #[gpu_only]
            fn mma(
                a: &MatrixA<$elem, $shape, $la>,
                b: &MatrixB<$elem, $shape, $lb>,
                c: &Accumulator<$acc, $shape>,
            ) -> Accumulator<$acc, $shape> {
                let mut result = Accumulator::new();

                // Extract A matrix values as raw `i16` lanes.
                //
                // SAFETY: the range index `[..$a_size]` panics unless at least
                // `$a_size` elements exist, so the array pointer covers only
                // elements of that slice. `transmute` refuses to compile unless
                // the storage array and `[i16; $a_size]` have the same size.
                let a_vals = unsafe {
                    core::mem::transmute::<
                        [<$elem as MatrixElement>::Storage; $a_size],
                        [i16; $a_size],
                    >(
                        *(&a.data[..$a_size] as *const [<$elem as MatrixElement>::Storage]
                            as *const [<$elem as MatrixElement>::Storage; $a_size]),
                    )
                };

                // Extract B matrix values as raw `i16` lanes.
                //
                // SAFETY: same argument as for `a_vals`, with `$b_size`.
                let b_vals = unsafe {
                    core::mem::transmute::<
                        [<$elem as MatrixElement>::Storage; $b_size],
                        [i16; $b_size],
                    >(
                        *(&b.data[..$b_size] as *const [<$elem as MatrixElement>::Storage]
                            as *const [<$elem as MatrixElement>::Storage; $b_size]),
                    )
                };

                // Extract C accumulator values; the lane type depends on `$acc`.
                //
                // SAFETY: see the `@extract_c` arms; the slice is bounds-checked
                // before it is reinterpreted.
                let c_vals = unsafe { impl_mma!(@extract_c $acc, c, $c_size) };

                // Call the intrinsic with every lane passed as its own argument.
                // NOTE(review): the intrinsic is presumably an `unsafe` function
                // whose preconditions are met by the `#[gpu_only]` context;
                // confirm against its declaration.
                let result_vals = unsafe {
                    impl_mma!(
                        @call_intrinsic $intrinsic, a_vals, b_vals, c_vals,
                        $a_size, $b_size, $c_size
                    )
                };

                // Store results into the first `$c_size` slots of the output.
                impl_mma!(@store_result result, result_vals, $acc, $c_size);

                result
            }
        }
    };

    // Helper: extract C values for an f32 accumulator (storage -> `[f32; N]`).
    // Must be expanded inside an `unsafe` block.
    (@extract_c f32, $c:expr, $size:literal) => {
        core::mem::transmute::<[<f32 as AccumulatorElement>::Storage; $size], [f32; $size]>(
            *(&$c.data[..$size] as *const [<f32 as AccumulatorElement>::Storage]
                as *const [<f32 as AccumulatorElement>::Storage; $size]),
        )
    };

    // Helper: extract C values for an f16 accumulator (storage -> `[i16; N]`).
    // Must be expanded inside an `unsafe` block.
    (@extract_c f16, $c:expr, $size:literal) => {
        core::mem::transmute::<[<f16 as AccumulatorElement>::Storage; $size], [i16; $size]>(
            *(&$c.data[..$size] as *const [<f16 as AccumulatorElement>::Storage]
                as *const [<f16 as AccumulatorElement>::Storage; $size]),
        )
    };

    // Helper: extract C values for an i32 accumulator (storage -> `[i32; N]`).
    // Must be expanded inside an `unsafe` block.
    (@extract_c i32, $c:expr, $size:literal) => {
        core::mem::transmute::<[<i32 as AccumulatorElement>::Storage; $size], [i32; $size]>(
            *(&$c.data[..$size] as *const [<i32 as AccumulatorElement>::Storage]
                as *const [<i32 as AccumulatorElement>::Storage; $size]),
        )
    };

    // Helper: call intrinsic for 16x16x16 (16 A, 16 B, 8 C).
    (@call_intrinsic $intrinsic:ident, $a:expr, $b:expr, $c:expr, 16, 16, 8) => {
        $intrinsic(
            $a[0], $a[1], $a[2], $a[3], $a[4], $a[5], $a[6], $a[7],
            $a[8], $a[9], $a[10], $a[11], $a[12], $a[13], $a[14], $a[15],
            $b[0], $b[1], $b[2], $b[3], $b[4], $b[5], $b[6], $b[7],
            $b[8], $b[9], $b[10], $b[11], $b[12], $b[13], $b[14], $b[15],
            $c[0], $c[1], $c[2], $c[3], $c[4], $c[5], $c[6], $c[7],
        )
    };

    // Helper: call intrinsic for 16x8x16 (8 A, 8 B, 4 C).
    (@call_intrinsic $intrinsic:ident, $a:expr, $b:expr, $c:expr, 8, 8, 4) => {
        $intrinsic(
            $a[0], $a[1], $a[2], $a[3], $a[4], $a[5], $a[6], $a[7],
            $b[0], $b[1], $b[2], $b[3], $b[4], $b[5], $b[6], $b[7],
            $c[0], $c[1], $c[2], $c[3],
        )
    };

    // Helper: call intrinsic for 32x8x16 (16 A, 8 B, 8 C).
    (@call_intrinsic $intrinsic:ident, $a:expr, $b:expr, $c:expr, 16, 8, 8) => {
        $intrinsic(
            $a[0], $a[1], $a[2], $a[3], $a[4], $a[5], $a[6], $a[7],
            $a[8], $a[9], $a[10], $a[11], $a[12], $a[13], $a[14], $a[15],
            $b[0], $b[1], $b[2], $b[3], $b[4], $b[5], $b[6], $b[7],
            $c[0], $c[1], $c[2], $c[3], $c[4], $c[5], $c[6], $c[7],
        )
    };

    // Helper: call intrinsic for 8x32x16 (8 A, 16 B, 8 C).
    (@call_intrinsic $intrinsic:ident, $a:expr, $b:expr, $c:expr, 8, 16, 8) => {
        $intrinsic(
            $a[0], $a[1], $a[2], $a[3], $a[4], $a[5], $a[6], $a[7],
            $b[0], $b[1], $b[2], $b[3], $b[4], $b[5], $b[6], $b[7],
            $b[8], $b[9], $b[10], $b[11], $b[12], $b[13], $b[14], $b[15],
            $c[0], $c[1], $c[2], $c[3], $c[4], $c[5], $c[6], $c[7],
        )
    };

    // Helper: call intrinsic for 8x8x4 (2 A, 2 B, 2 C) - f64.
    // NOTE(review): the main arm always reinterprets A/B as `i16` lanes and
    // there is no `f64` arm for `@extract_c` / `@store_result`, so this arm
    // cannot currently be reached through the main arm.
    (@call_intrinsic $intrinsic:ident, $a:expr, $b:expr, $c:expr, 2, 2, 2) => {
        $intrinsic(
            $a[0], $a[1],
            $b[0], $b[1],
            $c[0], $c[1],
        )
    };

    // Helper: store result for f32 (lanes are copied as they are).
    (@store_result $result:expr, $vals:expr, f32, $size:literal) => {
        $result.data[..$size].copy_from_slice(&$vals);
    };

    // Helper: store result for f16 (`i16` lanes are reinterpreted as storage).
    (@store_result $result:expr, $vals:expr, f16, $size:literal) => {
        // SAFETY: `transmute` only compiles when `[i16; N]` and the storage
        // array have the same size.
        // NOTE(review): assumes every `i16` bit pattern is a valid value of
        // the f16 storage type; confirm against its definition.
        $result.data[..$size].copy_from_slice(&unsafe {
            core::mem::transmute::<[i16; $size], [<f16 as AccumulatorElement>::Storage; $size]>(
                $vals,
            )
        });
    };

    // Helper: store result for i32 (lanes are copied as they are).
    (@store_result $result:expr, $vals:expr, i32, $size:literal) => {
        $result.data[..$size].copy_from_slice(&$vals);
    };
}

/// Implements `impl_mma!` for all four A/B layout combinations at once.
///
/// Invocation:
///
/// ```ignore
/// impl_mma_all_layouts!(Elem, Acc, Shape, base_name, shape_suffix, A_LEN, B_LEN, C_LEN);
/// ```
///
/// For each of Row-Row, Row-Col, Col-Row and Col-Col the intrinsic name is
/// built with `paste` as `<base_name>_<a>_<b>_<shape_suffix>`, where `<a>` and
/// `<b>` are `row` or `col`.
///
/// `Acc` and the three lengths are captured as `tt` so that they reach
/// `impl_mma!` as plain tokens and can still be matched against literal
/// tokens there.
macro_rules! impl_mma_all_layouts {
    (
        $elem:ty, $acc:tt, $shape:ty, $base_name:ident, $shape_suffix:ident,
        $a_size:tt, $b_size:tt, $c_size:tt
    ) => {
        paste::paste! {
            // Row-Row
            impl_mma!($elem, $acc, $shape, layout::Row, layout::Row,
                [<$base_name _row_row_ $shape_suffix>], $a_size, $b_size, $c_size);

            // Row-Col
            impl_mma!($elem, $acc, $shape, layout::Row, layout::Col,
                [<$base_name _row_col_ $shape_suffix>], $a_size, $b_size, $c_size);

            // Col-Row
            impl_mma!($elem, $acc, $shape, layout::Col, layout::Row,
                [<$base_name _col_row_ $shape_suffix>], $a_size, $b_size, $c_size);

            // Col-Col
            impl_mma!($elem, $acc, $shape, layout::Col, layout::Col,
                [<$base_name _col_col_ $shape_suffix>], $a_size, $b_size, $c_size);
        }
    };
}
0 commit comments