Skip to content

Commit df78b69

Browse files
committed
Move intrinsics to own file
1 parent 8a5f992 commit df78b69

19 files changed

+1796
-1643
lines changed

crates/cuda_std/src/warp/matrix/intrinsics.rs

Lines changed: 1314 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
//! Macros to reduce repetition in matrix module
2+
3+
/// Macro for implementing MMA operations with common pattern
4+
macro_rules! impl_mma {
5+
// Standard MMA implementation pattern
6+
($elem:ty, $acc:ty, $shape:ty, $la:ty, $lb:ty, $intrinsic:ident,
7+
$a_size:literal, $b_size:literal, $c_size:literal) => {
8+
impl MmaWithShapeAndLayout<$elem, $elem, $acc, $shape, $la, $lb> for $acc {
9+
type Output = $acc;
10+
11+
#[gpu_only]
12+
fn mma(
13+
a: &MatrixA<$elem, $shape, $la>,
14+
b: &MatrixB<$elem, $shape, $lb>,
15+
c: &Accumulator<$acc, $shape>,
16+
) -> Accumulator<$acc, $shape> {
17+
let mut result = Accumulator::new();
18+
19+
// Extract A matrix values
20+
let a_vals = unsafe {
21+
core::mem::transmute::<[<$elem as MatrixElement>::Storage; $a_size], [i16; $a_size]>(
22+
*(&a.data[..$a_size] as *const [<$elem as MatrixElement>::Storage]
23+
as *const [<$elem as MatrixElement>::Storage; $a_size]),
24+
)
25+
};
26+
27+
// Extract B matrix values
28+
let b_vals = unsafe {
29+
core::mem::transmute::<[<$elem as MatrixElement>::Storage; $b_size], [i16; $b_size]>(
30+
*(&b.data[..$b_size] as *const [<$elem as MatrixElement>::Storage]
31+
as *const [<$elem as MatrixElement>::Storage; $b_size]),
32+
)
33+
};
34+
35+
// Extract C accumulator values
36+
let c_vals = unsafe {
37+
impl_mma!(@extract_c $acc, c, $c_size)
38+
};
39+
40+
// Call the intrinsic with unpacked values
41+
let result_vals = unsafe {
42+
impl_mma!(@call_intrinsic $intrinsic, a_vals, b_vals, c_vals,
43+
$a_size, $b_size, $c_size)
44+
};
45+
46+
// Store results
47+
impl_mma!(@store_result result, result_vals, $acc, $c_size);
48+
49+
result
50+
}
51+
}
52+
};
53+
54+
// Helper: Extract C values for f32 accumulator
55+
(@extract_c f32, $c:expr, $size:literal) => {
56+
core::mem::transmute::<[<f32 as AccumulatorElement>::Storage; $size], [f32; $size]>(
57+
*(&$c.data[..$size] as *const [<f32 as AccumulatorElement>::Storage]
58+
as *const [<f32 as AccumulatorElement>::Storage; $size]),
59+
)
60+
};
61+
62+
// Helper: Extract C values for f16 accumulator
63+
(@extract_c f16, $c:expr, $size:literal) => {
64+
core::mem::transmute::<[<f16 as AccumulatorElement>::Storage; $size], [i16; $size]>(
65+
*(&$c.data[..$size] as *const [<f16 as AccumulatorElement>::Storage]
66+
as *const [<f16 as AccumulatorElement>::Storage; $size]),
67+
)
68+
};
69+
70+
// Helper: Extract C values for i32 accumulator
71+
(@extract_c i32, $c:expr, $size:literal) => {
72+
core::mem::transmute::<[<i32 as AccumulatorElement>::Storage; $size], [i32; $size]>(
73+
*(&$c.data[..$size] as *const [<i32 as AccumulatorElement>::Storage]
74+
as *const [<i32 as AccumulatorElement>::Storage; $size]),
75+
)
76+
};
77+
78+
// Helper: Call intrinsic for 16x16x16 (16 A, 16 B, 8 C)
79+
(@call_intrinsic $intrinsic:ident, $a:expr, $b:expr, $c:expr, 16, 16, 8) => {
80+
$intrinsic(
81+
$a[0], $a[1], $a[2], $a[3], $a[4], $a[5], $a[6], $a[7],
82+
$a[8], $a[9], $a[10], $a[11], $a[12], $a[13], $a[14], $a[15],
83+
$b[0], $b[1], $b[2], $b[3], $b[4], $b[5], $b[6], $b[7],
84+
$b[8], $b[9], $b[10], $b[11], $b[12], $b[13], $b[14], $b[15],
85+
$c[0], $c[1], $c[2], $c[3], $c[4], $c[5], $c[6], $c[7],
86+
)
87+
};
88+
89+
// Helper: Call intrinsic for 16x8x16 (8 A, 8 B, 4 C)
90+
(@call_intrinsic $intrinsic:ident, $a:expr, $b:expr, $c:expr, 8, 8, 4) => {
91+
$intrinsic(
92+
$a[0], $a[1], $a[2], $a[3], $a[4], $a[5], $a[6], $a[7],
93+
$b[0], $b[1], $b[2], $b[3], $b[4], $b[5], $b[6], $b[7],
94+
$c[0], $c[1], $c[2], $c[3],
95+
)
96+
};
97+
98+
// Helper: Call intrinsic for 32x8x16 (16 A, 8 B, 8 C)
99+
(@call_intrinsic $intrinsic:ident, $a:expr, $b:expr, $c:expr, 16, 8, 8) => {
100+
$intrinsic(
101+
$a[0], $a[1], $a[2], $a[3], $a[4], $a[5], $a[6], $a[7],
102+
$a[8], $a[9], $a[10], $a[11], $a[12], $a[13], $a[14], $a[15],
103+
$b[0], $b[1], $b[2], $b[3], $b[4], $b[5], $b[6], $b[7],
104+
$c[0], $c[1], $c[2], $c[3], $c[4], $c[5], $c[6], $c[7],
105+
)
106+
};
107+
108+
// Helper: Call intrinsic for 8x32x16 (8 A, 16 B, 8 C)
109+
(@call_intrinsic $intrinsic:ident, $a:expr, $b:expr, $c:expr, 8, 16, 8) => {
110+
$intrinsic(
111+
$a[0], $a[1], $a[2], $a[3], $a[4], $a[5], $a[6], $a[7],
112+
$b[0], $b[1], $b[2], $b[3], $b[4], $b[5], $b[6], $b[7],
113+
$b[8], $b[9], $b[10], $b[11], $b[12], $b[13], $b[14], $b[15],
114+
$c[0], $c[1], $c[2], $c[3], $c[4], $c[5], $c[6], $c[7],
115+
)
116+
};
117+
118+
// Helper: Call intrinsic for 8x8x4 (2 A, 2 B, 2 C) - f64
119+
(@call_intrinsic $intrinsic:ident, $a:expr, $b:expr, $c:expr, 2, 2, 2) => {
120+
$intrinsic(
121+
$a[0], $a[1],
122+
$b[0], $b[1],
123+
$c[0], $c[1],
124+
)
125+
};
126+
127+
// Helper: Store result for f32
128+
(@store_result $result:expr, $vals:expr, f32, $size:literal) => {
129+
$result.data[..$size].copy_from_slice(&$vals);
130+
};
131+
132+
// Helper: Store result for f16
133+
(@store_result $result:expr, $vals:expr, f16, $size:literal) => {
134+
$result.data[..$size].copy_from_slice(&unsafe {
135+
core::mem::transmute::<[i16; $size], [<f16 as AccumulatorElement>::Storage; $size]>($vals)
136+
});
137+
};
138+
139+
// Helper: Store result for i32
140+
(@store_result $result:expr, $vals:expr, i32, $size:literal) => {
141+
$result.data[..$size].copy_from_slice(&$vals);
142+
};
143+
}
144+
145+
/// Macro for implementing all 4 layout combinations at once
146+
macro_rules! impl_mma_all_layouts {
147+
($elem:ty, $acc:ty, $shape:ty, $base_name:ident, $shape_suffix:ident, $a_size:literal, $b_size:literal, $c_size:literal) => {
148+
paste::paste! {
149+
// Row-Row
150+
impl_mma!($elem, $acc, $shape, layout::Row, layout::Row,
151+
[<$base_name _row_row_ $shape_suffix>], $a_size, $b_size, $c_size);
152+
153+
// Row-Col
154+
impl_mma!($elem, $acc, $shape, layout::Row, layout::Col,
155+
[<$base_name _row_col_ $shape_suffix>], $a_size, $b_size, $c_size);
156+
157+
// Col-Row
158+
impl_mma!($elem, $acc, $shape, layout::Col, layout::Row,
159+
[<$base_name _col_row_ $shape_suffix>], $a_size, $b_size, $c_size);
160+
161+
// Col-Col
162+
impl_mma!($elem, $acc, $shape, layout::Col, layout::Col,
163+
[<$base_name _col_col_ $shape_suffix>], $a_size, $b_size, $c_size);
164+
}
165+
};
166+
}

0 commit comments

Comments
 (0)