6 changes: 3 additions & 3 deletions jax/_src/pallas/mosaic_gpu/primitives.py
@@ -3552,10 +3552,10 @@ def _multimem_load_reduce_lowering_rule(
"Failed to infer the output layout of multimem_load_reduce. Please apply"
" plgpu.layout_cast to its output right after its creation."
)
if not isinstance(layout, mgpu.TiledLayout):
if not isinstance(layout, (mgpu.TiledLayout, mgpu.WGStridedFragLayout)):
raise ValueError(
"Only tiled layouts are supported by multimem_load_reduce, but got"
f" {layout}"
"Only tiled and WG strided layouts are supported by"
f" multimem_load_reduce, but got {layout}"
)
dtype = ctx.avals_out[0].dtype
transforms = tree.unflatten(transforms_leaves)
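Note: with this relaxation, the output of `plgpu.multimem_load_reduce` can be cast to either a tiled or a WG-strided layout. A minimal kernel-body sketch, pieced together from the test added later in this PR; the shape, axis name, vector size, and reduction op below are illustrative, not prescriptive:

```python
# Sketch only: assumes a Pallas Mosaic GPU kernel whose x_ref is replicated
# over a mesh axis named "x"; values are illustrative.
from jax.experimental.pallas import mosaic_gpu as plgpu

def kernel_body(x_ref, y_ref):
  # WG-strided layouts are now accepted alongside tiled ones.
  layout = plgpu.Layout.WG_STRIDED((64, 32), vec_size=2)
  # Apply layout_cast right after creation, as the error message above advises.
  y_ref[...] = plgpu.layout_cast(
      plgpu.multimem_load_reduce(
          x_ref, collective_axes="x", reduction_op="add"
      ),
      layout,
  )
```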
21 changes: 16 additions & 5 deletions jax/experimental/mosaic/gpu/fragmented_array.py
@@ -2571,13 +2571,14 @@ def store_untiled(
# All values are the same so swizzle does not affect anything here.
self._store_untiled_splat(ref)
case WGStridedFragLayout():
if isinstance(ref, utils.MultimemRef):
raise NotImplementedError("Strided layout does not support multimem")
if swizzle != 16:
raise ValueError("Only TiledLayouts support swizzling")
assert isinstance(self.layout, WGStridedFragLayout)
for get, _update, ref, idx in self.transfer_strided(ref, self.layout.vec_size):
vector.store(get(self.registers), ref, idx)
if isinstance(ref, utils.MultimemRef):
ref.store(get(self.registers), idx)
else:
vector.store(get(self.registers), ref, idx)
case TiledLayout():
ref_shape = ir.MemRefType(ref.type).shape
ref = utils.memref_reshape(ref, (*(1 for _ in ref_shape), *ref_shape))
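Note: the strided store now dispatches per transfer: multimem refs take the `ref.store(...)` path, plain memrefs keep the `vector.store` path. A stand-in sketch of that control flow in plain Python; `MulticastRef` below is illustrative, not the real `utils.MultimemRef`:

```python
# Illustrative only: mimics the dispatch in store_untiled's strided case with
# plain Python objects instead of MLIR values.
from dataclasses import dataclass, field

@dataclass
class MulticastRef:  # stands in for utils.MultimemRef
  data: dict = field(default_factory=dict)

  def store(self, value, idx):
    # A real multimem store broadcasts the value to every mapped device;
    # here we only record it to keep the sketch self-contained.
    self.data[idx] = ("multicast", value)

def store_strided(registers, ref):
  """Mirrors the per-transfer branch added by this PR."""
  for idx, value in enumerate(registers):
    if isinstance(ref, MulticastRef):
      ref.store(value, idx)   # new path: the multimem ref emits the store
    else:
      ref[idx] = value        # pre-existing path: plain vector.store

local = [None] * 4
store_strided([1, 2, 3, 4], local)          # ordinary memref-like target
store_strided([1, 2, 3, 4], MulticastRef()) # multimem target
```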
@@ -2589,12 +2590,22 @@
def load_reduce_untiled(
cls,
ref: utils.MultimemRef,
layout: TiledLayout,
layout: TiledLayout | WGStridedFragLayout,
reduction: utils.MultimemReductionOp,
swizzle: int = 16,
is_signed: bool | None = None,
):
shape = ir.MemRefType(ref.type).shape
ref_ty = ir.MemRefType(ref.type)
shape = tuple(ref_ty.shape)
if isinstance(layout, WGStridedFragLayout):
if swizzle != 16:
raise ValueError("Only TiledLayouts support swizzling")
registers = np.empty(layout.registers_shape(shape), dtype=object)
vec_ty = ir.VectorType.get((layout.vec_size,), ref_ty.element_type)
for _get, update, ref, idx in cls.transfer_strided(ref, layout.vec_size):
ptr = utils.memref_ptr(utils.memref_slice(ref.ref, tuple(idx)))
update(registers, utils.multimem_load_reduce(vec_ty, ptr, reduction, is_signed))
return cls(_registers=registers, _layout=layout, _is_signed=is_signed)
ref = utils.memref_reshape(ref, (*(1 for _ in shape), *shape))
return cls.load_tiled(
ref.ref,
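Note: for the new WG-strided path, the number of register vectors each thread ends up holding follows directly from the ref shape and `vec_size`. A small worked check, assuming the usual 128-thread warpgroup; the helper below is illustrative, not the actual `registers_shape` implementation:

```python
import math

WARPGROUP_SIZE = 128  # assumption: one Mosaic GPU warpgroup = 4 warps of 32 threads

def strided_registers_per_thread(shape, vec_size):
  """Register vectors held per thread for a WG-strided layout (illustrative)."""
  n = math.prod(shape)
  if n % (WARPGROUP_SIZE * vec_size):
    raise ValueError(f"{shape} cannot be evenly strided with vec_size={vec_size}")
  return n // (WARPGROUP_SIZE * vec_size)

# The (64, 32) ref used in the tests: 2048 elements across 128 threads.
assert strided_registers_per_thread((64, 32), vec_size=2) == 8
assert strided_registers_per_thread((64, 32), vec_size=16) == 1
```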
7 changes: 5 additions & 2 deletions jax/experimental/mosaic/gpu/utils.py
@@ -626,7 +626,7 @@ def fold_until(shape, off , target) -> tuple[int, int]:
return ref


def memref_reshape(ref: ir.Value | MultimemRef, shape: tuple[int, ...]) -> ir.Value:
def memref_reshape(ref: ir.Value | MultimemRef, shape: tuple[int, ...]) -> ir.Value | MultimemRef:
"""Reshape by means of folding and unfolding.

The use of memref fold/unfold may avoid some possible issues with
@@ -671,7 +671,10 @@ def memref_reshape(ref: ir.Value | MultimemRef, shape: tuple[int, ...]) -> ir.Va
return _reshape(ref, src_shape, dst_shape)


def memref_fold(ref: ir.Value, dim, fold_rank) -> ir.Value:
def memref_fold(ref: ir.Value | MultimemRef, dim, fold_rank) -> ir.Value | MultimemRef:
if isinstance(ref, MultimemRef):
return MultimemRef(memref_fold(ref.ref, dim, fold_rank))

ref_ty = ir.MemRefType(ref.type)
new_shape = list(ref_ty.shape)
if dim < 0:
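Note: `memref_fold` (and, transitively, `memref_reshape`) now preserve the multimem wrapper by unwrapping, recursing, and rewrapping. A generic sketch of that pattern with a stand-in wrapper class; the names are illustrative, not the real `utils` types:

```python
from dataclasses import dataclass

@dataclass
class Wrapped:  # stands in for utils.MultimemRef wrapping a raw memref value
  inner: object

def fold(ref, dim, fold_rank):
  # Same shape as the new memref_fold: peel the wrapper, transform the raw
  # value, then restore the wrapper so callers keep getting a Wrapped back.
  if isinstance(ref, Wrapped):
    return Wrapped(fold(ref.inner, dim, fold_rank))
  # ... the original folding logic on the raw value would go here ...
  return ref

assert isinstance(fold(Wrapped("raw"), 0, 2), Wrapped)
```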
13 changes: 2 additions & 11 deletions tests/mosaic/gpu_test_distributed.py
@@ -206,9 +206,7 @@ def kernel(ctx, inp, sem, out, _):
other_dst = ctx.to_remote(sem, other_device)
other_sem = mgpu.SemaphoreRef(mgpu.utils.memref_ptr(other_dst))
with mgpu.when(arith.cmpi(arith.CmpIPredicate.eq, my_device, arith.constant(i32, 0))):
arr = mgpu.FragmentedArray.load_untiled(
inp, layout=mgpu.WGMMA_LAYOUT, optimized=False, is_signed=True
)
arr = mgpu.FragmentedArray.load_strided(inp, is_signed=True)
arr.store_untiled(ctx.to_remote_multicast(out), optimized=False)
other_sem.signal(arith.constant(i32, 1))
my_sem.wait(1)
@@ -269,14 +267,7 @@ def kernel(ctx, inp, sem, out, _):
my_sem = mgpu.SemaphoreRef(mgpu.utils.memref_ptr(sem))
other_dst = ctx.to_remote(sem, other_device)
other_sem = mgpu.SemaphoreRef(mgpu.utils.memref_ptr(other_dst))
layout = fa.TiledLayout(
fa.Tiling((
(64, 2 * vector_length), (16, 2 * vector_length), (vector_length,)
)),
warp_dims=(-5,),
lane_dims=(-3, -2),
vector_dim=-1,
)
layout = fa.WGStridedFragLayout((64, 32), vec_size=vector_length)
arr = mgpu.FragmentedArray.load_reduce_untiled(
ctx.to_remote_multicast(inp),
layout=layout,
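Note: after this change the WG-strided layout is exercised at both API levels. A small sketch of the two spellings used in this PR; the shape and vector length mirror the tests, and constructing the layout objects should not require a GPU, though running the kernels does:

```python
# Both layouts describe the same (64, 32) array strided across a warpgroup.
from jax.experimental.mosaic.gpu import fragmented_array as fa
from jax.experimental.pallas import mosaic_gpu as plgpu

vector_length = 2  # illustrative; the tests sweep 1, 2, 4, 8, 16

# Low-level Mosaic GPU spelling (gpu_test_distributed.py above).
mgpu_layout = fa.WGStridedFragLayout((64, 32), vec_size=vector_length)

# Pallas-level spelling (gpu_pallas_distributed_test.py below).
pallas_layout = plgpu.Layout.WG_STRIDED((64, 32), vec_size=vector_length)
```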
38 changes: 20 additions & 18 deletions tests/pallas/gpu_pallas_distributed_test.py
@@ -29,7 +29,6 @@
import jax.experimental.mosaic.gpu as mgpu
import jax.numpy as jnp
import numpy as np
import jax.experimental.mosaic.gpu.fragmented_array as fa


P = jax.sharding.PartitionSpec
@@ -268,22 +267,22 @@ def _store():
(jnp.int32, 1, "or"),
(jnp.int32, 1, "xor"),
(jnp.float32, 1, "add"),
(jnp.float32, 2, "add"),
(jnp.float32, 2, "add", True),
(jnp.float32, 4, "add"),
(jnp.float16, 2, "add"),
(jnp.float16, 2, "min"),
(jnp.float16, 4, "max"),
(jnp.float16, 8, "add"),
(jnp.float16, 8, "add", True),
(jnp.bfloat16, 2, "max"),
(jnp.bfloat16, 8, "add"),
(jnp.float8_e5m2, 4, "add"),
(jnp.float8_e5m2, 8, "min"),
(jnp.float8_e5m2, 16, "max"),
(jnp.float8_e4m3fn, 4, "min"),
(jnp.float8_e5m2, 16, "max", True),
(jnp.float8_e4m3fn, 4, "min", True),
(jnp.float8_e4m3fn, 8, "max"),
(jnp.float8_e4m3fn, 16, "add"),
)
def test_multimem_load_reduce(self, dtype, vector_length, reduction):
def test_multimem_load_reduce(self, dtype, vector_length, reduction, tiled_layout=False):
if dtype in (
jnp.float8_e5m2,
jnp.float8_e4m3fn,
@@ -294,18 +293,21 @@ def test_multimem_load_reduce(self, dtype, vector_length, reduction):
devices = jax.devices()[:2]

def kernel(x_ref, y_ref, _, sem_ref):
layout = plgpu.Layout.TILED(
fa.Tiling(
(
(64, 2 * vector_length),
(16, 2 * vector_length),
(vector_length,),
)
),
warp_dims=(-5,),
lane_dims=(-3, -2),
vector_dim=-1,
)
if tiled_layout:
layout = plgpu.Layout.TILED(
plgpu.Tiling(
(
(64, 2 * vector_length),
(16, 2 * vector_length),
(vector_length,),
)
),
warp_dims=(-5,),
lane_dims=(-3, -2),
vector_dim=-1,
)
else:
layout = plgpu.Layout.WG_STRIDED((64, 32), vec_size=vector_length)
y_ref[...] = plgpu.layout_cast(
plgpu.multimem_load_reduce(
x_ref.at[16:-16], collective_axes="x", reduction_op=reduction,
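Note: the new `tiled_layout` parameter defaults to False, so most existing cases now run through the WG-strided path and only the cases that pass True keep covering the tiled path. A minimal sketch of this absl `parameterized` pattern; the test class and case values here are illustrative:

```python
from absl.testing import absltest, parameterized

class LoadReduceTest(parameterized.TestCase):

  @parameterized.parameters(
      ("float32", 2, "add", True),   # explicitly opts back into the tiled layout
      ("float16", 8, "add"),         # omits the flag, so the default applies
  )
  def test_load_reduce(self, dtype, vector_length, reduction, tiled_layout=False):
    # The real test builds a Pallas kernel; this sketch only checks the
    # parameter plumbing that the diff above relies on.
    self.assertIn(tiled_layout, (True, False))

if __name__ == "__main__":
  absltest.main()
```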