Skip to content

Commit dbe9ccd

Browse files
apaszke authored and
Google-ML-Automation committed
Reverts 83e60a9
PiperOrigin-RevId: 711403091
1 parent: 4a6cfeb; commit: dbe9ccd

File tree

2 files changed

+9
-49
lines changed

2 files changed

+9
-49
lines changed

jax/_src/pallas/triton/lowering.py

+9 -35
Original file line number | Diff line number | Diff line change
@@ -1652,8 +1652,8 @@ def _reshape_lowering_rule(
16521652
)
16531653

16541654

1655-
def _compute_offsets_from_indices(
1656-
block_info: BlockInfo, nd_indexer: NDIndexer
1655+
def _compute_pointers_from_indices(
1656+
root_ptr: ir.Value, block_info: BlockInfo, nd_indexer: NDIndexer
16571657
) -> ir.Value:
16581658
full_shape = block_info.full_shape_dtype.shape
16591659
num_mapped_dims = sum(b is pallas_core.mapped for b in block_info.block_shape)
@@ -1732,14 +1732,7 @@ def _compute_offsets_from_indices(
17321732
dim_offsets = _mul(dim_offsets, _full(dim_offsets.type, dim_stride))
17331733
offsets = _add(offsets, dim_offsets)
17341734

1735-
return offsets
1736-
1737-
1738-
def _compute_pointers_from_indices(
1739-
root_ptr: ir.Value, block_info: BlockInfo, nd_indexer: NDIndexer
1740-
) -> ir.Value:
1741-
offsets = _compute_offsets_from_indices(block_info, nd_indexer)
1742-
return _add(_bcast_to(root_ptr, nd_indexer.get_indexer_shape()), offsets)
1735+
return _add(_bcast_to(root_ptr, indexer_shape), offsets)
17431736

17441737

17451738
@register_lowering(sp.get_p)
@@ -1855,20 +1848,14 @@ def _masked_load_lowering_rule(
18551848
if not tt_dialect.PointerType.isinstance(ptr.type):
18561849
assert len(ctx.avals_in) == 1
18571850
return ptr
1858-
1859-
offsets = _compute_offsets_from_indices(block_info, idx)
1860-
ptr_offsets = offsets
1861-
1862-
if block_info.full_shape_dtype.dtype in (jnp.int4, jnp.uint4):
1863-
ptr_offsets = _floordiv(offsets, _full(offsets.type, 2), signed=False)
1864-
1865-
shape = idx.get_indexer_shape()
1866-
ptr = _add(_bcast_to(ptr, shape), ptr_offsets)
1851+
ptr = _compute_pointers_from_indices(ptr, block_info, idx)
18671852
if mask is not None:
1868-
mask = _bcast_to(_ensure_ir_value(mask, mask_aval), shape)
1853+
mask = _bcast_to(_ensure_ir_value(mask, mask_aval), idx.get_indexer_shape())
18691854
if other is not None:
1870-
other = _bcast_to(_ensure_ir_value(other, other_aval), shape)
1871-
values = _load(
1855+
other = _bcast_to(
1856+
_ensure_ir_value(other, other_aval), idx.get_indexer_shape()
1857+
)
1858+
return _load(
18721859
ptr,
18731860
mask=mask,
18741861
other=other,
@@ -1877,19 +1864,6 @@ def _masked_load_lowering_rule(
18771864
eviction_policy=eviction_policy,
18781865
)
18791866

1880-
if block_info.full_shape_dtype.dtype not in (jnp.int4, jnp.uint4):
1881-
return values
1882-
1883-
# XLA packs pairs of `[u]int4` values into a `uint8` value with the first
1884-
# in the most significant bits and the second in the least significant.
1885-
offsets = _ir_cast(offsets, ir.IntegerType.get_signless(32), signed=False)
1886-
in_lsb = _mod(offsets, _full(offsets.type, 2), signed=False)
1887-
in_msb = arith_dialect.xori(in_lsb, _full(in_lsb.type, 1))
1888-
shift = _mul(in_msb, _full(in_msb.type, 4))
1889-
shift = _ir_cast(shift, values.type, signed=False)
1890-
values = arith_dialect.shrui(values, shift)
1891-
return _ir_cast(values, ir.IntegerType.get_signless(4), signed=False)
1892-
18931867

18941868
@register_lowering(sp.swap_p)
18951869
def _swap_lowering_rule(ctx: LoweringRuleContext, ptr, value, *idx, tree):

tests/pallas/pallas_test.py

-14
Original file line number | Diff line number | Diff line change
@@ -725,20 +725,6 @@ def dot_kernel(x_ref, y_ref, o_ref):
725725
)
726726
self.assertAllClose(dot_kernel(x, y), expected, atol=5e-2, rtol=5e-3)
727727

728-
@parameterized.parameters(jnp.int4, jnp.uint4)
729-
def test_subbyte_load(self, dtype):
730-
if not jtu.test_device_matches(["gpu"]):
731-
self.skipTest("`[u]int4` loads only supported on GPU.")
732-
733-
x = jnp.arange(-128, 128, dtype=jnp.int8)
734-
735-
@functools.partial(self.pallas_call, out_shape=x)
736-
def copy_kernel(x_ref, o_ref):
737-
o_ref[()] = x_ref[()].astype(jnp.int8)
738-
739-
expected = x.astype(dtype).astype(jnp.int8)
740-
self.assertAllClose(copy_kernel(x.astype(dtype)), expected)
741-
742728

743729
class PallasCallInterpretTest(PallasCallTest):
744730
INTERPRET = True

0 commit comments

Comments
 (0)