@@ -1334,8 +1334,10 @@ impl MmaWithShapeAndLayout<f32, f32, f32, dims::Shape<16, 16, 8>, layout::Row, l
13341334
13351335 let result_vals = unsafe {
13361336 wmma_mma_tf32_f32_row_row_m16n16k8 (
1337- a_vals[ 0 ] , a_vals[ 1 ] , a_vals[ 2 ] , a_vals[ 3 ] , a_vals[ 4 ] , a_vals[ 5 ] , a_vals[ 6 ] , a_vals[ 7 ] ,
1338- b_vals[ 0 ] , b_vals[ 1 ] , b_vals[ 2 ] , b_vals[ 3 ] , b_vals[ 4 ] , b_vals[ 5 ] , b_vals[ 6 ] , b_vals[ 7 ] ,
1337+ float_to_tf32 ( a_vals[ 0 ] ) , float_to_tf32 ( a_vals[ 1 ] ) , float_to_tf32 ( a_vals[ 2 ] ) , float_to_tf32 ( a_vals[ 3 ] ) ,
1338+ float_to_tf32 ( a_vals[ 4 ] ) , float_to_tf32 ( a_vals[ 5 ] ) , float_to_tf32 ( a_vals[ 6 ] ) , float_to_tf32 ( a_vals[ 7 ] ) ,
1339+ float_to_tf32 ( b_vals[ 0 ] ) , float_to_tf32 ( b_vals[ 1 ] ) , float_to_tf32 ( b_vals[ 2 ] ) , float_to_tf32 ( b_vals[ 3 ] ) ,
1340+ float_to_tf32 ( b_vals[ 4 ] ) , float_to_tf32 ( b_vals[ 5 ] ) , float_to_tf32 ( b_vals[ 6 ] ) , float_to_tf32 ( b_vals[ 7 ] ) ,
13391341 c_vals[ 0 ] , c_vals[ 1 ] , c_vals[ 2 ] , c_vals[ 3 ] , c_vals[ 4 ] , c_vals[ 5 ] , c_vals[ 6 ] , c_vals[ 7 ]
13401342 )
13411343 } ;
@@ -1362,8 +1364,10 @@ impl MmaWithShapeAndLayout<f32, f32, f32, dims::Shape<16, 16, 8>, layout::Row, l
13621364
13631365 let result_vals = unsafe {
13641366 wmma_mma_tf32_f32_row_col_m16n16k8 (
1365- a_vals[ 0 ] , a_vals[ 1 ] , a_vals[ 2 ] , a_vals[ 3 ] , a_vals[ 4 ] , a_vals[ 5 ] , a_vals[ 6 ] , a_vals[ 7 ] ,
1366- b_vals[ 0 ] , b_vals[ 1 ] , b_vals[ 2 ] , b_vals[ 3 ] , b_vals[ 4 ] , b_vals[ 5 ] , b_vals[ 6 ] , b_vals[ 7 ] ,
1367+ float_to_tf32 ( a_vals[ 0 ] ) , float_to_tf32 ( a_vals[ 1 ] ) , float_to_tf32 ( a_vals[ 2 ] ) , float_to_tf32 ( a_vals[ 3 ] ) ,
1368+ float_to_tf32 ( a_vals[ 4 ] ) , float_to_tf32 ( a_vals[ 5 ] ) , float_to_tf32 ( a_vals[ 6 ] ) , float_to_tf32 ( a_vals[ 7 ] ) ,
1369+ float_to_tf32 ( b_vals[ 0 ] ) , float_to_tf32 ( b_vals[ 1 ] ) , float_to_tf32 ( b_vals[ 2 ] ) , float_to_tf32 ( b_vals[ 3 ] ) ,
1370+ float_to_tf32 ( b_vals[ 4 ] ) , float_to_tf32 ( b_vals[ 5 ] ) , float_to_tf32 ( b_vals[ 6 ] ) , float_to_tf32 ( b_vals[ 7 ] ) ,
13671371 c_vals[ 0 ] , c_vals[ 1 ] , c_vals[ 2 ] , c_vals[ 3 ] , c_vals[ 4 ] , c_vals[ 5 ] , c_vals[ 6 ] , c_vals[ 7 ]
13681372 )
13691373 } ;
@@ -1390,8 +1394,10 @@ impl MmaWithShapeAndLayout<f32, f32, f32, dims::Shape<16, 16, 8>, layout::Col, l
13901394
13911395 let result_vals = unsafe {
13921396 wmma_mma_tf32_f32_col_row_m16n16k8 (
1393- a_vals[ 0 ] , a_vals[ 1 ] , a_vals[ 2 ] , a_vals[ 3 ] , a_vals[ 4 ] , a_vals[ 5 ] , a_vals[ 6 ] , a_vals[ 7 ] ,
1394- b_vals[ 0 ] , b_vals[ 1 ] , b_vals[ 2 ] , b_vals[ 3 ] , b_vals[ 4 ] , b_vals[ 5 ] , b_vals[ 6 ] , b_vals[ 7 ] ,
1397+ float_to_tf32 ( a_vals[ 0 ] ) , float_to_tf32 ( a_vals[ 1 ] ) , float_to_tf32 ( a_vals[ 2 ] ) , float_to_tf32 ( a_vals[ 3 ] ) ,
1398+ float_to_tf32 ( a_vals[ 4 ] ) , float_to_tf32 ( a_vals[ 5 ] ) , float_to_tf32 ( a_vals[ 6 ] ) , float_to_tf32 ( a_vals[ 7 ] ) ,
1399+ float_to_tf32 ( b_vals[ 0 ] ) , float_to_tf32 ( b_vals[ 1 ] ) , float_to_tf32 ( b_vals[ 2 ] ) , float_to_tf32 ( b_vals[ 3 ] ) ,
1400+ float_to_tf32 ( b_vals[ 4 ] ) , float_to_tf32 ( b_vals[ 5 ] ) , float_to_tf32 ( b_vals[ 6 ] ) , float_to_tf32 ( b_vals[ 7 ] ) ,
13951401 c_vals[ 0 ] , c_vals[ 1 ] , c_vals[ 2 ] , c_vals[ 3 ] , c_vals[ 4 ] , c_vals[ 5 ] , c_vals[ 6 ] , c_vals[ 7 ]
13961402 )
13971403 } ;
@@ -1418,8 +1424,10 @@ impl MmaWithShapeAndLayout<f32, f32, f32, dims::Shape<16, 16, 8>, layout::Col, l
14181424
14191425 let result_vals = unsafe {
14201426 wmma_mma_tf32_f32_col_col_m16n16k8 (
1421- a_vals[ 0 ] , a_vals[ 1 ] , a_vals[ 2 ] , a_vals[ 3 ] , a_vals[ 4 ] , a_vals[ 5 ] , a_vals[ 6 ] , a_vals[ 7 ] ,
1422- b_vals[ 0 ] , b_vals[ 1 ] , b_vals[ 2 ] , b_vals[ 3 ] , b_vals[ 4 ] , b_vals[ 5 ] , b_vals[ 6 ] , b_vals[ 7 ] ,
1427+ float_to_tf32 ( a_vals[ 0 ] ) , float_to_tf32 ( a_vals[ 1 ] ) , float_to_tf32 ( a_vals[ 2 ] ) , float_to_tf32 ( a_vals[ 3 ] ) ,
1428+ float_to_tf32 ( a_vals[ 4 ] ) , float_to_tf32 ( a_vals[ 5 ] ) , float_to_tf32 ( a_vals[ 6 ] ) , float_to_tf32 ( a_vals[ 7 ] ) ,
1429+ float_to_tf32 ( b_vals[ 0 ] ) , float_to_tf32 ( b_vals[ 1 ] ) , float_to_tf32 ( b_vals[ 2 ] ) , float_to_tf32 ( b_vals[ 3 ] ) ,
1430+ float_to_tf32 ( b_vals[ 4 ] ) , float_to_tf32 ( b_vals[ 5 ] ) , float_to_tf32 ( b_vals[ 6 ] ) , float_to_tf32 ( b_vals[ 7 ] ) ,
14231431 c_vals[ 0 ] , c_vals[ 1 ] , c_vals[ 2 ] , c_vals[ 3 ] , c_vals[ 4 ] , c_vals[ 5 ] , c_vals[ 6 ] , c_vals[ 7 ]
14241432 )
14251433 } ;
@@ -1429,6 +1437,175 @@ impl MmaWithShapeAndLayout<f32, f32, f32, dims::Shape<16, 16, 8>, layout::Col, l
14291437 }
14301438}
14311439
1440+ // i8 × i8 + i32 → i32 implementations for 32x8x16 (Row-Row only)
1441+ impl MmaWithShapeAndLayout < i8 , i8 , i32 , dims:: Shape < 32 , 8 , 16 > , layout:: Row , layout:: Row > for i32 {
1442+ type Output = i32 ;
1443+
1444+ #[ gpu_only]
1445+ fn mma (
1446+ a : & MatrixA < i8 , dims:: Shape < 32 , 8 , 16 > , layout:: Row > ,
1447+ b : & MatrixB < i8 , dims:: Shape < 32 , 8 , 16 > , layout:: Row > ,
1448+ c : & Accumulator < i32 , dims:: Shape < 32 , 8 , 16 > > ,
1449+ ) -> Accumulator < i32 , dims:: Shape < 32 , 8 , 16 > > {
1450+ let mut result = Accumulator :: new ( ) ;
1451+
1452+ let a_vals = unsafe { core:: mem:: transmute :: < [ <i8 as MatrixElement >:: Storage ; 4 ] , [ i32 ; 4 ] > (
1453+ * ( & a. data [ ..4 ] as * const [ <i8 as MatrixElement >:: Storage ] as * const [ <i8 as MatrixElement >:: Storage ; 4 ] )
1454+ ) } ;
1455+ let b_vals = unsafe { core:: mem:: transmute :: < [ <i8 as MatrixElement >:: Storage ; 2 ] , [ i32 ; 2 ] > (
1456+ * ( & b. data [ ..2 ] as * const [ <i8 as MatrixElement >:: Storage ] as * const [ <i8 as MatrixElement >:: Storage ; 2 ] )
1457+ ) } ;
1458+ let c_vals = unsafe { core:: mem:: transmute :: < [ <i32 as AccumulatorElement >:: Storage ; 8 ] , [ i32 ; 8 ] > (
1459+ * ( & c. data [ ..8 ] as * const [ <i32 as AccumulatorElement >:: Storage ] as * const [ <i32 as AccumulatorElement >:: Storage ; 8 ] )
1460+ ) } ;
1461+
1462+ let result_vals = unsafe {
1463+ wmma_mma_s8_s32_row_row_m32n8k16 (
1464+ a_vals[ 0 ] , a_vals[ 1 ] , a_vals[ 2 ] , a_vals[ 3 ] ,
1465+ b_vals[ 0 ] , b_vals[ 1 ] ,
1466+ c_vals[ 0 ] , c_vals[ 1 ] , c_vals[ 2 ] , c_vals[ 3 ] , c_vals[ 4 ] , c_vals[ 5 ] , c_vals[ 6 ] , c_vals[ 7 ]
1467+ )
1468+ } ;
1469+
1470+ result. data [ ..8 ] . copy_from_slice ( & unsafe { core:: mem:: transmute :: < [ i32 ; 8 ] , [ <i32 as AccumulatorElement >:: Storage ; 8 ] > ( result_vals) } ) ;
1471+ result
1472+ }
1473+ }
1474+
1475+ // u8 × u8 + i32 → i32 implementations for 32x8x16 (Row-Row only)
1476+ impl MmaWithShapeAndLayout < u8 , u8 , i32 , dims:: Shape < 32 , 8 , 16 > , layout:: Row , layout:: Row > for i32 {
1477+ type Output = i32 ;
1478+
1479+ #[ gpu_only]
1480+ fn mma (
1481+ a : & MatrixA < u8 , dims:: Shape < 32 , 8 , 16 > , layout:: Row > ,
1482+ b : & MatrixB < u8 , dims:: Shape < 32 , 8 , 16 > , layout:: Row > ,
1483+ c : & Accumulator < i32 , dims:: Shape < 32 , 8 , 16 > > ,
1484+ ) -> Accumulator < i32 , dims:: Shape < 32 , 8 , 16 > > {
1485+ let mut result = Accumulator :: new ( ) ;
1486+
1487+ let a_vals = unsafe { core:: mem:: transmute :: < [ <u8 as MatrixElement >:: Storage ; 4 ] , [ i32 ; 4 ] > (
1488+ * ( & a. data [ ..4 ] as * const [ <u8 as MatrixElement >:: Storage ] as * const [ <u8 as MatrixElement >:: Storage ; 4 ] )
1489+ ) } ;
1490+ let b_vals = unsafe { core:: mem:: transmute :: < [ <u8 as MatrixElement >:: Storage ; 2 ] , [ i32 ; 2 ] > (
1491+ * ( & b. data [ ..2 ] as * const [ <u8 as MatrixElement >:: Storage ] as * const [ <u8 as MatrixElement >:: Storage ; 2 ] )
1492+ ) } ;
1493+ let c_vals = unsafe { core:: mem:: transmute :: < [ <i32 as AccumulatorElement >:: Storage ; 8 ] , [ i32 ; 8 ] > (
1494+ * ( & c. data [ ..8 ] as * const [ <i32 as AccumulatorElement >:: Storage ] as * const [ <i32 as AccumulatorElement >:: Storage ; 8 ] )
1495+ ) } ;
1496+
1497+ let result_vals = unsafe {
1498+ wmma_mma_u8_s32_row_row_m32n8k16 (
1499+ a_vals[ 0 ] , a_vals[ 1 ] , a_vals[ 2 ] , a_vals[ 3 ] ,
1500+ b_vals[ 0 ] , b_vals[ 1 ] ,
1501+ c_vals[ 0 ] , c_vals[ 1 ] , c_vals[ 2 ] , c_vals[ 3 ] , c_vals[ 4 ] , c_vals[ 5 ] , c_vals[ 6 ] , c_vals[ 7 ]
1502+ )
1503+ } ;
1504+
1505+ result. data [ ..8 ] . copy_from_slice ( & unsafe { core:: mem:: transmute :: < [ i32 ; 8 ] , [ <i32 as AccumulatorElement >:: Storage ; 8 ] > ( result_vals) } ) ;
1506+ result
1507+ }
1508+ }
1509+
1510+ // i8 × i8 + i32 → i32 implementations for 8x32x16 (Row-Row only)
1511+ impl MmaWithShapeAndLayout < i8 , i8 , i32 , dims:: Shape < 8 , 32 , 16 > , layout:: Row , layout:: Row > for i32 {
1512+ type Output = i32 ;
1513+
1514+ #[ gpu_only]
1515+ fn mma (
1516+ a : & MatrixA < i8 , dims:: Shape < 8 , 32 , 16 > , layout:: Row > ,
1517+ b : & MatrixB < i8 , dims:: Shape < 8 , 32 , 16 > , layout:: Row > ,
1518+ c : & Accumulator < i32 , dims:: Shape < 8 , 32 , 16 > > ,
1519+ ) -> Accumulator < i32 , dims:: Shape < 8 , 32 , 16 > > {
1520+ let mut result = Accumulator :: new ( ) ;
1521+
1522+ let a_vals = unsafe { core:: mem:: transmute :: < [ <i8 as MatrixElement >:: Storage ; 2 ] , [ i32 ; 2 ] > (
1523+ * ( & a. data [ ..2 ] as * const [ <i8 as MatrixElement >:: Storage ] as * const [ <i8 as MatrixElement >:: Storage ; 2 ] )
1524+ ) } ;
1525+ let b_vals = unsafe { core:: mem:: transmute :: < [ <i8 as MatrixElement >:: Storage ; 4 ] , [ i32 ; 4 ] > (
1526+ * ( & b. data [ ..4 ] as * const [ <i8 as MatrixElement >:: Storage ] as * const [ <i8 as MatrixElement >:: Storage ; 4 ] )
1527+ ) } ;
1528+ let c_vals = unsafe { core:: mem:: transmute :: < [ <i32 as AccumulatorElement >:: Storage ; 8 ] , [ i32 ; 8 ] > (
1529+ * ( & c. data [ ..8 ] as * const [ <i32 as AccumulatorElement >:: Storage ] as * const [ <i32 as AccumulatorElement >:: Storage ; 8 ] )
1530+ ) } ;
1531+
1532+ let result_vals = unsafe {
1533+ wmma_mma_s8_s32_row_row_m8n32k16 (
1534+ a_vals[ 0 ] , a_vals[ 1 ] ,
1535+ b_vals[ 0 ] , b_vals[ 1 ] , b_vals[ 2 ] , b_vals[ 3 ] ,
1536+ c_vals[ 0 ] , c_vals[ 1 ] , c_vals[ 2 ] , c_vals[ 3 ] , c_vals[ 4 ] , c_vals[ 5 ] , c_vals[ 6 ] , c_vals[ 7 ]
1537+ )
1538+ } ;
1539+
1540+ result. data [ ..8 ] . copy_from_slice ( & unsafe { core:: mem:: transmute :: < [ i32 ; 8 ] , [ <i32 as AccumulatorElement >:: Storage ; 8 ] > ( result_vals) } ) ;
1541+ result
1542+ }
1543+ }
1544+
1545+ // u8 × u8 + i32 → i32 implementations for 8x32x16 (Row-Row only)
1546+ impl MmaWithShapeAndLayout < u8 , u8 , i32 , dims:: Shape < 8 , 32 , 16 > , layout:: Row , layout:: Row > for i32 {
1547+ type Output = i32 ;
1548+
1549+ #[ gpu_only]
1550+ fn mma (
1551+ a : & MatrixA < u8 , dims:: Shape < 8 , 32 , 16 > , layout:: Row > ,
1552+ b : & MatrixB < u8 , dims:: Shape < 8 , 32 , 16 > , layout:: Row > ,
1553+ c : & Accumulator < i32 , dims:: Shape < 8 , 32 , 16 > > ,
1554+ ) -> Accumulator < i32 , dims:: Shape < 8 , 32 , 16 > > {
1555+ let mut result = Accumulator :: new ( ) ;
1556+
1557+ let a_vals = unsafe { core:: mem:: transmute :: < [ <u8 as MatrixElement >:: Storage ; 2 ] , [ i32 ; 2 ] > (
1558+ * ( & a. data [ ..2 ] as * const [ <u8 as MatrixElement >:: Storage ] as * const [ <u8 as MatrixElement >:: Storage ; 2 ] )
1559+ ) } ;
1560+ let b_vals = unsafe { core:: mem:: transmute :: < [ <u8 as MatrixElement >:: Storage ; 4 ] , [ i32 ; 4 ] > (
1561+ * ( & b. data [ ..4 ] as * const [ <u8 as MatrixElement >:: Storage ] as * const [ <u8 as MatrixElement >:: Storage ; 4 ] )
1562+ ) } ;
1563+ let c_vals = unsafe { core:: mem:: transmute :: < [ <i32 as AccumulatorElement >:: Storage ; 8 ] , [ i32 ; 8 ] > (
1564+ * ( & c. data [ ..8 ] as * const [ <i32 as AccumulatorElement >:: Storage ] as * const [ <i32 as AccumulatorElement >:: Storage ; 8 ] )
1565+ ) } ;
1566+
1567+ let result_vals = unsafe {
1568+ wmma_mma_u8_s32_row_row_m8n32k16 (
1569+ a_vals[ 0 ] , a_vals[ 1 ] ,
1570+ b_vals[ 0 ] , b_vals[ 1 ] , b_vals[ 2 ] , b_vals[ 3 ] ,
1571+ c_vals[ 0 ] , c_vals[ 1 ] , c_vals[ 2 ] , c_vals[ 3 ] , c_vals[ 4 ] , c_vals[ 5 ] , c_vals[ 6 ] , c_vals[ 7 ]
1572+ )
1573+ } ;
1574+
1575+ result. data [ ..8 ] . copy_from_slice ( & unsafe { core:: mem:: transmute :: < [ i32 ; 8 ] , [ <i32 as AccumulatorElement >:: Storage ; 8 ] > ( result_vals) } ) ;
1576+ result
1577+ }
1578+ }
1579+
1580+ // f64 × f64 + f64 → f64 implementations for 8x8x4 (Row-Row only)
1581+ impl MmaWithShapeAndLayout < f64 , f64 , f64 , dims:: Shape < 8 , 8 , 4 > , layout:: Row , layout:: Row > for f64 {
1582+ type Output = f64 ;
1583+
1584+ #[ gpu_only]
1585+ fn mma (
1586+ a : & MatrixA < f64 , dims:: Shape < 8 , 8 , 4 > , layout:: Row > ,
1587+ b : & MatrixB < f64 , dims:: Shape < 8 , 8 , 4 > , layout:: Row > ,
1588+ c : & Accumulator < f64 , dims:: Shape < 8 , 8 , 4 > > ,
1589+ ) -> Accumulator < f64 , dims:: Shape < 8 , 8 , 4 > > {
1590+ let mut result = Accumulator :: new ( ) ;
1591+
1592+ let a_vals = unsafe { * ( & a. data [ ..2 ] as * const [ f64 ] as * const [ f64 ; 2 ] ) } ;
1593+ let b_vals = unsafe { * ( & b. data [ ..2 ] as * const [ f64 ] as * const [ f64 ; 2 ] ) } ;
1594+ let c_vals = unsafe { * ( & c. data [ ..2 ] as * const [ f64 ] as * const [ f64 ; 2 ] ) } ;
1595+
1596+ let result_vals = unsafe {
1597+ wmma_mma_f64_row_row_m8n8k4 (
1598+ a_vals[ 0 ] , a_vals[ 1 ] ,
1599+ b_vals[ 0 ] , b_vals[ 1 ] ,
1600+ c_vals[ 0 ] , c_vals[ 1 ]
1601+ )
1602+ } ;
1603+
1604+ result. data [ ..2 ] . copy_from_slice ( & result_vals) ;
1605+ result
1606+ }
1607+ }
1608+
14321609// ============================================================================
14331610// Stride Validation
14341611// ============================================================================
0 commit comments