@@ -108,6 +108,43 @@ extern "C" {
108108 stride : i32
109109 ) ;
110110
111+ // MMA intrinsics for f16 -> f16
112+ #[ link_name = "llvm.nvvm.wmma.m16n16k16.mma.sync.row.row.f16.f16" ]
113+ pub ( crate ) fn wmma_mma_f16_f16_row_row_m16n16k16 (
114+ a0 : i16 , a1 : i16 , a2 : i16 , a3 : i16 , a4 : i16 , a5 : i16 , a6 : i16 , a7 : i16 ,
115+ a8 : i16 , a9 : i16 , a10 : i16 , a11 : i16 , a12 : i16 , a13 : i16 , a14 : i16 , a15 : i16 ,
116+ b0 : i16 , b1 : i16 , b2 : i16 , b3 : i16 , b4 : i16 , b5 : i16 , b6 : i16 , b7 : i16 ,
117+ b8 : i16 , b9 : i16 , b10 : i16 , b11 : i16 , b12 : i16 , b13 : i16 , b14 : i16 , b15 : i16 ,
118+ c0 : i16 , c1 : i16 , c2 : i16 , c3 : i16 , c4 : i16 , c5 : i16 , c6 : i16 , c7 : i16
119+ ) -> [ i16 ; 8 ] ;
120+
121+ #[ link_name = "llvm.nvvm.wmma.m16n16k16.mma.sync.row.col.f16.f16" ]
122+ pub ( crate ) fn wmma_mma_f16_f16_row_col_m16n16k16 (
123+ a0 : i16 , a1 : i16 , a2 : i16 , a3 : i16 , a4 : i16 , a5 : i16 , a6 : i16 , a7 : i16 ,
124+ a8 : i16 , a9 : i16 , a10 : i16 , a11 : i16 , a12 : i16 , a13 : i16 , a14 : i16 , a15 : i16 ,
125+ b0 : i16 , b1 : i16 , b2 : i16 , b3 : i16 , b4 : i16 , b5 : i16 , b6 : i16 , b7 : i16 ,
126+ b8 : i16 , b9 : i16 , b10 : i16 , b11 : i16 , b12 : i16 , b13 : i16 , b14 : i16 , b15 : i16 ,
127+ c0 : i16 , c1 : i16 , c2 : i16 , c3 : i16 , c4 : i16 , c5 : i16 , c6 : i16 , c7 : i16
128+ ) -> [ i16 ; 8 ] ;
129+
130+ #[ link_name = "llvm.nvvm.wmma.m16n16k16.mma.sync.col.row.f16.f16" ]
131+ pub ( crate ) fn wmma_mma_f16_f16_col_row_m16n16k16 (
132+ a0 : i16 , a1 : i16 , a2 : i16 , a3 : i16 , a4 : i16 , a5 : i16 , a6 : i16 , a7 : i16 ,
133+ a8 : i16 , a9 : i16 , a10 : i16 , a11 : i16 , a12 : i16 , a13 : i16 , a14 : i16 , a15 : i16 ,
134+ b0 : i16 , b1 : i16 , b2 : i16 , b3 : i16 , b4 : i16 , b5 : i16 , b6 : i16 , b7 : i16 ,
135+ b8 : i16 , b9 : i16 , b10 : i16 , b11 : i16 , b12 : i16 , b13 : i16 , b14 : i16 , b15 : i16 ,
136+ c0 : i16 , c1 : i16 , c2 : i16 , c3 : i16 , c4 : i16 , c5 : i16 , c6 : i16 , c7 : i16
137+ ) -> [ i16 ; 8 ] ;
138+
139+ #[ link_name = "llvm.nvvm.wmma.m16n16k16.mma.sync.col.col.f16.f16" ]
140+ pub ( crate ) fn wmma_mma_f16_f16_col_col_m16n16k16 (
141+ a0 : i16 , a1 : i16 , a2 : i16 , a3 : i16 , a4 : i16 , a5 : i16 , a6 : i16 , a7 : i16 ,
142+ a8 : i16 , a9 : i16 , a10 : i16 , a11 : i16 , a12 : i16 , a13 : i16 , a14 : i16 , a15 : i16 ,
143+ b0 : i16 , b1 : i16 , b2 : i16 , b3 : i16 , b4 : i16 , b5 : i16 , b6 : i16 , b7 : i16 ,
144+ b8 : i16 , b9 : i16 , b10 : i16 , b11 : i16 , b12 : i16 , b13 : i16 , b14 : i16 , b15 : i16 ,
145+ c0 : i16 , c1 : i16 , c2 : i16 , c3 : i16 , c4 : i16 , c5 : i16 , c6 : i16 , c7 : i16
146+ ) -> [ i16 ; 8 ] ;
147+
111148 // MMA intrinsics for f16 -> f32
112149 #[ link_name = "llvm.nvvm.wmma.m16n16k16.mma.sync.row.row.f16.f32" ]
113150 pub ( crate ) fn wmma_mma_f16_f32_row_row_m16n16k16 (
@@ -898,6 +935,151 @@ where
898935 ) -> Accumulator < Self :: Output , Shape > ;
899936}
900937
938+ // f16 × f16 + f16 → f16 with 16x16x16 (all layout combinations)
939+ impl MmaWithShapeAndLayout < f16 , f16 , f16 , dims:: Shape < 16 , 16 , 16 > , layout:: Row , layout:: Row > for f16 {
940+ type Output = f16 ;
941+
942+ #[ gpu_only]
943+ fn mma (
944+ a : & MatrixA < f16 , dims:: Shape < 16 , 16 , 16 > , layout:: Row > ,
945+ b : & MatrixB < f16 , dims:: Shape < 16 , 16 , 16 > , layout:: Row > ,
946+ c : & Accumulator < f16 , dims:: Shape < 16 , 16 , 16 > > ,
947+ ) -> Accumulator < f16 , dims:: Shape < 16 , 16 , 16 > > {
948+ let mut result = Accumulator :: new ( ) ;
949+
950+ let a_vals = unsafe { core:: mem:: transmute :: < [ <f16 as MatrixElement >:: Storage ; 16 ] , [ i16 ; 16 ] > (
951+ * ( & a. data [ ..16 ] as * const [ <f16 as MatrixElement >:: Storage ] as * const [ <f16 as MatrixElement >:: Storage ; 16 ] )
952+ ) } ;
953+ let b_vals = unsafe { core:: mem:: transmute :: < [ <f16 as MatrixElement >:: Storage ; 16 ] , [ i16 ; 16 ] > (
954+ * ( & b. data [ ..16 ] as * const [ <f16 as MatrixElement >:: Storage ] as * const [ <f16 as MatrixElement >:: Storage ; 16 ] )
955+ ) } ;
956+ let c_vals = unsafe { core:: mem:: transmute :: < [ <f16 as AccumulatorElement >:: Storage ; 8 ] , [ i16 ; 8 ] > (
957+ * ( & c. data [ ..8 ] as * const [ <f16 as AccumulatorElement >:: Storage ] as * const [ <f16 as AccumulatorElement >:: Storage ; 8 ] )
958+ ) } ;
959+
960+ let result_vals = unsafe {
961+ wmma_mma_f16_f16_row_row_m16n16k16 (
962+ a_vals[ 0 ] , a_vals[ 1 ] , a_vals[ 2 ] , a_vals[ 3 ] , a_vals[ 4 ] , a_vals[ 5 ] , a_vals[ 6 ] , a_vals[ 7 ] ,
963+ a_vals[ 8 ] , a_vals[ 9 ] , a_vals[ 10 ] , a_vals[ 11 ] , a_vals[ 12 ] , a_vals[ 13 ] , a_vals[ 14 ] , a_vals[ 15 ] ,
964+ b_vals[ 0 ] , b_vals[ 1 ] , b_vals[ 2 ] , b_vals[ 3 ] , b_vals[ 4 ] , b_vals[ 5 ] , b_vals[ 6 ] , b_vals[ 7 ] ,
965+ b_vals[ 8 ] , b_vals[ 9 ] , b_vals[ 10 ] , b_vals[ 11 ] , b_vals[ 12 ] , b_vals[ 13 ] , b_vals[ 14 ] , b_vals[ 15 ] ,
966+ c_vals[ 0 ] , c_vals[ 1 ] , c_vals[ 2 ] , c_vals[ 3 ] , c_vals[ 4 ] , c_vals[ 5 ] , c_vals[ 6 ] , c_vals[ 7 ]
967+ )
968+ } ;
969+
970+ result. data [ ..8 ] . copy_from_slice ( & unsafe { core:: mem:: transmute :: < [ i16 ; 8 ] , [ <f16 as AccumulatorElement >:: Storage ; 8 ] > ( result_vals) } ) ;
971+ result
972+ }
973+ }
974+
975+ impl MmaWithShapeAndLayout < f16 , f16 , f16 , dims:: Shape < 16 , 16 , 16 > , layout:: Row , layout:: Col > for f16 {
976+ type Output = f16 ;
977+
978+ #[ gpu_only]
979+ fn mma (
980+ a : & MatrixA < f16 , dims:: Shape < 16 , 16 , 16 > , layout:: Row > ,
981+ b : & MatrixB < f16 , dims:: Shape < 16 , 16 , 16 > , layout:: Col > ,
982+ c : & Accumulator < f16 , dims:: Shape < 16 , 16 , 16 > > ,
983+ ) -> Accumulator < f16 , dims:: Shape < 16 , 16 , 16 > > {
984+ let mut result = Accumulator :: new ( ) ;
985+
986+ let a_vals = unsafe { core:: mem:: transmute :: < [ <f16 as MatrixElement >:: Storage ; 16 ] , [ i16 ; 16 ] > (
987+ * ( & a. data [ ..16 ] as * const [ <f16 as MatrixElement >:: Storage ] as * const [ <f16 as MatrixElement >:: Storage ; 16 ] )
988+ ) } ;
989+ let b_vals = unsafe { core:: mem:: transmute :: < [ <f16 as MatrixElement >:: Storage ; 16 ] , [ i16 ; 16 ] > (
990+ * ( & b. data [ ..16 ] as * const [ <f16 as MatrixElement >:: Storage ] as * const [ <f16 as MatrixElement >:: Storage ; 16 ] )
991+ ) } ;
992+ let c_vals = unsafe { core:: mem:: transmute :: < [ <f16 as AccumulatorElement >:: Storage ; 8 ] , [ i16 ; 8 ] > (
993+ * ( & c. data [ ..8 ] as * const [ <f16 as AccumulatorElement >:: Storage ] as * const [ <f16 as AccumulatorElement >:: Storage ; 8 ] )
994+ ) } ;
995+
996+ let result_vals = unsafe {
997+ wmma_mma_f16_f16_row_col_m16n16k16 (
998+ a_vals[ 0 ] , a_vals[ 1 ] , a_vals[ 2 ] , a_vals[ 3 ] , a_vals[ 4 ] , a_vals[ 5 ] , a_vals[ 6 ] , a_vals[ 7 ] ,
999+ a_vals[ 8 ] , a_vals[ 9 ] , a_vals[ 10 ] , a_vals[ 11 ] , a_vals[ 12 ] , a_vals[ 13 ] , a_vals[ 14 ] , a_vals[ 15 ] ,
1000+ b_vals[ 0 ] , b_vals[ 1 ] , b_vals[ 2 ] , b_vals[ 3 ] , b_vals[ 4 ] , b_vals[ 5 ] , b_vals[ 6 ] , b_vals[ 7 ] ,
1001+ b_vals[ 8 ] , b_vals[ 9 ] , b_vals[ 10 ] , b_vals[ 11 ] , b_vals[ 12 ] , b_vals[ 13 ] , b_vals[ 14 ] , b_vals[ 15 ] ,
1002+ c_vals[ 0 ] , c_vals[ 1 ] , c_vals[ 2 ] , c_vals[ 3 ] , c_vals[ 4 ] , c_vals[ 5 ] , c_vals[ 6 ] , c_vals[ 7 ]
1003+ )
1004+ } ;
1005+
1006+ result. data [ ..8 ] . copy_from_slice ( & unsafe { core:: mem:: transmute :: < [ i16 ; 8 ] , [ <f16 as AccumulatorElement >:: Storage ; 8 ] > ( result_vals) } ) ;
1007+ result
1008+ }
1009+ }
1010+
1011+ impl MmaWithShapeAndLayout < f16 , f16 , f16 , dims:: Shape < 16 , 16 , 16 > , layout:: Col , layout:: Row > for f16 {
1012+ type Output = f16 ;
1013+
1014+ #[ gpu_only]
1015+ fn mma (
1016+ a : & MatrixA < f16 , dims:: Shape < 16 , 16 , 16 > , layout:: Col > ,
1017+ b : & MatrixB < f16 , dims:: Shape < 16 , 16 , 16 > , layout:: Row > ,
1018+ c : & Accumulator < f16 , dims:: Shape < 16 , 16 , 16 > > ,
1019+ ) -> Accumulator < f16 , dims:: Shape < 16 , 16 , 16 > > {
1020+ let mut result = Accumulator :: new ( ) ;
1021+
1022+ let a_vals = unsafe { core:: mem:: transmute :: < [ <f16 as MatrixElement >:: Storage ; 16 ] , [ i16 ; 16 ] > (
1023+ * ( & a. data [ ..16 ] as * const [ <f16 as MatrixElement >:: Storage ] as * const [ <f16 as MatrixElement >:: Storage ; 16 ] )
1024+ ) } ;
1025+ let b_vals = unsafe { core:: mem:: transmute :: < [ <f16 as MatrixElement >:: Storage ; 16 ] , [ i16 ; 16 ] > (
1026+ * ( & b. data [ ..16 ] as * const [ <f16 as MatrixElement >:: Storage ] as * const [ <f16 as MatrixElement >:: Storage ; 16 ] )
1027+ ) } ;
1028+ let c_vals = unsafe { core:: mem:: transmute :: < [ <f16 as AccumulatorElement >:: Storage ; 8 ] , [ i16 ; 8 ] > (
1029+ * ( & c. data [ ..8 ] as * const [ <f16 as AccumulatorElement >:: Storage ] as * const [ <f16 as AccumulatorElement >:: Storage ; 8 ] )
1030+ ) } ;
1031+
1032+ let result_vals = unsafe {
1033+ wmma_mma_f16_f16_col_row_m16n16k16 (
1034+ a_vals[ 0 ] , a_vals[ 1 ] , a_vals[ 2 ] , a_vals[ 3 ] , a_vals[ 4 ] , a_vals[ 5 ] , a_vals[ 6 ] , a_vals[ 7 ] ,
1035+ a_vals[ 8 ] , a_vals[ 9 ] , a_vals[ 10 ] , a_vals[ 11 ] , a_vals[ 12 ] , a_vals[ 13 ] , a_vals[ 14 ] , a_vals[ 15 ] ,
1036+ b_vals[ 0 ] , b_vals[ 1 ] , b_vals[ 2 ] , b_vals[ 3 ] , b_vals[ 4 ] , b_vals[ 5 ] , b_vals[ 6 ] , b_vals[ 7 ] ,
1037+ b_vals[ 8 ] , b_vals[ 9 ] , b_vals[ 10 ] , b_vals[ 11 ] , b_vals[ 12 ] , b_vals[ 13 ] , b_vals[ 14 ] , b_vals[ 15 ] ,
1038+ c_vals[ 0 ] , c_vals[ 1 ] , c_vals[ 2 ] , c_vals[ 3 ] , c_vals[ 4 ] , c_vals[ 5 ] , c_vals[ 6 ] , c_vals[ 7 ]
1039+ )
1040+ } ;
1041+
1042+ result. data [ ..8 ] . copy_from_slice ( & unsafe { core:: mem:: transmute :: < [ i16 ; 8 ] , [ <f16 as AccumulatorElement >:: Storage ; 8 ] > ( result_vals) } ) ;
1043+ result
1044+ }
1045+ }
1046+
1047+ impl MmaWithShapeAndLayout < f16 , f16 , f16 , dims:: Shape < 16 , 16 , 16 > , layout:: Col , layout:: Col > for f16 {
1048+ type Output = f16 ;
1049+
1050+ #[ gpu_only]
1051+ fn mma (
1052+ a : & MatrixA < f16 , dims:: Shape < 16 , 16 , 16 > , layout:: Col > ,
1053+ b : & MatrixB < f16 , dims:: Shape < 16 , 16 , 16 > , layout:: Col > ,
1054+ c : & Accumulator < f16 , dims:: Shape < 16 , 16 , 16 > > ,
1055+ ) -> Accumulator < f16 , dims:: Shape < 16 , 16 , 16 > > {
1056+ let mut result = Accumulator :: new ( ) ;
1057+
1058+ let a_vals = unsafe { core:: mem:: transmute :: < [ <f16 as MatrixElement >:: Storage ; 16 ] , [ i16 ; 16 ] > (
1059+ * ( & a. data [ ..16 ] as * const [ <f16 as MatrixElement >:: Storage ] as * const [ <f16 as MatrixElement >:: Storage ; 16 ] )
1060+ ) } ;
1061+ let b_vals = unsafe { core:: mem:: transmute :: < [ <f16 as MatrixElement >:: Storage ; 16 ] , [ i16 ; 16 ] > (
1062+ * ( & b. data [ ..16 ] as * const [ <f16 as MatrixElement >:: Storage ] as * const [ <f16 as MatrixElement >:: Storage ; 16 ] )
1063+ ) } ;
1064+ let c_vals = unsafe { core:: mem:: transmute :: < [ <f16 as AccumulatorElement >:: Storage ; 8 ] , [ i16 ; 8 ] > (
1065+ * ( & c. data [ ..8 ] as * const [ <f16 as AccumulatorElement >:: Storage ] as * const [ <f16 as AccumulatorElement >:: Storage ; 8 ] )
1066+ ) } ;
1067+
1068+ let result_vals = unsafe {
1069+ wmma_mma_f16_f16_col_col_m16n16k16 (
1070+ a_vals[ 0 ] , a_vals[ 1 ] , a_vals[ 2 ] , a_vals[ 3 ] , a_vals[ 4 ] , a_vals[ 5 ] , a_vals[ 6 ] , a_vals[ 7 ] ,
1071+ a_vals[ 8 ] , a_vals[ 9 ] , a_vals[ 10 ] , a_vals[ 11 ] , a_vals[ 12 ] , a_vals[ 13 ] , a_vals[ 14 ] , a_vals[ 15 ] ,
1072+ b_vals[ 0 ] , b_vals[ 1 ] , b_vals[ 2 ] , b_vals[ 3 ] , b_vals[ 4 ] , b_vals[ 5 ] , b_vals[ 6 ] , b_vals[ 7 ] ,
1073+ b_vals[ 8 ] , b_vals[ 9 ] , b_vals[ 10 ] , b_vals[ 11 ] , b_vals[ 12 ] , b_vals[ 13 ] , b_vals[ 14 ] , b_vals[ 15 ] ,
1074+ c_vals[ 0 ] , c_vals[ 1 ] , c_vals[ 2 ] , c_vals[ 3 ] , c_vals[ 4 ] , c_vals[ 5 ] , c_vals[ 6 ] , c_vals[ 7 ]
1075+ )
1076+ } ;
1077+
1078+ result. data [ ..8 ] . copy_from_slice ( & unsafe { core:: mem:: transmute :: < [ i16 ; 8 ] , [ <f16 as AccumulatorElement >:: Storage ; 8 ] > ( result_vals) } ) ;
1079+ result
1080+ }
1081+ }
1082+
9011083// f16 × f16 + f32 → f32 with 16x16x16, Row-Row
9021084impl MmaWithShapeAndLayout < f16 , f16 , f32 , dims:: Shape < 16 , 16 , 16 > , layout:: Row , layout:: Row > for f32 {
9031085 type Output = f32 ;
0 commit comments