Remove sharding rules for q_lora and kv_lora from base.yml #2369
Description
An issue was identified while running MaxText/Tunix SFT with Deepseek-V3 on v6e-256: two axes within the MLA layer are assigned exactly the same sharding rule by the logical_axis_rules defined here.
Conflict 1: The embed axis and the q_lora axis both use an identical sharding specification: ['fsdp', 'sequence', 'context', 'expert'].
Conflict 2: The embed axis and the kv_lora axis also use the same sharding specification: ['fsdp', 'sequence', 'context', 'expert'].
During the optimizer sharding in Tunix here, jax.lax.with_sharding_constraint() looks at the sharding rules defined for each axis and uses those rules to determine how to shard the optimizer.
The error points to this line in the JAX source:
https://github.com/jax-ml/jax/blob/1a91543e92778bb659939cc3bdc3d4b7978191b6/jax/_src/named_sharding.py#L473
When JAX encounters the same sharding specification being passed for two different axes of one array (for example embed and q_lora), the same mesh axes end up assigned to more than one dimension. JAX sees this as an internal inconsistency and throws DuplicateSpecError.
This PR resolves that issue by removing the sharding rules for q_lora and kv_lora.
FIXES: b/444495481
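The conflict described above can be illustrated with a minimal sketch in plain Python. It is a simplified stand-in for logical-to-mesh axis resolution, not MaxText, Tunix, or JAX code: the helper names (resolve, duplicated_mesh_axes), the dictionary form of the rules, and the (embed, q_lora) weight shape are assumptions made for illustration, as is the premise that a logical axis without a rule is left unsharded.

```python
from collections import Counter

# Mesh axes named in the PR description.
MESH_AXES = ["fsdp", "sequence", "context", "expert"]

# Rules as described in this PR: three logical axes share one mesh-axis list.
# (Hypothetical dict form; the real config is a YAML list of rules.)
RULES_BEFORE = {
    "embed": MESH_AXES,
    "q_lora": MESH_AXES,
    "kv_lora": MESH_AXES,
}

# After this PR: q_lora and kv_lora have no rule.
# Assumption: an axis with no rule is left unsharded.
RULES_AFTER = {"embed": MESH_AXES}


def resolve(logical_axes, rules):
    """Map each dimension's logical axis name to a tuple of mesh axes."""
    return [tuple(rules.get(name, ())) for name in logical_axes]


def duplicated_mesh_axes(spec):
    """Return the mesh axes that are used by more than one dimension."""
    counts = Counter(axis for dim in spec for axis in dim)
    return sorted(axis for axis, n in counts.items() if n > 1)


# Hypothetical MLA projection weight with dimensions (embed, q_lora).
dims = ("embed", "q_lora")
before = duplicated_mesh_axes(resolve(dims, RULES_BEFORE))
after = duplicated_mesh_axes(resolve(dims, RULES_AFTER))

print("duplicated before:", before)
print("duplicated after:", after)
```

With the old rules every mesh axis is claimed by both dimensions, which is the kind of spec that triggers DuplicateSpecError. With the rules for q_lora and kv_lora removed, only the embed dimension is sharded and no mesh axis repeats.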
Tests
Tested Deepseek-V3 on v6e-256: https://cloudlogging.app.goo.gl/FYJACjyjTdAF3uQE7
Checklist
Before submitting this PR, please make sure (put X in square brackets):