Better sharding for dsv3 moe layer #2373
Conversation
@richjames0 for thoughts =D
Thanks Qinwen! Could you also attach test results in the description?
Thanks! I meant to say the original links, so that we can go directly to the profiles or anything else. One example here; we usually put the original links in for reviewers.
Just nits
lgtm
Commits:
update
refactor
update
Add optional config
Explicitly shard input tensors across mesh devices
Run on 0.7.2 candidate image
Fix typo in image tag
Revert to use latest tag
update test for new jax version
Remove sharding rules for q_lora and kv_lora from base.yml
update with configs
clean up
update
Force-pushed from b0cbe7f to 859abdf
if raw_keys["fsdp_shard_on_exp"] and raw_keys["num_experts"] % raw_keys["ici_fsdp_parallelism"]!=0: | ||
raise ValueError("fsdp_shard_on_exp requires num_experts is divisiable by ici_fsdp_parallelism.") | ||
if raw_keys["fsdp_shard_on_exp"] and (using_tensor_parallelism(raw_keys) or useing_expert_parallelism(raw_keys)): | ||
raise ValueError("fsdp_shard_on_exp requires ici_expert_parallelism = 1 and ici_tensor_parallelism/ici_tensor_transpose_parallelism = 1.") |
Nit: probably just say fsdp_shard_on_exp does not support EP and TP shardings?
Thanks @suexu1025
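For reviewers skimming the check above, here is a minimal sketch (the values are hypothetical, not from any real config) of the constraint it enforces: sharding the expert dimension over the fsdp mesh axis only works when every fsdp shard receives a whole number of experts.

```python
# Hypothetical values; in MaxText these come from raw_keys in the parsed config.
num_experts = 256
ici_fsdp_parallelism = 64

# fsdp_shard_on_exp splits the expert dimension across the fsdp mesh axis,
# so num_experts must divide evenly by the fsdp axis size.
assert num_experts % ici_fsdp_parallelism == 0, (
    "fsdp_shard_on_exp requires num_experts to be divisible by ici_fsdp_parallelism"
)
experts_per_shard = num_experts // ici_fsdp_parallelism  # 4 whole experts per fsdp shard
```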
  axis=-1,
  kernel_init=self.kernel_init,
- kernel_axes=("embed", "q_lora"),
+ kernel_axes=("embed", "q_lora_up_proj"),
This part is helpful for all MLA models or only DS V3?
cc @richjames0 another case.
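For context on the rename above, here is a minimal sketch (not the MaxText implementation; the mesh shape and rule values are assumptions) of how a logical kernel axis name is resolved to a mesh axis. Giving the up-projection kernel its own logical name, q_lora_up_proj, lets it carry a sharding rule independent of whatever q_lora maps to.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Illustrative 2-D mesh; MaxText builds its mesh from the ici_* config values.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("fsdp", "tensor"))

# Illustrative logical-to-mesh axis rules (assumed values, not the ones in base.yml).
rules = {
    "embed": "fsdp",
    "q_lora": None,             # old shared name: replicated here
    "q_lora_up_proj": "tensor"  # dedicated name: can get its own rule
}

def sharding_for(kernel_axes):
  """Resolve a tuple of logical axis names to a NamedSharding on the mesh."""
  return NamedSharding(mesh, PartitionSpec(*(rules.get(name) for name in kernel_axes)))

# A kernel annotated with ("embed", "q_lora_up_proj") now resolves through its own
# rule instead of whatever "q_lora" maps to.
print(sharding_for(("embed", "q_lora_up_proj")).spec)  # PartitionSpec('fsdp', 'tensor')
```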
Description
Land the sharding strategy for the MoE layer. Set
--fsdp_shard_on_exp=true
to shard fsdp on the num_experts dimension [43s].
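A minimal JAX sketch of what the flag changes, using a hypothetical mesh and made-up shapes (and assuming the existing behavior splits the embed dimension under fsdp): with fsdp_shard_on_exp=true the expert weights are instead split over the fsdp axis along num_experts.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical 1-D fsdp mesh and small illustrative shapes.
devices = np.array(jax.devices())
n_fsdp = devices.size
mesh = Mesh(devices, axis_names=("fsdp",))

num_experts, embed, mlp = 4 * n_fsdp, 128 * n_fsdp, 256
w = jnp.zeros((num_experts, embed, mlp), dtype=jnp.bfloat16)  # MoE expert kernel

# Without the flag (assumed baseline), fsdp splits the embed dimension.
w_shard_on_embed = jax.device_put(w, NamedSharding(mesh, PartitionSpec(None, "fsdp", None)))

# With fsdp_shard_on_exp=true, fsdp splits the num_experts dimension instead, so each
# fsdp shard owns num_experts // n_fsdp whole experts (hence the divisibility check).
w_shard_on_exp = jax.device_put(w, NamedSharding(mesh, PartitionSpec("fsdp", None, None)))
```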
If the change fixes a bug or a Github issue, please include a link, e.g.,:
FIXES: b/123456
FIXES: #123456
Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.
Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.
Tests
Please describe how you tested this change, and include any instructions and/or
commands to reproduce.
Checklist
Before submitting this PR, please make sure (put X in square brackets):