Skip to content

Commit ea1d46c

Browse files
committed
Autogenerate a lot of strides
1 parent b3a82d5 commit ea1d46c

File tree

16 files changed

+225
-75
lines changed

16 files changed

+225
-75
lines changed

crates/cuda_std/build.rs

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
use std::env;
2+
use std::fs;
3+
use std::io::Write;
4+
use std::path::Path;
5+
6+
/// Maximum stride value to support
7+
const MAX_STRIDE: usize = 8192;
8+
9+
fn main() {
10+
// Tell cargo to rerun if build.rs changes
11+
println!("cargo:rerun-if-changed=build.rs");
12+
13+
// Generate stride implementations
14+
generate_stride_impls();
15+
}
16+
17+
fn generate_stride_impls() {
18+
let out_dir = env::var_os("OUT_DIR").unwrap();
19+
let out_path = Path::new(&out_dir);
20+
21+
// Create stride subdirectory in OUT_DIR
22+
let stride_dir = out_path.join("stride");
23+
fs::create_dir_all(&stride_dir).unwrap();
24+
25+
// Define stride requirements for each type
26+
// (type_name, file_name, start, step)
27+
let stride_configs = vec![
28+
// 16-byte alignment types (multiple of element count to fill 16 bytes)
29+
("f16", "f16", 8, 8), // f16: 2 bytes × 8 = 16 bytes
30+
("bf16", "bf16", 8, 8), // bf16: 2 bytes × 8 = 16 bytes
31+
("i16", "i16", 8, 8), // i16: 2 bytes × 8 = 16 bytes
32+
("u16", "u16", 8, 8), // u16: 2 bytes × 8 = 16 bytes
33+
("f32", "f32", 4, 4), // f32: 4 bytes × 4 = 16 bytes
34+
("i32", "i32", 4, 4), // i32: 4 bytes × 4 = 16 bytes
35+
("u32", "u32", 4, 4), // u32: 4 bytes × 4 = 16 bytes
36+
("f64", "f64", 2, 2), // f64: 8 bytes × 2 = 16 bytes
37+
("i64", "i64", 2, 2), // i64: 8 bytes × 2 = 16 bytes
38+
("u64", "u64", 2, 2), // u64: 8 bytes × 2 = 16 bytes
39+
("i8", "i8", 16, 16), // i8: 1 byte × 16 = 16 bytes
40+
("u8", "u8", 16, 16), // u8: 1 byte × 16 = 16 bytes
41+
("bool", "bool", 16, 16), // bool: 1 byte × 16 = 16 bytes
42+
];
43+
44+
for (type_name, file_name, start, step) in stride_configs {
45+
generate_type_strides(&stride_dir, type_name, file_name, start, step);
46+
}
47+
}
48+
49+
fn generate_type_strides(
50+
out_dir: &Path,
51+
type_name: &str,
52+
file_name: &str,
53+
start: usize,
54+
step: usize,
55+
) {
56+
let dest_path = out_dir.join(format!("{}_impls.rs", file_name));
57+
let mut file = fs::File::create(&dest_path).unwrap();
58+
59+
// Write header
60+
writeln!(file, "// @generated").unwrap();
61+
writeln!(file, "// Auto-generated stride implementations for {}", type_name).unwrap();
62+
writeln!(file, "// DO NOT EDIT THIS FILE MANUALLY").unwrap();
63+
writeln!(file).unwrap();
64+
writeln!(file, "// {} requires stride to be multiple of {} (up to {})",
65+
type_name, step, MAX_STRIDE).unwrap();
66+
writeln!(file).unwrap();
67+
68+
// Generate implementations
69+
// These files are included in modules that already have the necessary imports
70+
let mut stride = start;
71+
while stride <= MAX_STRIDE {
72+
writeln!(file, "impl ValidStride for StrideValidator<{}, {}> {{}}",
73+
type_name, stride).unwrap();
74+
stride += step;
75+
}
76+
}

crates/cuda_std/src/warp/matrix/mod.rs

Lines changed: 7 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ pub mod ops;
1717
mod intrinsics;
1818
use intrinsics::*;
1919

20+
// Stride validation module
21+
pub mod stride;
22+
// Re-export for convenience
23+
pub use stride::{StrideValidator, ValidStride};
24+
2025
// ============================================================================
2126
// Shape Types with Compile-Time Validation
2227
// ============================================================================
@@ -1749,80 +1754,7 @@ impl MmaWithShapeAndLayout<f64, f64, f64, dims::Shape<8, 8, 4>, layout::Row, lay
17491754
}
17501755
}
17511756

1752-
// ============================================================================
1753-
// Stride Validation
1754-
// ============================================================================
1755-
1756-
/// Compile-time stride validation
1757-
pub struct StrideValidator<T, const STRIDE: usize>(PhantomData<T>);
1758-
1759-
/// Trait for valid strides
1760-
#[diagnostic::on_unimplemented(
1761-
message = "`{Self}` is not a valid stride configuration",
1762-
label = "invalid stride for tensor core operations",
1763-
note = "f16/bf16 require stride to be a multiple of 8",
1764-
note = "f32/i32 require stride to be a multiple of 4",
1765-
note = "i8/u8/bool require stride to be a multiple of 16",
1766-
note = "f64 requires stride to be a multiple of 2"
1767-
)]
1768-
pub trait ValidStride: sealed::Sealed {}
1769-
1770-
// f16 requires stride to be multiple of 8
1771-
impl ValidStride for StrideValidator<f16, 8> {}
1772-
impl ValidStride for StrideValidator<f16, 16> {}
1773-
impl ValidStride for StrideValidator<f16, 24> {}
1774-
impl ValidStride for StrideValidator<f16, 32> {}
1775-
impl ValidStride for StrideValidator<f16, 40> {}
1776-
impl ValidStride for StrideValidator<f16, 48> {}
1777-
impl ValidStride for StrideValidator<f16, 56> {}
1778-
impl ValidStride for StrideValidator<f16, 64> {}
1779-
1780-
// f32 requires stride to be multiple of 4
1781-
impl ValidStride for StrideValidator<f32, 4> {}
1782-
impl ValidStride for StrideValidator<f32, 8> {}
1783-
impl ValidStride for StrideValidator<f32, 12> {}
1784-
impl ValidStride for StrideValidator<f32, 16> {}
1785-
impl ValidStride for StrideValidator<f32, 20> {}
1786-
impl ValidStride for StrideValidator<f32, 24> {}
1787-
impl ValidStride for StrideValidator<f32, 28> {}
1788-
impl ValidStride for StrideValidator<f32, 32> {}
1789-
1790-
// bf16 requires stride to be multiple of 8 (same as f16)
1791-
impl ValidStride for StrideValidator<bf16, 8> {}
1792-
impl ValidStride for StrideValidator<bf16, 16> {}
1793-
impl ValidStride for StrideValidator<bf16, 24> {}
1794-
impl ValidStride for StrideValidator<bf16, 32> {}
1795-
impl ValidStride for StrideValidator<bf16, 40> {}
1796-
impl ValidStride for StrideValidator<bf16, 48> {}
1797-
impl ValidStride for StrideValidator<bf16, 56> {}
1798-
impl ValidStride for StrideValidator<bf16, 64> {}
1799-
1800-
// i8/u8 require stride to be multiple of 16
1801-
impl ValidStride for StrideValidator<i8, 16> {}
1802-
impl ValidStride for StrideValidator<i8, 32> {}
1803-
impl ValidStride for StrideValidator<i8, 48> {}
1804-
impl ValidStride for StrideValidator<i8, 64> {}
1805-
1806-
impl ValidStride for StrideValidator<u8, 16> {}
1807-
impl ValidStride for StrideValidator<u8, 32> {}
1808-
impl ValidStride for StrideValidator<u8, 48> {}
1809-
impl ValidStride for StrideValidator<u8, 64> {}
1810-
1811-
// i32 requires stride to be multiple of 4 (same as f32)
1812-
impl ValidStride for StrideValidator<i32, 4> {}
1813-
impl ValidStride for StrideValidator<i32, 8> {}
1814-
impl ValidStride for StrideValidator<i32, 12> {}
1815-
impl ValidStride for StrideValidator<i32, 16> {}
1816-
impl ValidStride for StrideValidator<i32, 20> {}
1817-
impl ValidStride for StrideValidator<i32, 24> {}
1818-
impl ValidStride for StrideValidator<i32, 28> {}
1819-
impl ValidStride for StrideValidator<i32, 32> {}
1820-
1821-
// bool uses u8 stride requirements (multiple of 16)
1822-
impl ValidStride for StrideValidator<bool, 16> {}
1823-
impl ValidStride for StrideValidator<bool, 32> {}
1824-
impl ValidStride for StrideValidator<bool, 48> {}
1825-
impl ValidStride for StrideValidator<bool, 64> {}
1757+
// Stride validation types and implementations are now in the stride module
18261758

18271759
// ============================================================================
18281760
// Ergonomic Builder API
@@ -1990,5 +1922,5 @@ mod sealed {
19901922
impl Sealed for bool {}
19911923

19921924
// Seal stride validators
1993-
impl<T, const S: usize> Sealed for StrideValidator<T, S> {}
1925+
// StrideValidator sealing is now in stride module
19941926
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
//! Stride implementations for bf16 type
2+
3+
use super::{StrideValidator, ValidStride};
4+
use crate::bf16;
5+
6+
// Include the auto-generated implementations
7+
include!(concat!(env!("OUT_DIR"), "/stride/bf16_impls.rs"));
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
//! Stride implementations for bool type
2+
3+
use super::{StrideValidator, ValidStride};
4+
5+
// Include the auto-generated implementations
6+
include!(concat!(env!("OUT_DIR"), "/stride/bool_impls.rs"));
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
//! Stride implementations for f16 type
2+
3+
use super::{StrideValidator, ValidStride};
4+
use crate::f16;
5+
6+
// Include the auto-generated implementations
7+
include!(concat!(env!("OUT_DIR"), "/stride/f16_impls.rs"));
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
//! Stride implementations for f32 type
2+
3+
use super::{StrideValidator, ValidStride};
4+
5+
// Include the auto-generated implementations
6+
include!(concat!(env!("OUT_DIR"), "/stride/f32_impls.rs"));
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
//! Stride implementations for f64 type
2+
3+
use super::{StrideValidator, ValidStride};
4+
5+
// Include the auto-generated implementations
6+
include!(concat!(env!("OUT_DIR"), "/stride/f64_impls.rs"));
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
//! Stride implementations for i16 type
2+
3+
use super::{StrideValidator, ValidStride};
4+
5+
// Include the auto-generated implementations
6+
include!(concat!(env!("OUT_DIR"), "/stride/i16_impls.rs"));
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
//! Stride implementations for i32 type
2+
3+
use super::{StrideValidator, ValidStride};
4+
5+
// Include the auto-generated implementations
6+
include!(concat!(env!("OUT_DIR"), "/stride/i32_impls.rs"));
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
//! Stride implementations for i64 type
2+
3+
use super::{StrideValidator, ValidStride};
4+
5+
// Include the auto-generated implementations
6+
include!(concat!(env!("OUT_DIR"), "/stride/i64_impls.rs"));

0 commit comments

Comments
 (0)