Skip to content

Commit 974323d

Browse files
Merge pull request #2367 from AI-Hypercomputer:rbierneni-fix-0.7.2
PiperOrigin-RevId: 809235377
2 parents 2b25adf + 9733a99 commit 974323d

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

tests/attention_test.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
 26   26  
 27   27   import numpy as np
 28   28  
 29      - from jax.sharding import Mesh
      29 + from jax.sharding import Mesh, NamedSharding, PartitionSpec
 30   30   import jax
 31   31   import jax.numpy as jnp
 32   32  

@@ -1383,6 +1383,21 @@ def _forward_with_context_expert_parallelism(cfg_cp, mesh_cp, attention_cp, lnx, ...)
 1383 1383      decoder_positions = reordered_batch["inputs_position"]
 1384 1384      # apply attention with sharding
 1385 1385      with mesh_cp, nn_partitioning.axis_rules(cfg_cp.logical_axis_rules):
      1386 +      lnx_spec = nn_partitioning.logical_to_mesh_axes(
      1387 +          ('activation_batch_no_exp', 'activation_length_no_exp', 'activation_embed'),
      1388 +          nn_partitioning.get_axis_rules()
      1389 +      )
      1390 +      pos_spec = nn_partitioning.logical_to_mesh_axes(
      1391 +          ('activation_batch_no_exp', 'activation_length_no_exp'),
      1392 +          nn_partitioning.get_axis_rules()
      1393 +      )
      1394 +      lnx_sharding = NamedSharding(mesh_cp, lnx_spec)
      1395 +      pos_sharding = NamedSharding(mesh_cp, pos_spec)
      1396 +
      1397 +      lnx = jax.device_put(lnx, lnx_sharding)
      1398 +      decoder_segment_ids = jax.device_put(decoder_segment_ids, pos_sharding)
      1399 +      decoder_positions = jax.device_put(decoder_positions, pos_sharding)
      1400 +
 1386 1401      attention_cp_output = attention_cp(
 1387 1402          lnx,
 1388 1403          lnx,

0 commit comments

Comments (0)