Skip to content
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 149 additions & 0 deletions examples/deepseek_v3/conf/train/engram.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# DeepSeek Engram 27B
system:
  no_shared_fs: ${experiment.runner.no_shared_fs}
  num_workers: 2
  tensor_model_parallel_size: 8
  expert_model_parallel_size: 8
  expert_tensor_parallel_size: 1
  context_parallel_size: 1
  engram_embedding_parallel_size: 8
  sequence_parallel: true
  use_distributed_optimizer: true
  overlap_grad_reduce: true
  overlap_param_gather: true
  precision:
    bf16: true
    attention_softmax_in_fp32: true
    accumulate_allreduce_grads_in_fp32: true
  logging:
    log_interval: 1
    tensorboard_log_interval: 1
    wandb_project: ${experiment.exp_name}
    wandb_exp_name: ${experiment.exp_name}
    log_timers_to_tensorboard: true
    log_validation_ppl_to_tensorboard: true
    log_throughput: true
    log_params_norm: true
    log_num_zeros_in_grad: true
    log_memory_to_tensorboard: true
  checkpoint:
    save_interval: ${experiment.save_steps}
    load: ${experiment.load}
    ckpt_format: ${experiment.ckpt_format}

model:
  # nsys profile args =================
  # profile: true
  # profile_step_start: 5
  # profile_step_end: 6
  # profile_ranks: [0,7] # default [0]
  # Note: these only take effect when launched under `nsys profile`.

  # torch profiler args =================
  # profile: true
  # use_pytorch_profiler: true
  # profile_step_start: 5
  # profile_step_end: 6
  # profile_ranks: [0] # default [0]
  # tensorboard_dir: /workspace/torch_profile
  transformer_impl: transformer_engine
  num_layers: 30
  hidden_size: 2560
  num_attention_heads: 32
  num_query_groups: 32 # num_key_value_heads
  seq_length: 4096
  max_position_embeddings: 4096
  norm_epsilon: 1e-6
  use_rotary_position_embeddings: true
  rotary_base: 1000000
  swiglu: true
  normalization: RMSNorm
  qk_layernorm: true
  init_method_std: 0.02
  attention_dropout: 0.0
  hidden_dropout: 0.0
  position_embedding_type: rope
  untie_embeddings_and_output_weights: true
  no_position_embedding: true
  no_rope_fusion: true
  disable_bias_linear: true

  # mla args ==================
  multi_latent_attention: true
  q_lora_rank: 768
  kv_lora_rank: 512
  qk_head_dim: 128
  qk_pos_emb_head_dim: 64
  v_head_dim: 128

  # moe args ===================
  ffn_hidden_size: 12288
  moe_ffn_hidden_size: 1536
  moe_grouped_gemm: true
  moe_shared_expert_intermediate_size: 3072
  num_experts: 56
  moe_router_load_balancing_type: "seq_aux_loss"
  moe_router_score_function: sigmoid
  moe_router_enable_expert_bias: true
  moe_router_bias_update_rate: 0.001
  moe_aux_loss_coeff: 0.02
  moe_layer_freq: "[0]+[1]*29" # first layer dense, remaining 29 MoE
  # node limited routing
  moe_router_num_groups: 1
  moe_router_group_topk: 1
  moe_router_topk: 6
  moe_router_topk_scaling_factor: 2.446
  moe_token_dispatcher_type: "alltoall"
  # overlap_moe_expert_parallel_comm: true # Optional.

  # mtp args ====================
  # mtp_num_layers: 1
  # mtp_loss_scaling_factor: 0.3

  # engram args =================
  use_engram: true
  engram_tokenizer_name_or_path: /workspace/qwentokenizer
  engram_vocab_size: [1131200, 1131200]
  max_ngram_size: 3
  n_embed_per_ngram: 1280
  n_head_per_ngram: 8
  engram_layer_ids: [2, 15]
  engram_pad_id: 0
  engram_seed: 0
  engram_kernel_size: 4
  engram_hc_mult: 1
  engram_embedding_parallel_method: alltoall # alltoall, allreduce, offload

  # training
  seed: ${experiment.seed}
  finetune: false
  micro_batch_size: 2
  global_batch_size: 2048
  eval_iters: 0
  train_iters: 20

  optimizer:
    clip_grad: 1.0
    weight_decay: 0.1
    adam_beta1: 0.9
    adam_beta2: 0.95
    lr_scheduler:
      lr: 3.0e-3
      min_lr: 3.0e-4
      lr_warmup_fraction: 0.01
      lr_decay_style: WSD
      lr_wsd_decay_style: cosine
      lr_wsd_decay_iters: 10

data:
  # Lowercase booleans for consistency with the rest of the file
  # (yamllint `truthy`; was `True`/`True`).
  reset_position_ids: true
  reset_attention_mask: true
  data_path: /workspace/data/enron_emails_demo_text_document_qwen
  split: 1
  no_mmap_bin_files: true
  tokenizer:
    legacy_tokenizer: true
    tokenizer_type: QwenTokenizerFS
    tokenizer_path: /workspace/qwentokenizer
    vocab_size: 151851
    make_vocab_size_divisible_by: 64
51 changes: 51 additions & 0 deletions examples/deepseek_v3/conf/train_engram.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
defaults:
  - _self_
  - train: engram

experiment:
  exp_name: DeepSeek-Engram
  seed: 42
  save_steps: 100
  load: null
  exp_dir: outputs/${experiment.exp_name}
  ckpt_format: torch
  # ckpt_format: fsdp_dtensor # Just for Megatron FSDP.
  task:
    type: train
    backend: megatron
    entrypoint: flagscale/train/megatron/train_engram.py
  runner:
    # Single-node settings:
    # per_node_task: false
    # no_shared_fs: false
    # rdzv_backend: static
    # hostfile: null
    # ssh_port: 10710
    # Multi-node settings:
    per_node_task: false
    no_shared_fs: false
    backend: torchrun
    nnodes: 3
    nproc_per_node: 8
    hostfile: hostfile # Select an available hostfile, with lines like: ip_1 slots=8\nip_2 slots=8 ...
    master_port: 10720 # Select an available port.
    ssh_port: 10710 # Select an available port.
    master_addr: <master_ip>
    rdzv_backend: static
  cmds:
    before_start: ulimit -n 1048576 && source /root/miniconda3/bin/activate /root/miniconda3/envs/flagscale-train
  envs:
    # Environment variables are strings; numeric-looking values are quoted
    # so YAML does not retype them.
    LOGLEVEL: "INFO"
    CUDA_VISIBLE_DEVICES: "0,1,2,3,4,5,6,7"
    CUDA_DEVICE_MAX_CONNECTIONS: "1"
    NCCL_IB_HCA: "IB interface" # Select correct IB interface.
    NCCL_SOCKET_IFNAME: "IP interface" # Select correct interface.
    NCCL_IB_DISABLE: "0"
    NCCL_DEBUG: "WARN"
    NCCL_IB_GID_INDEX: "3"

action: run

hydra:
  run:
    dir: ${experiment.exp_dir}/hydra
50 changes: 47 additions & 3 deletions flagscale/models/megatron/engram/engram.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from .short_conv import ShortConv


## Megatron
from megatron.core.transformer.utils import sharded_state_dict_default

class Engram(nn.Module):
def __init__(self, engram_cfg: EngramConfig, layer_id):
super().__init__()
Expand All @@ -34,13 +37,17 @@ def __init__(self, engram_cfg: EngramConfig, layer_id):
pad_id=engram_cfg.engram_pad_id,
seed=engram_cfg.engram_seed,
)
self.multi_head_embedding = MultiHeadEmbedding(
self.memory = MultiHeadEmbedding(
engram_cfg,
list_of_N=[
x for y in global_hash_mapping.vocab_size_across_layers[self.layer_id] for x in y
],
D=engram_cfg.n_embed_per_ngram // engram_cfg.n_head_per_ngram,
)
self.embedding_cache = None # Cache for pre-computed embeddings
self.embedding_stream = None # Stream for pre-computing embeddings
if torch.cuda.is_available():
self.embedding_stream = torch.cuda.Stream()
self.short_conv = ShortConv(
hidden_size=self.backbone_config.hidden_size,
kernel_size=engram_cfg.engram_kernel_size,
Expand Down Expand Up @@ -81,8 +88,14 @@ def forward(self, hidden_states, hash_input_ids):
# [B, L, N_GRAM * N_HEADS_PER_GRAM]
# fake hyper-connection
hidden_states = hidden_states.unsqueeze(2)

embeddings = self.multi_head_embedding(hash_input_ids).flatten(start_dim=-2)
if self.embedding_cache is not None:
embeddings, embedding_event = self.embedding_cache
if embedding_event is not None:
torch.cuda.current_stream().wait_event(embedding_event) # Ensure pre-computed embeddings are ready
self.embedding_cache = None # Clear cache after use
del embedding_event # Free the event
else:
embeddings = self.memory(hash_input_ids).flatten(start_dim=-2)
# [L/tp_size, B, N_GRAM * N_HEADS_PER_GRAM, N_EMBED_PER_GRAM // N_HEADS_PER_GRAM]
# [L/tp_size, B, N_GRAM * N_EMBED_PER_NGRAM]

Expand Down Expand Up @@ -120,3 +133,34 @@ def forward(self, hidden_states, hash_input_ids):
output = output.squeeze(2)

return output

def pre_compute_embedding(self, input_ids: torch.Tensor):
    """
    Pre-compute the multi-head embedding lookup for ``input_ids`` on the side
    CUDA stream so it can overlap with other compute, and stash the result in
    ``self.embedding_cache`` as a ``(tensor, cuda_event)`` pair.

    ``forward`` consumes the cache instead of recomputing: it waits on the
    recorded event (if any) before using the tensor, then clears the cache.

    Args:
        input_ids: Hashed n-gram input ids for this layer. Must not be None.
    """
    assert input_ids is not None, "Input ids can not be None for EngramModel"
    if self.embedding_stream is None:
        # __init__ only creates the side stream when CUDA is available.
        # Without it, compute synchronously on the current device; forward
        # already accepts a cache entry with a None event.
        self.embedding_cache = (self.memory(input_ids).flatten(start_dim=-2), None)
        return
    self.embedding_stream.synchronize()  # Ensure previous work on the side stream is finished
    with torch.cuda.stream(self.embedding_stream):
        embedding_result = self.memory(input_ids).flatten(start_dim=-2)
        embedding_event = torch.cuda.Event()
        embedding_event.record(self.embedding_stream)
    self.embedding_cache = (embedding_result, embedding_event)

def sharded_state_dict(
    self, prefix: str = "", sharded_offsets: tuple = (), metadata: dict | None = None
):
    """Build the distributed-checkpoint sharded state dict for this module.

    ``self.memory`` supplies its own sharded entries via its
    ``sharded_state_dict`` method; every other submodule is wrapped with
    Megatron's ``sharded_state_dict_default`` helper.

    Args:
        prefix: Key prefix prepended to all entries produced here.
        sharded_offsets: Shard offsets, passed through to the submodules.
        metadata: Optional checkpoint metadata, passed through unchanged.

    Returns:
        A dict mapping checkpoint keys to sharded tensors/objects.
    """
    sharded_dict = {}
    # The embedding table handles its own sharding.
    memory_prefix = f"{prefix}memory."
    sharded_dict.update(self.memory.sharded_state_dict(memory_prefix, sharded_offsets, metadata))
    # Remaining submodules use the default (non-custom) sharding wrapper.
    conv_prefix = f"{prefix}short_conv."
    sharded_dict.update(sharded_state_dict_default(self.short_conv, conv_prefix, sharded_offsets, metadata))
    value_proj_prefix = f"{prefix}value_proj."
    sharded_dict.update(sharded_state_dict_default(self.value_proj, value_proj_prefix, sharded_offsets, metadata))
    key_projs_prefix = f"{prefix}key_projs."
    sharded_dict.update(sharded_state_dict_default(self.key_projs, key_projs_prefix, sharded_offsets, metadata))
    norm1_prefix = f"{prefix}norm1."
    sharded_dict.update(sharded_state_dict_default(self.norm1, norm1_prefix, sharded_offsets, metadata))
    norm2_prefix = f"{prefix}norm2."
    sharded_dict.update(sharded_state_dict_default(self.norm2, norm2_prefix, sharded_offsets, metadata))
    return sharded_dict
2 changes: 2 additions & 0 deletions flagscale/models/megatron/engram/engram_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@ class EngramConfig(MLATransformerConfig):
engram_seed: int = 0  # seed passed to the n-gram hash mapping in Engram.__init__
engram_kernel_size: int = 1  # kernel_size of the ShortConv inside Engram
engram_hc_mult: int = 1  # hyper-connection multiplier -- NOTE(review): confirm exact semantics
engram_embedding_parallel_size: int | None = 1  # presumably the number of ranks sharding the engram embedding -- TODO confirm
engram_embedding_parallel_method: str = "alltoall"  # one of: alltoall, allreduce, offload
32 changes: 29 additions & 3 deletions flagscale/models/megatron/engram/engram_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# ruff: noqa: RUF013
## built-in
from typing import Optional

import torch
from torch import Tensor

Expand Down Expand Up @@ -171,7 +173,31 @@ def forward(
inference_context=inference_context,
)

def sharded_state_dict(
self, prefix: str = "", sharded_offsets: tuple = (), metadata: dict | None = None
def build_schedule_plan(
self,
input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_context: BaseInferenceContext = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
runtime_gather_output: Optional[bool] = None,
inference_params: Optional[BaseInferenceContext] = None,
loss_mask: Optional[Tensor] = None,
):
raise NotImplementedError("Sharded state dict is not supported for EngramModel")
self.engram_hash_input_ids = LazyHashInputIds(
hash_mapping=self.engram_hash,
input_ids=input_ids,
hash_stream=self._hash_stream,
)
return super().build_schedule_plan(
input_ids,
position_ids,
attention_mask,
decoder_input,
labels=labels,
loss_mask=loss_mask
)

22 changes: 16 additions & 6 deletions flagscale/models/megatron/engram/engram_transformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,18 @@ def forward(
sequence_len_offset=sequence_len_offset,
inference_params=inference_params,
)

def pre_compute_embedding(self, hash_input_ids: Tensor):
    """Delegate to the wrapped Engram module so it can pre-compute and cache
    its n-gram embeddings before this layer's forward pass runs.
    """
    self.engram.pre_compute_embedding(hash_input_ids)

def sharded_state_dict(
self, prefix: str = "", sharded_offsets: tuple = (), metadata: dict | None = None
):
raise NotImplementedError("Sharded state dict is not supported for EngramTransformerLayer")
sharded_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
engram_prefix = f"{prefix}engram."
engram_sharded = self.engram.sharded_state_dict(engram_prefix, sharded_offsets, metadata)
sharded_dict.update(engram_sharded)
return sharded_dict


class EngramTransformerBlock(TransformerBlock):
Expand Down Expand Up @@ -283,6 +290,14 @@ def forward(
# Build kwargs based on layer type
layer_kwargs = {}

# Pre-compute embeddings for the next EngramTransformerLayer if exists to overlap with current layer's computation
if l_no < len(self.layers) - 1:
next_layer = self.layers[l_no + 1]
if isinstance(next_layer, EngramTransformerLayer):
engram_hash_layer_id = next_layer.layer_number - 1
hash_input_ids = engram_hash_input_ids[engram_hash_layer_id]
next_layer.pre_compute_embedding(hash_input_ids)

# Only pass input_ids to EngramTransformerLayer
if isinstance(layer, EngramTransformerLayer):
layer_kwargs["input_ids"] = input_ids
Expand Down Expand Up @@ -333,8 +348,3 @@ def forward(
hidden_states = hidden_states.clone()

return hidden_states

def sharded_state_dict(
self, prefix: str = "", sharded_offsets: tuple = (), metadata: dict = None
):
raise NotImplementedError("Sharded state dict is not supported for EngramTransformerBlock")
Loading
Loading