Force-pushed from 31717c0 to e6509b7
1. Use all-to-all when computing embeddings. 2. Precompute multi_head_embedding.
Force-pushed from e6509b7 to c7d57c3
Pull request overview
This PR improves DeepSeek Engram training performance by adding embedding-parallel communication options (all-to-all / all-reduce / offload), enabling precomputation/overlap of multi-head embedding work, and wiring new Engram-parallel arguments through initialization and example configs.
Changes:
- Add Engram embedding parallel method/size arguments and pass them into distributed initialization.
- Introduce an embedding-parallel EngramMemory implementation and integrate it into Engram's multi-head embedding and sharded checkpointing.
- Add embedding precompute/caching (and attempted overlap across layers) plus new DeepSeek Engram Hydra configs.
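To make the embedding-parallel idea concrete, here is a single-process sketch of the lookup that an embedding-parallel memory performs (the names and the list-based "tensors" are illustrative only; the real EngramMemory routes token ids between ranks with torch.distributed collectives): each rank owns a contiguous vocab shard, every token id is sent to the shard that owns it, looked up locally, and the result is returned to the token's original position.

```python
# Illustrative single-process simulation of an embedding-parallel lookup.
# Real code would use all-to-all collectives across embedding-parallel ranks.

def lookup_embedding_parallel(token_ids, shards, vocab_per_rank):
    """Route each token id to the shard that owns it, look it up there,
    and place the result back at the token's original position."""
    out = [None] * len(token_ids)
    for pos, tok in enumerate(token_ids):
        owner = tok // vocab_per_rank            # which rank holds this id
        local_id = tok - owner * vocab_per_rank  # offset inside that shard
        out[pos] = shards[owner][local_id]
    return out

# Vocab of 8 ids split across 2 "ranks"; embedding dim 1 for brevity.
full_table = [[float(i)] for i in range(8)]
shards = [full_table[:4], full_table[4:]]  # rank 0 owns ids 0-3, rank 1 owns 4-7
print(lookup_embedding_parallel([6, 1, 3, 7], shards, vocab_per_rank=4))
# → [[6.0], [1.0], [3.0], [7.0]]
```

The sharded lookup reproduces what a lookup into the full table would return, which is the invariant the all-to-all implementation must preserve.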
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 12 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/functional_tests/train/deepseek/gold_values/tp2_pp2_ep2_engram.json | Reformat/update stored gold loss JSON. |
| flagscale/train/megatron/training/initialize.py | Plumb engram_embedding_parallel_size into model-parallel initialization. |
| flagscale/train/megatron/training/arguments_fs.py | Add Engram embedding parallel CLI args + validation/warnings. |
| flagscale/train/megatron/train_engram.py | Switch to parallel_state and add an extra TP token broadcast path. |
| flagscale/models/megatron/engram/multi_head_embedding.py | Add EngramMemory and route MultiHeadEmbedding through it; add sharded state dict. |
| flagscale/models/megatron/engram/engram_transformer_layer.py | Add embedding precompute hook and enable sharded state dict. |
| flagscale/models/megatron/engram/engram_model.py | Implement build_schedule_plan for Engram model. |
| flagscale/models/megatron/engram/engram_config.py | Add Engram embedding parallel config fields. |
| flagscale/models/megatron/engram/engram.py | Add embedding precompute/cache + sharded state dict plumbing; rename embedding member to memory. |
| examples/deepseek_v3/conf/train_engram.yaml | New top-level Hydra entry for Engram training run. |
| examples/deepseek_v3/conf/train/engram.yaml | New DeepSeek Engram training preset including embedding-parallel settings. |
Comment on lines +59 to +65
if not (parallel_state.get_pipeline_model_parallel_world_size() == 1 or parallel_state.is_pipeline_first_stage()):
    if parallel_state.get_tensor_model_parallel_rank() == 0:
        torch.distributed.broadcast(batch["tokens"], src=parallel_state.get_tensor_model_parallel_src_rank(), group=parallel_state.get_tensor_model_parallel_group())
    else:
        tokens = torch.empty_like(batch["labels"])
        torch.distributed.broadcast(tokens, src=parallel_state.get_tensor_model_parallel_src_rank(), group=parallel_state.get_tensor_model_parallel_group())
        batch["tokens"] = tokens
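The snippet above sends the tokens from the tensor-parallel source rank to the other ranks in the TP group, with non-source ranks first allocating a receive buffer (the `torch.empty_like(batch["labels"])` call). A single-process sketch of that pattern, with plain dicts and lists standing in for batches and tensors (names are illustrative, not FlagScale's API):

```python
# Illustrative emulation of a TP-group broadcast of the "tokens" field.
# In the real code this is torch.distributed.broadcast over the TP group.

def broadcast_tokens(per_rank_batches, src_rank):
    """Copy the source rank's tokens into every other rank's batch.
    Non-source ranks receive into a fresh buffer, mirroring the
    empty_like-then-broadcast pattern in the diff above."""
    src_tokens = per_rank_batches[src_rank]["tokens"]
    for rank, batch in enumerate(per_rank_batches):
        if rank != src_rank:
            batch["tokens"] = list(src_tokens)  # receive buffer filled by "broadcast"
    return per_rank_batches

batches = [
    {"tokens": [5, 9], "labels": [0, 0]},   # source rank has real tokens
    {"tokens": None, "labels": [0, 0]},     # other rank has no tokens yet
]
broadcast_tokens(batches, src_rank=0)
print(batches[1]["tokens"])  # → [5, 9]
```

After the call every rank holds identical tokens, which is the precondition the later pipeline stages rely on.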
from megatron.core.utils import get_attr_wrapped_model, StragglerDetector
from megatron.core.tokenizers.text.utils.build_tokenizer import build_tokenizer
from megatron.core import mpu
from megatron.core import parallel_state
Comment on lines +78 to +81
(self.vocab_start_index, self.vocab_end_index) = (
    VocabUtility.vocab_range_from_global_vocab_size(
        self.num_embeddings, get_pg_rank(self.embedding_parallel_group), get_pg_size(self.embedding_parallel_group)
    )
)
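For reference, `VocabUtility.vocab_range_from_global_vocab_size` carves the global vocabulary into equal contiguous per-rank slices. A minimal re-implementation of that arithmetic (assuming the vocab size divides evenly by the world size, which Megatron asserts upstream):

```python
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
    """Return this rank's half-open [start, end) slice of the vocabulary,
    assuming an even split (a simplifying assumption for this sketch)."""
    per_partition = global_vocab_size // world_size
    start = rank * per_partition
    return start, start + per_partition

# A 1024-token vocab over 4 embedding-parallel ranks:
for r in range(4):
    print(r, vocab_range_from_global_vocab_size(1024, r, 4))
# → 0 (0, 256), 1 (256, 512), 2 (512, 768), 3 (768, 1024)
```

The diff computes this range against the embedding-parallel group rather than the tensor-parallel group, which is the key difference the new arguments introduce.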
Comment on lines +102 to +103
rank=get_pg_rank(self.tp_group),
world_size=get_pg_size(self.tp_group),
input_ids = input_ids.view(-1)
routing_map = input_ids // self.num_embeddings_per_partition
# [num_partitions], number of tokens assigned to each partition from the current rank's input.
num_tokens_per_partition = torch.histc(routing_map, bins=self.embedding_parallel_size, min=0, max=self.embedding_parallel_size)
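The `torch.histc` call above counts how many of the rank's tokens land in each vocab partition, where the bin index is the integer division already computed into `routing_map`. The same counts can be written out in pure Python, which makes the intent easy to check:

```python
def tokens_per_partition(input_ids, num_embeddings_per_partition, num_partitions):
    """Count tokens owned by each embedding-parallel partition,
    mirroring torch.histc over routing_map = id // partition_size."""
    counts = [0] * num_partitions
    for tok in input_ids:
        counts[tok // num_embeddings_per_partition] += 1
    return counts

# 2 partitions of 4 ids each (vocab = 8):
print(tokens_per_partition([6, 1, 3, 7, 0], 4, 2))  # → [3, 2]
```

These per-partition counts are exactly what an all-to-all exchange needs as its send/receive split sizes.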
Comment on lines +375 to +376
warnings.warn("[rank0]: We do not recommend using allreduce for the engram embedding; it is deprecated and will be removed in a later version.", DeprecationWarning)
if self.args.engram_embedding_parallel_size is not None:
if parallel_state.get_tensor_model_parallel_rank() == 0:
    torch.distributed.broadcast(batch["tokens"], src=parallel_state.get_tensor_model_parallel_src_rank(), group=parallel_state.get_tensor_model_parallel_group())
else:
    tokens = torch.empty_like(batch["labels"])
self,
prefix: str = '',
sharded_offsets: Tuple[Tuple[int, int, int]] = (),
metadata: Optional[dict] = None,
**kwargs,
Comment on lines +390 to +391
assert not self.args.use_megatron_fsdp, "Megatron FSDP is not supported yet; support is planned for a later version."
assert not self.args.init_model_with_meta_device, "init_model_with_meta_device is not supported yet; support is planned for a later version."
backend: torchrun
nnodes: 3
nproc_per_node: 8
hostfile: hostfile # Select an available hostfile, e.g. ip_1 slots=8\nip_2 slots=8 ...
…l and offload to be enabled simultaneously.
PR Category
[Train]
PR Types
[Improvements]
PR Description