Merged
149 changes: 149 additions & 0 deletions examples/deepseek_v3/conf/train/engram.yaml
@@ -0,0 +1,149 @@
# DeepSeek Engram 27B
system:
  no_shared_fs: ${experiment.runner.no_shared_fs}
  num_workers: 2
  tensor_model_parallel_size: 8
  expert_model_parallel_size: 8
  expert_tensor_parallel_size: 1
  context_parallel_size: 1
  engram_embedding_parallel_size: 8
  sequence_parallel: true
  use_distributed_optimizer: true
  overlap_grad_reduce: true
  overlap_param_gather: true
  precision:
    bf16: true
    attention_softmax_in_fp32: true
    accumulate_allreduce_grads_in_fp32: true
  logging:
    log_interval: 1
    tensorboard_log_interval: 1
    wandb_project: ${experiment.exp_name}
    wandb_exp_name: ${experiment.exp_name}
    log_timers_to_tensorboard: true
    log_validation_ppl_to_tensorboard: true
    log_throughput: true
    log_params_norm: true
    log_num_zeros_in_grad: true
    log_memory_to_tensorboard: true
  checkpoint:
    save_interval: ${experiment.save_steps}
    load: ${experiment.load}
    ckpt_format: ${experiment.ckpt_format}

model:
  # nsys profile args =================
  # profile: true
  # profile_step_start: 5
  # profile_step_end: 6
  # profile_ranks: [0,7] # default [0]
  # Note: these options take effect only when launched under `nsys profile`.

  # torch profiler args =================
  # profile: true
  # use_pytorch_profiler: true
  # profile_step_start: 5
  # profile_step_end: 6
  # profile_ranks: [0] # default [0]
  # tensorboard_dir: /workspace/torch_profile
  transformer_impl: transformer_engine
  num_layers: 30
  hidden_size: 2560
  num_attention_heads: 32
  num_query_groups: 32 # num_key_value_heads
  seq_length: 4096
  max_position_embeddings: 4096
  norm_epsilon: 1e-6
  use_rotary_position_embeddings: true
  rotary_base: 1000000
  swiglu: true
  normalization: RMSNorm
  qk_layernorm: true
  init_method_std: 0.02
  attention_dropout: 0.0
  hidden_dropout: 0.0
  position_embedding_type: rope
  untie_embeddings_and_output_weights: true
  no_position_embedding: true
  no_rope_fusion: true
  disable_bias_linear: true

  # mla args ==================
  multi_latent_attention: true
  q_lora_rank: 768
  kv_lora_rank: 512
  qk_head_dim: 128
  qk_pos_emb_head_dim: 64
  v_head_dim: 128

  # moe args ===================
  ffn_hidden_size: 12288
  moe_ffn_hidden_size: 1536
  moe_grouped_gemm: true
  moe_shared_expert_intermediate_size: 3072
  num_experts: 56
  moe_router_load_balancing_type: "seq_aux_loss"
  moe_router_score_function: sigmoid
  moe_router_enable_expert_bias: true
  moe_router_bias_update_rate: 0.001
  moe_aux_loss_coeff: 0.02
  moe_layer_freq: "[0]+[1]*29"
  # node limited routing
  moe_router_num_groups: 1
  moe_router_group_topk: 1
  moe_router_topk: 6
  moe_router_topk_scaling_factor: 2.446
  moe_token_dispatcher_type: "alltoall"
  # overlap_moe_expert_parallel_comm: true # Optional.

  # mtp args ====================
  # mtp_num_layers: 1
  # mtp_loss_scaling_factor: 0.3

  # engram args =================
  use_engram: true
  engram_tokenizer_name_or_path: /workspace/qwentokenizer
  engram_vocab_size: [1131200, 1131200]
  max_ngram_size: 3
  n_embed_per_ngram: 1280
  n_head_per_ngram: 8
  engram_layer_ids: [2, 15]
  engram_pad_id: 0
  engram_seed: 0
  engram_kernel_size: 4
  engram_hc_mult: 1
  engram_embedding_parallel_method: alltoall # alltoall, allreduce, offload
Copilot AI Apr 14, 2026

The inline comment lists offload as a valid value for engram_embedding_parallel_method, but the CLI/config choices added in this PR are only alltoall and allreduce (offloading is controlled by a separate boolean flag). Update the comment to avoid misleading users, and optionally document engram_offload_embedding_optimizer_states here instead.

Suggested change
-  engram_embedding_parallel_method: alltoall # alltoall, allreduce, offload
+  engram_embedding_parallel_method: alltoall # alltoall, allreduce
+  # engram_offload_embedding_optimizer_states: true # Optional; controls optimizer-state offloading.

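The split the review describes can be sketched with standard `argparse`. The config field names come from this PR's `engram_config.py`; the dashed CLI spellings below are assumptions for illustration:

```python
import argparse

# Sketch of the choice split the review describes (flag spellings assumed):
# the parallel method is a two-way choice, while optimizer-state offloading
# is a separate boolean flag, not a third method value.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--engram-embedding-parallel-method",
    choices=["alltoall", "allreduce"],  # "offload" is not a valid choice
    default="alltoall",
)
parser.add_argument(
    "--engram-offload-embedding-optimizer-states",
    action="store_true",  # offloading is toggled independently
)

args = parser.parse_args(["--engram-embedding-parallel-method", "allreduce"])
print(args.engram_embedding_parallel_method)  # -> allreduce
```

Passing `offload` as the method would then fail fast at argument parsing rather than being silently accepted.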

  # training
  seed: ${experiment.seed}
  finetune: false
  micro_batch_size: 2
  global_batch_size: 2048
  eval_iters: 0
  train_iters: 20

  optimizer:
    clip_grad: 1.0
    weight_decay: 0.1
    adam_beta1: 0.9
    adam_beta2: 0.95
    lr_scheduler:
      lr: 3.0e-3
      min_lr: 3.0e-4
      lr_warmup_fraction: 0.01
      lr_decay_style: WSD
      lr_wsd_decay_style: cosine
      lr_wsd_decay_iters: 10

data:
  reset_position_ids: True
  reset_attention_mask: True
  data_path: /workspace/data/enron_emails_demo_text_document_qwen
  split: 1
  no_mmap_bin_files: true
  tokenizer:
    legacy_tokenizer: true
    tokenizer_type: QwenTokenizerFS
    tokenizer_path: /workspace/qwentokenizer
    vocab_size: 151851
    make_vocab_size_divisible_by: 64
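The `moe_layer_freq: "[0]+[1]*29"` value in the config above is a Megatron-style per-layer MoE pattern expressed as a Python list expression. A quick sketch of how it expands (the 0-dense/1-MoE semantics are the usual Megatron convention, assumed here):

```python
# Entry i tells whether layer i is an MoE layer (1) or a dense layer (0).
# For this 30-layer model, "[0]+[1]*29" makes layer 0 dense and layers 1-29 MoE.
pattern = "[0]+[1]*29"
layer_is_moe = eval(pattern)  # plain list arithmetic, nothing else in the string
print(len(layer_is_moe))      # -> 30  (matches num_layers: 30)
print(layer_is_moe[:3])       # -> [0, 1, 1]
```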
51 changes: 51 additions & 0 deletions examples/deepseek_v3/conf/train_engram.yaml
@@ -0,0 +1,51 @@
defaults:
  - _self_
  - train: engram

experiment:
  exp_name: DeepSeek-Engram
  seed: 42
  save_steps: 100
  load: null
  exp_dir: outputs/${experiment.exp_name}
  ckpt_format: torch
  # ckpt_format: fsdp_dtensor # Just for Megatron FSDP.
  task:
    type: train
    backend: megatron
    entrypoint: flagscale/train/megatron/train_engram.py
  runner:
    # Single-node
    # per_node_task: false
    # no_shared_fs: false
    # rdzv_backend: static
    # hostfile: null
    # ssh_port: 10710
    # Multi-node
    per_node_task: false
    no_shared_fs: false
    backend: torchrun
    nnodes: 3
    nproc_per_node: 8
    hostfile: hostfile # Select an available hostfile, e.g. "ip_1 slots=8\nip_2 slots=8".
    master_port: 10720 # Select an available port.
    ssh_port: 10710 # Select an available port.
    master_addr: <master_ip>
    rdzv_backend: static
    cmds:
      before_start: ulimit -n 1048576 && source /root/miniconda3/bin/activate /root/miniconda3/envs/flagscale-train
    envs:
      LOGLEVEL: "INFO"
      CUDA_VISIBLE_DEVICES: "0,1,2,3,4,5,6,7"
      CUDA_DEVICE_MAX_CONNECTIONS: 1
      NCCL_IB_HCA: "IB interface" # Select the correct IB interface.
      NCCL_SOCKET_IFNAME: "IP interface" # Select the correct network interface.
      NCCL_IB_DISABLE: 0
      NCCL_DEBUG: "WARN"
      NCCL_IB_GID_INDEX: 3

  action: run

hydra:
  run:
    dir: ${experiment.exp_dir}/hydra
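For the `hostfile` the runner expects, a minimal sketch matching `nnodes: 3` and `nproc_per_node: 8` (the IPs are placeholders, and the launch command is an assumed FlagScale-style hydra entrypoint, not taken from this PR):

```shell
# Write a 3-node hostfile in the "ip slots=N" format the runner comment describes.
cat > hostfile <<'EOF'
10.0.0.1 slots=8
10.0.0.2 slots=8
10.0.0.3 slots=8
EOF
grep -c "slots=8" hostfile  # -> 3

# Then launch (assumed entrypoint; adjust paths to your checkout):
# python run.py --config-path examples/deepseek_v3/conf --config-name train_engram
```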
50 changes: 47 additions & 3 deletions flagscale/models/megatron/engram/engram.py
@@ -14,6 +14,9 @@
from .short_conv import ShortConv


## Megatron
from megatron.core.transformer.utils import sharded_state_dict_default

class Engram(nn.Module):
    def __init__(self, engram_cfg: EngramConfig, layer_id):
        super().__init__()
@@ -34,13 +37,17 @@ def __init__(self, engram_cfg: EngramConfig, layer_id):
            pad_id=engram_cfg.engram_pad_id,
            seed=engram_cfg.engram_seed,
        )
-        self.multi_head_embedding = MultiHeadEmbedding(
+        self.memory = MultiHeadEmbedding(
            engram_cfg,
            list_of_N=[
                x for y in global_hash_mapping.vocab_size_across_layers[self.layer_id] for x in y
            ],
            D=engram_cfg.n_embed_per_ngram // engram_cfg.n_head_per_ngram,
        )
+        self.embedding_cache = None  # Cache for pre-computed embeddings
+        self.embedding_stream = None  # Stream for pre-computing embeddings
+        if torch.cuda.is_available():
+            self.embedding_stream = torch.cuda.Stream()
        self.short_conv = ShortConv(
            hidden_size=self.backbone_config.hidden_size,
            kernel_size=engram_cfg.engram_kernel_size,
@@ -81,8 +88,14 @@ def forward(self, hidden_states, hash_input_ids):
        # [B, L, N_GRAM * N_HEADS_PER_GRAM]
        # fake hyper-connection
        hidden_states = hidden_states.unsqueeze(2)

-        embeddings = self.multi_head_embedding(hash_input_ids).flatten(start_dim=-2)
+        if self.embedding_cache is not None:
+            embeddings, embedding_event = self.embedding_cache
+            if embedding_event is not None:
+                torch.cuda.current_stream().wait_event(embedding_event)  # Ensure pre-computed embeddings are ready
+            self.embedding_cache = None  # Clear cache after use
+            del embedding_event  # Free the event
+        else:
+            embeddings = self.memory(hash_input_ids).flatten(start_dim=-2)
        # [L/tp_size, B, N_GRAM * N_HEADS_PER_GRAM, N_EMBED_PER_GRAM // N_HEADS_PER_GRAM]
        # [L/tp_size, B, N_GRAM * N_EMBED_PER_NGRAM]

@@ -120,3 +133,34 @@ def forward(self, hidden_states, hash_input_ids):
        output = output.squeeze(2)

        return output

    def pre_compute_embedding(self, input_ids: torch.Tensor):
        """
        Pre-compute the multi-head embedding for the given input IDs.
        This can be called before the forward pass to warm up the embedding cache.
        """
        assert input_ids is not None, "Input ids can not be None for EngramModel"
        self.embedding_stream.synchronize()  # Ensure previous computations on the stream are finished
        with torch.cuda.stream(self.embedding_stream):
            embedding_result = self.memory(input_ids).flatten(start_dim=-2)
            embedding_event = torch.cuda.Event()
            embedding_event.record(self.embedding_stream)
        self.embedding_cache = (embedding_result, embedding_event)
Comment on lines +137 to +148
Copilot AI Apr 14, 2026

pre_compute_embedding() unconditionally uses CUDA APIs (self.embedding_stream.synchronize(), torch.cuda.stream(...), torch.cuda.Event()) but self.embedding_stream is only created when torch.cuda.is_available(). If Engram is instantiated in a CPU-only run (or CUDA is disabled), this will raise an exception. Guard this method (and/or the call sites) so it either becomes a no-op on CPU or falls back to synchronous computation without CUDA streams/events.


Comment on lines +143 to +149
Copilot AI Apr 14, 2026

pre_compute_embedding() unconditionally uses self.embedding_stream (synchronize() and torch.cuda.stream(...)), but self.embedding_stream is set only when torch.cuda.is_available(). If this method is called in a CPU-only environment (or before CUDA is initialized), it will raise an attribute error / CUDA error. Guard for self.embedding_stream is None (fallback to synchronous compute without streams/events) or assert CUDA availability with a clearer error message.

Suggested change
-        self.embedding_stream.synchronize()  # Ensure previous computations on the stream are finished
-        with torch.cuda.stream(self.embedding_stream):
-            embedding_result = self.memory(input_ids).flatten(start_dim=-2)
-            embedding_event = torch.cuda.Event()
-            embedding_event.record(self.embedding_stream)
-        self.embedding_cache = (embedding_result, embedding_event)
+        embedding_stream = getattr(self, "embedding_stream", None)
+        if embedding_stream is not None and torch.cuda.is_available():
+            embedding_stream.synchronize()  # Ensure previous computations on the stream are finished
+            with torch.cuda.stream(embedding_stream):
+                embedding_result = self.memory(input_ids).flatten(start_dim=-2)
+                embedding_event = torch.cuda.Event()
+                embedding_event.record(embedding_stream)
+        else:
+            embedding_result = self.memory(input_ids).flatten(start_dim=-2)
+            embedding_event = None
+        self.embedding_cache = (embedding_result, embedding_event)

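The guarded pattern both review comments converge on can be sketched standalone (function and variable names here are hypothetical, not from the PR). It prefetches on a side CUDA stream when one exists and degrades to synchronous computation otherwise, so it also runs on CPU-only machines:

```python
import torch

def prefetch(compute, stream=None):
    """Run `compute` on a side CUDA stream when available, else synchronously.
    Returns (result, event); event is None on the synchronous path."""
    if stream is not None and torch.cuda.is_available():
        stream.synchronize()                  # drain any earlier prefetch
        with torch.cuda.stream(stream):
            result = compute()
            event = torch.cuda.Event()
            event.record(stream)              # consumer waits on this event
    else:
        result = compute()                    # CPU / no-stream fallback
        event = None
    return result, event

# Consumer side: only wait when an event was actually recorded.
result, event = prefetch(lambda: torch.arange(4) * 2)
if event is not None:
    torch.cuda.current_stream().wait_event(event)
print(result.tolist())  # -> [0, 2, 4, 6]
```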
    def sharded_state_dict(
        self, prefix: str = "", sharded_offsets: tuple = (), metadata: dict | None = None
    ):
        sharded_dict = {}
        memory_prefix = f"{prefix}memory."
        sharded_dict.update(self.memory.sharded_state_dict(memory_prefix, sharded_offsets, metadata))
        conv_prefix = f"{prefix}short_conv."
        sharded_dict.update(sharded_state_dict_default(self.short_conv, conv_prefix, sharded_offsets, metadata))
        value_proj_prefix = f"{prefix}value_proj."
        sharded_dict.update(sharded_state_dict_default(self.value_proj, value_proj_prefix, sharded_offsets, metadata))
        key_projs_prefix = f"{prefix}key_projs."
        sharded_dict.update(sharded_state_dict_default(self.key_projs, key_projs_prefix, sharded_offsets, metadata))
        norm1_prefix = f"{prefix}norm1."
        sharded_dict.update(sharded_state_dict_default(self.norm1, norm1_prefix, sharded_offsets, metadata))
        norm2_prefix = f"{prefix}norm2."
        sharded_dict.update(sharded_state_dict_default(self.norm2, norm2_prefix, sharded_offsets, metadata))
        return sharded_dict
3 changes: 3 additions & 0 deletions flagscale/models/megatron/engram/engram_config.py
@@ -17,3 +17,6 @@ class EngramConfig(MLATransformerConfig):
    engram_seed: int = 0
    engram_kernel_size: int = 1
    engram_hc_mult: int = 1
    engram_embedding_parallel_size: int | None = 1
    engram_embedding_parallel_method: str = "alltoall"
    engram_offload_embedding_optimizer_states: bool = False
86 changes: 68 additions & 18 deletions flagscale/models/megatron/engram/engram_model.py
@@ -1,5 +1,7 @@
# ruff: noqa: RUF013
## built-in
from typing import Optional

import torch
from torch import Tensor

@@ -27,26 +29,41 @@ def __init__(self, hash_mapping, input_ids, hash_stream=None):
        self.input_ids = input_ids
        self.hash_stream = hash_stream
        self._result = None
-        self._computation_started = False

-        # torch.cuda.nvtx.range_push("LazyHashInputIds hash")
-        # Start async computation immediately if stream is available
+        self._is_async_pending = False
+        # Async
        if self.hash_stream is not None:
+            # self.hash_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(self.hash_stream):
                self._result = self.hash_mapping.hash(self.input_ids)
-            self._computation_started = True
-        # torch.cuda.nvtx.range_pop()
+            self._is_async_pending = True
+            # record result to use across stream
+            self._record_current_stream()

-    def __getitem__(self, key):
-        """Access hash result, synchronizing if necessary."""
+    def _record_current_stream(self):
+        """Helper to record current stream on all result tensors"""
        if self._result is None:
-            if self.hash_stream is not None and self._computation_started:
-                # Wait for async computation to complete
-                torch.cuda.current_stream().wait_stream(self.hash_stream)
-                self._computation_started = False  # Mark as synchronized
-            else:
-                # Compute synchronously if no stream or computation not started
-                self._result = self.hash_mapping.hash(self.input_ids)
+            return
+        current_stream = torch.cuda.current_stream()
+        if isinstance(self._result, dict):
+            for t in self._result.values():
+                if isinstance(t, torch.Tensor):
+                    t.record_stream(current_stream)
+        elif isinstance(self._result, torch.Tensor):
+            self._result.record_stream(current_stream)

+    def __getitem__(self, key):
+        # Case 1: Async compute -> wait
+        if self._is_async_pending:
+            torch.cuda.current_stream().wait_stream(self.hash_stream)
+            self._is_async_pending = False  # Async finish
+            self._record_current_stream()

+        # Case 2: Sync but no compute -> start compute
+        elif self._result is None:
+            self._result = self.hash_mapping.hash(self.input_ids)

+        # Case 3: Async or sync compute is finished.
+        # print(f"[rank{torch.distributed.get_rank()}]: LazyHashInputIds result = {self._result}")
        return self._result[key]

    def get(self, key, default=None):
@@ -171,7 +188,40 @@ def forward(
            inference_context=inference_context,
        )

-    def sharded_state_dict(
-        self, prefix: str = "", sharded_offsets: tuple = (), metadata: dict | None = None
+    def build_schedule_plan(
+        self,
+        input_ids: Tensor,
+        position_ids: Tensor,
+        attention_mask: Tensor,
+        decoder_input: Tensor = None,
+        labels: Tensor = None,
+        inference_context: BaseInferenceContext = None,
+        packed_seq_params: PackedSeqParams = None,
+        extra_block_kwargs: dict = None,
+        runtime_gather_output: Optional[bool] = None,
+        inference_params: Optional[BaseInferenceContext] = None,
+        loss_mask: Optional[Tensor] = None,
    ):
-        raise NotImplementedError("Sharded state dict is not supported for EngramModel")
+        """
+        Adaptation of overlap_moe_expert_parallel_comm.
+        """
+        # Precompute the engram_hash_input_ids; it will be used to create a TransformerChunkSchedulePlan.
+        engram_hash_input_ids = LazyHashInputIds(
+            hash_mapping=self.engram_hash,
+            input_ids=input_ids,
+            hash_stream=self._hash_stream,
+        )
+        if extra_block_kwargs is None:
+            extra_block_kwargs = {
+                "engram_hash_input_ids": engram_hash_input_ids,
+            }
Comment on lines +214 to +217
Copilot AI Apr 14, 2026

build_schedule_plan() only injects engram_hash_input_ids when extra_block_kwargs is None. If a caller passes a non-empty extra_block_kwargs, the schedule plan will be built without engram_hash_input_ids, but EngramTransformerBlock.forward() requires it. Merge into the provided dict (e.g., copy + setdefault) so engram_hash_input_ids is always present without clobbering user-supplied keys.

Suggested change
-        if extra_block_kwargs is None:
-            extra_block_kwargs = {
-                "engram_hash_input_ids": engram_hash_input_ids,
-            }
+        extra_block_kwargs = dict(extra_block_kwargs or {})
+        extra_block_kwargs.setdefault("engram_hash_input_ids", engram_hash_input_ids)

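The copy-plus-`setdefault` merge the review suggests can be shown in isolation (helper name hypothetical): it handles a `None` argument, never mutates the caller's dict, and never clobbers a user-supplied key.

```python
# Sketch of the merge-don't-clobber pattern: copy the caller's dict (or start
# fresh) and inject required keys only where they are absent.
def with_required_keys(extra_kwargs, required):
    merged = dict(extra_kwargs or {})      # handles None, leaves caller's dict untouched
    for key, value in required.items():
        merged.setdefault(key, value)      # user-supplied keys win
    return merged

print(with_required_keys(None, {"engram_hash_input_ids": "h"}))
# -> {'engram_hash_input_ids': 'h'}
print(with_required_keys({"a": 1}, {"engram_hash_input_ids": "h"}))
# -> {'a': 1, 'engram_hash_input_ids': 'h'}
```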
        return super().build_schedule_plan(
            input_ids,
            position_ids,
            attention_mask,
            decoder_input,
            labels=labels,
            loss_mask=loss_mask,
            extra_block_kwargs=extra_block_kwargs,
        )
