@@ -966,16 +966,11 @@ def load_strided(
966
966
)
967
967
else :
968
968
layout = WGStridedFragLayout (shape = shape , vec_size = vec_size )
969
+ registers = np .empty (layout .registers_shape (shape ), dtype = object )
969
970
vec_ty = ir .VectorType .get ((layout .vec_size ,), ref_ty .element_type )
970
- try :
971
- # Flattening the reference potentially produces simpler PTX but
972
- # if the ref is not already 1D and has strided dimensions
973
- # flattening won't work.
974
- ref_ = mgpu .memref_fold (ref , 0 , len (ref_ty .shape ))
975
- vecs = [vector .load (vec_ty , ref_ , [vec_idx ]) for vec_idx in layout .linear_thread_idxs ()]
976
- except NotImplementedError :
977
- vecs = [vector .load (vec_ty , ref , vec_idx ) for vec_idx in layout .thread_idxs (shape )]
978
- return cls (_registers = np .array (vecs ), _layout = layout , _is_signed = is_signed )
971
+ for _get , update , ref , idx in cls .transfer_strided (ref , layout .vec_size ):
972
+ update (registers , vector .load (vec_ty , ref , idx ))
973
+ return cls (_registers = registers , _layout = layout , _is_signed = is_signed )
979
974
980
975
@classmethod
981
976
def splat (
@@ -2579,8 +2574,10 @@ def store_untiled(
2579
2574
if isinstance (ref , utils .MultimemRef ):
2580
2575
raise NotImplementedError ("Strided layout does not support multimem" )
2581
2576
if swizzle != 16 :
2582
- raise NotImplementedError
2583
- self ._store_untiled_wg_strided (ref )
2577
+ raise ValueError ("Only TiledLayouts support swizzling" )
2578
+ assert isinstance (self .layout , WGStridedFragLayout )
2579
+ for get , _update , ref , idx in self .transfer_strided (ref , self .layout .vec_size ):
2580
+ vector .store (get (self .registers ), ref , idx )
2584
2581
case TiledLayout ():
2585
2582
ref_shape = ir .MemRefType (ref .type ).shape
2586
2583
ref = utils .memref_reshape (ref , (* (1 for _ in ref_shape ), * ref_shape ))
@@ -2621,8 +2618,8 @@ def load_untiled(
2621
2618
is_signed : bool | None = None ,
2622
2619
optimized : bool = True ,
2623
2620
) -> FragmentedArray :
2624
- ref_shape = ir .MemRefType (ref .type ). shape
2625
- ref = utils .memref_reshape (ref , (* (1 for _ in ref_shape ), * ref_shape ))
2621
+ ref_ty = ir .MemRefType (ref .type )
2622
+ ref = utils .memref_reshape (ref , (* (1 for _ in ref_ty . shape ), * ref_ty . shape ))
2626
2623
return cls .load_tiled (
2627
2624
ref , swizzle = swizzle , is_signed = is_signed , layout = layout , optimized = optimized
2628
2625
)
@@ -2653,27 +2650,6 @@ def _store_untiled_splat(self, ref: ir.Value):
2653
2650
)
2654
2651
fa .store_untiled (ref )
2655
2652
2656
- def _store_untiled_wg_strided (self , ref : ir .Value ):
2657
- assert isinstance (self .layout , WGStridedFragLayout )
2658
- ref_ty = ir .MemRefType (ref .type )
2659
- idxs : Iterable [Sequence [ir .Value ]]
2660
- try :
2661
- # Flattening the reference potentially produces simpler PTX but
2662
- # if the ref is not already 1D and has strided dimensions
2663
- # flattening won't work. We use a different variable for ref in
2664
- # case `NotImplementedError` is thrown by
2665
- # .linear_thread_idxs().
2666
- ref_ = mgpu .memref_fold (ref , 0 , len (ref_ty .shape ))
2667
- idxs = ((i ,) for i in self .layout .linear_thread_idxs ())
2668
- except NotImplementedError :
2669
- ref_ = ref
2670
- idxs = self .layout .thread_idxs (self .shape )
2671
- ref_shape = tuple (ref_ty .shape )
2672
- if ref_shape != self .shape :
2673
- raise ValueError ((ref_shape , self .shape ))
2674
- for idx , reg in zip (idxs , self .registers .flat ):
2675
- vector .store (reg , ref_ , idx )
2676
-
2677
2653
def store_tiled (self , ref : ir .Value | utils .MultimemRef , swizzle : int | None , optimized : bool = True ):
2678
2654
if not isinstance (self .layout , TiledLayout ):
2679
2655
raise NotImplementedError (self .layout )
@@ -2731,6 +2707,51 @@ def load_tiled(
2731
2707
update (registers , loaded_reg )
2732
2708
return cls (_registers = registers , _layout = layout , _is_signed = is_signed )
2733
2709
2710
@classmethod
def transfer_strided(cls, ref: ir.Value, vec_size: int):
    """Yield per-vector accessors for moving data between ``ref`` and registers.

    A ``WGStridedFragLayout`` is built from the shape of ``ref`` and
    ``vec_size``. The reference is folded into a single dimension when
    ``mgpu.memref_fold`` accepts it; if folding raises ``ValueError`` the
    original reference is kept and the layout's multi-dimensional indices
    are used instead.

    This is a generator: nothing (including the validation below) runs until
    the caller starts iterating.

    Args:
      ref: The memref to transfer to or from.
      vec_size: Vector size passed to the ``WGStridedFragLayout``.

    Yields:
      ``(get, update, ref, idx)`` tuples, one per index produced by the
      layout, where ``get(registers)`` returns ``registers[i]``,
      ``update(registers, reg)`` assigns ``registers[i] = reg``, ``ref`` is
      the (possibly folded) reference and ``idx`` is the index sequence to
      use with it. ``i`` is the position of the index in iteration order.
      NOTE(review): plain integer indexing presumably relies on
      ``registers`` being one-dimensional -- confirm against
      ``registers_shape``.

    Raises:
      ValueError: If the reference cannot be folded, ``vec_size > 1`` and
        either a stride-1 dimension has a size that is not a multiple of
        ``vec_size``, a dimension of size > 1 with a non-unit stride has a
        stride that is not a multiple of ``vec_size``, or no dimension has
        stride 1.
    """
    ref_ty = ir.MemRefType(ref.type)
    ref_shape = tuple(ref_ty.shape)
    layout = WGStridedFragLayout(shape=ref_shape, vec_size=vec_size)
    try:
        # Flattening the reference potentially produces simpler PTX but
        # if the ref is not already 1D and has strided dimensions
        # flattening won't work.
        ref = mgpu.memref_fold(ref, 0, len(ref_shape))
    except ValueError:
        strides, _ = ref_ty.get_strides_and_offset()
        if vec_size > 1:
            # TODO(apaszke): We could fold all the pairs of dims that are
            # contiguous. This check is too strict if we don't do that.
            has_contiguous_dim = False
            for size, stride in zip(ref_shape, strides):
                if stride == 1:
                    has_contiguous_dim = True
                    if size % vec_size != 0:
                        raise ValueError(
                            "The contiguous dimension of the reference must be a"
                            f" multiple of the layout's vector size (got {size} and"
                            f" vector size {vec_size})"
                        ) from None
                elif size > 1:
                    if stride % vec_size != 0:
                        raise ValueError(
                            "Non-contiguous dimension of the reference must have strides"
                            " that are multiples of the layout's vector size (got"
                            f" {stride} and vector size {vec_size})"
                        ) from None
            if not has_contiguous_dim:
                # Suppress the memref_fold error context here too, matching
                # the other validation errors raised from this handler.
                raise ValueError(
                    "The reference must have a contiguous dimension when vec_size > 1"
                ) from None
        idx_gen = layout.thread_idxs(ref_shape)
    else:
        # The folded reference is 1D, so wrap each linear index in a list.
        idx_gen = ([linear_idx] for linear_idx in layout.linear_thread_idxs())
    for i, vec_idx in enumerate(idx_gen):
        # Bind the current position as a default argument so each closure
        # keeps its own index rather than the loop variable's final value.
        def update(registers, reg, _i=i):
            registers[_i] = reg

        def get(registers, _i=i):
            return registers[_i]

        yield get, update, ref, vec_idx
2754
+
2734
2755
@staticmethod
2735
2756
def transfer_tiled (
2736
2757
ref : ir .Value ,
0 commit comments