@@ -9,38 +9,38 @@ const MAX_STRIDE: usize = 8192;
99fn main ( ) {
1010 // Tell cargo to rerun if build.rs changes
1111 println ! ( "cargo:rerun-if-changed=build.rs" ) ;
12-
12+
1313 // Generate stride implementations
1414 generate_stride_impls ( ) ;
1515}
1616
1717fn generate_stride_impls ( ) {
1818 let out_dir = env:: var_os ( "OUT_DIR" ) . unwrap ( ) ;
1919 let out_path = Path :: new ( & out_dir) ;
20-
20+
2121 // Create stride subdirectory in OUT_DIR
2222 let stride_dir = out_path. join ( "stride" ) ;
2323 fs:: create_dir_all ( & stride_dir) . unwrap ( ) ;
24-
24+
2525 // Define stride requirements for each type
2626 // (type_name, file_name, start, step)
2727 let stride_configs = vec ! [
2828 // 16-byte alignment types (multiple of element count to fill 16 bytes)
29- ( "f16" , "f16" , 8 , 8 ) , // f16: 2 bytes × 8 = 16 bytes
30- ( "bf16" , "bf16" , 8 , 8 ) , // bf16: 2 bytes × 8 = 16 bytes
31- ( "i16" , "i16" , 8 , 8 ) , // i16: 2 bytes × 8 = 16 bytes
32- ( "u16" , "u16" , 8 , 8 ) , // u16: 2 bytes × 8 = 16 bytes
33- ( "f32" , "f32" , 4 , 4 ) , // f32: 4 bytes × 4 = 16 bytes
34- ( "i32" , "i32" , 4 , 4 ) , // i32: 4 bytes × 4 = 16 bytes
35- ( "u32" , "u32" , 4 , 4 ) , // u32: 4 bytes × 4 = 16 bytes
36- ( "f64" , "f64" , 2 , 2 ) , // f64: 8 bytes × 2 = 16 bytes
37- ( "i64" , "i64" , 2 , 2 ) , // i64: 8 bytes × 2 = 16 bytes
38- ( "u64" , "u64" , 2 , 2 ) , // u64: 8 bytes × 2 = 16 bytes
39- ( "i8" , "i8" , 16 , 16 ) , // i8: 1 byte × 16 = 16 bytes
40- ( "u8" , "u8" , 16 , 16 ) , // u8: 1 byte × 16 = 16 bytes
41- ( "bool" , "bool" , 16 , 16 ) , // bool: 1 byte × 16 = 16 bytes
29+ ( "f16" , "f16" , 8 , 8 ) , // f16: 2 bytes × 8 = 16 bytes
30+ ( "bf16" , "bf16" , 8 , 8 ) , // bf16: 2 bytes × 8 = 16 bytes
31+ ( "i16" , "i16" , 8 , 8 ) , // i16: 2 bytes × 8 = 16 bytes
32+ ( "u16" , "u16" , 8 , 8 ) , // u16: 2 bytes × 8 = 16 bytes
33+ ( "f32" , "f32" , 4 , 4 ) , // f32: 4 bytes × 4 = 16 bytes
34+ ( "i32" , "i32" , 4 , 4 ) , // i32: 4 bytes × 4 = 16 bytes
35+ ( "u32" , "u32" , 4 , 4 ) , // u32: 4 bytes × 4 = 16 bytes
36+ ( "f64" , "f64" , 2 , 2 ) , // f64: 8 bytes × 2 = 16 bytes
37+ ( "i64" , "i64" , 2 , 2 ) , // i64: 8 bytes × 2 = 16 bytes
38+ ( "u64" , "u64" , 2 , 2 ) , // u64: 8 bytes × 2 = 16 bytes
39+ ( "i8" , "i8" , 16 , 16 ) , // i8: 1 byte × 16 = 16 bytes
40+ ( "u8" , "u8" , 16 , 16 ) , // u8: 1 byte × 16 = 16 bytes
41+ ( "bool" , "bool" , 16 , 16 ) , // bool: 1 byte × 16 = 16 bytes
4242 ] ;
43-
43+
4444 for ( type_name, file_name, start, step) in stride_configs {
4545 generate_type_strides ( & stride_dir, type_name, file_name, start, step) ;
4646 }
@@ -55,22 +55,35 @@ fn generate_type_strides(
5555) {
5656 let dest_path = out_dir. join ( format ! ( "{}_impls.rs" , file_name) ) ;
5757 let mut file = fs:: File :: create ( & dest_path) . unwrap ( ) ;
58-
58+
5959 // Write header
6060 writeln ! ( file, "// @generated" ) . unwrap ( ) ;
61- writeln ! ( file, "// Auto-generated stride implementations for {}" , type_name) . unwrap ( ) ;
61+ writeln ! (
62+ file,
63+ "// Auto-generated stride implementations for {}" ,
64+ type_name
65+ )
66+ . unwrap ( ) ;
6267 writeln ! ( file, "// DO NOT EDIT THIS FILE MANUALLY" ) . unwrap ( ) ;
6368 writeln ! ( file) . unwrap ( ) ;
64- writeln ! ( file, "// {} requires stride to be multiple of {} (up to {})" ,
65- type_name, step, MAX_STRIDE ) . unwrap ( ) ;
69+ writeln ! (
70+ file,
71+ "// {} requires stride to be multiple of {} (up to {})" ,
72+ type_name, step, MAX_STRIDE
73+ )
74+ . unwrap ( ) ;
6675 writeln ! ( file) . unwrap ( ) ;
67-
76+
6877 // Generate implementations
6978 // These files are included in modules that already have the necessary imports
7079 let mut stride = start;
7180 while stride <= MAX_STRIDE {
72- writeln ! ( file, "impl ValidStride for StrideValidator<{}, {}> {{}}" ,
73- type_name, stride) . unwrap ( ) ;
81+ writeln ! (
82+ file,
83+ "impl ValidStride for StrideValidator<{}, {}> {{}}" ,
84+ type_name, stride
85+ )
86+ . unwrap ( ) ;
7487 stride += step;
7588 }
76- }
89+ }
0 commit comments