@@ -1652,8 +1652,8 @@ def _reshape_lowering_rule(
1652
1652
)
1653
1653
1654
1654
1655
- def _compute_offsets_from_indices (
1656
- block_info : BlockInfo , nd_indexer : NDIndexer
1655
+ def _compute_pointers_from_indices (
1656
+ root_ptr : ir . Value , block_info : BlockInfo , nd_indexer : NDIndexer
1657
1657
) -> ir .Value :
1658
1658
full_shape = block_info .full_shape_dtype .shape
1659
1659
num_mapped_dims = sum (b is pallas_core .mapped for b in block_info .block_shape )
@@ -1732,14 +1732,7 @@ def _compute_offsets_from_indices(
1732
1732
dim_offsets = _mul (dim_offsets , _full (dim_offsets .type , dim_stride ))
1733
1733
offsets = _add (offsets , dim_offsets )
1734
1734
1735
- return offsets
1736
-
1737
-
1738
- def _compute_pointers_from_indices (
1739
- root_ptr : ir .Value , block_info : BlockInfo , nd_indexer : NDIndexer
1740
- ) -> ir .Value :
1741
- offsets = _compute_offsets_from_indices (block_info , nd_indexer )
1742
- return _add (_bcast_to (root_ptr , nd_indexer .get_indexer_shape ()), offsets )
1735
+ return _add (_bcast_to (root_ptr , indexer_shape ), offsets )
1743
1736
1744
1737
1745
1738
@register_lowering (sp .get_p )
@@ -1855,20 +1848,14 @@ def _masked_load_lowering_rule(
1855
1848
if not tt_dialect .PointerType .isinstance (ptr .type ):
1856
1849
assert len (ctx .avals_in ) == 1
1857
1850
return ptr
1858
-
1859
- offsets = _compute_offsets_from_indices (block_info , idx )
1860
- ptr_offsets = offsets
1861
-
1862
- if block_info .full_shape_dtype .dtype in (jnp .int4 , jnp .uint4 ):
1863
- ptr_offsets = _floordiv (offsets , _full (offsets .type , 2 ), signed = False )
1864
-
1865
- shape = idx .get_indexer_shape ()
1866
- ptr = _add (_bcast_to (ptr , shape ), ptr_offsets )
1851
+ ptr = _compute_pointers_from_indices (ptr , block_info , idx )
1867
1852
if mask is not None :
1868
- mask = _bcast_to (_ensure_ir_value (mask , mask_aval ), shape )
1853
+ mask = _bcast_to (_ensure_ir_value (mask , mask_aval ), idx . get_indexer_shape () )
1869
1854
if other is not None :
1870
- other = _bcast_to (_ensure_ir_value (other , other_aval ), shape )
1871
- values = _load (
1855
+ other = _bcast_to (
1856
+ _ensure_ir_value (other , other_aval ), idx .get_indexer_shape ()
1857
+ )
1858
+ return _load (
1872
1859
ptr ,
1873
1860
mask = mask ,
1874
1861
other = other ,
@@ -1877,19 +1864,6 @@ def _masked_load_lowering_rule(
1877
1864
eviction_policy = eviction_policy ,
1878
1865
)
1879
1866
1880
- if block_info .full_shape_dtype .dtype not in (jnp .int4 , jnp .uint4 ):
1881
- return values
1882
-
1883
- # XLA packs pairs of `[u]int4` values into a `uint8` value with the first
1884
- # in the most significant bits and the second in the least significant.
1885
- offsets = _ir_cast (offsets , ir .IntegerType .get_signless (32 ), signed = False )
1886
- in_lsb = _mod (offsets , _full (offsets .type , 2 ), signed = False )
1887
- in_msb = arith_dialect .xori (in_lsb , _full (in_lsb .type , 1 ))
1888
- shift = _mul (in_msb , _full (in_msb .type , 4 ))
1889
- shift = _ir_cast (shift , values .type , signed = False )
1890
- values = arith_dialect .shrui (values , shift )
1891
- return _ir_cast (values , ir .IntegerType .get_signless (4 ), signed = False )
1892
-
1893
1867
1894
1868
@register_lowering (sp .swap_p )
1895
1869
def _swap_lowering_rule (ctx : LoweringRuleContext , ptr , value , * idx , tree ):
0 commit comments