
Deepseek Engram optimization.#1147

Open
LiJunscs wants to merge 16 commits into flagos-ai:main from LiJunscs:deepseek_related

Conversation

@LiJunscs
Collaborator

PR Category

[Train]

PR Types

[Improvements]

PR Description

  1. All-to-all communication when computing multi_head_embedding.
  2. Precompute multi_head_embedding.
  3. Optionally offload the embedding's optimizer states.
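For context on item 1: a vocab-parallel embedding typically masks out-of-range token ids, looks the rest up in the local shard, and combines partial results across ranks (an all-reduce), while an all-to-all variant instead routes each token to the rank that owns its vocab slice. Below is a minimal single-process sketch of the masked-lookup-plus-sum pattern only; all names are hypothetical and this is not the PR's actual implementation:

```python
# Single-process simulation of a vocab-parallel embedding lookup.
# Each "rank" owns a contiguous vocab shard; tokens outside the shard
# produce zero vectors, and the partial results are summed, which is
# the role an all-reduce plays in the real distributed setting.

def shard_lookup(shard, vocab_start, vocab_end, token_ids, dim):
    out = []
    for t in token_ids:
        if vocab_start <= t < vocab_end:
            out.append(shard[t - vocab_start])
        else:
            out.append([0.0] * dim)  # masked: this rank does not own the token
    return out

def vocab_parallel_embedding(full_table, token_ids, world_size):
    dim = len(full_table[0])
    per_rank = len(full_table) // world_size  # assumes even divisibility
    total = [[0.0] * dim for _ in token_ids]
    for rank in range(world_size):
        start, end = rank * per_rank, (rank + 1) * per_rank
        partial = shard_lookup(full_table[start:end], start, end, token_ids, dim)
        # Sum partials across ranks (simulated all-reduce); exactly one
        # rank contributes a non-zero row per token.
        for i in range(len(token_ids)):
            for d in range(dim):
                total[i][d] += partial[i][d]
    return total
```

The all-to-all alternative avoids materializing and reducing full-size partial outputs by exchanging only the tokens each rank owns, at the cost of an extra routing step.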

@CLAassistant

CLAassistant commented Mar 12, 2026

CLA assistant check
All committers have signed the CLA.


Copilot AI left a comment


Pull request overview

This PR improves DeepSeek Engram training performance by adding embedding-parallel communication options (all-to-all / all-reduce / offload), enabling precomputation/overlap of multi-head embedding work, and wiring new Engram-parallel arguments through initialization and example configs.

Changes:

  • Add Engram embedding parallel method/size arguments and pass them into distributed initialization.
  • Introduce an embedding-parallel EngramMemory implementation and integrate it into Engram’s multi-head embedding + sharded checkpointing.
  • Add embedding precompute/caching (and attempt overlap across layers) plus new DeepSeek Engram Hydra configs.

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 12 comments.

| File | Description |
| --- | --- |
| tests/functional_tests/train/deepseek/gold_values/tp2_pp2_ep2_engram.json | Reformat/update stored gold loss JSON. |
| flagscale/train/megatron/training/initialize.py | Plumb engram_embedding_parallel_size into model-parallel initialization. |
| flagscale/train/megatron/training/arguments_fs.py | Add Engram embedding parallel CLI args + validation/warnings. |
| flagscale/train/megatron/train_engram.py | Switch to parallel_state and add an extra TP token broadcast path. |
| flagscale/models/megatron/engram/multi_head_embedding.py | Add EngramMemory and route MultiHeadEmbedding through it; add sharded state dict. |
| flagscale/models/megatron/engram/engram_transformer_layer.py | Add embedding precompute hook and enable sharded state dict. |
| flagscale/models/megatron/engram/engram_model.py | Implement build_schedule_plan for the Engram model. |
| flagscale/models/megatron/engram/engram_config.py | Add Engram embedding parallel config fields. |
| flagscale/models/megatron/engram/engram.py | Add embedding precompute/cache + sharded state dict plumbing; rename embedding member to memory. |
| examples/deepseek_v3/conf/train_engram.yaml | New top-level Hydra entry for the Engram training run. |
| examples/deepseek_v3/conf/train/engram.yaml | New DeepSeek Engram training preset including embedding-parallel settings. |


Comment on lines +59 to +65

```python
# Note: get_pipeline_model_parallel_world_size is a function; it must be
# called, otherwise the comparison `fn == 1` is always False.
if not (parallel_state.get_pipeline_model_parallel_world_size() == 1 or parallel_state.is_pipeline_first_stage()):
    if parallel_state.get_tensor_model_parallel_rank() == 0:
        torch.distributed.broadcast(batch["tokens"], src=parallel_state.get_tensor_model_parallel_src_rank(), group=parallel_state.get_tensor_model_parallel_group())
    else:
        tokens = torch.empty_like(batch["labels"])
        torch.distributed.broadcast(tokens, src=parallel_state.get_tensor_model_parallel_src_rank(), group=parallel_state.get_tensor_model_parallel_group())
        batch["tokens"] = tokens
```
```python
from megatron.core.utils import get_attr_wrapped_model, StragglerDetector
from megatron.core.tokenizers.text.utils.build_tokenizer import build_tokenizer
from megatron.core import mpu
from megatron.core import parallel_state
```
Comment on lines +78 to +81

```python
# Note: the quoted span was missing the closing parenthesis of the
# outer tuple assignment; it is restored here.
(self.vocab_start_index, self.vocab_end_index) = (
    VocabUtility.vocab_range_from_global_vocab_size(
        self.num_embeddings,
        get_pg_rank(self.embedding_parallel_group),
        get_pg_size(self.embedding_parallel_group),
    )
)
```
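The `VocabUtility.vocab_range_from_global_vocab_size` call above splits the vocabulary into contiguous slices across the embedding-parallel group. Its contract can be sketched as follows (a simplified re-implementation for illustration, assuming the vocab size is divisible by the group size):

```python
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
    # Each rank owns a contiguous, equally sized slice [start, end) of the vocab.
    per_partition = global_vocab_size // world_size
    start = rank * per_partition
    end = start + per_partition
    return start, end
```

For example, with a 1024-entry vocabulary and 4 ranks, rank 0 owns ids [0, 256) and rank 3 owns [768, 1024).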
Comment on lines +102 to +103

```python
rank=get_pg_rank(self.tp_group),
world_size=get_pg_size(self.tp_group),
```

```python
input_ids = input_ids.view(-1)
routing_map = input_ids // self.num_embeddings_per_partition
# [num_partitions]: number of tokens assigned to each partition from the current rank's input.
# Note: torch.histc requires a floating-point input, so an integer routing_map must be cast.
num_tokens_per_partition = torch.histc(routing_map.float(), bins=self.embedding_parallel_size, min=0, max=self.embedding_parallel_size)
```
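The routing logic quoted above maps each token id to the vocab partition that owns it (integer division by the partition width) and then counts tokens per partition; those counts are what an all-to-all dispatch would use as its split sizes. A plain-Python equivalent of the two steps (a hypothetical helper, not the PR's code):

```python
def route_tokens(input_ids, num_embeddings_per_partition, num_partitions):
    # Which vocab partition (owning rank) each token id belongs to,
    # mirroring: routing_map = input_ids // num_embeddings_per_partition
    routing_map = [t // num_embeddings_per_partition for t in input_ids]
    # Per-partition token counts, mirroring the torch.histc call with one
    # bin per partition; these become the all-to-all split sizes.
    counts = [0] * num_partitions
    for p in routing_map:
        counts[p] += 1
    return routing_map, counts
```

For instance, with a partition width of 5 and 2 partitions, token ids `[0, 5, 9, 2, 7]` route to partitions `[0, 1, 1, 0, 1]`, giving split sizes `[2, 3]`.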
Comment on lines +375 to +376

```python
warnings.warn(
    "[rank0]: Using allreduce for the Engram embedding is not recommended; it is deprecated and will be removed in a later version.",
    DeprecationWarning,
)
if self.args.engram_embedding_parallel_size is not None:
```

```python
if parallel_state.get_tensor_model_parallel_rank() == 0:
    torch.distributed.broadcast(batch["tokens"], src=parallel_state.get_tensor_model_parallel_src_rank(), group=parallel_state.get_tensor_model_parallel_group())
else:
    tokens = torch.empty_like(batch["labels"])
```
```python
self,
prefix: str = '',
sharded_offsets: Tuple[Tuple[int, int, int]] = (),
metadata: Optional[dict] = None,
**kwargs,
```
Comment on lines +390 to +391

```python
assert not self.args.use_megatron_fsdp, "Megatron FSDP is not supported yet; support is planned for a later version."
assert not self.args.init_model_with_meta_device, "init_model_with_meta_device is not supported yet; support is planned for a later version."
```
```yaml
backend: torchrun
nnodes: 3
nproc_per_node: 8
hostfile: hostfile  # Select an available hostfile, e.g. "ip_1 slots=8\nip_2 slots=8" ...
```
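For reference, a hostfile in the format the comment describes lists one host per line with a `slots=` count giving the processes that host can run. A plausible 3-node file matching `nnodes: 3` and `nproc_per_node: 8` (the addresses are placeholders):

```text
ip_1 slots=8
ip_2 slots=8
ip_3 slots=8
```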
lxd-cumt previously approved these changes Mar 16, 2026
Collaborator

@lxd-cumt lxd-cumt left a comment


LGTM


5 participants