Skip to content

Commit 85ac4a4

Browse files
committed
Macros!
1 parent a614ef7 commit 85ac4a4

File tree

8 files changed

+222
-663
lines changed

8 files changed

+222
-663
lines changed

crates/cuda_std/src/warp/reduce.rs

Lines changed: 81 additions & 376 deletions
Large diffs are not rendered by default.

crates/cuda_std/src/warp/shuffle.rs

Lines changed: 69 additions & 206 deletions
Original file line numberDiff line numberDiff line change
@@ -469,58 +469,26 @@ const SHUFFLE_MODE_UP: u32 = 1;
469469
const SHUFFLE_MODE_XOR: u32 = 2;
470470
const SHUFFLE_MODE_IDX: u32 = 3;
471471

472+
// Generic macro to implement a single shuffle operation
473+
macro_rules! impl_shuffle_op {
474+
($fn_name:ident, $mode:expr) => {
475+
#[gpu_only]
476+
unsafe fn $fn_name(mask: WarpMask, value: Self, param: u32, width: u32) -> ShuffleResult<Self> {
477+
let result = warp_shuffle_32($mode, mask.raw(), value as u32, param, width);
478+
ShuffleResult::new(result.value as Self, result.valid)
479+
}
480+
};
481+
}
482+
472483
// Macro to implement shuffle for 32-bit types
473484
macro_rules! impl_shuffle_32 {
474485
($($ty:ty),* $(,)?) => {
475486
$(
476487
impl ShuffleValue for $ty {
477-
#[gpu_only]
478-
unsafe fn shuffle_down(mask: WarpMask, value: Self, delta: u32, width: u32) -> ShuffleResult<Self> {
479-
let result = warp_shuffle_32(
480-
SHUFFLE_MODE_DOWN,
481-
mask.raw(),
482-
value as u32,
483-
delta,
484-
width
485-
);
486-
ShuffleResult::new(result.value as Self, result.valid)
487-
}
488-
489-
#[gpu_only]
490-
unsafe fn shuffle_up(mask: WarpMask, value: Self, delta: u32, width: u32) -> ShuffleResult<Self> {
491-
let result = warp_shuffle_32(
492-
SHUFFLE_MODE_UP,
493-
mask.raw(),
494-
value as u32,
495-
delta,
496-
width
497-
);
498-
ShuffleResult::new(result.value as Self, result.valid)
499-
}
500-
501-
#[gpu_only]
502-
unsafe fn shuffle_xor(mask: WarpMask, value: Self, lane_mask: u32, width: u32) -> ShuffleResult<Self> {
503-
let result = warp_shuffle_32(
504-
SHUFFLE_MODE_XOR,
505-
mask.raw(),
506-
value as u32,
507-
lane_mask,
508-
width
509-
);
510-
ShuffleResult::new(result.value as Self, result.valid)
511-
}
512-
513-
#[gpu_only]
514-
unsafe fn shuffle_idx(mask: WarpMask, value: Self, src_lane: u32, width: u32) -> ShuffleResult<Self> {
515-
let result = warp_shuffle_32(
516-
SHUFFLE_MODE_IDX,
517-
mask.raw(),
518-
value as u32,
519-
src_lane,
520-
width
521-
);
522-
ShuffleResult::new(result.value as Self, result.valid)
523-
}
488+
impl_shuffle_op!(shuffle_down, SHUFFLE_MODE_DOWN);
489+
impl_shuffle_op!(shuffle_up, SHUFFLE_MODE_UP);
490+
impl_shuffle_op!(shuffle_xor, SHUFFLE_MODE_XOR);
491+
impl_shuffle_op!(shuffle_idx, SHUFFLE_MODE_IDX);
524492
}
525493
)*
526494
};
@@ -530,106 +498,52 @@ impl_shuffle_32! {
530498
i32, u32,
531499
}
532500

501+
// Macro for floating-point shuffle operations using bit conversion
502+
macro_rules! impl_shuffle_float_op {
503+
($fn_name:ident, $bits_ty:ty, $to_bits:ident, $from_bits:ident) => {
504+
#[gpu_only]
505+
unsafe fn $fn_name(mask: WarpMask, value: Self, param: u32, width: u32) -> ShuffleResult<Self> {
506+
let bits = value.$to_bits();
507+
let result = <$bits_ty as ShuffleValue>::$fn_name(mask, bits, param, width);
508+
ShuffleResult::new(Self::$from_bits(result.value), result.valid)
509+
}
510+
};
511+
}
512+
533513
// Special case for f32 to preserve bit pattern
534514
impl ShuffleValue for f32 {
535-
#[gpu_only]
536-
unsafe fn shuffle_down(
537-
mask: WarpMask,
538-
value: Self,
539-
delta: u32,
540-
width: u32,
541-
) -> ShuffleResult<Self> {
542-
let bits = value.to_bits();
543-
let result = <u32 as ShuffleValue>::shuffle_down(mask, bits, delta, width);
544-
ShuffleResult::new(f32::from_bits(result.value), result.valid)
545-
}
546-
547-
#[gpu_only]
548-
unsafe fn shuffle_up(
549-
mask: WarpMask,
550-
value: Self,
551-
delta: u32,
552-
width: u32,
553-
) -> ShuffleResult<Self> {
554-
let bits = value.to_bits();
555-
let result = <u32 as ShuffleValue>::shuffle_up(mask, bits, delta, width);
556-
ShuffleResult::new(f32::from_bits(result.value), result.valid)
557-
}
558-
559-
#[gpu_only]
560-
unsafe fn shuffle_xor(
561-
mask: WarpMask,
562-
value: Self,
563-
lane_mask: u32,
564-
width: u32,
565-
) -> ShuffleResult<Self> {
566-
let bits = value.to_bits();
567-
let result = <u32 as ShuffleValue>::shuffle_xor(mask, bits, lane_mask, width);
568-
ShuffleResult::new(f32::from_bits(result.value), result.valid)
569-
}
515+
impl_shuffle_float_op!(shuffle_down, u32, to_bits, from_bits);
516+
impl_shuffle_float_op!(shuffle_up, u32, to_bits, from_bits);
517+
impl_shuffle_float_op!(shuffle_xor, u32, to_bits, from_bits);
518+
impl_shuffle_float_op!(shuffle_idx, u32, to_bits, from_bits);
519+
}
570520

571-
#[gpu_only]
572-
unsafe fn shuffle_idx(
573-
mask: WarpMask,
574-
value: Self,
575-
src_lane: u32,
576-
width: u32,
577-
) -> ShuffleResult<Self> {
578-
let bits = value.to_bits();
579-
let result = <u32 as ShuffleValue>::shuffle_idx(mask, bits, src_lane, width);
580-
ShuffleResult::new(f32::from_bits(result.value), result.valid)
581-
}
521+
// Generic macro for 64-bit shuffle operations
522+
macro_rules! impl_shuffle_64_op {
523+
($fn_name:ident) => {
524+
#[gpu_only]
525+
unsafe fn $fn_name(mask: WarpMask, value: Self, param: u32, width: u32) -> ShuffleResult<Self> {
526+
let lo = (value & 0xFFFFFFFF) as u32;
527+
let hi = (value >> 32) as u32;
528+
let lo_result = <u32 as ShuffleValue>::$fn_name(mask, lo, param, width);
529+
let hi_result = <u32 as ShuffleValue>::$fn_name(mask, hi, param, width);
530+
// Both parts must be valid
531+
let valid = lo_result.valid && hi_result.valid;
532+
let value = ((hi_result.value as Self) << 32) | (lo_result.value as Self);
533+
ShuffleResult::new(value, valid)
534+
}
535+
};
582536
}
583537

584538
// For 64-bit types, we shuffle high and low parts separately
585539
macro_rules! impl_shuffle_64 {
586540
($($ty:ty),* $(,)?) => {
587541
$(
588542
impl ShuffleValue for $ty {
589-
#[gpu_only]
590-
unsafe fn shuffle_down(mask: WarpMask, value: Self, delta: u32, width: u32) -> ShuffleResult<Self> {
591-
let lo = (value & 0xFFFFFFFF) as u32;
592-
let hi = (value >> 32) as u32;
593-
let lo_result = <u32 as ShuffleValue>::shuffle_down(mask, lo, delta, width);
594-
let hi_result = <u32 as ShuffleValue>::shuffle_down(mask, hi, delta, width);
595-
// Both parts must be valid
596-
let valid = lo_result.valid && hi_result.valid;
597-
let value = ((hi_result.value as $ty) << 32) | (lo_result.value as $ty);
598-
ShuffleResult::new(value, valid)
599-
}
600-
601-
#[gpu_only]
602-
unsafe fn shuffle_up(mask: WarpMask, value: Self, delta: u32, width: u32) -> ShuffleResult<Self> {
603-
let lo = (value & 0xFFFFFFFF) as u32;
604-
let hi = (value >> 32) as u32;
605-
let lo_result = <u32 as ShuffleValue>::shuffle_up(mask, lo, delta, width);
606-
let hi_result = <u32 as ShuffleValue>::shuffle_up(mask, hi, delta, width);
607-
let valid = lo_result.valid && hi_result.valid;
608-
let value = ((hi_result.value as $ty) << 32) | (lo_result.value as $ty);
609-
ShuffleResult::new(value, valid)
610-
}
611-
612-
#[gpu_only]
613-
unsafe fn shuffle_xor(mask: WarpMask, value: Self, lane_mask: u32, width: u32) -> ShuffleResult<Self> {
614-
let lo = (value & 0xFFFFFFFF) as u32;
615-
let hi = (value >> 32) as u32;
616-
let lo_result = <u32 as ShuffleValue>::shuffle_xor(mask, lo, lane_mask, width);
617-
let hi_result = <u32 as ShuffleValue>::shuffle_xor(mask, hi, lane_mask, width);
618-
let valid = lo_result.valid && hi_result.valid;
619-
let value = ((hi_result.value as $ty) << 32) | (lo_result.value as $ty);
620-
ShuffleResult::new(value, valid)
621-
}
622-
623-
#[gpu_only]
624-
unsafe fn shuffle_idx(mask: WarpMask, value: Self, src_lane: u32, width: u32) -> ShuffleResult<Self> {
625-
let lo = (value & 0xFFFFFFFF) as u32;
626-
let hi = (value >> 32) as u32;
627-
let lo_result = <u32 as ShuffleValue>::shuffle_idx(mask, lo, src_lane, width);
628-
let hi_result = <u32 as ShuffleValue>::shuffle_idx(mask, hi, src_lane, width);
629-
let valid = lo_result.valid && hi_result.valid;
630-
let value = ((hi_result.value as $ty) << 32) | (lo_result.value as $ty);
631-
ShuffleResult::new(value, valid)
632-
}
543+
impl_shuffle_64_op!(shuffle_down);
544+
impl_shuffle_64_op!(shuffle_up);
545+
impl_shuffle_64_op!(shuffle_xor);
546+
impl_shuffle_64_op!(shuffle_idx);
633547
}
634548
)*
635549
};
@@ -641,83 +555,32 @@ impl_shuffle_64! {
641555

642556
// For f64, we need to handle the bit pattern correctly
643557
impl ShuffleValue for f64 {
644-
#[gpu_only]
645-
unsafe fn shuffle_down(
646-
mask: WarpMask,
647-
value: Self,
648-
delta: u32,
649-
width: u32,
650-
) -> ShuffleResult<Self> {
651-
let bits = value.to_bits();
652-
let result = <u64 as ShuffleValue>::shuffle_down(mask, bits, delta, width);
653-
ShuffleResult::new(f64::from_bits(result.value), result.valid)
654-
}
655-
656-
#[gpu_only]
657-
unsafe fn shuffle_up(
658-
mask: WarpMask,
659-
value: Self,
660-
delta: u32,
661-
width: u32,
662-
) -> ShuffleResult<Self> {
663-
let bits = value.to_bits();
664-
let result = <u64 as ShuffleValue>::shuffle_up(mask, bits, delta, width);
665-
ShuffleResult::new(f64::from_bits(result.value), result.valid)
666-
}
667-
668-
#[gpu_only]
669-
unsafe fn shuffle_xor(
670-
mask: WarpMask,
671-
value: Self,
672-
lane_mask: u32,
673-
width: u32,
674-
) -> ShuffleResult<Self> {
675-
let bits = value.to_bits();
676-
let result = <u64 as ShuffleValue>::shuffle_xor(mask, bits, lane_mask, width);
677-
ShuffleResult::new(f64::from_bits(result.value), result.valid)
678-
}
558+
impl_shuffle_float_op!(shuffle_down, u64, to_bits, from_bits);
559+
impl_shuffle_float_op!(shuffle_up, u64, to_bits, from_bits);
560+
impl_shuffle_float_op!(shuffle_xor, u64, to_bits, from_bits);
561+
impl_shuffle_float_op!(shuffle_idx, u64, to_bits, from_bits);
562+
}
679563

680-
#[gpu_only]
681-
unsafe fn shuffle_idx(
682-
mask: WarpMask,
683-
value: Self,
684-
src_lane: u32,
685-
width: u32,
686-
) -> ShuffleResult<Self> {
687-
let bits = value.to_bits();
688-
let result = <u64 as ShuffleValue>::shuffle_idx(mask, bits, src_lane, width);
689-
ShuffleResult::new(f64::from_bits(result.value), result.valid)
690-
}
564+
// Generic macro for small type shuffle operations
565+
macro_rules! impl_shuffle_small_op {
566+
($fn_name:ident) => {
567+
#[gpu_only]
568+
unsafe fn $fn_name(mask: WarpMask, value: Self, param: u32, width: u32) -> ShuffleResult<Self> {
569+
let result = <u32 as ShuffleValue>::$fn_name(mask, value as u32, param, width);
570+
ShuffleResult::new(result.value as Self, result.valid)
571+
}
572+
};
691573
}
692574

693575
// For smaller types, we pack them into 32-bit values
694576
macro_rules! impl_shuffle_small {
695577
($($ty:ty),* $(,)?) => {
696578
$(
697579
impl ShuffleValue for $ty {
698-
#[gpu_only]
699-
unsafe fn shuffle_down(mask: WarpMask, value: Self, delta: u32, width: u32) -> ShuffleResult<Self> {
700-
let result = <u32 as ShuffleValue>::shuffle_down(mask, value as u32, delta, width);
701-
ShuffleResult::new(result.value as Self, result.valid)
702-
}
703-
704-
#[gpu_only]
705-
unsafe fn shuffle_up(mask: WarpMask, value: Self, delta: u32, width: u32) -> ShuffleResult<Self> {
706-
let result = <u32 as ShuffleValue>::shuffle_up(mask, value as u32, delta, width);
707-
ShuffleResult::new(result.value as Self, result.valid)
708-
}
709-
710-
#[gpu_only]
711-
unsafe fn shuffle_xor(mask: WarpMask, value: Self, lane_mask: u32, width: u32) -> ShuffleResult<Self> {
712-
let result = <u32 as ShuffleValue>::shuffle_xor(mask, value as u32, lane_mask, width);
713-
ShuffleResult::new(result.value as Self, result.valid)
714-
}
715-
716-
#[gpu_only]
717-
unsafe fn shuffle_idx(mask: WarpMask, value: Self, src_lane: u32, width: u32) -> ShuffleResult<Self> {
718-
let result = <u32 as ShuffleValue>::shuffle_idx(mask, value as u32, src_lane, width);
719-
ShuffleResult::new(result.value as Self, result.valid)
720-
}
580+
impl_shuffle_small_op!(shuffle_down);
581+
impl_shuffle_small_op!(shuffle_up);
582+
impl_shuffle_small_op!(shuffle_xor);
583+
impl_shuffle_small_op!(shuffle_idx);
721584
}
722585
)*
723586
};

0 commit comments

Comments
 (0)