Commit 47d933c

apaszke authored and Google-ML-Automation committed
[Pallas:MGPU] Add support for multimem stores in Pallas
PiperOrigin-RevId: 813711270
1 parent 036d7a1 commit 47d933c

File tree

4 files changed: +115 -0 lines changed


jax/_src/pallas/mosaic_gpu/primitives.py

Lines changed: 83 additions & 0 deletions
@@ -3434,3 +3434,86 @@ def query_cluster_cancel(
       grid_names=grid_names,
       transforms_tree=result_transforms_tree)
   return tuple(result[:-1]), result[-1]
+
+
+multimem_store_p = jax_core.Primitive("multimem_store")
+multimem_store_p.multiple_results = True
+
+
+def multimem_store(source: jax.Array, ref: _Ref, collective_axes: Hashable | tuple[Hashable, ...]):
+  """Stores the value to ref on all devices present in collective_axes.
+
+  The store is done using multimem instructions, meaning that the data is
+  only transferred to the switch once and then broadcast to all other
+  devices from there.
+
+  Args:
+    source: The value to store.
+    ref: The GMEM reference to store the value to.
+    collective_axes: The JAX mesh axes indicating the devices to store to.
+  """
+  if isinstance(ref, pallas_core.TransformedRef):
+    transforms_leaves, transforms_tree = jax.tree.flatten(ref.transforms)
+    ref = ref.ref
+  else:
+    transforms_leaves, transforms_tree = [], None
+  multimem_store_p.bind(
+      source,
+      ref,
+      *transforms_leaves,
+      collective_axes=collective_axes,
+      transforms_tree=transforms_tree,
+  )
+
+
+@multimem_store_p.def_effectful_abstract_eval
+def _multimem_store_abstract_eval(source, ref, *transforms_leaves, transforms_tree, **_):
+  _check_ref(ref, "ref", gpu_core.GMEM)
+  shape, dtype = ref.shape, ref.dtype
+  if transforms_tree is not None:
+    transforms = jax.tree.unflatten(transforms_tree, transforms_leaves)
+    for t in transforms:
+      shape = t.transform_shape(shape)
+      dtype = t.transform_dtype(dtype)
+  if source.dtype != dtype:
+    raise ValueError(f"Value dtype {source.dtype} does not match ref dtype {dtype}")
+  if source.shape != shape:
+    raise ValueError(f"Value shape {source.shape} does not match ref shape {shape}")
+  return [], {pallas_core.comms_effect}
+
+
+@lowering.register_lowering_rule(multimem_store_p, mgpu.LoweringSemantics.Lane)
+def _multimem_store_lowering_rule(
+    ctx: lowering.LoweringRuleContext, value, local_ref, *transforms_leaves, transforms_tree, collective_axes,
+):
+  if (mesh_info := ctx.module_ctx.mesh_info) is None:
+    raise ValueError(
+        "JAX device mesh is required by multimem_store, but not defined."
+    )
+  if set(collective_axes) != set(mesh_info.axis_names):
+    raise NotImplementedError(
+        "Only collective_axes that include all JAX device mesh"
+        f" ({mesh_info.axis_names}) axes are supported, but got"
+        f" {collective_axes}"
+    )
+  if not isinstance(value, mgpu.FragmentedArray):
+    raise TypeError(f"Can only store arrays (got {value}).")
+  if transforms_tree is not None:
+    transforms = tree_util.tree_unflatten(transforms_tree, transforms_leaves)
+    local_ref, transforms = lowering._handle_transforms(
+        ctx, local_ref, transforms, allow_peer_refs=False
+    )
+    if transforms:
+      raise NotImplementedError(
+          f"Unhandled transforms for multimem_store: {transforms}"
+      )
+  multi_ref = ctx.launch_ctx.to_remote_multicast(local_ref)
+  if not ctx.avals_in[0].shape:
+    multi_ref.store(lowering._ensure_ir_value(value, ctx.avals_out[0].dtype), [])
+  else:
+    value.store_untiled(multi_ref, optimized=False)
+  if ctx.module_ctx.auto_barriers:
+    mgpu.warpgroup_barrier()  # Make sure the writes have completed.
+  return ()
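
For orientation, the snippet below condenses how the new primitive is meant to be called from a kernel. It is a sketch distilled from the distributed test added later in this commit; the mesh axis name 'x', the (128, 128) int32 shape, and the semaphore-based synchronization are illustrative assumptions taken from that test, not requirements of the API.

import jax.numpy as jnp
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


def kernel(y_ref, sem):
  # Only one device along the 'x' mesh axis issues the store; the multimem
  # instruction broadcasts the data into y_ref on every device in that axis.
  @pl.when(lax.axis_index('x') == 0)
  def _store():
    value = plgpu.layout_cast(
        lax.broadcasted_iota(jnp.int32, (128, 128), 1), plgpu.Layout.WGMMA)
    plgpu.multimem_store(value, y_ref, 'x')
  # The receiving device waits until the broadcast has landed before the
  # kernel finishes (assumes two devices along 'x', as in the test).
  pl.semaphore_signal(sem, 1, device_id=1 - lax.axis_index('x'))
  pl.semaphore_wait(sem)

The full pallas_call and shard_map setup, plus verification via process_allgather, appears in the new test in tests/pallas/gpu_pallas_distributed_test.py below.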

jax/experimental/pallas/mosaic_gpu.py

Lines changed: 1 addition & 0 deletions
@@ -67,6 +67,7 @@
 from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem
 from jax._src.pallas.mosaic_gpu.primitives import inline_mgpu as inline_mgpu
 from jax._src.pallas.mosaic_gpu.primitives import load as load
+from jax._src.pallas.mosaic_gpu.primitives import multimem_store as multimem_store
 from jax._src.pallas.mosaic_gpu.primitives import print_layout as print_layout
 from jax._src.pallas.mosaic_gpu.primitives import query_cluster_cancel as query_cluster_cancel
 from jax._src.pallas.mosaic_gpu.primitives import RefType as RefType
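
With this re-export, the primitive is reachable through the public alias used in the tests; a trivial sketch of the import path (nothing here beyond what the diff adds):

from jax.experimental.pallas import mosaic_gpu as plgpu

# plgpu.multimem_store is now part of the public Pallas:MGPU surface,
# alongside copy_smem_to_gmem, load, and the other re-exported primitives.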

tests/pallas/gpu_pallas_distributed_test.py

Lines changed: 30 additions & 0 deletions
@@ -22,6 +22,7 @@
 from jax import lax
 from jax._src import test_util as jtu
 from jax._src import test_multiprocess as jt_multiprocess
+from jax.experimental import multihost_utils
 from jax.experimental import pallas as pl
 from jax.experimental import shard_map
 from jax.experimental.pallas import mosaic_gpu as plgpu
@@ -229,6 +230,35 @@ def kernel(y_ref, sem):
     with self.assertRaisesRegex(NotImplementedError, msg):
       f()

+  def test_multimem_store(self):
+    if jax.process_index() > 2:
+      return  # Only 2 processes needed.
+
+    def kernel(y_ref, sem):
+      @pl.when(lax.axis_index('x') == 0)
+      def _store():
+        output = plgpu.layout_cast(lax.broadcasted_iota(jnp.int32, (128, 128), 1), plgpu.Layout.WGMMA)
+        plgpu.multimem_store(output, y_ref, 'x')
+      other_dev_id = 1 - lax.axis_index('x')
+      pl.semaphore_signal(sem, 1, device_id=other_dev_id)
+      pl.semaphore_wait(sem)
+
+    kernel_call = pl.pallas_call(
+        kernel,
+        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
+        out_shape=jax.ShapeDtypeStruct((128, 128), jnp.int32),
+        scratch_shapes=[plgpu.SemaphoreType.REGULAR],
+    )
+    mesh = jax.sharding.Mesh(jax.devices(), ['x'])
+    y = jax.jit(
+        shard_map.shard_map(
+            kernel_call, mesh, in_specs=(), out_specs=P("x"), check_rep=False,
+        )
+    )()
+    y = multihost_utils.process_allgather(y, tiled=True)
+    ref = lax.broadcasted_iota(jnp.int32, (128, 128), 1)
+    np.testing.assert_array_equal(y, np.concat([ref, ref], axis=0))
+

 if __name__ == '__main__':
   # This test doesn't work with the platform allocator, so we override it

tests/pallas/mosaic_gpu_test.py

Lines changed: 1 addition & 0 deletions
@@ -2603,6 +2603,7 @@ def test_missing_primitive_lowerings_are_tracked(self):
         mgpu_primitives.semaphore_signal_parallel_p,
         mgpu_primitives.try_cluster_cancel_p,
         mgpu_primitives.query_cluster_cancel_p,
+        mgpu_primitives.multimem_store_p,
         lax.slice_p,
         lax.iota_p,
         pallas_core.core_map_p,
