Skip to content

Commit 7b6238e

Browse files
committed
Formatting
1 parent 27251bd commit 7b6238e

File tree

16 files changed

+60
-47
lines changed

16 files changed

+60
-47
lines changed

crates/cuda_std/build.rs

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,38 +9,38 @@ const MAX_STRIDE: usize = 8192;
99
fn main() {
1010
// Tell cargo to rerun if build.rs changes
1111
println!("cargo:rerun-if-changed=build.rs");
12-
12+
1313
// Generate stride implementations
1414
generate_stride_impls();
1515
}
1616

1717
fn generate_stride_impls() {
1818
let out_dir = env::var_os("OUT_DIR").unwrap();
1919
let out_path = Path::new(&out_dir);
20-
20+
2121
// Create stride subdirectory in OUT_DIR
2222
let stride_dir = out_path.join("stride");
2323
fs::create_dir_all(&stride_dir).unwrap();
24-
24+
2525
// Define stride requirements for each type
2626
// (type_name, file_name, start, step)
2727
let stride_configs = vec![
2828
// 16-byte alignment types (multiple of element count to fill 16 bytes)
29-
("f16", "f16", 8, 8), // f16: 2 bytes × 8 = 16 bytes
30-
("bf16", "bf16", 8, 8), // bf16: 2 bytes × 8 = 16 bytes
31-
("i16", "i16", 8, 8), // i16: 2 bytes × 8 = 16 bytes
32-
("u16", "u16", 8, 8), // u16: 2 bytes × 8 = 16 bytes
33-
("f32", "f32", 4, 4), // f32: 4 bytes × 4 = 16 bytes
34-
("i32", "i32", 4, 4), // i32: 4 bytes × 4 = 16 bytes
35-
("u32", "u32", 4, 4), // u32: 4 bytes × 4 = 16 bytes
36-
("f64", "f64", 2, 2), // f64: 8 bytes × 2 = 16 bytes
37-
("i64", "i64", 2, 2), // i64: 8 bytes × 2 = 16 bytes
38-
("u64", "u64", 2, 2), // u64: 8 bytes × 2 = 16 bytes
39-
("i8", "i8", 16, 16), // i8: 1 byte × 16 = 16 bytes
40-
("u8", "u8", 16, 16), // u8: 1 byte × 16 = 16 bytes
41-
("bool", "bool", 16, 16), // bool: 1 byte × 16 = 16 bytes
29+
("f16", "f16", 8, 8), // f16: 2 bytes × 8 = 16 bytes
30+
("bf16", "bf16", 8, 8), // bf16: 2 bytes × 8 = 16 bytes
31+
("i16", "i16", 8, 8), // i16: 2 bytes × 8 = 16 bytes
32+
("u16", "u16", 8, 8), // u16: 2 bytes × 8 = 16 bytes
33+
("f32", "f32", 4, 4), // f32: 4 bytes × 4 = 16 bytes
34+
("i32", "i32", 4, 4), // i32: 4 bytes × 4 = 16 bytes
35+
("u32", "u32", 4, 4), // u32: 4 bytes × 4 = 16 bytes
36+
("f64", "f64", 2, 2), // f64: 8 bytes × 2 = 16 bytes
37+
("i64", "i64", 2, 2), // i64: 8 bytes × 2 = 16 bytes
38+
("u64", "u64", 2, 2), // u64: 8 bytes × 2 = 16 bytes
39+
("i8", "i8", 16, 16), // i8: 1 byte × 16 = 16 bytes
40+
("u8", "u8", 16, 16), // u8: 1 byte × 16 = 16 bytes
41+
("bool", "bool", 16, 16), // bool: 1 byte × 16 = 16 bytes
4242
];
43-
43+
4444
for (type_name, file_name, start, step) in stride_configs {
4545
generate_type_strides(&stride_dir, type_name, file_name, start, step);
4646
}
@@ -55,22 +55,35 @@ fn generate_type_strides(
5555
) {
5656
let dest_path = out_dir.join(format!("{}_impls.rs", file_name));
5757
let mut file = fs::File::create(&dest_path).unwrap();
58-
58+
5959
// Write header
6060
writeln!(file, "// @generated").unwrap();
61-
writeln!(file, "// Auto-generated stride implementations for {}", type_name).unwrap();
61+
writeln!(
62+
file,
63+
"// Auto-generated stride implementations for {}",
64+
type_name
65+
)
66+
.unwrap();
6267
writeln!(file, "// DO NOT EDIT THIS FILE MANUALLY").unwrap();
6368
writeln!(file).unwrap();
64-
writeln!(file, "// {} requires stride to be multiple of {} (up to {})",
65-
type_name, step, MAX_STRIDE).unwrap();
69+
writeln!(
70+
file,
71+
"// {} requires stride to be multiple of {} (up to {})",
72+
type_name, step, MAX_STRIDE
73+
)
74+
.unwrap();
6675
writeln!(file).unwrap();
67-
76+
6877
// Generate implementations
6978
// These files are included in modules that already have the necessary imports
7079
let mut stride = start;
7180
while stride <= MAX_STRIDE {
72-
writeln!(file, "impl ValidStride for StrideValidator<{}, {}> {{}}",
73-
type_name, stride).unwrap();
81+
writeln!(
82+
file,
83+
"impl ValidStride for StrideValidator<{}, {}> {{}}",
84+
type_name, stride
85+
)
86+
.unwrap();
7487
stride += step;
7588
}
76-
}
89+
}

crates/cuda_std/src/warp/matrix/intrinsics.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// WMMA intrinsic declarations
2-
// rustfmt::skip
2+
#![cfg_attr(rustfmt, rustfmt_skip)]
33
#![allow(improper_ctypes)]
44

55
// ============= 16x16x16 shape =============
@@ -803,4 +803,4 @@ pub(crate) use m8n8k4::mma::*;
803803
pub(crate) use m16n16k8::convert::*;
804804
pub(crate) use m16n16k8::load::*;
805805
pub(crate) use m16n16k8::store::*;
806-
pub(crate) use m16n16k8::mma::*;
806+
pub(crate) use m16n16k8::mma::*;

crates/cuda_std/src/warp/matrix/stride/bf16.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ use super::{StrideValidator, ValidStride};
44
use crate::bf16;
55

66
// Include the auto-generated implementations
7-
include!(concat!(env!("OUT_DIR"), "/stride/bf16_impls.rs"));
7+
include!(concat!(env!("OUT_DIR"), "/stride/bf16_impls.rs"));

crates/cuda_std/src/warp/matrix/stride/bool.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
use super::{StrideValidator, ValidStride};
44

55
// Include the auto-generated implementations
6-
include!(concat!(env!("OUT_DIR"), "/stride/bool_impls.rs"));
6+
include!(concat!(env!("OUT_DIR"), "/stride/bool_impls.rs"));

crates/cuda_std/src/warp/matrix/stride/f16.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ use super::{StrideValidator, ValidStride};
44
use crate::f16;
55

66
// Include the auto-generated implementations
7-
include!(concat!(env!("OUT_DIR"), "/stride/f16_impls.rs"));
7+
include!(concat!(env!("OUT_DIR"), "/stride/f16_impls.rs"));

crates/cuda_std/src/warp/matrix/stride/f32.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
use super::{StrideValidator, ValidStride};
44

55
// Include the auto-generated implementations
6-
include!(concat!(env!("OUT_DIR"), "/stride/f32_impls.rs"));
6+
include!(concat!(env!("OUT_DIR"), "/stride/f32_impls.rs"));

crates/cuda_std/src/warp/matrix/stride/f64.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
use super::{StrideValidator, ValidStride};
44

55
// Include the auto-generated implementations
6-
include!(concat!(env!("OUT_DIR"), "/stride/f64_impls.rs"));
6+
include!(concat!(env!("OUT_DIR"), "/stride/f64_impls.rs"));

crates/cuda_std/src/warp/matrix/stride/i16.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
use super::{StrideValidator, ValidStride};
44

55
// Include the auto-generated implementations
6-
include!(concat!(env!("OUT_DIR"), "/stride/i16_impls.rs"));
6+
include!(concat!(env!("OUT_DIR"), "/stride/i16_impls.rs"));

crates/cuda_std/src/warp/matrix/stride/i32.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
use super::{StrideValidator, ValidStride};
44

55
// Include the auto-generated implementations
6-
include!(concat!(env!("OUT_DIR"), "/stride/i32_impls.rs"));
6+
include!(concat!(env!("OUT_DIR"), "/stride/i32_impls.rs"));

crates/cuda_std/src/warp/matrix/stride/i64.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
use super::{StrideValidator, ValidStride};
44

55
// Include the auto-generated implementations
6-
include!(concat!(env!("OUT_DIR"), "/stride/i64_impls.rs"));
6+
include!(concat!(env!("OUT_DIR"), "/stride/i64_impls.rs"));

0 commit comments

Comments
 (0)