Skip to content

Commit 10ac7b3

Browse files
committed
Moar tests
1 parent 85ac4a4 commit 10ac7b3

File tree

13 files changed

+885
-9
lines changed

13 files changed

+885
-9
lines changed

crates/cuda_std/src/warp/shuffle.rs

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,12 @@ const SHUFFLE_MODE_IDX: u32 = 3;
473473
macro_rules! impl_shuffle_op {
474474
($fn_name:ident, $mode:expr) => {
475475
#[gpu_only]
476-
unsafe fn $fn_name(mask: WarpMask, value: Self, param: u32, width: u32) -> ShuffleResult<Self> {
476+
unsafe fn $fn_name(
477+
mask: WarpMask,
478+
value: Self,
479+
param: u32,
480+
width: u32,
481+
) -> ShuffleResult<Self> {
477482
let result = warp_shuffle_32($mode, mask.raw(), value as u32, param, width);
478483
ShuffleResult::new(result.value as Self, result.valid)
479484
}
@@ -502,7 +507,12 @@ impl_shuffle_32! {
502507
macro_rules! impl_shuffle_float_op {
503508
($fn_name:ident, $bits_ty:ty, $to_bits:ident, $from_bits:ident) => {
504509
#[gpu_only]
505-
unsafe fn $fn_name(mask: WarpMask, value: Self, param: u32, width: u32) -> ShuffleResult<Self> {
510+
unsafe fn $fn_name(
511+
mask: WarpMask,
512+
value: Self,
513+
param: u32,
514+
width: u32,
515+
) -> ShuffleResult<Self> {
506516
let bits = value.$to_bits();
507517
let result = <$bits_ty as ShuffleValue>::$fn_name(mask, bits, param, width);
508518
ShuffleResult::new(Self::$from_bits(result.value), result.valid)
@@ -522,7 +532,12 @@ impl ShuffleValue for f32 {
522532
macro_rules! impl_shuffle_64_op {
523533
($fn_name:ident) => {
524534
#[gpu_only]
525-
unsafe fn $fn_name(mask: WarpMask, value: Self, param: u32, width: u32) -> ShuffleResult<Self> {
535+
unsafe fn $fn_name(
536+
mask: WarpMask,
537+
value: Self,
538+
param: u32,
539+
width: u32,
540+
) -> ShuffleResult<Self> {
526541
let lo = (value & 0xFFFFFFFF) as u32;
527542
let hi = (value >> 32) as u32;
528543
let lo_result = <u32 as ShuffleValue>::$fn_name(mask, lo, param, width);
@@ -561,11 +576,16 @@ impl ShuffleValue for f64 {
561576
impl_shuffle_float_op!(shuffle_idx, u64, to_bits, from_bits);
562577
}
563578

564-
// Generic macro for small type shuffle operations
579+
// Generic macro for small type shuffle operations
565580
macro_rules! impl_shuffle_small_op {
566581
($fn_name:ident) => {
567582
#[gpu_only]
568-
unsafe fn $fn_name(mask: WarpMask, value: Self, param: u32, width: u32) -> ShuffleResult<Self> {
583+
unsafe fn $fn_name(
584+
mask: WarpMask,
585+
value: Self,
586+
param: u32,
587+
width: u32,
588+
) -> ShuffleResult<Self> {
569589
let result = <u32 as ShuffleValue>::$fn_name(mask, value as u32, param, width);
570590
ShuffleResult::new(result.value as Self, result.valid)
571591
}

crates/cuda_std/src/warp/sync.rs

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,17 +99,54 @@ impl WarpMask {
9999
Self(!self.0)
100100
}
101101

102-
/// Check if this mask is empty.
102+
/// Create a mask for even-numbered lanes (0, 2, 4, ...).
103+
#[inline]
104+
pub const fn even_lanes() -> Self {
105+
Self(0x55555555)
106+
}
107+
108+
/// Create a mask for odd-numbered lanes (1, 3, 5, ...).
109+
#[inline]
110+
pub const fn odd_lanes() -> Self {
111+
Self(0xAAAAAAAA)
112+
}
113+
114+
/// Create a mask for a specific quadrant (0-3).
115+
#[inline]
116+
pub const fn quadrant(quad: u32) -> Self {
117+
debug_assert!(quad < 4);
118+
Self(0xFF << (quad * 8))
119+
}
120+
121+
/// Create a mask for a range of lanes.
122+
#[inline]
123+
pub const fn range(start: u32, end: u32) -> Self {
124+
debug_assert!(start <= end);
125+
debug_assert!(end <= super::WARP_SIZE);
126+
if start == end {
127+
return Self(0);
128+
}
129+
let count = end - start;
130+
Self::lanes(start, count)
131+
}
132+
133+
/// Check if the mask is empty (no lanes set).
103134
#[inline]
104135
pub const fn is_empty(self) -> bool {
105136
self.0 == 0
106137
}
107138

108-
/// Check if this mask includes all lanes.
139+
/// Check if the mask is full (all lanes set).
109140
#[inline]
110141
pub const fn is_full(self) -> bool {
111142
self.0 == super::FULL_MASK
112143
}
144+
145+
/// Count the number of set lanes in the mask.
146+
#[inline]
147+
pub const fn count(self) -> u32 {
148+
self.0.count_ones()
149+
}
113150
}
114151

115152
impl Default for WarpMask {

crates/cuda_std/src/warp/vote.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,10 @@ macro_rules! impl_vote_equality_float {
410410
#[gpu_only]
411411
unsafe fn vote_all_equal(mask: WarpMask, value: Self) -> EqualityResult {
412412
let result = __nvvm_warp_match_all_32(mask.raw(), value.to_bits());
413-
EqualityResult::with_mask(result.all_matched != 0, WarpMask::new(result.matched_mask))
413+
EqualityResult::with_mask(
414+
result.all_matched != 0,
415+
WarpMask::new(result.matched_mask),
416+
)
414417
}
415418
}
416419
};
@@ -419,7 +422,10 @@ macro_rules! impl_vote_equality_float {
419422
#[gpu_only]
420423
unsafe fn vote_all_equal(mask: WarpMask, value: Self) -> EqualityResult {
421424
let result = __nvvm_warp_match_all_64(mask.raw(), value.to_bits());
422-
EqualityResult::with_mask(result.all_matched != 0, WarpMask::new(result.matched_mask))
425+
EqualityResult::with_mask(
426+
result.all_matched != 0,
427+
WarpMask::new(result.matched_mask),
428+
)
423429
}
424430
}
425431
};
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
// Test warp sync API and mask operations compile
2+
// build-pass
3+
4+
use cuda_std::kernel;
5+
use cuda_std::warp::{self, WarpMask};
6+
7+
#[kernel]
8+
pub unsafe fn test_warp_mask_creation() {
9+
// Test different ways to create WarpMask
10+
let _all_mask = WarpMask::all();
11+
let _none_mask = WarpMask::none();
12+
let _custom_mask = WarpMask::new(0x12345678);
13+
14+
// Test single lane masks
15+
let _lane_mask = WarpMask::lane(5);
16+
let _lane_mask_2 = WarpMask::lane(31);
17+
18+
// Test range masks
19+
let _lower_16 = WarpMask::range(0, 16);
20+
let _middle_8 = WarpMask::range(12, 20);
21+
}
22+
23+
#[kernel]
24+
pub unsafe fn test_warp_mask_operations() {
25+
let mask1 = WarpMask::new(0x0F0F0F0F);
26+
let mask2 = WarpMask::new(0xF0F0F0F0);
27+
28+
// Test mask predicates
29+
let _empty_check = mask1.is_empty();
30+
let _full_check = mask1.is_full();
31+
32+
// Test popcount
33+
let _count1 = mask1.count();
34+
let _count2 = mask2.count();
35+
36+
// Test contains_lane
37+
let _contains = mask1.contains_lane(0);
38+
let _contains2 = mask2.contains_lane(4);
39+
40+
// Test raw access
41+
let _raw = mask1.raw();
42+
}
43+
44+
#[kernel]
45+
pub unsafe fn test_active_mask() {
46+
// Get current active mask
47+
let active = warp::active_mask();
48+
49+
// Sync with active mask
50+
warp::sync(active);
51+
52+
// Use active mask in operations
53+
let _count = active.count();
54+
let _contains = active.contains_lane(0);
55+
}
56+
57+
#[kernel]
58+
pub unsafe fn test_lane_operations() {
59+
// Get lane ID
60+
let lane_id = warp::lane_id();
61+
62+
// Create mask for current lane
63+
let my_mask = WarpMask::lane(lane_id);
64+
65+
// Use mask in sync
66+
warp::sync(my_mask);
67+
}
68+
69+
#[kernel]
70+
pub unsafe fn test_mask_builders() {
71+
// Test even/odd lane masks
72+
let even_lanes = WarpMask::even_lanes();
73+
let odd_lanes = WarpMask::odd_lanes();
74+
75+
// Test quadrant masks
76+
let _q1 = WarpMask::quadrant(0);
77+
let _q2 = WarpMask::quadrant(1);
78+
let _q3 = WarpMask::quadrant(2);
79+
let _q4 = WarpMask::quadrant(3);
80+
81+
// Use masks in sync operations
82+
warp::sync(even_lanes);
83+
warp::sync(odd_lanes);
84+
}
85+
86+
#[kernel]
87+
pub unsafe fn test_mask_combinations() {
88+
let mask1 = WarpMask::new(0x0F0F0F0F);
89+
let mask2 = WarpMask::new(0xF0F0F0F0);
90+
91+
// Combine masks using bitwise operations
92+
let combined = WarpMask::new(mask1.raw() | mask2.raw());
93+
let intersection = WarpMask::new(mask1.raw() & mask2.raw());
94+
let xor = WarpMask::new(mask1.raw() ^ mask2.raw());
95+
96+
// Use combined masks
97+
warp::sync(combined);
98+
warp::sync(intersection);
99+
warp::sync(xor);
100+
}
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// Test sync operations with generic functions
2+
// build-pass
3+
4+
use cuda_std::kernel;
5+
use cuda_std::warp::{self, WarpMask};
6+
7+
#[kernel]
8+
pub unsafe fn test_sync_with_generics() {
9+
// Generic function that performs sync
10+
fn generic_sync<T>() {
11+
unsafe {
12+
warp::sync(WarpMask::all());
13+
let mask = WarpMask::new(0xFFFFFFFF);
14+
warp::sync(mask);
15+
}
16+
}
17+
18+
generic_sync::<i32>();
19+
generic_sync::<f32>();
20+
generic_sync::<()>();
21+
}
22+
23+
#[kernel]
24+
pub unsafe fn test_generic_mask_ops() {
25+
let mask1 = WarpMask::all();
26+
let mask2 = WarpMask::none();
27+
28+
// Generic function using masks
29+
fn sync_with_mask<T>(_phantom: T, mask: WarpMask) {
30+
unsafe {
31+
warp::sync(mask);
32+
}
33+
}
34+
35+
sync_with_mask(42i32, mask1);
36+
sync_with_mask(3.14f32, mask2);
37+
sync_with_mask((), mask1);
38+
}
39+
40+
#[kernel]
41+
pub unsafe fn test_generic_lane_ops() {
42+
// Generic function that uses lane operations
43+
fn get_lane_mask<T>() -> WarpMask {
44+
let lane_id = unsafe { warp::lane_id() };
45+
WarpMask::lane(lane_id)
46+
}
47+
48+
let mask_i32 = get_lane_mask::<i32>();
49+
warp::sync(mask_i32);
50+
51+
let mask_f64 = get_lane_mask::<f64>();
52+
warp::sync(mask_f64);
53+
54+
let mask_unit = get_lane_mask::<()>();
55+
warp::sync(mask_unit);
56+
}
57+
58+
#[kernel]
59+
pub unsafe fn test_generic_active_mask() {
60+
// Get active mask in generic context
61+
let active = warp::active_mask();
62+
63+
// Use in generic sync
64+
fn sync_active<T>() {
65+
unsafe {
66+
let mask = warp::active_mask();
67+
warp::sync(mask);
68+
}
69+
}
70+
71+
sync_active::<i32>();
72+
sync_active::<u64>();
73+
sync_active::<()>();
74+
75+
// Generic function with multiple type params
76+
fn sync_with_types<T, U>() {
77+
unsafe {
78+
warp::sync(WarpMask::all());
79+
}
80+
}
81+
82+
sync_with_types::<i32, f32>();
83+
sync_with_types::<u8, u16>();
84+
}

0 commit comments

Comments
 (0)