
Commit 5375748

apaszke authored and Google-ML-Automation committed
[Mosaic GPU] Improve checks on the WG strided transfers
Our previous code didn't actually verify that the vector loads/stores are safe, so you could pass in a non-contiguous reference with a weird shape but the right number of elements and get nonsensical results. The current check is a bit too conservative, but it's better to lean this way. This also factors the address calculation into a common `transfer_strided` class method, following the same pattern we use for tiled layouts.

PiperOrigin-RevId: 813722484
1 parent 47d933c
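
To make the new safety condition concrete, the check added to `transfer_strided` below can be paraphrased in plain Python over a memref's shape and strides. This is a sketch only; the helper name `strided_transfer_is_safe` is invented for illustration and is not part of the commit.

def strided_transfer_is_safe(shape, strides, vec_size):
  """Sketch of the new WG strided transfer check (not the real API).

  When the reference cannot be flattened to 1D, vector transfers of width
  vec_size are only accepted if:
    * some dimension is contiguous (stride == 1) and its size is a multiple
      of vec_size, and
    * every other dimension of size > 1 has a stride that is a multiple of
      vec_size, so a vector never straddles a non-contiguous boundary.
  """
  if vec_size == 1:
    return True  # Scalar transfers are always fine.
  has_contiguous_dim = False
  for size, stride in zip(shape, strides):
    if stride == 1:
      has_contiguous_dim = True
      if size % vec_size != 0:
        return False
    elif size > 1 and stride % vec_size != 0:
      return False
  return has_contiguous_dim

# Example: a (4, 8) view with strides (16, 1) is fine for vec_size=4,
# but strides (10, 1) would be rejected because 10 % 4 != 0.
assert strided_transfer_is_safe((4, 8), (16, 1), 4)
assert not strided_transfer_is_safe((4, 8), (10, 1), 4)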

9 files changed (+272, -36 lines)

docs/jax.experimental.pallas.mosaic_gpu.rst

Lines changed: 9 additions & 0 deletions
@@ -86,6 +86,15 @@ Blackwell-specific functions
   try_cluster_cancel
   query_cluster_cancel
 
+Multimem operations
+-------------------
+
+.. autosummary::
+  :toctree: _autosummary
+
+  multimem_store
+  multimem_load_reduce
+
 Aliases
 -------
 

jax/_src/pallas/mosaic_gpu/primitives.py

Lines changed: 85 additions & 0 deletions
@@ -3517,3 +3517,88 @@ def _multimem_store_lowering_rule(
   if ctx.module_ctx.auto_barriers:
     mgpu.warpgroup_barrier()  # Make sure the writes have completed.
   return ()
+
+
+multimem_load_reduce_p = jax_core.Primitive("multimem_load_reduce")
+
+@multimem_load_reduce_p.def_effectful_abstract_eval
+def _multimem_load_reduce_abstract_eval(ref, *avals_flat, tree, collective_axes, reduction_op):
+  del collective_axes, reduction_op
+  _check_ref(ref, "ref", gpu_core.GMEM)
+  shape, dtype = ref.shape, ref.dtype
+  if tree is not None:
+    transforms = jax.tree.unflatten(tree, avals_flat)
+    for t in transforms:
+      shape = t.transform_shape(shape)
+      dtype = t.transform_dtype(dtype)
+  return jax_core.ShapedArray(shape, dtype), {pallas_core.comms_effect}
+
+@lowering.register_lowering_rule(multimem_load_reduce_p, mgpu.LoweringSemantics.Lane)
+def _multimem_load_reduce_lowering_rule(
+    ctx: lowering.LoweringRuleContext, ref, *transforms_leaves, tree, collective_axes, reduction_op,
+):
+  if (mesh_info := ctx.module_ctx.mesh_info) is None:
+    raise ValueError(
+        "JAX device mesh is required by multimem_load_reduce, but not defined."
+    )
+  if set(collective_axes) != set(mesh_info.axis_names):
+    raise NotImplementedError(
+        "Only collective_axes that include all JAX device mesh"
+        f" ({mesh_info.axis_names}) axes are supported, but got"
+        f" {collective_axes}"
+    )
+  if ctx.out_layout_hint is None:
+    raise RuntimeError(
+        "Failed to infer the output layout of multimem_load_reduce. Please apply"
+        " plgpu.layout_cast to its output right after its creation."
+    )
+  dtype = ctx.avals_out[0].dtype
+  transforms = tree.unflatten(transforms_leaves)
+  ref, transforms = lowering._handle_transforms(ctx, ref, transforms, allow_peer_refs=False)
+  if transforms:
+    raise NotImplementedError(
+        f"Unhandled transforms for multimem_load_reduce: {transforms}"
+    )
+  multi_ref = ctx.launch_ctx.to_remote_multicast(ref)
+  is_signed = mgpu_utils.is_signed(dtype)
+  arr = mgpu.FragmentedArray.load_reduce_untiled(
+      multi_ref,
+      layout=ctx.out_layout_hint,
+      is_signed=is_signed,
+      reduction=reduction_op,
+  )
+  return arr
+
+def multimem_load_reduce(
+    ref: _Ref,
+    *,
+    collective_axes: Hashable | tuple[Hashable, ...],
+    reduction_op: mgpu.MultimemReductionOp,
+) -> jax.Array:
+  """Loads from a GMEM reference on all devices present in collective_axes and reduces the loaded values.
+
+  The supported dtypes are: ``jnp.float32``, ``jnp.float16``, ``jnp.bfloat16``,
+  ``jnp.float8_e5m2``, ``jnp.float8_e4m3fn``, ``jnp.int32`` and ``jnp.int64``.
+
+  8-bit floating point dtypes are only supported on Blackwell GPUs.
+
+  Args:
+    ref: The GMEM reference to load from.
+    collective_axes: The JAX mesh axes indicating the devices to load from.
+    reduction_op: The reduction operation to perform on the loaded values. The
+      allowed values are add (all dtypes), min, max (all dtypes but f32), as
+      well as and, or and xor (integer types only).
+  """
+  ref, ref_transforms = state_primitives.get_ref_and_transforms(
+      ref, None, "multimem_load_reduce"
+  )
+  flat_ref_transforms, ref_transforms_treedef = tree_util.tree_flatten(
+      ref_transforms
+  )
+  return multimem_load_reduce_p.bind(
+      ref,
+      *flat_ref_transforms,
+      tree=ref_transforms_treedef,
+      collective_axes=collective_axes,
+      reduction_op=reduction_op,
+  )

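For orientation, here is a minimal kernel-body sketch of how the new primitive is intended to be called. It is not taken from the commit: the mesh axis name "x" is an assumption and `layout` stands in for whatever register layout you cast to; the full working setup, including shard_map, GMEM placement and symmetric-memory aliasing, is in the distributed test at the bottom of this commit.

# Sketch only: reduce x_ref across every device on the "x" mesh axis.
# `layout` is a placeholder for a register layout accepted by
# plgpu.layout_cast; the cast right after the load is what lets the
# lowering infer the output layout, as the lowering rule above requires.
def kernel(x_ref, y_ref):
  y_ref[...] = plgpu.layout_cast(
      plgpu.multimem_load_reduce(
          x_ref, collective_axes="x", reduction_op="add"
      ),
      layout,
  )
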
jax/experimental/mosaic/gpu/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -86,6 +86,7 @@
     Partition1D as Partition1D,
     SemaphoreRef as SemaphoreRef,
     ThreadSubset as ThreadSubset,
+    MultimemReductionOp as MultimemReductionOp,
     bitwidth as bitwidth,
     bytewidth as bytewidth,
     c as c,

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 55 additions & 34 deletions
@@ -966,16 +966,11 @@ def load_strided(
       )
     else:
       layout = WGStridedFragLayout(shape=shape, vec_size=vec_size)
+      registers = np.empty(layout.registers_shape(shape), dtype=object)
       vec_ty = ir.VectorType.get((layout.vec_size,), ref_ty.element_type)
-      try:
-        # Flattening the reference potentially produces simpler PTX but
-        # if the ref is not already 1D and has strided dimensions
-        # flattening won't work.
-        ref_ = mgpu.memref_fold(ref, 0, len(ref_ty.shape))
-        vecs = [vector.load(vec_ty, ref_, [vec_idx]) for vec_idx in layout.linear_thread_idxs()]
-      except NotImplementedError:
-        vecs = [vector.load(vec_ty, ref, vec_idx) for vec_idx in layout.thread_idxs(shape)]
-      return cls(_registers=np.array(vecs), _layout=layout, _is_signed=is_signed)
+      for _get, update, ref, idx in cls.transfer_strided(ref, layout.vec_size):
+        update(registers, vector.load(vec_ty, ref, idx))
+      return cls(_registers=registers, _layout=layout, _is_signed=is_signed)
 
   @classmethod
   def splat(
@@ -2579,8 +2574,10 @@ def store_untiled(
         if isinstance(ref, utils.MultimemRef):
           raise NotImplementedError("Strided layout does not support multimem")
         if swizzle != 16:
-          raise NotImplementedError
-        self._store_untiled_wg_strided(ref)
+          raise ValueError("Only TiledLayouts support swizzling")
+        assert isinstance(self.layout, WGStridedFragLayout)
+        for get, _update, ref, idx in self.transfer_strided(ref, self.layout.vec_size):
+          vector.store(get(self.registers), ref, idx)
       case TiledLayout():
         ref_shape = ir.MemRefType(ref.type).shape
         ref = utils.memref_reshape(ref, (*(1 for _ in ref_shape), *ref_shape))
@@ -2621,8 +2618,8 @@ def load_untiled(
       is_signed: bool | None = None,
      optimized: bool = True,
  ) -> FragmentedArray:
-    ref_shape = ir.MemRefType(ref.type).shape
-    ref = utils.memref_reshape(ref, (*(1 for _ in ref_shape), *ref_shape))
+    ref_ty = ir.MemRefType(ref.type)
+    ref = utils.memref_reshape(ref, (*(1 for _ in ref_ty.shape), *ref_ty.shape))
     return cls.load_tiled(
         ref, swizzle=swizzle, is_signed=is_signed, layout=layout, optimized=optimized
     )
@@ -2653,27 +2650,6 @@ def _store_untiled_splat(self, ref: ir.Value):
     )
     fa.store_untiled(ref)
 
-  def _store_untiled_wg_strided(self, ref: ir.Value):
-    assert isinstance(self.layout, WGStridedFragLayout)
-    ref_ty = ir.MemRefType(ref.type)
-    idxs: Iterable[Sequence[ir.Value]]
-    try:
-      # Flattening the reference potentially produces simpler PTX but
-      # if the ref is not already 1D and has strided dimensions
-      # flattening won't work. We use a different variable for ref in
-      # case `NotImplementedError` is thrown by
-      # .linear_thread_idxs().
-      ref_ = mgpu.memref_fold(ref, 0, len(ref_ty.shape))
-      idxs = ((i,) for i in self.layout.linear_thread_idxs())
-    except NotImplementedError:
-      ref_ = ref
-      idxs = self.layout.thread_idxs(self.shape)
-    ref_shape = tuple(ref_ty.shape)
-    if ref_shape != self.shape:
-      raise ValueError((ref_shape, self.shape))
-    for idx, reg in zip(idxs, self.registers.flat):
-      vector.store(reg, ref_, idx)
-
   def store_tiled(self, ref: ir.Value | utils.MultimemRef, swizzle: int | None, optimized: bool = True):
     if not isinstance(self.layout, TiledLayout):
       raise NotImplementedError(self.layout)
@@ -2731,6 +2707,51 @@ def load_tiled(
       update(registers, loaded_reg)
     return cls(_registers=registers, _layout=layout, _is_signed=is_signed)
 
+  @classmethod
+  def transfer_strided(self, ref: ir.Value, vec_size: int):
+    ref_ty = ir.MemRefType(ref.type)
+    layout = WGStridedFragLayout(shape=tuple(ref_ty.shape), vec_size=vec_size)
+    try:
+      # Flattening the reference potentially produces simpler PTX but
+      # if the ref is not already 1D and has strided dimensions
+      # flattening won't work.
+      ref = mgpu.memref_fold(ref, 0, len(ref_ty.shape))
+    except ValueError:
+      strides, _ = ref_ty.get_strides_and_offset()
+      if vec_size > 1:
+        # TODO(apaszke): We could fold all the pairs of dims that are contiguous
+        # This check is a too strict if we don't do that.
+        has_contiguous_dim = False
+        for size, stride in zip(ref_ty.shape, strides):
+          if stride == 1:
+            has_contiguous_dim = True
+            if size % vec_size != 0:
+              raise ValueError(
+                  "The contiguous dimension of the reference must be a"
+                  f" multiple of the layout's vector size (got {size} and"
+                  f" vector size {vec_size})"
+              ) from None
+          elif size > 1:
+            if stride % vec_size != 0:
+              raise ValueError(
+                  "Non-contiguous dimension of the reference must have strides"
+                  " that are multiples of the layout's vector size (got"
+                  f" {stride} and vector size {vec_size})"
+              ) from None
+        if not has_contiguous_dim:
+          raise ValueError(
+              "The reference must have a contiguous dimension when vec_size > 1"
+          )
+      idx_gen = layout.thread_idxs(tuple(ref_ty.shape))
+    else:
+      idx_gen = map(lambda x: [x], layout.linear_thread_idxs())
+    for i, vec_idx in enumerate(idx_gen):
+      def update(registers, reg, _i=i):
+        registers[_i] = reg
+      def get(registers, _i=i):
+        return registers[_i]
+      yield get, update, ref, vec_idx
+
   @staticmethod
   def transfer_tiled(
       ref: ir.Value,

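To spell out the refactor described in the commit message, the two call sites above consume `transfer_strided` like this (restated here as comments): only the register callback differs between the load and the store, while the address computation and the new stride checks live in one place, mirroring how `transfer_tiled` is used.

# Load (load_strided): fill a fresh object array of registers.
#   for _get, update, ref, idx in cls.transfer_strided(ref, layout.vec_size):
#     update(registers, vector.load(vec_ty, ref, idx))
#
# Store (store_untiled, WGStridedFragLayout case): read each register back.
#   for get, _update, ref, idx in self.transfer_strided(ref, self.layout.vec_size):
#     vector.store(get(self.registers), ref, idx)
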
jax/experimental/mosaic/gpu/utils.py

Lines changed: 1 addition & 1 deletion
@@ -694,7 +694,7 @@ def memref_fold(ref: ir.Value, dim, fold_rank) -> ir.Value:
     new_strides[dim : dim + fold_rank] = [new_strides[dim + fold_rank - 1]]
     new_layout = ir.StridedLayoutAttr.get(offset, new_strides)
   else:
-    raise NotImplementedError(
+    raise ValueError(
        f"strides={ref_ty.get_strides_and_offset()[0]}, {ref_ty.shape=},"
        f" {dim=}, {fold_rank=}"
    )

jax/experimental/pallas/mosaic_gpu.py

Lines changed: 1 addition & 0 deletions
@@ -68,6 +68,7 @@
 from jax._src.pallas.mosaic_gpu.primitives import inline_mgpu as inline_mgpu
 from jax._src.pallas.mosaic_gpu.primitives import load as load
 from jax._src.pallas.mosaic_gpu.primitives import multimem_store as multimem_store
+from jax._src.pallas.mosaic_gpu.primitives import multimem_load_reduce as multimem_load_reduce
 from jax._src.pallas.mosaic_gpu.primitives import print_layout as print_layout
 from jax._src.pallas.mosaic_gpu.primitives import query_cluster_cancel as query_cluster_cancel
 from jax._src.pallas.mosaic_gpu.primitives import RefType as RefType

tests/mosaic/gpu_test.py

Lines changed: 1 addition & 1 deletion
@@ -447,7 +447,7 @@ def kernel(ctx, inp, out, _):
       np.testing.assert_array_equal(y, out)
 
     if throws_not_impl:
-      with self.assertRaises(NotImplementedError):
+      with self.assertRaises(ValueError):
        do_test()
    else:
      do_test()

tests/pallas/gpu_pallas_distributed_test.py

Lines changed: 118 additions & 0 deletions
@@ -29,6 +29,7 @@
 import jax.experimental.mosaic.gpu as mgpu
 import jax.numpy as jnp
 import numpy as np
+import jax.experimental.mosaic.gpu.fragmented_array as fa
 
 
 P = jax.sharding.PartitionSpec
@@ -259,6 +260,123 @@ def _store():
     ref = lax.broadcasted_iota(jnp.int32, (128, 128), 1)
     np.testing.assert_array_equal(y, np.concat([ref, ref], axis=0))
 
+  @parameterized.parameters(
+      (jnp.int32, 1, "add"),
+      (jnp.int32, 1, "min"),
+      (jnp.int32, 1, "max"),
+      (jnp.int32, 1, "and"),
+      (jnp.int32, 1, "or"),
+      (jnp.int32, 1, "xor"),
+      (jnp.float32, 1, "add"),
+      (jnp.float32, 2, "add"),
+      (jnp.float32, 4, "add"),
+      (jnp.float16, 2, "add"),
+      (jnp.float16, 2, "min"),
+      (jnp.float16, 4, "max"),
+      (jnp.float16, 8, "add"),
+      (jnp.bfloat16, 2, "max"),
+      (jnp.bfloat16, 8, "add"),
+      (jnp.float8_e5m2, 4, "add"),
+      (jnp.float8_e5m2, 8, "min"),
+      (jnp.float8_e5m2, 16, "max"),
+      (jnp.float8_e4m3fn, 4, "min"),
+      (jnp.float8_e4m3fn, 8, "max"),
+      (jnp.float8_e4m3fn, 16, "add"),
+  )
+  def test_multimem_load_reduce(self, dtype, vector_length, reduction):
+    if dtype in (
+        jnp.float8_e5m2,
+        jnp.float8_e4m3fn,
+    ) and not jtu.is_cuda_compute_capability_at_least("10.0"):
+      self.skipTest("Only works on GPU with capability >= sm100")
+    if jax.process_index() > 2:
+      return  # Only 2 processes needed.
+    devices = jax.devices()[:2]
+
+    def kernel(x_ref, y_ref, _, sem_ref):
+      layout = plgpu.Layout.TILED(
+          fa.Tiling(
+              (
+                  (64, 2 * vector_length),
+                  (16, 2 * vector_length),
+                  (vector_length,),
+              )
+          ),
+          warp_dims=(-5,),
+          lane_dims=(-3, -2),
+          vector_dim=-1,
+      )
+      y_ref[...] = plgpu.layout_cast(
+          plgpu.multimem_load_reduce(
+              x_ref.at[16:-16], collective_axes="x", reduction_op=reduction,
+          ),
+          layout
+      )
+      my_device = lax.axis_index("x")
+      other_device = 1 - my_device
+      pl.semaphore_signal(sem_ref, 1, device_id=other_device)
+      pl.semaphore_wait(sem_ref)
+
+    # The rounding we see in low precision types seems to be different from
+    # what JAX/XLA use.
+    match jnp.dtype(dtype).itemsize:
+      case 4:
+        bound = 800000
+      case 2:
+        bound = 128
+      case 1:
+        bound = 4
+      case _:
+        raise ValueError(f"Unsupported dtype: {dtype}")
+    x_local = jax.random.randint(
+        jax.random.key(1234), (128 + 64, 32), dtype=jnp.int32, minval=-bound, maxval=bound,
+    ).astype(dtype)
+    mesh = jax.sharding.Mesh(devices, ("x",))
+    x_shard = jax.ShapeDtypeStruct((64 + 32, 32), dtype)
+    y_shape = jax.ShapeDtypeStruct((64, 32), dtype)
+    y, _ = jax.jit(
+        shard_map.shard_map(
+            pl.pallas_call(
+                kernel,
+                in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
+                out_specs=[
+                    pl.BlockSpec(memory_space=plgpu.SMEM),
+                    pl.BlockSpec(memory_space=plgpu.GMEM),
+                ],
+                out_shape=(y_shape, x_shard),
+                scratch_shapes=[plgpu.SemaphoreType.REGULAR],
+                # TODO(b/448323639): Without aliasing XLA doesn't actually
+                # insert the copy that puts the operand in symmetric memory,
+                # which causes the kernel to crash.
+                input_output_aliases={0: 1},
+            ),
+            mesh=mesh,
+            in_specs=P("x"),
+            out_specs=P("x"),  # Not really, but lets us test.
+            check_rep=False,
+        )
+    )(x_local)
+    y = multihost_utils.process_allgather(y, tiled=True)
+    match reduction:
+      case "add":
+        np_reduction = jnp.add
+      case "min":
+        np_reduction = jnp.minimum
+      case "max":
+        np_reduction = jnp.maximum
+      case "and":
+        np_reduction = jnp.bitwise_and
+      case "or":
+        np_reduction = jnp.bitwise_or
+      case "xor":
+        np_reduction = jnp.bitwise_xor
+      case _:
+        raise ValueError(reduction)
+    np.testing.assert_array_equal(
+        y.astype(jnp.float32),
+        np.tile(np_reduction(x_local[16:64+16], x_local[64+48:128+48]), (2, 1)),
+    )
+
 
 if __name__ == '__main__':
   # This test doesn't work with the platform allocator, so we override it
