Conversation


@suexu1025 suexu1025 commented Sep 19, 2025

Description

Land a sharding strategy for the MoE layer.

  • Add --fsdp_shard_on_exp=true to enable FSDP sharding on the num_experts dimension (see the sketch after this list).
  • dsv3 step time decreases:
    • before (FSDP on embed): 47s
    • after (FSDP on num_experts): 43s
  • No change for the Mixtral 8x7B model.
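
As a rough, non-authoritative illustration of the two layouts (not the PR's code), the sketch below shards a hypothetical MoE expert-weight tensor on the embed dimension versus the num_experts dimension with jax.sharding; the mesh axis name "fsdp", the shapes, and all variable names are assumptions for illustration only.

```python
# Minimal sketch: FSDP sharding of an MoE expert-weight tensor on the embed
# dim vs. the num_experts dim. Mesh axis name, shapes, and names are assumed.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

num_experts, embed, mlp = 8, 1024, 4096
mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))
w = jnp.zeros((num_experts, embed, mlp))

# Baseline-style layout: shard the embed dimension over the fsdp axis.
w_on_embed = jax.device_put(w, NamedSharding(mesh, P(None, "fsdp", None)))

# fsdp_shard_on_exp-style layout: shard the num_experts dimension instead,
# which only divides evenly when num_experts % fsdp_parallelism == 0.
w_on_experts = jax.device_put(w, NamedSharding(mesh, P("fsdp", None, None)))

print(w_on_embed.sharding)
print(w_on_experts.sharding)
```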


Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@suexu1025 suexu1025 changed the title from "Better Sharding for dsv3 moe layer" to "Better sharding for dsv3 moe layer" on Sep 19, 2025

@gobbleturk gobbleturk left a comment


@richjames0 for thoughts =D


@RissyRan RissyRan left a comment


Thanks Qinwen! Could you also attach test results in the description?

@suexu1025

> Thanks Qinwen! Could you also attach test results in the description?

Please see the test results here:
https://screenshot.googleplex.com/5xAhsHPUdHFS6zc

@RissyRan

> Thanks Qinwen! Could you also attach test results in the description?
> Please see the test results here:
> https://screenshot.googleplex.com/5xAhsHPUdHFS6zc

Thanks! I meant the original links, so that we can go directly to the profiles and other details. One example here; we usually put the original links for reviewers.


@RissyRan RissyRan left a comment


Just nits


@richjames0 richjames0 left a comment


lgtm

update

refactor

update

Add optional config

Explicitly shard input tensors across mesh devices

Run on 0.7.2 candidate image

Fix typo in image tag

Revert to use latest tag

update test for new jax version

Remove sharding rules for q_lora and kv_lora from base.yml

update with configs

clean up

update
```python
if raw_keys["fsdp_shard_on_exp"] and raw_keys["num_experts"] % raw_keys["ici_fsdp_parallelism"] != 0:
  raise ValueError("fsdp_shard_on_exp requires num_experts to be divisible by ici_fsdp_parallelism.")
if raw_keys["fsdp_shard_on_exp"] and (using_tensor_parallelism(raw_keys) or using_expert_parallelism(raw_keys)):
  raise ValueError(
      "fsdp_shard_on_exp requires ici_expert_parallelism = 1 and "
      "ici_tensor_parallelism/ici_tensor_transpose_parallelism = 1."
  )
```


Nit: probably just say fsdp_shard_on_exp does not support EP and TP shardings?
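
As a purely hypothetical illustration of how the divisibility check above behaves (the helper name and config values below are made up, not from the PR):

```python
# Hypothetical sketch of the num_experts / ici_fsdp_parallelism check above;
# the function name and config values are illustrative, not MaxText's code.
def check_fsdp_shard_on_exp(raw_keys):
  if raw_keys["fsdp_shard_on_exp"] and raw_keys["num_experts"] % raw_keys["ici_fsdp_parallelism"] != 0:
    raise ValueError("fsdp_shard_on_exp requires num_experts to be divisible by ici_fsdp_parallelism.")

# Passes: 256 % 64 == 0.
check_fsdp_shard_on_exp({"fsdp_shard_on_exp": True, "num_experts": 256, "ici_fsdp_parallelism": 64})

# Raises: 8 % 3 != 0.
try:
  check_fsdp_shard_on_exp({"fsdp_shard_on_exp": True, "num_experts": 8, "ici_fsdp_parallelism": 3})
except ValueError as err:
  print(err)
```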


@RissyRan RissyRan left a comment


Thanks @suexu1025

```diff
  axis=-1,
  kernel_init=self.kernel_init,
- kernel_axes=("embed", "q_lora"),
+ kernel_axes=("embed", "q_lora_up_proj"),
```

@RissyRan RissyRan Sep 24, 2025


Is this part helpful for all MLA models, or only for DS V3?

cc @richjames0, another case.
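
As context for the kernel_axes change above, the sketch below shows, in a purely illustrative way, how a logical kernel axis name such as "q_lora_up_proj" is resolved to a physical mesh axis through Flax logical-axis rules; the rule table is an assumption, not MaxText's actual base.yml sharding rules.

```python
# Illustrative only: mapping logical kernel axes to mesh axes with Flax rules.
# The rule table below is assumed, not taken from MaxText's base.yml.
import flax.linen as nn

rules = (("embed", "fsdp"), ("q_lora_up_proj", None))
with nn.logical_axis_rules(rules):
  spec = nn.logical_to_mesh_axes(("embed", "q_lora_up_proj"))

print(spec)  # PartitionSpec('fsdp', None)
```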

@copybara-service copybara-service bot merged commit c84c424 into main Sep 24, 2025
30 checks passed
@copybara-service copybara-service bot deleted the qinwen/update_sharding_moe branch September 24, 2025 17:59