
Commit a394880: "working?!?"
1 parent 8075a6a

File tree
3 files changed: +189, -1 lines


crates/cuda_std/src/warp/matrix.rs
Lines changed: 182 additions & 0 deletions
@@ -108,6 +108,43 @@ extern "C" {
         stride: i32
     );
 
+    // MMA intrinsics for f16 -> f16
+    #[link_name = "llvm.nvvm.wmma.m16n16k16.mma.sync.row.row.f16.f16"]
+    pub(crate) fn wmma_mma_f16_f16_row_row_m16n16k16(
+        a0: i16, a1: i16, a2: i16, a3: i16, a4: i16, a5: i16, a6: i16, a7: i16,
+        a8: i16, a9: i16, a10: i16, a11: i16, a12: i16, a13: i16, a14: i16, a15: i16,
+        b0: i16, b1: i16, b2: i16, b3: i16, b4: i16, b5: i16, b6: i16, b7: i16,
+        b8: i16, b9: i16, b10: i16, b11: i16, b12: i16, b13: i16, b14: i16, b15: i16,
+        c0: i16, c1: i16, c2: i16, c3: i16, c4: i16, c5: i16, c6: i16, c7: i16
+    ) -> [i16; 8];
+
+    #[link_name = "llvm.nvvm.wmma.m16n16k16.mma.sync.row.col.f16.f16"]
+    pub(crate) fn wmma_mma_f16_f16_row_col_m16n16k16(
+        a0: i16, a1: i16, a2: i16, a3: i16, a4: i16, a5: i16, a6: i16, a7: i16,
+        a8: i16, a9: i16, a10: i16, a11: i16, a12: i16, a13: i16, a14: i16, a15: i16,
+        b0: i16, b1: i16, b2: i16, b3: i16, b4: i16, b5: i16, b6: i16, b7: i16,
+        b8: i16, b9: i16, b10: i16, b11: i16, b12: i16, b13: i16, b14: i16, b15: i16,
+        c0: i16, c1: i16, c2: i16, c3: i16, c4: i16, c5: i16, c6: i16, c7: i16
+    ) -> [i16; 8];
+
+    #[link_name = "llvm.nvvm.wmma.m16n16k16.mma.sync.col.row.f16.f16"]
+    pub(crate) fn wmma_mma_f16_f16_col_row_m16n16k16(
+        a0: i16, a1: i16, a2: i16, a3: i16, a4: i16, a5: i16, a6: i16, a7: i16,
+        a8: i16, a9: i16, a10: i16, a11: i16, a12: i16, a13: i16, a14: i16, a15: i16,
+        b0: i16, b1: i16, b2: i16, b3: i16, b4: i16, b5: i16, b6: i16, b7: i16,
+        b8: i16, b9: i16, b10: i16, b11: i16, b12: i16, b13: i16, b14: i16, b15: i16,
+        c0: i16, c1: i16, c2: i16, c3: i16, c4: i16, c5: i16, c6: i16, c7: i16
+    ) -> [i16; 8];
+
+    #[link_name = "llvm.nvvm.wmma.m16n16k16.mma.sync.col.col.f16.f16"]
+    pub(crate) fn wmma_mma_f16_f16_col_col_m16n16k16(
+        a0: i16, a1: i16, a2: i16, a3: i16, a4: i16, a5: i16, a6: i16, a7: i16,
+        a8: i16, a9: i16, a10: i16, a11: i16, a12: i16, a13: i16, a14: i16, a15: i16,
+        b0: i16, b1: i16, b2: i16, b3: i16, b4: i16, b5: i16, b6: i16, b7: i16,
+        b8: i16, b9: i16, b10: i16, b11: i16, b12: i16, b13: i16, b14: i16, b15: i16,
+        c0: i16, c1: i16, c2: i16, c3: i16, c4: i16, c5: i16, c6: i16, c7: i16
+    ) -> [i16; 8];
+
     // MMA intrinsics for f16 -> f32
     #[link_name = "llvm.nvvm.wmma.m16n16k16.mma.sync.row.row.f16.f32"]
     pub(crate) fn wmma_mma_f16_f32_row_row_m16n16k16(
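For reference, every one of these intrinsics performs the same 16x16x16 fused multiply-accumulate over warp-distributed tiles: D = A * B + C. The host-side sketch below spells out those semantics on whole 16x16 tiles; f32 stands in for f16, since host Rust has no built-in half type, and f16 rounding is not modeled (both are assumptions of this sketch, not part of the commit).

// Reference semantics of one m16n16k16 WMMA step: D = A * B + C over
// 16x16 row-major tiles. The real intrinsics operate on warp-distributed
// fragments, not whole tiles; this is only the mathematical shape.
fn mma_reference(a: &[f32; 256], b: &[f32; 256], c: &[f32; 256]) -> [f32; 256] {
    let mut d = [0.0f32; 256];
    for i in 0..16 {
        for j in 0..16 {
            // Dot product of row i of A with column j of B, seeded with
            // the existing accumulator value.
            let mut acc = c[i * 16 + j];
            for k in 0..16 {
                acc += a[i * 16 + k] * b[k * 16 + j];
            }
            d[i * 16 + j] = acc;
        }
    }
    d
}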
@@ -898,6 +935,151 @@ where
     ) -> Accumulator<Self::Output, Shape>;
 }
 
+// f16 × f16 + f16 → f16 with 16x16x16 (all layout combinations)
+impl MmaWithShapeAndLayout<f16, f16, f16, dims::Shape<16, 16, 16>, layout::Row, layout::Row> for f16 {
+    type Output = f16;
+
+    #[gpu_only]
+    fn mma(
+        a: &MatrixA<f16, dims::Shape<16, 16, 16>, layout::Row>,
+        b: &MatrixB<f16, dims::Shape<16, 16, 16>, layout::Row>,
+        c: &Accumulator<f16, dims::Shape<16, 16, 16>>,
+    ) -> Accumulator<f16, dims::Shape<16, 16, 16>> {
+        let mut result = Accumulator::new();
+
+        let a_vals = unsafe { core::mem::transmute::<[<f16 as MatrixElement>::Storage; 16], [i16; 16]>(
+            *(&a.data[..16] as *const [<f16 as MatrixElement>::Storage] as *const [<f16 as MatrixElement>::Storage; 16])
+        )};
+        let b_vals = unsafe { core::mem::transmute::<[<f16 as MatrixElement>::Storage; 16], [i16; 16]>(
+            *(&b.data[..16] as *const [<f16 as MatrixElement>::Storage] as *const [<f16 as MatrixElement>::Storage; 16])
+        )};
+        let c_vals = unsafe { core::mem::transmute::<[<f16 as AccumulatorElement>::Storage; 8], [i16; 8]>(
+            *(&c.data[..8] as *const [<f16 as AccumulatorElement>::Storage] as *const [<f16 as AccumulatorElement>::Storage; 8])
+        )};
+
+        let result_vals = unsafe {
+            wmma_mma_f16_f16_row_row_m16n16k16(
+                a_vals[0], a_vals[1], a_vals[2], a_vals[3], a_vals[4], a_vals[5], a_vals[6], a_vals[7],
+                a_vals[8], a_vals[9], a_vals[10], a_vals[11], a_vals[12], a_vals[13], a_vals[14], a_vals[15],
+                b_vals[0], b_vals[1], b_vals[2], b_vals[3], b_vals[4], b_vals[5], b_vals[6], b_vals[7],
+                b_vals[8], b_vals[9], b_vals[10], b_vals[11], b_vals[12], b_vals[13], b_vals[14], b_vals[15],
+                c_vals[0], c_vals[1], c_vals[2], c_vals[3], c_vals[4], c_vals[5], c_vals[6], c_vals[7]
+            )
+        };
+
+        result.data[..8].copy_from_slice(&unsafe { core::mem::transmute::<[i16; 8], [<f16 as AccumulatorElement>::Storage; 8]>(result_vals) });
+        result
+    }
+}
+
+impl MmaWithShapeAndLayout<f16, f16, f16, dims::Shape<16, 16, 16>, layout::Row, layout::Col> for f16 {
+    type Output = f16;
+
+    #[gpu_only]
+    fn mma(
+        a: &MatrixA<f16, dims::Shape<16, 16, 16>, layout::Row>,
+        b: &MatrixB<f16, dims::Shape<16, 16, 16>, layout::Col>,
+        c: &Accumulator<f16, dims::Shape<16, 16, 16>>,
+    ) -> Accumulator<f16, dims::Shape<16, 16, 16>> {
+        let mut result = Accumulator::new();
+
+        let a_vals = unsafe { core::mem::transmute::<[<f16 as MatrixElement>::Storage; 16], [i16; 16]>(
+            *(&a.data[..16] as *const [<f16 as MatrixElement>::Storage] as *const [<f16 as MatrixElement>::Storage; 16])
+        )};
+        let b_vals = unsafe { core::mem::transmute::<[<f16 as MatrixElement>::Storage; 16], [i16; 16]>(
+            *(&b.data[..16] as *const [<f16 as MatrixElement>::Storage] as *const [<f16 as MatrixElement>::Storage; 16])
+        )};
+        let c_vals = unsafe { core::mem::transmute::<[<f16 as AccumulatorElement>::Storage; 8], [i16; 8]>(
+            *(&c.data[..8] as *const [<f16 as AccumulatorElement>::Storage] as *const [<f16 as AccumulatorElement>::Storage; 8])
+        )};
+
+        let result_vals = unsafe {
+            wmma_mma_f16_f16_row_col_m16n16k16(
+                a_vals[0], a_vals[1], a_vals[2], a_vals[3], a_vals[4], a_vals[5], a_vals[6], a_vals[7],
+                a_vals[8], a_vals[9], a_vals[10], a_vals[11], a_vals[12], a_vals[13], a_vals[14], a_vals[15],
+                b_vals[0], b_vals[1], b_vals[2], b_vals[3], b_vals[4], b_vals[5], b_vals[6], b_vals[7],
+                b_vals[8], b_vals[9], b_vals[10], b_vals[11], b_vals[12], b_vals[13], b_vals[14], b_vals[15],
+                c_vals[0], c_vals[1], c_vals[2], c_vals[3], c_vals[4], c_vals[5], c_vals[6], c_vals[7]
+            )
+        };
+
+        result.data[..8].copy_from_slice(&unsafe { core::mem::transmute::<[i16; 8], [<f16 as AccumulatorElement>::Storage; 8]>(result_vals) });
+        result
+    }
+}
+
+impl MmaWithShapeAndLayout<f16, f16, f16, dims::Shape<16, 16, 16>, layout::Col, layout::Row> for f16 {
+    type Output = f16;
+
+    #[gpu_only]
+    fn mma(
+        a: &MatrixA<f16, dims::Shape<16, 16, 16>, layout::Col>,
+        b: &MatrixB<f16, dims::Shape<16, 16, 16>, layout::Row>,
+        c: &Accumulator<f16, dims::Shape<16, 16, 16>>,
+    ) -> Accumulator<f16, dims::Shape<16, 16, 16>> {
+        let mut result = Accumulator::new();
+
+        let a_vals = unsafe { core::mem::transmute::<[<f16 as MatrixElement>::Storage; 16], [i16; 16]>(
+            *(&a.data[..16] as *const [<f16 as MatrixElement>::Storage] as *const [<f16 as MatrixElement>::Storage; 16])
+        )};
+        let b_vals = unsafe { core::mem::transmute::<[<f16 as MatrixElement>::Storage; 16], [i16; 16]>(
+            *(&b.data[..16] as *const [<f16 as MatrixElement>::Storage] as *const [<f16 as MatrixElement>::Storage; 16])
+        )};
+        let c_vals = unsafe { core::mem::transmute::<[<f16 as AccumulatorElement>::Storage; 8], [i16; 8]>(
+            *(&c.data[..8] as *const [<f16 as AccumulatorElement>::Storage] as *const [<f16 as AccumulatorElement>::Storage; 8])
+        )};
+
+        let result_vals = unsafe {
+            wmma_mma_f16_f16_col_row_m16n16k16(
+                a_vals[0], a_vals[1], a_vals[2], a_vals[3], a_vals[4], a_vals[5], a_vals[6], a_vals[7],
+                a_vals[8], a_vals[9], a_vals[10], a_vals[11], a_vals[12], a_vals[13], a_vals[14], a_vals[15],
+                b_vals[0], b_vals[1], b_vals[2], b_vals[3], b_vals[4], b_vals[5], b_vals[6], b_vals[7],
+                b_vals[8], b_vals[9], b_vals[10], b_vals[11], b_vals[12], b_vals[13], b_vals[14], b_vals[15],
+                c_vals[0], c_vals[1], c_vals[2], c_vals[3], c_vals[4], c_vals[5], c_vals[6], c_vals[7]
+            )
+        };
+
+        result.data[..8].copy_from_slice(&unsafe { core::mem::transmute::<[i16; 8], [<f16 as AccumulatorElement>::Storage; 8]>(result_vals) });
+        result
+    }
+}
+
+impl MmaWithShapeAndLayout<f16, f16, f16, dims::Shape<16, 16, 16>, layout::Col, layout::Col> for f16 {
+    type Output = f16;
+
+    #[gpu_only]
+    fn mma(
+        a: &MatrixA<f16, dims::Shape<16, 16, 16>, layout::Col>,
+        b: &MatrixB<f16, dims::Shape<16, 16, 16>, layout::Col>,
+        c: &Accumulator<f16, dims::Shape<16, 16, 16>>,
+    ) -> Accumulator<f16, dims::Shape<16, 16, 16>> {
+        let mut result = Accumulator::new();
+
+        let a_vals = unsafe { core::mem::transmute::<[<f16 as MatrixElement>::Storage; 16], [i16; 16]>(
+            *(&a.data[..16] as *const [<f16 as MatrixElement>::Storage] as *const [<f16 as MatrixElement>::Storage; 16])
+        )};
+        let b_vals = unsafe { core::mem::transmute::<[<f16 as MatrixElement>::Storage; 16], [i16; 16]>(
+            *(&b.data[..16] as *const [<f16 as MatrixElement>::Storage] as *const [<f16 as MatrixElement>::Storage; 16])
+        )};
+        let c_vals = unsafe { core::mem::transmute::<[<f16 as AccumulatorElement>::Storage; 8], [i16; 8]>(
+            *(&c.data[..8] as *const [<f16 as AccumulatorElement>::Storage] as *const [<f16 as AccumulatorElement>::Storage; 8])
+        )};
+
+        let result_vals = unsafe {
+            wmma_mma_f16_f16_col_col_m16n16k16(
+                a_vals[0], a_vals[1], a_vals[2], a_vals[3], a_vals[4], a_vals[5], a_vals[6], a_vals[7],
+                a_vals[8], a_vals[9], a_vals[10], a_vals[11], a_vals[12], a_vals[13], a_vals[14], a_vals[15],
+                b_vals[0], b_vals[1], b_vals[2], b_vals[3], b_vals[4], b_vals[5], b_vals[6], b_vals[7],
+                b_vals[8], b_vals[9], b_vals[10], b_vals[11], b_vals[12], b_vals[13], b_vals[14], b_vals[15],
+                c_vals[0], c_vals[1], c_vals[2], c_vals[3], c_vals[4], c_vals[5], c_vals[6], c_vals[7]
+            )
+        };
+
+        result.data[..8].copy_from_slice(&unsafe { core::mem::transmute::<[i16; 8], [<f16 as AccumulatorElement>::Storage; 8]>(result_vals) });
+        result
+    }
+}
+
 // f16 × f16 + f32 → f32 with 16x16x16, Row-Row
 impl MmaWithShapeAndLayout<f16, f16, f32, dims::Shape<16, 16, 16>, layout::Row, layout::Row> for f32 {
     type Output = f32;
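All four impls lean on the same trick: an f16 fragment is bitwise-reinterpreted as i16 so it can cross the extern boundary, and the i16 results are reinterpreted back into accumulator storage. The standalone snippet below illustrates that round trip, using u16 as a stand-in for `<f16 as MatrixElement>::Storage` (an assumption of this sketch, not something the commit defines).

use core::mem::transmute;

fn main() {
    // 0x3C00 is the IEEE 754 binary16 bit pattern for 1.0; u16 stands in
    // for the crate's f16 storage type here (sketch assumption).
    let fragment: [u16; 8] = [0x3C00; 8];
    // Same size and alignment, so this only changes the nominal type that
    // the intrinsic signature expects; no bits move.
    let as_i16: [i16; 8] = unsafe { transmute::<[u16; 8], [i16; 8]>(fragment) };
    let back: [u16; 8] = unsafe { transmute::<[i16; 8], [u16; 8]>(as_i16) };
    assert_eq!(fragment, back);
}

Since the four bodies differ only in the layout parameters and the intrinsic they call, a macro could generate them, but the commit writes each one out explicitly.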

crates/rustc_codegen_nvvm/libintrinsics.ll
Lines changed: 6 additions & 0 deletions
@@ -531,6 +531,12 @@ declare { double, double } @llvm.nvvm.wmma.m8n8k4.mma.sync.row.col.f64.f64(doubl
 declare { double, double } @llvm.nvvm.wmma.m8n8k4.mma.sync.col.row.f64.f64(double, double, double, double, double, double, double, double, double, double) #1
 declare { double, double } @llvm.nvvm.wmma.m8n8k4.mma.sync.col.col.f64.f64(double, double, double, double, double, double, double, double, double, double) #1
 
+; f16 accumulator MMA operations (16x16x16 shape)
+declare { i16, i16, i16, i16, i16, i16, i16, i16 } @llvm.nvvm.wmma.m16n16k16.mma.sync.row.row.f16.f16(i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16) #1
+declare { i16, i16, i16, i16, i16, i16, i16, i16 } @llvm.nvvm.wmma.m16n16k16.mma.sync.row.col.f16.f16(i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16) #1
+declare { i16, i16, i16, i16, i16, i16, i16, i16 } @llvm.nvvm.wmma.m16n16k16.mma.sync.col.row.f16.f16(i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16) #1
+declare { i16, i16, i16, i16, i16, i16, i16, i16 } @llvm.nvvm.wmma.m16n16k16.mma.sync.col.col.f16.f16(i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16) #1
+
 ; TF32 tensor core operations (16x16x8 shape)
 ; Note: TF32 uses float storage but with reduced precision during computation
 declare float @llvm.nvvm.f2tf32.rna.f32(float) #1
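Each new declaration takes 40 i16 parameters, matching the Rust extern signatures in matrix.rs: 16 for the A fragment, 16 for B, and 8 for the C accumulator, with the eight-field aggregate carrying the D result (the Rust side models it as `[i16; 8]`). A throwaway generator for these declare lines, which is not part of the commit but shows the structure:

fn declare_f16_f16(a_layout: &str, b_layout: &str) -> String {
    let ret = ["i16"; 8].join(", ");
    let args = ["i16"; 40].join(", "); // 16 A + 16 B + 8 C fragment registers
    format!("declare {{ {ret} }} @llvm.nvvm.wmma.m16n16k16.mma.sync.{a_layout}.{b_layout}.f16.f16({args}) #1")
}

fn main() {
    // Reproduces the four declare lines above, one per layout combination.
    for (a, b) in [("row", "row"), ("row", "col"), ("col", "row"), ("col", "col")] {
        println!("{}", declare_f16_f16(a, b));
    }
}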

tests/compiletests/ui/warp/matrix/matrix_operations.rs
Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 
 use cuda_std::f16;
 use cuda_std::kernel;
-use cuda_std::warp::matrix::{dims, layout, MmaExt, TensorCore};
+use cuda_std::warp::matrix::{dims, layout, TensorCore};
 
 #[kernel]
 pub unsafe fn test_mma_operations() {
