
Commit 8075a6a

Working! untested runtime
1 parent: fe06e38


1 file changed: crates/cuda_std/src/warp/matrix.rs (185 additions, 8 deletions)
```diff
@@ -1334,8 +1334,10 @@ impl MmaWithShapeAndLayout<f32, f32, f32, dims::Shape<16, 16, 8>, layout::Row, l
 
         let result_vals = unsafe {
             wmma_mma_tf32_f32_row_row_m16n16k8(
-                a_vals[0], a_vals[1], a_vals[2], a_vals[3], a_vals[4], a_vals[5], a_vals[6], a_vals[7],
-                b_vals[0], b_vals[1], b_vals[2], b_vals[3], b_vals[4], b_vals[5], b_vals[6], b_vals[7],
+                float_to_tf32(a_vals[0]), float_to_tf32(a_vals[1]), float_to_tf32(a_vals[2]), float_to_tf32(a_vals[3]),
+                float_to_tf32(a_vals[4]), float_to_tf32(a_vals[5]), float_to_tf32(a_vals[6]), float_to_tf32(a_vals[7]),
+                float_to_tf32(b_vals[0]), float_to_tf32(b_vals[1]), float_to_tf32(b_vals[2]), float_to_tf32(b_vals[3]),
+                float_to_tf32(b_vals[4]), float_to_tf32(b_vals[5]), float_to_tf32(b_vals[6]), float_to_tf32(b_vals[7]),
                 c_vals[0], c_vals[1], c_vals[2], c_vals[3], c_vals[4], c_vals[5], c_vals[6], c_vals[7]
             )
         };
@@ -1362,8 +1364,10 @@ impl MmaWithShapeAndLayout<f32, f32, f32, dims::Shape<16, 16, 8>, layout::Row, l
 
         let result_vals = unsafe {
             wmma_mma_tf32_f32_row_col_m16n16k8(
-                a_vals[0], a_vals[1], a_vals[2], a_vals[3], a_vals[4], a_vals[5], a_vals[6], a_vals[7],
-                b_vals[0], b_vals[1], b_vals[2], b_vals[3], b_vals[4], b_vals[5], b_vals[6], b_vals[7],
+                float_to_tf32(a_vals[0]), float_to_tf32(a_vals[1]), float_to_tf32(a_vals[2]), float_to_tf32(a_vals[3]),
+                float_to_tf32(a_vals[4]), float_to_tf32(a_vals[5]), float_to_tf32(a_vals[6]), float_to_tf32(a_vals[7]),
+                float_to_tf32(b_vals[0]), float_to_tf32(b_vals[1]), float_to_tf32(b_vals[2]), float_to_tf32(b_vals[3]),
+                float_to_tf32(b_vals[4]), float_to_tf32(b_vals[5]), float_to_tf32(b_vals[6]), float_to_tf32(b_vals[7]),
                 c_vals[0], c_vals[1], c_vals[2], c_vals[3], c_vals[4], c_vals[5], c_vals[6], c_vals[7]
             )
         };
@@ -1390,8 +1394,10 @@ impl MmaWithShapeAndLayout<f32, f32, f32, dims::Shape<16, 16, 8>, layout::Col, l
 
         let result_vals = unsafe {
             wmma_mma_tf32_f32_col_row_m16n16k8(
-                a_vals[0], a_vals[1], a_vals[2], a_vals[3], a_vals[4], a_vals[5], a_vals[6], a_vals[7],
-                b_vals[0], b_vals[1], b_vals[2], b_vals[3], b_vals[4], b_vals[5], b_vals[6], b_vals[7],
+                float_to_tf32(a_vals[0]), float_to_tf32(a_vals[1]), float_to_tf32(a_vals[2]), float_to_tf32(a_vals[3]),
+                float_to_tf32(a_vals[4]), float_to_tf32(a_vals[5]), float_to_tf32(a_vals[6]), float_to_tf32(a_vals[7]),
+                float_to_tf32(b_vals[0]), float_to_tf32(b_vals[1]), float_to_tf32(b_vals[2]), float_to_tf32(b_vals[3]),
+                float_to_tf32(b_vals[4]), float_to_tf32(b_vals[5]), float_to_tf32(b_vals[6]), float_to_tf32(b_vals[7]),
                 c_vals[0], c_vals[1], c_vals[2], c_vals[3], c_vals[4], c_vals[5], c_vals[6], c_vals[7]
             )
         };
@@ -1418,8 +1424,10 @@ impl MmaWithShapeAndLayout<f32, f32, f32, dims::Shape<16, 16, 8>, layout::Col, l
 
         let result_vals = unsafe {
             wmma_mma_tf32_f32_col_col_m16n16k8(
-                a_vals[0], a_vals[1], a_vals[2], a_vals[3], a_vals[4], a_vals[5], a_vals[6], a_vals[7],
-                b_vals[0], b_vals[1], b_vals[2], b_vals[3], b_vals[4], b_vals[5], b_vals[6], b_vals[7],
+                float_to_tf32(a_vals[0]), float_to_tf32(a_vals[1]), float_to_tf32(a_vals[2]), float_to_tf32(a_vals[3]),
+                float_to_tf32(a_vals[4]), float_to_tf32(a_vals[5]), float_to_tf32(a_vals[6]), float_to_tf32(a_vals[7]),
+                float_to_tf32(b_vals[0]), float_to_tf32(b_vals[1]), float_to_tf32(b_vals[2]), float_to_tf32(b_vals[3]),
+                float_to_tf32(b_vals[4]), float_to_tf32(b_vals[5]), float_to_tf32(b_vals[6]), float_to_tf32(b_vals[7]),
                 c_vals[0], c_vals[1], c_vals[2], c_vals[3], c_vals[4], c_vals[5], c_vals[6], c_vals[7]
             )
         };
```
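The four hunks above all make the same fix: the A and B fragment registers are now passed through `float_to_tf32` before reaching the `m16n16k8` TF32 intrinsics, instead of being forwarded as raw f32 register values. As background, TF32 keeps the f32 sign and 8-bit exponent but only the top 10 mantissa bits. The snippet below is a minimal, illustrative truncating conversion under that definition; it is not the crate's `float_to_tf32`, whose return type and rounding mode (for example round-to-nearest via the PTX `cvt.rna.tf32.f32` instruction) may differ.

```rust
/// Illustrative only, not the crate's `float_to_tf32`: TF32 reuses the f32
/// sign and exponent and keeps the top 10 of 23 mantissa bits, so a simple
/// truncating conversion masks off the low 13 mantissa bits.
fn float_to_tf32_truncated(x: f32) -> u32 {
    x.to_bits() & 0xFFFF_E000
}

fn main() {
    // 1.0 has an all-zero mantissa and is unchanged by truncation.
    assert_eq!(float_to_tf32_truncated(1.0), 1.0f32.to_bits());
    // Low mantissa bits are dropped.
    let noisy = f32::from_bits(0x3F80_0FFF);
    assert_eq!(float_to_tf32_truncated(noisy) & 0x0000_1FFF, 0);
}
```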
```diff
@@ -1429,6 +1437,175 @@ impl MmaWithShapeAndLayout<f32, f32, f32, dims::Shape<16, 16, 8>, layout::Col, l
     }
 }
 
+// i8 × i8 + i32 → i32 implementations for 32x8x16 (Row-Row only)
+impl MmaWithShapeAndLayout<i8, i8, i32, dims::Shape<32, 8, 16>, layout::Row, layout::Row> for i32 {
+    type Output = i32;
+
+    #[gpu_only]
+    fn mma(
+        a: &MatrixA<i8, dims::Shape<32, 8, 16>, layout::Row>,
+        b: &MatrixB<i8, dims::Shape<32, 8, 16>, layout::Row>,
+        c: &Accumulator<i32, dims::Shape<32, 8, 16>>,
+    ) -> Accumulator<i32, dims::Shape<32, 8, 16>> {
+        let mut result = Accumulator::new();
+
+        let a_vals = unsafe { core::mem::transmute::<[<i8 as MatrixElement>::Storage; 4], [i32; 4]>(
+            *(&a.data[..4] as *const [<i8 as MatrixElement>::Storage] as *const [<i8 as MatrixElement>::Storage; 4])
+        )};
+        let b_vals = unsafe { core::mem::transmute::<[<i8 as MatrixElement>::Storage; 2], [i32; 2]>(
+            *(&b.data[..2] as *const [<i8 as MatrixElement>::Storage] as *const [<i8 as MatrixElement>::Storage; 2])
+        )};
+        let c_vals = unsafe { core::mem::transmute::<[<i32 as AccumulatorElement>::Storage; 8], [i32; 8]>(
+            *(&c.data[..8] as *const [<i32 as AccumulatorElement>::Storage] as *const [<i32 as AccumulatorElement>::Storage; 8])
+        )};
+
+        let result_vals = unsafe {
+            wmma_mma_s8_s32_row_row_m32n8k16(
+                a_vals[0], a_vals[1], a_vals[2], a_vals[3],
+                b_vals[0], b_vals[1],
+                c_vals[0], c_vals[1], c_vals[2], c_vals[3], c_vals[4], c_vals[5], c_vals[6], c_vals[7]
+            )
+        };
+
+        result.data[..8].copy_from_slice(&unsafe { core::mem::transmute::<[i32; 8], [<i32 as AccumulatorElement>::Storage; 8]>(result_vals) });
+        result
+    }
+}
+
+// u8 × u8 + i32 → i32 implementations for 32x8x16 (Row-Row only)
+impl MmaWithShapeAndLayout<u8, u8, i32, dims::Shape<32, 8, 16>, layout::Row, layout::Row> for i32 {
+    type Output = i32;
+
+    #[gpu_only]
+    fn mma(
+        a: &MatrixA<u8, dims::Shape<32, 8, 16>, layout::Row>,
+        b: &MatrixB<u8, dims::Shape<32, 8, 16>, layout::Row>,
+        c: &Accumulator<i32, dims::Shape<32, 8, 16>>,
+    ) -> Accumulator<i32, dims::Shape<32, 8, 16>> {
+        let mut result = Accumulator::new();
+
+        let a_vals = unsafe { core::mem::transmute::<[<u8 as MatrixElement>::Storage; 4], [i32; 4]>(
+            *(&a.data[..4] as *const [<u8 as MatrixElement>::Storage] as *const [<u8 as MatrixElement>::Storage; 4])
+        )};
+        let b_vals = unsafe { core::mem::transmute::<[<u8 as MatrixElement>::Storage; 2], [i32; 2]>(
+            *(&b.data[..2] as *const [<u8 as MatrixElement>::Storage] as *const [<u8 as MatrixElement>::Storage; 2])
+        )};
+        let c_vals = unsafe { core::mem::transmute::<[<i32 as AccumulatorElement>::Storage; 8], [i32; 8]>(
+            *(&c.data[..8] as *const [<i32 as AccumulatorElement>::Storage] as *const [<i32 as AccumulatorElement>::Storage; 8])
+        )};
+
+        let result_vals = unsafe {
+            wmma_mma_u8_s32_row_row_m32n8k16(
+                a_vals[0], a_vals[1], a_vals[2], a_vals[3],
+                b_vals[0], b_vals[1],
+                c_vals[0], c_vals[1], c_vals[2], c_vals[3], c_vals[4], c_vals[5], c_vals[6], c_vals[7]
+            )
+        };
+
+        result.data[..8].copy_from_slice(&unsafe { core::mem::transmute::<[i32; 8], [<i32 as AccumulatorElement>::Storage; 8]>(result_vals) });
+        result
+    }
+}
+
+// i8 × i8 + i32 → i32 implementations for 8x32x16 (Row-Row only)
+impl MmaWithShapeAndLayout<i8, i8, i32, dims::Shape<8, 32, 16>, layout::Row, layout::Row> for i32 {
+    type Output = i32;
+
+    #[gpu_only]
+    fn mma(
+        a: &MatrixA<i8, dims::Shape<8, 32, 16>, layout::Row>,
+        b: &MatrixB<i8, dims::Shape<8, 32, 16>, layout::Row>,
+        c: &Accumulator<i32, dims::Shape<8, 32, 16>>,
+    ) -> Accumulator<i32, dims::Shape<8, 32, 16>> {
+        let mut result = Accumulator::new();
+
+        let a_vals = unsafe { core::mem::transmute::<[<i8 as MatrixElement>::Storage; 2], [i32; 2]>(
+            *(&a.data[..2] as *const [<i8 as MatrixElement>::Storage] as *const [<i8 as MatrixElement>::Storage; 2])
+        )};
+        let b_vals = unsafe { core::mem::transmute::<[<i8 as MatrixElement>::Storage; 4], [i32; 4]>(
+            *(&b.data[..4] as *const [<i8 as MatrixElement>::Storage] as *const [<i8 as MatrixElement>::Storage; 4])
+        )};
+        let c_vals = unsafe { core::mem::transmute::<[<i32 as AccumulatorElement>::Storage; 8], [i32; 8]>(
+            *(&c.data[..8] as *const [<i32 as AccumulatorElement>::Storage] as *const [<i32 as AccumulatorElement>::Storage; 8])
+        )};
+
+        let result_vals = unsafe {
+            wmma_mma_s8_s32_row_row_m8n32k16(
+                a_vals[0], a_vals[1],
+                b_vals[0], b_vals[1], b_vals[2], b_vals[3],
+                c_vals[0], c_vals[1], c_vals[2], c_vals[3], c_vals[4], c_vals[5], c_vals[6], c_vals[7]
+            )
+        };
+
+        result.data[..8].copy_from_slice(&unsafe { core::mem::transmute::<[i32; 8], [<i32 as AccumulatorElement>::Storage; 8]>(result_vals) });
+        result
+    }
+}
+
+// u8 × u8 + i32 → i32 implementations for 8x32x16 (Row-Row only)
+impl MmaWithShapeAndLayout<u8, u8, i32, dims::Shape<8, 32, 16>, layout::Row, layout::Row> for i32 {
+    type Output = i32;
+
+    #[gpu_only]
+    fn mma(
+        a: &MatrixA<u8, dims::Shape<8, 32, 16>, layout::Row>,
+        b: &MatrixB<u8, dims::Shape<8, 32, 16>, layout::Row>,
+        c: &Accumulator<i32, dims::Shape<8, 32, 16>>,
+    ) -> Accumulator<i32, dims::Shape<8, 32, 16>> {
+        let mut result = Accumulator::new();
+
+        let a_vals = unsafe { core::mem::transmute::<[<u8 as MatrixElement>::Storage; 2], [i32; 2]>(
+            *(&a.data[..2] as *const [<u8 as MatrixElement>::Storage] as *const [<u8 as MatrixElement>::Storage; 2])
+        )};
+        let b_vals = unsafe { core::mem::transmute::<[<u8 as MatrixElement>::Storage; 4], [i32; 4]>(
+            *(&b.data[..4] as *const [<u8 as MatrixElement>::Storage] as *const [<u8 as MatrixElement>::Storage; 4])
+        )};
+        let c_vals = unsafe { core::mem::transmute::<[<i32 as AccumulatorElement>::Storage; 8], [i32; 8]>(
+            *(&c.data[..8] as *const [<i32 as AccumulatorElement>::Storage] as *const [<i32 as AccumulatorElement>::Storage; 8])
+        )};
+
+        let result_vals = unsafe {
+            wmma_mma_u8_s32_row_row_m8n32k16(
+                a_vals[0], a_vals[1],
+                b_vals[0], b_vals[1], b_vals[2], b_vals[3],
+                c_vals[0], c_vals[1], c_vals[2], c_vals[3], c_vals[4], c_vals[5], c_vals[6], c_vals[7]
+            )
+        };
+
+        result.data[..8].copy_from_slice(&unsafe { core::mem::transmute::<[i32; 8], [<i32 as AccumulatorElement>::Storage; 8]>(result_vals) });
+        result
+    }
+}
+
+// f64 × f64 + f64 → f64 implementations for 8x8x4 (Row-Row only)
+impl MmaWithShapeAndLayout<f64, f64, f64, dims::Shape<8, 8, 4>, layout::Row, layout::Row> for f64 {
+    type Output = f64;
+
+    #[gpu_only]
+    fn mma(
+        a: &MatrixA<f64, dims::Shape<8, 8, 4>, layout::Row>,
+        b: &MatrixB<f64, dims::Shape<8, 8, 4>, layout::Row>,
+        c: &Accumulator<f64, dims::Shape<8, 8, 4>>,
+    ) -> Accumulator<f64, dims::Shape<8, 8, 4>> {
+        let mut result = Accumulator::new();
+
+        let a_vals = unsafe { *(&a.data[..2] as *const [f64] as *const [f64; 2]) };
+        let b_vals = unsafe { *(&b.data[..2] as *const [f64] as *const [f64; 2]) };
+        let c_vals = unsafe { *(&c.data[..2] as *const [f64] as *const [f64; 2]) };
+
+        let result_vals = unsafe {
+            wmma_mma_f64_row_row_m8n8k4(
+                a_vals[0], a_vals[1],
+                b_vals[0], b_vals[1],
+                c_vals[0], c_vals[1]
+            )
+        };
+
+        result.data[..2].copy_from_slice(&result_vals);
+        result
+    }
+}
+
 // ============================================================================
 // Stride Validation
 // ============================================================================
```
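The new i8/u8 impls hand the fragment storage to the sub-byte WMMA intrinsics as packed 32-bit register words, four 8-bit elements per word, which is why the 32x8x16 variants reinterpret the A fragment as `[i32; 4]`, the B fragment as `[i32; 2]`, and the i32 accumulator as `[i32; 8]` (the 8x32x16 variants swap the A/B register counts). Below is a self-contained, host-side sketch of that packing; `pack_i8x4` is a hypothetical helper for illustration, not part of the crate, whose real counterpart is whatever `<i8 as MatrixElement>::Storage` defines.

```rust
/// Hypothetical helper (host-side sketch): pack four 8-bit elements into one
/// 32-bit word, the register granularity the s8/u8 WMMA intrinsics consume.
fn pack_i8x4(vals: [i8; 4]) -> i32 {
    i32::from_ne_bytes(vals.map(|v| v as u8))
}

fn main() {
    // Four such words hold 16 packed i8 elements, matching the a_vals[..4]
    // slice that the 32x8x16 impl above feeds to the intrinsic.
    let word = pack_i8x4([1, -2, 3, -4]);
    println!("packed word = {word:#010x}");
}
```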
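For the f64 path, the tile being computed is D = A × B + C with A 8x4, B 4x8, and C/D 8x8; the impl above passes two f64 values from each fragment slice into `wmma_mma_f64_row_row_m8n8k4`. As a purely mathematical, host-side reference for what that tile computation means (it says nothing about how the fragments are actually distributed across the 32 lanes of a warp), a scalar row-major version could look like this:

```rust
/// Scalar reference for the 8x8x4 f64 tile the new impl targets:
/// D = A (8x4, row-major) * B (4x8, row-major) + C (8x8, row-major).
fn mma_8x8x4_ref(a: &[f64; 32], b: &[f64; 32], c: &[f64; 64]) -> [f64; 64] {
    let (m, n, k) = (8, 8, 4);
    let mut d = *c;
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0;
            for p in 0..k {
                acc += a[i * k + p] * b[p * n + j];
            }
            d[i * n + j] += acc;
        }
    }
    d
}

fn main() {
    // Sanity check: A has ones on the diagonal of its top 4x4 block and zeros
    // elsewhere, C is zero, so the first rows of D reproduce the rows of B.
    let mut a = [0.0f64; 32];
    for i in 0..4 {
        a[i * 4 + i] = 1.0;
    }
    let b: [f64; 32] = core::array::from_fn(|i| i as f64);
    let c = [0.0f64; 64];
    let d = mma_8x8x4_ref(&a, &b, &c);
    assert_eq!(&d[..8], &b[..8]);
}
```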
