Skip to content

Commit 783f6ac

Browse files
committed
MMA vs WMMA distinction
1 parent 7b6238e commit 783f6ac

29 files changed

+1102
-347
lines changed

crates/cuda_std/src/warp/matrix/intrinsics.rs

Lines changed: 6 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -244,92 +244,12 @@ mod m16n16k16 {
244244
}
245245

246246
// ============= 16x8x16 shape =============
247+
// NOTE: m16n8k16 only supports MMA operations, not WMMA load/store
247248
mod m16n8k16 {
248-
// Load operations
249-
pub(crate) mod load {
250-
extern "C" {
251-
// f16 loads
252-
#[link_name = "llvm.nvvm.wmma.m16n8k16.load.a.sync.row.stride.f16"]
253-
pub(crate) fn wmma_load_a_f16_row_m16n8k16(ptr: *const u8, stride: i32) -> [i16; 8];
254-
#[link_name = "llvm.nvvm.wmma.m16n8k16.load.a.sync.col.stride.f16"]
255-
pub(crate) fn wmma_load_a_f16_col_m16n8k16(ptr: *const u8, stride: i32) -> [i16; 8];
256-
#[link_name = "llvm.nvvm.wmma.m16n8k16.load.b.sync.row.stride.f16"]
257-
pub(crate) fn wmma_load_b_f16_row_m16n8k16(ptr: *const u8, stride: i32) -> [i16; 8];
258-
#[link_name = "llvm.nvvm.wmma.m16n8k16.load.b.sync.col.stride.f16"]
259-
pub(crate) fn wmma_load_b_f16_col_m16n8k16(ptr: *const u8, stride: i32) -> [i16; 8];
260-
261-
// bf16 loads
262-
#[link_name = "llvm.nvvm.wmma.m16n8k16.load.a.sync.row.stride.bf16"]
263-
pub(crate) fn wmma_load_a_bf16_row_m16n8k16(ptr: *const u8, stride: i32) -> [i16; 8];
264-
#[link_name = "llvm.nvvm.wmma.m16n8k16.load.a.sync.col.stride.bf16"]
265-
pub(crate) fn wmma_load_a_bf16_col_m16n8k16(ptr: *const u8, stride: i32) -> [i16; 8];
266-
#[link_name = "llvm.nvvm.wmma.m16n8k16.load.b.sync.row.stride.bf16"]
267-
pub(crate) fn wmma_load_b_bf16_row_m16n8k16(ptr: *const u8, stride: i32) -> [i16; 8];
268-
#[link_name = "llvm.nvvm.wmma.m16n8k16.load.b.sync.col.stride.bf16"]
269-
pub(crate) fn wmma_load_b_bf16_col_m16n8k16(ptr: *const u8, stride: i32) -> [i16; 8];
270-
271-
// i8/u8 loads
272-
#[link_name = "llvm.nvvm.wmma.m16n8k16.load.a.sync.row.stride.s8"]
273-
pub(crate) fn wmma_load_a_s8_row_m16n8k16(ptr: *const u8, stride: i32) -> [i32; 2];
274-
#[link_name = "llvm.nvvm.wmma.m16n8k16.load.a.sync.col.stride.s8"]
275-
pub(crate) fn wmma_load_a_s8_col_m16n8k16(ptr: *const u8, stride: i32) -> [i32; 2];
276-
#[link_name = "llvm.nvvm.wmma.m16n8k16.load.a.sync.row.stride.u8"]
277-
pub(crate) fn wmma_load_a_u8_row_m16n8k16(ptr: *const u8, stride: i32) -> [i32; 2];
278-
#[link_name = "llvm.nvvm.wmma.m16n8k16.load.a.sync.col.stride.u8"]
279-
pub(crate) fn wmma_load_a_u8_col_m16n8k16(ptr: *const u8, stride: i32) -> [i32; 2];
280-
#[link_name = "llvm.nvvm.wmma.m16n8k16.load.b.sync.row.stride.s8"]
281-
pub(crate) fn wmma_load_b_s8_row_m16n8k16(ptr: *const u8, stride: i32) -> [i32; 2];
282-
#[link_name = "llvm.nvvm.wmma.m16n8k16.load.b.sync.col.stride.s8"]
283-
pub(crate) fn wmma_load_b_s8_col_m16n8k16(ptr: *const u8, stride: i32) -> [i32; 2];
284-
#[link_name = "llvm.nvvm.wmma.m16n8k16.load.b.sync.row.stride.u8"]
285-
pub(crate) fn wmma_load_b_u8_row_m16n8k16(ptr: *const u8, stride: i32) -> [i32; 2];
286-
#[link_name = "llvm.nvvm.wmma.m16n8k16.load.b.sync.col.stride.u8"]
287-
pub(crate) fn wmma_load_b_u8_col_m16n8k16(ptr: *const u8, stride: i32) -> [i32; 2];
288-
289-
// Accumulator loads
290-
#[link_name = "llvm.nvvm.wmma.m16n8k16.load.c.sync.row.stride.f32"]
291-
pub(crate) fn wmma_load_c_f32_row_m16n8k16(ptr: *const u8, stride: i32) -> [f32; 4];
292-
#[link_name = "llvm.nvvm.wmma.m16n8k16.load.c.sync.col.stride.f32"]
293-
pub(crate) fn wmma_load_c_f32_col_m16n8k16(ptr: *const u8, stride: i32) -> [f32; 4];
294-
#[link_name = "llvm.nvvm.wmma.m16n8k16.load.c.sync.row.stride.s32"]
295-
pub(crate) fn wmma_load_c_s32_row_m16n8k16(ptr: *const u8, stride: i32) -> [i32; 4];
296-
#[link_name = "llvm.nvvm.wmma.m16n8k16.load.c.sync.col.stride.s32"]
297-
pub(crate) fn wmma_load_c_s32_col_m16n8k16(ptr: *const u8, stride: i32) -> [i32; 4];
298-
}
299-
}
300-
301-
// Store operations
302-
pub(crate) mod store {
303-
extern "C" {
304-
// f32 stores
305-
#[link_name = "llvm.nvvm.wmma.m16n8k16.store.d.sync.row.stride.f32"]
306-
pub(crate) fn wmma_store_d_f32_row_m16n8k16(
307-
ptr: *mut u8,
308-
d0: f32, d1: f32, d2: f32, d3: f32,
309-
stride: i32,
310-
);
311-
#[link_name = "llvm.nvvm.wmma.m16n8k16.store.d.sync.col.stride.f32"]
312-
pub(crate) fn wmma_store_d_f32_col_m16n8k16(
313-
ptr: *mut u8,
314-
d0: f32, d1: f32, d2: f32, d3: f32,
315-
stride: i32,
316-
);
317-
318-
// i32 stores
319-
#[link_name = "llvm.nvvm.wmma.m16n8k16.store.d.sync.row.stride.s32"]
320-
pub(crate) fn wmma_store_d_s32_row_m16n8k16(
321-
ptr: *mut u8,
322-
d0: i32, d1: i32, d2: i32, d3: i32,
323-
stride: i32,
324-
);
325-
#[link_name = "llvm.nvvm.wmma.m16n8k16.store.d.sync.col.stride.s32"]
326-
pub(crate) fn wmma_store_d_s32_col_m16n8k16(
327-
ptr: *mut u8,
328-
d0: i32, d1: i32, d2: i32, d3: i32,
329-
stride: i32,
330-
);
331-
}
332-
}
249+
// No WMMA load operations - this shape only supports MMA
250+
// LLVM's NVVM intrinsic definitions (IntrinsicsNVVM.td) provide only MMA intrinsics for the m16n8k16 shape; no WMMA load/store intrinsics exist for it
251+
252+
// No WMMA store operations - this shape only supports MMA
333253

334254
// MMA operations
335255
pub(crate) mod mma {
@@ -784,8 +704,7 @@ pub(crate) use m16n16k16::load::*;
784704
pub(crate) use m16n16k16::store::*;
785705
pub(crate) use m16n16k16::mma::*;
786706

787-
pub(crate) use m16n8k16::load::*;
788-
pub(crate) use m16n8k16::store::*;
707+
// m16n8k16 has no load/store operations (MMA-only)
789708
pub(crate) use m16n8k16::mma::*;
790709

791710
pub(crate) use m32n8k16::load::*;

crates/cuda_std/src/warp/matrix/mod.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,24 @@ pub trait TensorCoreShape: sealed::Sealed {
4545
const K: usize;
4646
}
4747

48+
/// Shapes that support MMA compute operations
49+
#[diagnostic::on_unimplemented(
50+
message = "`{Self}` does not support MMA compute operations",
51+
label = "MMA compute not available for this shape",
52+
note = "This shape is not a valid tensor core shape for MMA operations"
53+
)]
54+
pub trait MmaShape: TensorCoreShape {}
55+
56+
/// Shapes that support full WMMA operations (load, store, in addition to compute)
57+
/// This trait extends MmaShape since WMMA shapes can do everything MMA shapes can do
58+
#[diagnostic::on_unimplemented(
59+
message = "`{Self}` does not support WMMA load/store operations",
60+
label = "WMMA load/store not available for this shape",
61+
note = "This shape only supports MMA compute operations, not WMMA load/store",
62+
note = "See LLVM source: https://github.com/llvm/llvm-project/blob/main/llvm/include/llvm/IR/IntrinsicsNVVM.td#L419-L1067"
63+
)]
64+
pub trait WmmaShape: MmaShape {}
65+
4866
// Only these exact combinations are valid for tensor cores
4967
impl TensorCoreShape for dims::Shape<16, 16, 16> {
5068
const M: usize = 16;
@@ -94,6 +112,28 @@ impl TensorCoreShape for dims::Shape<8, 8, 4> {
94112
const K: usize = 4;
95113
}
96114

115+
// All tensor core shapes support MMA compute operations
116+
impl MmaShape for dims::Shape<16, 16, 16> {}
117+
impl MmaShape for dims::Shape<32, 8, 16> {}
118+
impl MmaShape for dims::Shape<8, 32, 16> {}
119+
impl MmaShape for dims::Shape<16, 8, 16> {} // MMA-only (no WMMA load/store)
120+
impl MmaShape for dims::Shape<16, 16, 8> {}
121+
impl MmaShape for dims::Shape<8, 8, 32> {}
122+
impl MmaShape for dims::Shape<8, 8, 128> {}
123+
impl MmaShape for dims::Shape<8, 8, 4> {}
124+
125+
// Shapes that ALSO support WMMA load/store (in addition to MMA compute)
126+
impl WmmaShape for dims::Shape<16, 16, 16> {}
127+
impl WmmaShape for dims::Shape<32, 8, 16> {}
128+
impl WmmaShape for dims::Shape<8, 32, 16> {}
129+
// Note: 16x8x16 does NOT implement WmmaShape - it's MMA-only
130+
131+
// Other shapes support WMMA for their respective data types
132+
impl WmmaShape for dims::Shape<16, 16, 8> {} // TF32
133+
impl WmmaShape for dims::Shape<8, 8, 32> {} // i8/u8
134+
impl WmmaShape for dims::Shape<8, 8, 128> {} // i4/u4
135+
impl WmmaShape for dims::Shape<8, 8, 4> {} // f64
136+
97137
// ============================================================================
98138
// Layout Types
99139
// ============================================================================
@@ -306,6 +346,7 @@ where
306346
#[gpu_only]
307347
pub unsafe fn load<const STRIDE: usize>(&mut self, ptr: *const T)
308348
where
349+
Shape: WmmaShape, // Require WmmaShape for load operations
309350
StrideValidator<T, STRIDE>: ValidStride,
310351
T: ops::LoadMatrixA<Shape, L>,
311352
{
@@ -343,6 +384,7 @@ where
343384
#[gpu_only]
344385
pub unsafe fn load<const STRIDE: usize>(&mut self, ptr: *const T)
345386
where
387+
Shape: WmmaShape, // Require WmmaShape for load operations
346388
StrideValidator<T, STRIDE>: ValidStride,
347389
T: ops::LoadMatrixB<Shape, L>,
348390
{
@@ -379,6 +421,7 @@ where
379421
#[gpu_only]
380422
pub unsafe fn load<L, const STRIDE: usize>(&mut self, ptr: *const T)
381423
where
424+
Shape: WmmaShape, // Require WmmaShape for load operations
382425
L: Layout,
383426
StrideValidator<T, STRIDE>: ValidStride,
384427
T: ops::LoadMatrixC<Shape, L>,
@@ -398,6 +441,7 @@ where
398441
#[gpu_only]
399442
pub unsafe fn store<L, const STRIDE: usize>(&self, ptr: *mut T)
400443
where
444+
Shape: WmmaShape, // Require WmmaShape for store operations
401445
L: Layout,
402446
StrideValidator<T, STRIDE>: ValidStride,
403447
T: ops::StoreMatrixD<Shape, L>,

crates/cuda_std/src/warp/matrix/multi.rs

Lines changed: 1 addition & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ impl<T, Shape, L, const N: usize, const STRIDE: usize> BatchLoad<STRIDE>
247247
for FragmentArray<MatrixA<T, Shape, L>, N>
248248
where
249249
T: MatrixElement,
250-
Shape: TensorCoreShape,
250+
Shape: TensorCoreShape + crate::warp::matrix::WmmaShape, // Require WmmaShape for load
251251
L: Layout,
252252
crate::warp::matrix::StrideValidator<T, STRIDE>: crate::warp::matrix::ValidStride,
253253
T: crate::warp::matrix::ops::LoadMatrixA<Shape, L>,
@@ -360,63 +360,3 @@ pub type FragmentRow<T, const N: usize> = FragmentArray<T, N>;
360360

361361
/// Single column of fragments
362362
pub type FragmentCol<T, const N: usize> = FragmentArray<T, N>;
363-
364-
// ============================================================================
365-
// Example Usage
366-
// ============================================================================
367-
368-
#[cfg(test)]
369-
mod examples {
370-
use super::*;
371-
use crate::warp::matrix::dims;
372-
use half::f16;
373-
374-
/// Example: Flash Attention with idiomatic Rust patterns
375-
fn flash_attention_example() {
376-
// Use const generics for compile-time guarantees
377-
const TILE_M: usize = 2;
378-
const TILE_N: usize = 8;
379-
380-
type Shape = dims::Shape<16, 8, 16>;
381-
382-
// Create tiled fragments with type inference
383-
let mut q_tiles: FragmentGrid<MatrixA<f16, Shape, Row>, TILE_M, TILE_N> =
384-
Default::default();
385-
let mut k_tiles: FragmentGrid<MatrixB<f16, Shape, Row>, TILE_M, TILE_N> =
386-
Default::default();
387-
let mut acc_tiles: FragmentGrid<Accumulator<f32, Shape>, TILE_M, TILE_N> =
388-
Default::default();
389-
390-
// Load with iterator pattern
391-
for (i, tile) in q_tiles.iter_mut().enumerate() {
392-
// tile.load(ptr.offset(i * stride), stride);
393-
}
394-
395-
// Use indexing for specific tiles
396-
let tile_0_0 = &q_tiles[(0, 0)];
397-
398-
// Map operations are zero-cost
399-
let scaled_tiles = acc_tiles.map(|tile| {
400-
// Scale each tile
401-
tile
402-
});
403-
}
404-
405-
/// Example: Using the builder pattern
406-
fn builder_example() {
407-
type Shape = dims::Shape<16, 16, 16>;
408-
409-
// Type inference makes this clean
410-
let a_tiles = FragmentBuilder::<f16, Shape>::array::<4>();
411-
let b_tiles = FragmentBuilder::<f16, Shape>::grid::<2, 2>();
412-
}
413-
414-
/// Example: Pattern matching for optimization
415-
fn pattern_matching_example() {
416-
// Compiler can optimize based on const values
417-
const PATTERN: TilePattern<2, 4, 2> = TilePattern;
418-
419-
type Shape = dims::Shape<16, 16, 16>;
420-
let (a, b, c) = PATTERN.create_fragments::<f16, f32, Shape>();
421-
}
422-
}

0 commit comments

Comments (0)