@@ -469,58 +469,26 @@ const SHUFFLE_MODE_UP: u32 = 1;
469469const SHUFFLE_MODE_XOR : u32 = 2 ;
470470const SHUFFLE_MODE_IDX : u32 = 3 ;
471471
472+ // Generic macro to implement a single shuffle operation
473+ macro_rules! impl_shuffle_op {
474+ ( $fn_name: ident, $mode: expr) => {
475+ #[ gpu_only]
476+ unsafe fn $fn_name( mask: WarpMask , value: Self , param: u32 , width: u32 ) -> ShuffleResult <Self > {
477+ let result = warp_shuffle_32( $mode, mask. raw( ) , value as u32 , param, width) ;
478+ ShuffleResult :: new( result. value as Self , result. valid)
479+ }
480+ } ;
481+ }
482+
472483// Macro to implement shuffle for 32-bit types
473484macro_rules! impl_shuffle_32 {
474485 ( $( $ty: ty) ,* $( , ) ?) => {
475486 $(
476487 impl ShuffleValue for $ty {
477- #[ gpu_only]
478- unsafe fn shuffle_down( mask: WarpMask , value: Self , delta: u32 , width: u32 ) -> ShuffleResult <Self > {
479- let result = warp_shuffle_32(
480- SHUFFLE_MODE_DOWN ,
481- mask. raw( ) ,
482- value as u32 ,
483- delta,
484- width
485- ) ;
486- ShuffleResult :: new( result. value as Self , result. valid)
487- }
488-
489- #[ gpu_only]
490- unsafe fn shuffle_up( mask: WarpMask , value: Self , delta: u32 , width: u32 ) -> ShuffleResult <Self > {
491- let result = warp_shuffle_32(
492- SHUFFLE_MODE_UP ,
493- mask. raw( ) ,
494- value as u32 ,
495- delta,
496- width
497- ) ;
498- ShuffleResult :: new( result. value as Self , result. valid)
499- }
500-
501- #[ gpu_only]
502- unsafe fn shuffle_xor( mask: WarpMask , value: Self , lane_mask: u32 , width: u32 ) -> ShuffleResult <Self > {
503- let result = warp_shuffle_32(
504- SHUFFLE_MODE_XOR ,
505- mask. raw( ) ,
506- value as u32 ,
507- lane_mask,
508- width
509- ) ;
510- ShuffleResult :: new( result. value as Self , result. valid)
511- }
512-
513- #[ gpu_only]
514- unsafe fn shuffle_idx( mask: WarpMask , value: Self , src_lane: u32 , width: u32 ) -> ShuffleResult <Self > {
515- let result = warp_shuffle_32(
516- SHUFFLE_MODE_IDX ,
517- mask. raw( ) ,
518- value as u32 ,
519- src_lane,
520- width
521- ) ;
522- ShuffleResult :: new( result. value as Self , result. valid)
523- }
488+ impl_shuffle_op!( shuffle_down, SHUFFLE_MODE_DOWN ) ;
489+ impl_shuffle_op!( shuffle_up, SHUFFLE_MODE_UP ) ;
490+ impl_shuffle_op!( shuffle_xor, SHUFFLE_MODE_XOR ) ;
491+ impl_shuffle_op!( shuffle_idx, SHUFFLE_MODE_IDX ) ;
524492 }
525493 ) *
526494 } ;
@@ -530,106 +498,52 @@ impl_shuffle_32! {
530498 i32 , u32 ,
531499}
532500
501+ // Macro for floating-point shuffle operations using bit conversion
502+ macro_rules! impl_shuffle_float_op {
503+ ( $fn_name: ident, $bits_ty: ty, $to_bits: ident, $from_bits: ident) => {
504+ #[ gpu_only]
505+ unsafe fn $fn_name( mask: WarpMask , value: Self , param: u32 , width: u32 ) -> ShuffleResult <Self > {
506+ let bits = value. $to_bits( ) ;
507+ let result = <$bits_ty as ShuffleValue >:: $fn_name( mask, bits, param, width) ;
508+ ShuffleResult :: new( Self :: $from_bits( result. value) , result. valid)
509+ }
510+ } ;
511+ }
512+
533513// Special case for f32 to preserve bit pattern
534514impl ShuffleValue for f32 {
535- #[ gpu_only]
536- unsafe fn shuffle_down (
537- mask : WarpMask ,
538- value : Self ,
539- delta : u32 ,
540- width : u32 ,
541- ) -> ShuffleResult < Self > {
542- let bits = value. to_bits ( ) ;
543- let result = <u32 as ShuffleValue >:: shuffle_down ( mask, bits, delta, width) ;
544- ShuffleResult :: new ( f32:: from_bits ( result. value ) , result. valid )
545- }
546-
547- #[ gpu_only]
548- unsafe fn shuffle_up (
549- mask : WarpMask ,
550- value : Self ,
551- delta : u32 ,
552- width : u32 ,
553- ) -> ShuffleResult < Self > {
554- let bits = value. to_bits ( ) ;
555- let result = <u32 as ShuffleValue >:: shuffle_up ( mask, bits, delta, width) ;
556- ShuffleResult :: new ( f32:: from_bits ( result. value ) , result. valid )
557- }
558-
559- #[ gpu_only]
560- unsafe fn shuffle_xor (
561- mask : WarpMask ,
562- value : Self ,
563- lane_mask : u32 ,
564- width : u32 ,
565- ) -> ShuffleResult < Self > {
566- let bits = value. to_bits ( ) ;
567- let result = <u32 as ShuffleValue >:: shuffle_xor ( mask, bits, lane_mask, width) ;
568- ShuffleResult :: new ( f32:: from_bits ( result. value ) , result. valid )
569- }
515+ impl_shuffle_float_op ! ( shuffle_down, u32 , to_bits, from_bits) ;
516+ impl_shuffle_float_op ! ( shuffle_up, u32 , to_bits, from_bits) ;
517+ impl_shuffle_float_op ! ( shuffle_xor, u32 , to_bits, from_bits) ;
518+ impl_shuffle_float_op ! ( shuffle_idx, u32 , to_bits, from_bits) ;
519+ }
570520
571- #[ gpu_only]
572- unsafe fn shuffle_idx (
573- mask : WarpMask ,
574- value : Self ,
575- src_lane : u32 ,
576- width : u32 ,
577- ) -> ShuffleResult < Self > {
578- let bits = value. to_bits ( ) ;
579- let result = <u32 as ShuffleValue >:: shuffle_idx ( mask, bits, src_lane, width) ;
580- ShuffleResult :: new ( f32:: from_bits ( result. value ) , result. valid )
581- }
521+ // Generic macro for 64-bit shuffle operations
522+ macro_rules! impl_shuffle_64_op {
523+ ( $fn_name: ident) => {
524+ #[ gpu_only]
525+ unsafe fn $fn_name( mask: WarpMask , value: Self , param: u32 , width: u32 ) -> ShuffleResult <Self > {
526+ let lo = ( value & 0xFFFFFFFF ) as u32 ;
527+ let hi = ( value >> 32 ) as u32 ;
528+ let lo_result = <u32 as ShuffleValue >:: $fn_name( mask, lo, param, width) ;
529+ let hi_result = <u32 as ShuffleValue >:: $fn_name( mask, hi, param, width) ;
530+ // Both parts must be valid
531+ let valid = lo_result. valid && hi_result. valid;
532+ let value = ( ( hi_result. value as Self ) << 32 ) | ( lo_result. value as Self ) ;
533+ ShuffleResult :: new( value, valid)
534+ }
535+ } ;
582536}
583537
584538// For 64-bit types, we shuffle high and low parts separately
585539macro_rules! impl_shuffle_64 {
586540 ( $( $ty: ty) ,* $( , ) ?) => {
587541 $(
588542 impl ShuffleValue for $ty {
589- #[ gpu_only]
590- unsafe fn shuffle_down( mask: WarpMask , value: Self , delta: u32 , width: u32 ) -> ShuffleResult <Self > {
591- let lo = ( value & 0xFFFFFFFF ) as u32 ;
592- let hi = ( value >> 32 ) as u32 ;
593- let lo_result = <u32 as ShuffleValue >:: shuffle_down( mask, lo, delta, width) ;
594- let hi_result = <u32 as ShuffleValue >:: shuffle_down( mask, hi, delta, width) ;
595- // Both parts must be valid
596- let valid = lo_result. valid && hi_result. valid;
597- let value = ( ( hi_result. value as $ty) << 32 ) | ( lo_result. value as $ty) ;
598- ShuffleResult :: new( value, valid)
599- }
600-
601- #[ gpu_only]
602- unsafe fn shuffle_up( mask: WarpMask , value: Self , delta: u32 , width: u32 ) -> ShuffleResult <Self > {
603- let lo = ( value & 0xFFFFFFFF ) as u32 ;
604- let hi = ( value >> 32 ) as u32 ;
605- let lo_result = <u32 as ShuffleValue >:: shuffle_up( mask, lo, delta, width) ;
606- let hi_result = <u32 as ShuffleValue >:: shuffle_up( mask, hi, delta, width) ;
607- let valid = lo_result. valid && hi_result. valid;
608- let value = ( ( hi_result. value as $ty) << 32 ) | ( lo_result. value as $ty) ;
609- ShuffleResult :: new( value, valid)
610- }
611-
612- #[ gpu_only]
613- unsafe fn shuffle_xor( mask: WarpMask , value: Self , lane_mask: u32 , width: u32 ) -> ShuffleResult <Self > {
614- let lo = ( value & 0xFFFFFFFF ) as u32 ;
615- let hi = ( value >> 32 ) as u32 ;
616- let lo_result = <u32 as ShuffleValue >:: shuffle_xor( mask, lo, lane_mask, width) ;
617- let hi_result = <u32 as ShuffleValue >:: shuffle_xor( mask, hi, lane_mask, width) ;
618- let valid = lo_result. valid && hi_result. valid;
619- let value = ( ( hi_result. value as $ty) << 32 ) | ( lo_result. value as $ty) ;
620- ShuffleResult :: new( value, valid)
621- }
622-
623- #[ gpu_only]
624- unsafe fn shuffle_idx( mask: WarpMask , value: Self , src_lane: u32 , width: u32 ) -> ShuffleResult <Self > {
625- let lo = ( value & 0xFFFFFFFF ) as u32 ;
626- let hi = ( value >> 32 ) as u32 ;
627- let lo_result = <u32 as ShuffleValue >:: shuffle_idx( mask, lo, src_lane, width) ;
628- let hi_result = <u32 as ShuffleValue >:: shuffle_idx( mask, hi, src_lane, width) ;
629- let valid = lo_result. valid && hi_result. valid;
630- let value = ( ( hi_result. value as $ty) << 32 ) | ( lo_result. value as $ty) ;
631- ShuffleResult :: new( value, valid)
632- }
543+ impl_shuffle_64_op!( shuffle_down) ;
544+ impl_shuffle_64_op!( shuffle_up) ;
545+ impl_shuffle_64_op!( shuffle_xor) ;
546+ impl_shuffle_64_op!( shuffle_idx) ;
633547 }
634548 ) *
635549 } ;
@@ -641,83 +555,32 @@ impl_shuffle_64! {
641555
642556// For f64, we need to handle the bit pattern correctly
643557impl ShuffleValue for f64 {
644- #[ gpu_only]
645- unsafe fn shuffle_down (
646- mask : WarpMask ,
647- value : Self ,
648- delta : u32 ,
649- width : u32 ,
650- ) -> ShuffleResult < Self > {
651- let bits = value. to_bits ( ) ;
652- let result = <u64 as ShuffleValue >:: shuffle_down ( mask, bits, delta, width) ;
653- ShuffleResult :: new ( f64:: from_bits ( result. value ) , result. valid )
654- }
655-
656- #[ gpu_only]
657- unsafe fn shuffle_up (
658- mask : WarpMask ,
659- value : Self ,
660- delta : u32 ,
661- width : u32 ,
662- ) -> ShuffleResult < Self > {
663- let bits = value. to_bits ( ) ;
664- let result = <u64 as ShuffleValue >:: shuffle_up ( mask, bits, delta, width) ;
665- ShuffleResult :: new ( f64:: from_bits ( result. value ) , result. valid )
666- }
667-
668- #[ gpu_only]
669- unsafe fn shuffle_xor (
670- mask : WarpMask ,
671- value : Self ,
672- lane_mask : u32 ,
673- width : u32 ,
674- ) -> ShuffleResult < Self > {
675- let bits = value. to_bits ( ) ;
676- let result = <u64 as ShuffleValue >:: shuffle_xor ( mask, bits, lane_mask, width) ;
677- ShuffleResult :: new ( f64:: from_bits ( result. value ) , result. valid )
678- }
558+ impl_shuffle_float_op ! ( shuffle_down, u64 , to_bits, from_bits) ;
559+ impl_shuffle_float_op ! ( shuffle_up, u64 , to_bits, from_bits) ;
560+ impl_shuffle_float_op ! ( shuffle_xor, u64 , to_bits, from_bits) ;
561+ impl_shuffle_float_op ! ( shuffle_idx, u64 , to_bits, from_bits) ;
562+ }
679563
680- #[ gpu_only]
681- unsafe fn shuffle_idx (
682- mask : WarpMask ,
683- value : Self ,
684- src_lane : u32 ,
685- width : u32 ,
686- ) -> ShuffleResult < Self > {
687- let bits = value. to_bits ( ) ;
688- let result = <u64 as ShuffleValue >:: shuffle_idx ( mask, bits, src_lane, width) ;
689- ShuffleResult :: new ( f64:: from_bits ( result. value ) , result. valid )
690- }
564+ // Generic macro for small type shuffle operations
565+ macro_rules! impl_shuffle_small_op {
566+ ( $fn_name: ident) => {
567+ #[ gpu_only]
568+ unsafe fn $fn_name( mask: WarpMask , value: Self , param: u32 , width: u32 ) -> ShuffleResult <Self > {
569+ let result = <u32 as ShuffleValue >:: $fn_name( mask, value as u32 , param, width) ;
570+ ShuffleResult :: new( result. value as Self , result. valid)
571+ }
572+ } ;
691573}
692574
693575// For smaller types, we pack them into 32-bit values
694576macro_rules! impl_shuffle_small {
695577 ( $( $ty: ty) ,* $( , ) ?) => {
696578 $(
697579 impl ShuffleValue for $ty {
698- #[ gpu_only]
699- unsafe fn shuffle_down( mask: WarpMask , value: Self , delta: u32 , width: u32 ) -> ShuffleResult <Self > {
700- let result = <u32 as ShuffleValue >:: shuffle_down( mask, value as u32 , delta, width) ;
701- ShuffleResult :: new( result. value as Self , result. valid)
702- }
703-
704- #[ gpu_only]
705- unsafe fn shuffle_up( mask: WarpMask , value: Self , delta: u32 , width: u32 ) -> ShuffleResult <Self > {
706- let result = <u32 as ShuffleValue >:: shuffle_up( mask, value as u32 , delta, width) ;
707- ShuffleResult :: new( result. value as Self , result. valid)
708- }
709-
710- #[ gpu_only]
711- unsafe fn shuffle_xor( mask: WarpMask , value: Self , lane_mask: u32 , width: u32 ) -> ShuffleResult <Self > {
712- let result = <u32 as ShuffleValue >:: shuffle_xor( mask, value as u32 , lane_mask, width) ;
713- ShuffleResult :: new( result. value as Self , result. valid)
714- }
715-
716- #[ gpu_only]
717- unsafe fn shuffle_idx( mask: WarpMask , value: Self , src_lane: u32 , width: u32 ) -> ShuffleResult <Self > {
718- let result = <u32 as ShuffleValue >:: shuffle_idx( mask, value as u32 , src_lane, width) ;
719- ShuffleResult :: new( result. value as Self , result. valid)
720- }
580+ impl_shuffle_small_op!( shuffle_down) ;
581+ impl_shuffle_small_op!( shuffle_up) ;
582+ impl_shuffle_small_op!( shuffle_xor) ;
583+ impl_shuffle_small_op!( shuffle_idx) ;
721584 }
722585 ) *
723586 } ;
0 commit comments