
Commit 658f34f

Checkpoint after trying to do loading
1 parent 783f6ac commit 658f34f

File tree

62 files changed: +2714 -398 lines changed


crates/cuda_std/src/warp/matrix/intrinsics.rs

Lines changed: 43 additions & 0 deletions
@@ -723,3 +723,46 @@

```rust
pub(crate) use m16n16k8::convert::*;
pub(crate) use m16n16k8::load::*;
pub(crate) use m16n16k8::store::*;
pub(crate) use m16n16k8::mma::*;

// ============= ldmatrix intrinsics =============
// These load matrix fragments from shared memory for MMA operations.
pub(crate) mod ldmatrix {
    #[allow(dead_code)]
    extern "C" {
        // 8x8 matrix with 16-bit elements (bf16/f16)
        #[link_name = "llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16"]
        pub(crate) fn ldmatrix_m8n8_x1_b16(ptr: *const u8) -> i32;

        #[link_name = "llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.b16"]
        pub(crate) fn ldmatrix_m8n8_x2_b16(ptr: *const u8) -> [i32; 2];

        #[link_name = "llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.b16"]
        pub(crate) fn ldmatrix_m8n8_x4_b16(ptr: *const u8) -> [i32; 4];

        // With transpose
        #[link_name = "llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16"]
        pub(crate) fn ldmatrix_m8n8_x1_trans_b16(ptr: *const u8) -> i32;

        #[link_name = "llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16"]
        pub(crate) fn ldmatrix_m8n8_x2_trans_b16(ptr: *const u8) -> [i32; 2];

        #[link_name = "llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16"]
        pub(crate) fn ldmatrix_m8n8_x4_trans_b16(ptr: *const u8) -> [i32; 4];

        // 16x16 matrix with 8-bit elements
        #[link_name = "llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.b8"]
        pub(crate) fn ldmatrix_m16n16_x1_b8(ptr: *const u8) -> [i32; 2];

        #[link_name = "llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.b8"]
        pub(crate) fn ldmatrix_m16n16_x2_b8(ptr: *const u8) -> [i32; 4];

        // 16x16 with transpose (mandatory for 16x16)
        #[link_name = "llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8"]
        pub(crate) fn ldmatrix_m16n16_x1_trans_b8(ptr: *const u8) -> [i32; 2];

        #[link_name = "llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8"]
        pub(crate) fn ldmatrix_m16n16_x2_trans_b8(ptr: *const u8) -> [i32; 4];
    }
}

pub(crate) use ldmatrix::*;
```
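For orientation (not part of this commit): a minimal caller sketch for the x2 variant. The wrapper name is hypothetical, and the per-lane addressing is only alluded to; ldmatrix requires each participating lane to supply the shared-memory address of one matrix row.

```rust
// Hypothetical wrapper, for illustration only: load two 8x8 b16 tiles.
// `smem_row_ptr` must be a shared-memory address chosen per the ldmatrix
// per-lane addressing contract. Each returned i32 packs two 16-bit elements.
#[inline(always)]
pub(crate) unsafe fn load_two_b16_tiles(smem_row_ptr: *const u8) -> [i32; 2] {
    ldmatrix_m8n8_x2_b16(smem_row_ptr)
}
```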
Lines changed: 113 additions & 0 deletions

@@ -0,0 +1,113 @@

# LdMatrix Integration Design

## Overview

`ldmatrix` provides a way to load matrix data from shared memory into registers formatted for MMA operations. This enables shared-memory loading for MMA-only shapes such as Shape<16, 8, 16>.
## Trait Hierarchy

```rust
// Base trait - all shapes support MMA compute
pub trait MmaShape: TensorCoreShape {
    // MMA compute operations
}

// Shapes that support loading from shared memory
pub trait LdMatrixShape: MmaShape {
    const LDMATRIX_SHAPE: LdMatrixShapeType;
    const LDMATRIX_NUM: LdMatrixNum;
}

// Shapes that support WMMA load/store from global memory
pub trait WmmaShape: MmaShape {
    // WMMA global memory operations
}

// Some shapes implement both!
impl LdMatrixShape for Shape<16, 16, 16> { ... }
impl WmmaShape for Shape<16, 16, 16> { ... }

// Shape<16, 8, 16> only supports ldmatrix (not WMMA)
impl LdMatrixShape for Shape<16, 8, 16> {
    const LDMATRIX_SHAPE: LdMatrixShapeType = LdMatrixShapeType::M8N8;
    const LDMATRIX_NUM: LdMatrixNum = LdMatrixNum::X2; // Need 2 matrices for 16x8
}
```
## API Methods

```rust
impl<T, Shape, L> MatrixA<T, Shape, L>
where
    T: MatrixElement,
    Shape: MmaShape + FragmentSize<T>,
    L: Layout,
{
    /// Available for all MMA shapes - register initialization
    pub fn from_array(values: [T; Shape::A_REGISTERS]) -> Self { ... }
    pub fn splat(value: T) -> Self { ... }
}

impl<T, Shape, L> MatrixA<T, Shape, L>
where
    T: MatrixElement,
    Shape: LdMatrixShape + FragmentSize<T>,
    L: Layout,
{
    /// Load from shared memory using ldmatrix
    pub unsafe fn load_shared(&mut self, ptr: *const T, stride: usize) {
        // Use the ldmatrix intrinsic
    }
}

impl<T, Shape, L> MatrixA<T, Shape, L>
where
    T: MatrixElement,
    Shape: WmmaShape + FragmentSize<T>,
    L: Layout,
{
    /// Load from global memory using WMMA
    pub unsafe fn load<const STRIDE: usize>(&mut self, ptr: *const T) {
        // Use the WMMA load intrinsic
    }
}
```
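One plausible body for `load_shared`, sketched for the Shape<16, 8, 16> (X2, b16) case only: `regs` is an assumed field holding the fragment's two packed registers, and the lane-to-row mapping is illustrative rather than a verified ldmatrix address scheme.

```rust
// Sketch for the LdMatrixShape impl above, specialized to X2 / b16.
// Assumptions: `self.regs: [i32; 2]`, `stride` counts elements of T, and
// lanes 0..16 each supply one 8-element row (two 8x8 tiles = 16 rows).
pub unsafe fn load_shared(&mut self, ptr: *const T, stride: usize) {
    let lane = (cuda_std::thread::thread_idx_x() % 32) as usize;
    let row = lane % 16;
    // Each participating lane passes the shared-memory address of one row.
    let row_ptr = ptr.add(row * stride) as *const u8;
    self.regs = intrinsics::ldmatrix_m8n8_x2_b16(row_ptr);
}
```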
## Usage Examples

```rust
// Shape<16, 8, 16> - MMA + ldmatrix (no WMMA)
type Shape = dims::Shape<16, 8, 16>;

// Can use the register API
let a = MatrixA::<bf16, Shape, Row>::from_array([bf16::ZERO; 8]);

// Can use the shared memory load
// (shared_array! is cuda_std's shared-memory declaration macro)
let smem = cuda_std::shared_array![bf16; 1024];
let mut b = MatrixA::<bf16, Shape, Row>::new();
b.load_shared(smem, 16); // OK - has LdMatrixShape

// Cannot use the global memory load
// b.load::<16>(global_ptr); // ERROR: Shape doesn't implement WmmaShape

// Shape<16, 16, 16> supports all three APIs
// (aliased as FullShape to avoid clashing with the WmmaShape trait name)
type FullShape = dims::Shape<16, 16, 16>;

let mut c = MatrixA::<bf16, FullShape, Row>::new();
c.load_shared(smem, 16); // OK - has LdMatrixShape
c.load::<16>(global_ptr); // OK - has WmmaShape
```
## Implementation Strategy

1. Define ldmatrix intrinsics for each shape/type combination
2. Implement the LdMatrixShape trait for applicable shapes (sketched below)
3. Add load_shared methods conditionally based on LdMatrixShape
4. Update documentation to explain the memory hierarchy
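A sketch of step 2 for the 16x16x16 shape. The constant values are assumptions that follow the pattern of the Shape<16, 8, 16> impl above: a 16x16 tile decomposes into four 8x8 matrices.

```rust
// Assumed values, not yet in the codebase: 16x16 = four 8x8 tiles => X4.
impl LdMatrixShape for Shape<16, 16, 16> {
    const LDMATRIX_SHAPE: LdMatrixShapeType = LdMatrixShapeType::M8N8;
    const LDMATRIX_NUM: LdMatrixNum = LdMatrixNum::X4;
}
```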
## Benefits

1. **Type-safe**: Compile-time enforcement of what each shape supports
2. **Ergonomic**: Same API style, with different method names for different memory sources
3. **Zero-cost**: All resolved at compile time
4. **Clear semantics**: Method names indicate the memory source
5. **Flexible**: Shapes can support multiple loading strategies
Lines changed: 106 additions & 0 deletions

@@ -0,0 +1,106 @@

# MMA-Only Shape API Design

## Problem

Shape<16, 8, 16> supports MMA compute but not WMMA load/store. We need an ergonomic way to:

1. Get data into fragments for MMA-only shapes
2. Keep the API consistent with WMMA shapes where possible
3. Make it compile-time safe with zero runtime cost

## Proposed Solution
### 1. Register-based construction (works for ALL shapes)

```rust
impl<T, Shape, L> MatrixA<T, Shape, L>
where
    T: MatrixElement,
    Shape: MmaShape, // Note: only requires MmaShape, not WmmaShape
    L: Layout,
{
    /// Create from register values directly.
    /// This works for both WMMA and MMA-only shapes.
    pub fn from_registers(values: &[T]) -> Self {
        // Implementation details
    }

    /// Fill with a single value (broadcast)
    pub fn splat(value: T) -> Self {
        // Implementation details
    }
}
```
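The bodies are left as stubs here; one plausible implementation, as a sketch only, assuming the fragment stores its values in an array field `regs` sized by a `Shape::A_REGISTERS` constant (both names are assumptions, reused from the companion LdMatrix design notes):

```rust
// Sketch only; the `regs` field and `Shape::A_REGISTERS` are assumed names.
pub fn from_registers(values: &[T]) -> Self {
    assert_eq!(values.len(), Shape::A_REGISTERS);
    let mut regs = [values[0]; Shape::A_REGISTERS];
    regs.copy_from_slice(values);
    Self { regs, _layout: core::marker::PhantomData }
}

/// Broadcast one value into every register slot.
pub fn splat(value: T) -> Self {
    Self { regs: [value; Shape::A_REGISTERS], _layout: core::marker::PhantomData }
}
```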
### 2. Conditional load method (only available for WMMA shapes)

```rust
impl<T, Shape, L> MatrixA<T, Shape, L>
where
    Shape: WmmaShape, // Only for WMMA-capable shapes
{
    pub unsafe fn load<const STRIDE: usize>(&mut self, ptr: *const T) { ... }
}
```
### 3. Builder pattern for complex initialization

```rust
impl<T, Shape, L> MatrixA<T, Shape, L>
where
    Shape: MmaShape,
{
    pub fn builder() -> MatrixBuilder<T, Shape, L> { ... }
}

pub struct MatrixBuilder<T, Shape, L> { ... }

impl<T, Shape, L> MatrixBuilder<T, Shape, L> {
    pub fn set_lane(self, lane: usize, value: T) -> Self { ... }
    pub fn set_row(self, row: usize, values: &[T]) -> Self { ... }
    pub fn build(self) -> MatrixA<T, Shape, L> { ... }
}
```
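A possible shape for the builder's internals (a sketch; the `values` buffer, the `Shape::A_REGISTERS` size constant, and the bounds are assumptions, and `set_row` plus the `builder()` constructor are elided):

```rust
// Sketch only: accumulate values locally, construct the fragment on build.
pub struct MatrixBuilder<T: MatrixElement, Shape: MmaShape + FragmentSize<T>, L> {
    values: [T; Shape::A_REGISTERS],
    _marker: core::marker::PhantomData<(Shape, L)>,
}

impl<T: MatrixElement, Shape: MmaShape + FragmentSize<T>, L: Layout> MatrixBuilder<T, Shape, L> {
    pub fn set_lane(mut self, lane: usize, value: T) -> Self {
        self.values[lane] = value;
        self
    }

    pub fn build(self) -> MatrixA<T, Shape, L> {
        MatrixA::from_registers(&self.values)
    }
}
```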
### 4. Associated types for fragment size

```rust
pub trait MmaShape: TensorCoreShape {
    type FragmentA<T: MatrixElement>: AsRef<[T::Storage]>;
    type FragmentB<T: MatrixElement>: AsRef<[T::Storage]>;
    type FragmentC<T: AccumulatorElement>: AsRef<[T::Storage]>;
}
```
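A concrete impl could then pin the fragment arrays to fixed per-lane lengths. The lengths below are assumptions, chosen to match the 8-element A-fragment used in the examples in this doc:

```rust
// Sketch: per-lane fragment sizes for Shape<16, 8, 16> (assumed values).
impl MmaShape for Shape<16, 8, 16> {
    type FragmentA<T: MatrixElement> = [T::Storage; 8];
    type FragmentB<T: MatrixElement> = [T::Storage; 4];
    type FragmentC<T: AccumulatorElement> = [T::Storage; 4];
}
```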
## Usage Examples

### For WMMA shapes (unchanged):

```rust
// Traditional WMMA load still works
let mut a_frag = tc.matrix_a();
a_frag.load::<16>(data_ptr);
```

### For MMA-only shapes:

```rust
// Shape<16, 8, 16> - MMA only
type Shape = dims::Shape<16, 8, 16>;
let tc = TensorCore::<bf16, Shape>::new();

// Option 1: From registers
let values = [bf16::from_f32(1.0); 8];
let a_frag = MatrixA::from_registers(&values);

// Option 2: Splat
let b_frag = MatrixB::splat(bf16::from_f32(2.0));

// Option 3: Builder
let c_frag = Accumulator::builder()
    .set_lane(0, 1.0)
    .set_lane(1, 2.0)
    .build();

// MMA operations work the same way
let result = c_frag.mma(&a_frag, &b_frag);
```
## Benefits

1. **Ergonomic**: Similar API for both WMMA and MMA-only shapes
2. **Type-safe**: Compile-time errors if you try to `load` on MMA-only shapes
3. **Zero-cost**: All resolved at compile time
4. **Intuitive**: Methods clearly indicate the data source (registers vs. memory)
5. **Flexible**: Multiple ways to construct fragments based on needs
