Skip to content

Commit a614ef7

Browse files
committed
Working with full tests
1 parent e697cf4 commit a614ef7

33 files changed

+1086
-75
lines changed

crates/cuda_std/src/warp/reduce.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
//! making invalid states unrepresentable and ensuring all validation happens
55
//! at compile time.
66
7-
use super::shuffle::ShuffleValue;
87
use super::sync::WarpMask;
98
use crate::gpu_only;
109
#[cfg(target_os = "cuda")]
@@ -416,6 +415,7 @@ impl BitwiseReduceValue for u32 {
416415
impl ReduceValue for i64 {
417416
#[gpu_only]
418417
unsafe fn reduce_add(mask: WarpMask, mut value: Self) -> Self {
418+
use super::shuffle::ShuffleValue;
419419
// Implement using shuffle operations in a tree reduction pattern
420420
for offset in [16, 8, 4, 2, 1] {
421421
let shuffled =
@@ -427,6 +427,7 @@ impl ReduceValue for i64 {
427427

428428
#[gpu_only]
429429
unsafe fn reduce_min(mask: WarpMask, mut value: Self) -> Self {
430+
use super::shuffle::ShuffleValue;
430431
for offset in [16, 8, 4, 2, 1] {
431432
let shuffled =
432433
<Self as ShuffleValue>::shuffle_down(mask, value, offset, 32).unwrap_or(value);
@@ -437,6 +438,7 @@ impl ReduceValue for i64 {
437438

438439
#[gpu_only]
439440
unsafe fn reduce_max(mask: WarpMask, mut value: Self) -> Self {
441+
use super::shuffle::ShuffleValue;
440442
for offset in [16, 8, 4, 2, 1] {
441443
let shuffled =
442444
<Self as ShuffleValue>::shuffle_down(mask, value, offset, 32).unwrap_or(value);
@@ -449,6 +451,7 @@ impl ReduceValue for i64 {
449451
impl BitwiseReduceValue for i64 {
450452
#[gpu_only]
451453
unsafe fn reduce_and(mask: WarpMask, mut value: Self) -> Self {
454+
use super::shuffle::ShuffleValue;
452455
for offset in [16, 8, 4, 2, 1] {
453456
let shuffled =
454457
<Self as ShuffleValue>::shuffle_down(mask, value, offset, 32).unwrap_or(value);
@@ -459,6 +462,7 @@ impl BitwiseReduceValue for i64 {
459462

460463
#[gpu_only]
461464
unsafe fn reduce_or(mask: WarpMask, mut value: Self) -> Self {
465+
use super::shuffle::ShuffleValue;
462466
for offset in [16, 8, 4, 2, 1] {
463467
let shuffled =
464468
<Self as ShuffleValue>::shuffle_down(mask, value, offset, 32).unwrap_or(value);
@@ -469,6 +473,7 @@ impl BitwiseReduceValue for i64 {
469473

470474
#[gpu_only]
471475
unsafe fn reduce_xor(mask: WarpMask, mut value: Self) -> Self {
476+
use super::shuffle::ShuffleValue;
472477
for offset in [16, 8, 4, 2, 1] {
473478
let shuffled =
474479
<Self as ShuffleValue>::shuffle_down(mask, value, offset, 32).unwrap_or(value);
@@ -481,6 +486,7 @@ impl BitwiseReduceValue for i64 {
481486
impl ReduceValue for u64 {
482487
#[gpu_only]
483488
unsafe fn reduce_add(mask: WarpMask, mut value: Self) -> Self {
489+
use super::shuffle::ShuffleValue;
484490
for offset in [16, 8, 4, 2, 1] {
485491
let shuffled =
486492
<Self as ShuffleValue>::shuffle_down(mask, value, offset, 32).unwrap_or(value);
@@ -491,6 +497,7 @@ impl ReduceValue for u64 {
491497

492498
#[gpu_only]
493499
unsafe fn reduce_min(mask: WarpMask, mut value: Self) -> Self {
500+
use super::shuffle::ShuffleValue;
494501
for offset in [16, 8, 4, 2, 1] {
495502
let shuffled =
496503
<Self as ShuffleValue>::shuffle_down(mask, value, offset, 32).unwrap_or(value);
@@ -501,6 +508,7 @@ impl ReduceValue for u64 {
501508

502509
#[gpu_only]
503510
unsafe fn reduce_max(mask: WarpMask, mut value: Self) -> Self {
511+
use super::shuffle::ShuffleValue;
504512
for offset in [16, 8, 4, 2, 1] {
505513
let shuffled =
506514
<Self as ShuffleValue>::shuffle_down(mask, value, offset, 32).unwrap_or(value);
@@ -513,6 +521,7 @@ impl ReduceValue for u64 {
513521
impl BitwiseReduceValue for u64 {
514522
#[gpu_only]
515523
unsafe fn reduce_and(mask: WarpMask, mut value: Self) -> Self {
524+
use super::shuffle::ShuffleValue;
516525
for offset in [16, 8, 4, 2, 1] {
517526
let shuffled =
518527
<Self as ShuffleValue>::shuffle_down(mask, value, offset, 32).unwrap_or(value);
@@ -523,6 +532,7 @@ impl BitwiseReduceValue for u64 {
523532

524533
#[gpu_only]
525534
unsafe fn reduce_or(mask: WarpMask, mut value: Self) -> Self {
535+
use super::shuffle::ShuffleValue;
526536
for offset in [16, 8, 4, 2, 1] {
527537
let shuffled =
528538
<Self as ShuffleValue>::shuffle_down(mask, value, offset, 32).unwrap_or(value);
@@ -533,6 +543,7 @@ impl BitwiseReduceValue for u64 {
533543

534544
#[gpu_only]
535545
unsafe fn reduce_xor(mask: WarpMask, mut value: Self) -> Self {
546+
use super::shuffle::ShuffleValue;
536547
for offset in [16, 8, 4, 2, 1] {
537548
let shuffled =
538549
<Self as ShuffleValue>::shuffle_down(mask, value, offset, 32).unwrap_or(value);
@@ -549,6 +560,7 @@ impl BitwiseReduceValue for u64 {
549560
impl ReduceValue for f32 {
550561
#[gpu_only]
551562
unsafe fn reduce_add(mask: WarpMask, mut value: Self) -> Self {
563+
use super::shuffle::ShuffleValue;
552564
for offset in [16, 8, 4, 2, 1] {
553565
let shuffled =
554566
<Self as ShuffleValue>::shuffle_down(mask, value, offset, 32).unwrap_or(value);
@@ -559,6 +571,7 @@ impl ReduceValue for f32 {
559571

560572
#[gpu_only]
561573
unsafe fn reduce_min(mask: WarpMask, mut value: Self) -> Self {
574+
use super::shuffle::ShuffleValue;
562575
for offset in [16, 8, 4, 2, 1] {
563576
let shuffled =
564577
<Self as ShuffleValue>::shuffle_down(mask, value, offset, 32).unwrap_or(value);
@@ -569,6 +582,7 @@ impl ReduceValue for f32 {
569582

570583
#[gpu_only]
571584
unsafe fn reduce_max(mask: WarpMask, mut value: Self) -> Self {
585+
use super::shuffle::ShuffleValue;
572586
for offset in [16, 8, 4, 2, 1] {
573587
let shuffled =
574588
<Self as ShuffleValue>::shuffle_down(mask, value, offset, 32).unwrap_or(value);
@@ -581,6 +595,7 @@ impl ReduceValue for f32 {
581595
impl ReduceValue for f64 {
582596
#[gpu_only]
583597
unsafe fn reduce_add(mask: WarpMask, mut value: Self) -> Self {
598+
use super::shuffle::ShuffleValue;
584599
for offset in [16, 8, 4, 2, 1] {
585600
let shuffled =
586601
<Self as ShuffleValue>::shuffle_down(mask, value, offset, 32).unwrap_or(value);
@@ -591,6 +606,7 @@ impl ReduceValue for f64 {
591606

592607
#[gpu_only]
593608
unsafe fn reduce_min(mask: WarpMask, mut value: Self) -> Self {
609+
use super::shuffle::ShuffleValue;
594610
for offset in [16, 8, 4, 2, 1] {
595611
let shuffled =
596612
<Self as ShuffleValue>::shuffle_down(mask, value, offset, 32).unwrap_or(value);
@@ -601,6 +617,7 @@ impl ReduceValue for f64 {
601617

602618
#[gpu_only]
603619
unsafe fn reduce_max(mask: WarpMask, mut value: Self) -> Self {
620+
use super::shuffle::ShuffleValue;
604621
for offset in [16, 8, 4, 2, 1] {
605622
let shuffled =
606623
<Self as ShuffleValue>::shuffle_down(mask, value, offset, 32).unwrap_or(value);

crates/cuda_std/src/warp/shuffle.rs

Lines changed: 8 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
//! Warp shuffle operations with extreme type safety.
22
//!
3-
//! This module provides type-safe abstractions for CUDA warp shuffle operations,
4-
//! making invalid states unrepresentable and ensuring all validation happens
5-
//! at compile time.
3+
//! This module provides type-safe abstractions for CUDA warp shuffle operations.
64
75
use super::sync::WarpMask;
86
use crate::gpu_only;
9-
#[cfg(target_os = "cuda")]
10-
use core::arch::asm;
117
use core::marker::PhantomData;
128

139
// ============================================================================
@@ -30,12 +26,6 @@ impl ShuffleWidth {
3026
}
3127
}
3228

33-
/// Create from raw value (internal use only)
34-
#[inline(always)]
35-
const fn from_raw(width: u32) -> Self {
36-
Self(width)
37-
}
38-
3929
/// Full warp width (32 threads)
4030
#[inline(always)]
4131
pub const fn full_warp() -> Self {
@@ -165,7 +155,7 @@ impl<T> ShuffleResult<T> {
165155
f()
166156
}
167157
}
168-
158+
169159
#[inline(always)]
170160
pub fn unwrap(self) -> T {
171161
if self.valid {
@@ -174,7 +164,7 @@ impl<T> ShuffleResult<T> {
174164
panic!("called `ShuffleResult::unwrap()` on an invalid shuffle result")
175165
}
176166
}
177-
167+
178168
#[inline(always)]
179169
pub fn unwrap_or_default(self) -> T
180170
where
@@ -748,16 +738,10 @@ pub trait ShuffleExt: ShuffleValue {
748738
fn shuffle(mask: WarpMask, width: ShuffleWidth) -> Shuffle<Self>;
749739

750740
/// Shuffle this value down
751-
unsafe fn shuffle_down(
752-
self,
753-
mask: WarpMask,
754-
delta: u32,
755-
width: u32,
756-
) -> ShuffleResult<Self>;
741+
unsafe fn shuffle_down(self, mask: WarpMask, delta: u32, width: u32) -> ShuffleResult<Self>;
757742

758743
/// Shuffle this value up
759-
unsafe fn shuffle_up(self, mask: WarpMask, delta: u32, width: u32)
760-
-> ShuffleResult<Self>;
744+
unsafe fn shuffle_up(self, mask: WarpMask, delta: u32, width: u32) -> ShuffleResult<Self>;
761745
}
762746

763747
impl<T: ShuffleValue> ShuffleExt for T {
@@ -768,23 +752,13 @@ impl<T: ShuffleValue> ShuffleExt for T {
768752

769753
#[gpu_only]
770754
#[inline(always)]
771-
unsafe fn shuffle_down(
772-
self,
773-
mask: WarpMask,
774-
delta: u32,
775-
width: u32,
776-
) -> ShuffleResult<Self> {
755+
unsafe fn shuffle_down(self, mask: WarpMask, delta: u32, width: u32) -> ShuffleResult<Self> {
777756
ShuffleValue::shuffle_down(mask, self, delta, width)
778757
}
779758

780759
#[gpu_only]
781760
#[inline(always)]
782-
unsafe fn shuffle_up(
783-
self,
784-
mask: WarpMask,
785-
delta: u32,
786-
width: u32,
787-
) -> ShuffleResult<Self> {
761+
unsafe fn shuffle_up(self, mask: WarpMask, delta: u32, width: u32) -> ShuffleResult<Self> {
788762
ShuffleValue::shuffle_up(mask, self, delta, width)
789763
}
790-
}
764+
}

crates/cuda_std/src/warp/vote.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,15 +295,15 @@ pub struct EqualityResult {
295295

296296
impl EqualityResult {
297297
#[inline(always)]
298-
const fn new(all_equal: bool) -> Self {
298+
pub const fn new(all_equal: bool) -> Self {
299299
Self {
300300
all_equal,
301301
match_mask: None,
302302
}
303303
}
304304

305305
#[inline(always)]
306-
const fn with_mask(all_equal: bool, mask: WarpMask) -> Self {
306+
pub const fn with_mask(all_equal: bool, mask: WarpMask) -> Self {
307307
Self {
308308
all_equal,
309309
match_mask: Some(mask),

tests/compiletests/ui/dis/shuffle_crashing.rs

Lines changed: 0 additions & 20 deletions
This file was deleted.

tests/compiletests/ui/dis/shuffle_working.rs

Lines changed: 0 additions & 14 deletions
This file was deleted.
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// build-fail
2+
3+
use cuda_std::kernel;
4+
use cuda_std::warp::reduce::Reduction;
5+
6+
#[kernel]
7+
pub unsafe fn test_float_bitwise_and() {
8+
let reduction = Reduction::<f32>::all_threads();
9+
10+
// f32 doesn't implement BitwiseReduceValue
11+
let _result = reduction.and(3.14f32);
12+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
error[E0277]: the trait bound `f32: BitwiseReduceValue` is not satisfied
2+
--> $DIR/float_bitwise_and.rs:11:29
3+
|
4+
11 | let _result = reduction.and(3.14f32);
5+
| ^^^ the trait `BitwiseReduceValue` is not implemented for `f32`
6+
|
7+
= help: the following other types implement trait `BitwiseReduceValue`:
8+
i32
9+
i64
10+
u32
11+
u64
12+
note: required by a bound in `Reduction::<T>::and`
13+
--> /workspace/crates/cuda_std/src/warp/reduce.rs:201:12
14+
|
15+
199 | pub unsafe fn and(&self, value: T) -> ReductionResult<T, And>
16+
| --- required by a bound in this associated function
17+
200 | where
18+
201 | T: BitwiseReduceValue,
19+
| ^^^^^^^^^^^^^^^^^^ required by this bound in `Reduction::<T>::and`
20+
21+
error: aborting due to 1 previous error
22+
23+
For more information about this error, try `rustc --explain E0277`.
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// build-fail
2+
3+
use cuda_std::kernel;
4+
use cuda_std::warp::reduce::Reduction;
5+
6+
#[kernel]
7+
pub unsafe fn test_float_bitwise_or() {
8+
let reduction = Reduction::<f64>::all_threads();
9+
10+
// f64 doesn't implement BitwiseReduceValue
11+
let _result = reduction.or(2.718f64);
12+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
error[E0277]: the trait bound `f64: BitwiseReduceValue` is not satisfied
2+
--> $DIR/float_bitwise_or.rs:11:29
3+
|
4+
11 | let _result = reduction.or(2.718f64);
5+
| ^^ the trait `BitwiseReduceValue` is not implemented for `f64`
6+
|
7+
= help: the following other types implement trait `BitwiseReduceValue`:
8+
i32
9+
i64
10+
u32
11+
u64
12+
note: required by a bound in `Reduction::<T>::or`
13+
--> /workspace/crates/cuda_std/src/warp/reduce.rs:211:12
14+
|
15+
209 | pub unsafe fn or(&self, value: T) -> ReductionResult<T, Or>
16+
| -- required by a bound in this associated function
17+
210 | where
18+
211 | T: BitwiseReduceValue,
19+
| ^^^^^^^^^^^^^^^^^^ required by this bound in `Reduction::<T>::or`
20+
21+
error: aborting due to 1 previous error
22+
23+
For more information about this error, try `rustc --explain E0277`.
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// build-fail
2+
3+
use cuda_std::kernel;
4+
use cuda_std::warp::reduce::Reduction;
5+
6+
#[kernel]
7+
pub unsafe fn test_float_bitwise_xor() {
8+
let reduction = Reduction::<f32>::all_threads();
9+
10+
// f32 doesn't implement BitwiseReduceValue
11+
let _result = reduction.xor(1.414f32);
12+
}

0 commit comments

Comments
 (0)