@@ -3434,3 +3434,86 @@ def query_cluster_cancel(
      grid_names=grid_names,
      transforms_tree=result_transforms_tree)
  return tuple(result[:-1]), result[-1]
+
+
+multimem_store_p = jax_core.Primitive("multimem_store")
+multimem_store_p.multiple_results = True
+
+
+def multimem_store(source: jax.Array, ref: _Ref, collective_axes: Hashable | tuple[Hashable, ...]):
+  """Stores the value to ref on all devices present in collective_axes.
+
+  The store is done using multimem instructions, meaning that the data is
+  only transferred to the switch once and broadcast to all other devices
+  from there.
+
+  Args:
+    source: The value to store.
+    ref: The GMEM reference to store the value to.
+    collective_axes: The JAX mesh axes indicating the devices to store to.
+  """
+  if isinstance(ref, pallas_core.TransformedRef):
+    transforms_leaves, transforms_tree = jax.tree.flatten(
+        ref.transforms
+    )
+    ref = ref.ref
+  else:
+    transforms_leaves, transforms_tree = [], None
+  multimem_store_p.bind(
+      source,
+      ref,
+      *transforms_leaves,
+      collective_axes=collective_axes,
+      transforms_tree=transforms_tree,
+  )
+
+
+@multimem_store_p.def_effectful_abstract_eval
+def _multimem_store_abstract_eval(source, ref, *transforms_leaves, transforms_tree, **_):
+  _check_ref(ref, "ref", gpu_core.GMEM)
+  shape, dtype = ref.shape, ref.dtype
+  if transforms_tree is not None:
+    transforms = jax.tree.unflatten(transforms_tree, transforms_leaves)
+    for t in transforms:
+      shape = t.transform_shape(shape)
+      dtype = t.transform_dtype(dtype)
+  if source.dtype != dtype:
+    raise ValueError(f"Value dtype {source.dtype} does not match ref dtype {dtype}")
+  if source.shape != shape:
+    raise ValueError(f"Value shape {source.shape} does not match ref shape {shape}")
+  return [], {pallas_core.comms_effect}
+
+
+@lowering.register_lowering_rule(multimem_store_p, mgpu.LoweringSemantics.Lane)
+def _multimem_store_lowering_rule(
+    ctx: lowering.LoweringRuleContext, value, local_ref, *transforms_leaves, transforms_tree, collective_axes,
+):
+  if (mesh_info := ctx.module_ctx.mesh_info) is None:
+    raise ValueError(
+        "JAX device mesh is required by multimem_store, but not defined."
+    )
+  if set(collective_axes) != set(mesh_info.axis_names):
+    raise NotImplementedError(
+        "Only collective_axes that include all JAX device mesh"
+        f" ({mesh_info.axis_names}) axes are supported, but got"
+        f" {collective_axes}"
+    )
+  if not isinstance(value, mgpu.FragmentedArray):
+    raise TypeError(f"Can only store arrays (got {value}).")
+  if transforms_tree is not None:
+    transforms = tree_util.tree_unflatten(transforms_tree, transforms_leaves)
+    local_ref, transforms = lowering._handle_transforms(
+        ctx, local_ref, transforms, allow_peer_refs=False
+    )
+    if transforms:
+      raise NotImplementedError(
+          f"Unhandled transforms for multimem_store: {transforms}"
+      )
+  multi_ref = ctx.launch_ctx.to_remote_multicast(local_ref)
+  if not ctx.avals_in[0].shape:
+    multi_ref.store(lowering._ensure_ir_value(value, ctx.avals_out[0].dtype), [])
+  else:
+    value.store_untiled(multi_ref, optimized=False)
+  if ctx.module_ctx.auto_barriers:
+    mgpu.warpgroup_barrier()  # Make sure the writes have completed.
+  return ()
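A minimal usage sketch of the new primitive (not part of the diff): it assumes the helper is re-exported as plgpu.multimem_store from jax.experimental.pallas.mosaic_gpu and that the kernel runs under a single-axis JAX mesh named "x" (the lowering above requires collective_axes to cover every mesh axis); the surrounding pallas_call / shard_map plumbing is elided.

from jax.experimental.pallas import mosaic_gpu as plgpu  # assumed export path


def broadcast_kernel(x_ref, out_gmem_ref):
  # Write the locally read block to `out_gmem_ref` on every device of the "x"
  # mesh axis: the data crosses the fabric once and is fanned out by the
  # switch, as described in the multimem_store docstring.
  plgpu.multimem_store(x_ref[...], out_gmem_ref, collective_axes="x")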