-
Notifications
You must be signed in to change notification settings - Fork 148
Deepseek Engram optimization. #1147
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
63a5dcb
df6f6bb
8f528db
9b26315
178d7fd
ad494c8
6ef717d
4bc0f4d
67513a6
66b3167
2a474bf
09fd1aa
480801d
e603ba7
c423043
790ccee
1f6ad03
3e3fe90
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,149 @@ | ||
| # DeepSeek Engram 27B | ||
| system: | ||
| no_shared_fs: ${experiment.runner.no_shared_fs} | ||
| num_workers: 2 | ||
| tensor_model_parallel_size: 8 | ||
| expert_model_parallel_size: 8 | ||
| expert_tensor_parallel_size: 1 | ||
| context_parallel_size: 1 | ||
| engram_embedding_parallel_size: 8 | ||
| sequence_parallel: true | ||
| use_distributed_optimizer: true | ||
| overlap_grad_reduce: true | ||
| overlap_param_gather: true | ||
| precision: | ||
| bf16: true | ||
| attention_softmax_in_fp32: true | ||
| accumulate_allreduce_grads_in_fp32: true | ||
| logging: | ||
| log_interval: 1 | ||
| tensorboard_log_interval: 1 | ||
| wandb_project: ${experiment.exp_name} | ||
| wandb_exp_name: ${experiment.exp_name} | ||
| log_timers_to_tensorboard: true | ||
| log_validation_ppl_to_tensorboard: true | ||
| log_throughput: true | ||
| log_params_norm: true | ||
| log_num_zeros_in_grad: true | ||
| log_memory_to_tensorboard: true | ||
| checkpoint: | ||
| save_interval: ${experiment.save_steps} | ||
| load: ${experiment.load} | ||
| ckpt_format: ${experiment.ckpt_format} | ||
|
|
||
| model: | ||
| # nsys profile args ================= | ||
| # profile: true | ||
| # profile_step_start: 5 | ||
| # profile_step_end: 6 | ||
| # profile_ranks: [0,7] # default [0] | ||
| # Note, need to run with nsys profile | ||
|
|
||
| # # torch profiler args ================= | ||
| # profile: true | ||
| # use_pytorch_profiler: true | ||
| # profile_step_start: 5 | ||
| # profile_step_end: 6 | ||
| # profile_ranks: [0] # default [0] | ||
| # tensorboard_dir: /workspace/torch_profile | ||
| transformer_impl: transformer_engine | ||
| num_layers: 30 | ||
| hidden_size: 2560 | ||
| num_attention_heads: 32 | ||
| num_query_groups: 32 # num_key_value_heads | ||
| seq_length: 4096 | ||
| max_position_embeddings: 4096 | ||
| norm_epsilon: 1e-6 | ||
| use_rotary_position_embeddings: true | ||
| rotary_base: 1000000 | ||
| swiglu: true | ||
| normalization: RMSNorm | ||
| qk_layernorm: true | ||
| init_method_std: 0.02 | ||
| attention_dropout: 0.0 | ||
| hidden_dropout: 0.0 | ||
| position_embedding_type: rope | ||
| untie_embeddings_and_output_weights: true | ||
| no_position_embedding: true | ||
| no_rope_fusion: true | ||
| disable_bias_linear: true | ||
|
|
||
| # mla args ================== | ||
| multi_latent_attention: true | ||
| q_lora_rank: 768 | ||
| kv_lora_rank: 512 | ||
| qk_head_dim: 128 | ||
| qk_pos_emb_head_dim: 64 | ||
| v_head_dim: 128 | ||
|
|
||
| # moe args =================== | ||
| ffn_hidden_size: 12288 | ||
| moe_ffn_hidden_size: 1536 | ||
| moe_grouped_gemm: true | ||
| moe_shared_expert_intermediate_size: 3072 | ||
| num_experts: 56 | ||
| moe_router_load_balancing_type: "seq_aux_loss" | ||
| moe_router_score_function: sigmoid | ||
| moe_router_enable_expert_bias: true | ||
| moe_router_bias_update_rate: 0.001 | ||
| moe_aux_loss_coeff: 0.02 | ||
| moe_layer_freq: "[0]+[1]*29" | ||
| # node limited routing | ||
| moe_router_num_groups: 1 | ||
| moe_router_group_topk: 1 | ||
| moe_router_topk: 6 | ||
| moe_router_topk_scaling_factor: 2.446 | ||
| moe_token_dispatcher_type: "alltoall" | ||
| # overlap_moe_expert_parallel_comm: true # Optional. | ||
|
|
||
| # mtp args ==================== | ||
| # mtp_num_layers: 1 | ||
| # mtp_loss_scaling_factor: 0.3 | ||
|
|
||
| # engram args ================= | ||
| use_engram: true | ||
| engram_tokenizer_name_or_path: /workspace/qwentokenizer | ||
| engram_vocab_size: [1131200, 1131200] | ||
| max_ngram_size: 3 | ||
| n_embed_per_ngram: 1280 | ||
| n_head_per_ngram: 8 | ||
| engram_layer_ids: [2, 15] | ||
| engram_pad_id: 0 | ||
| engram_seed: 0 | ||
| engram_kernel_size: 4 | ||
| engram_hc_mult: 1 | ||
| engram_embedding_parallel_method: alltoall # alltoall, allreduce, offload | ||
|
|
||
| # training | ||
| seed: ${experiment.seed} | ||
| finetune: false | ||
| micro_batch_size: 2 | ||
| global_batch_size: 2048 | ||
| eval_iters: 0 | ||
| train_iters: 20 | ||
|
|
||
| optimizer: | ||
| clip_grad: 1.0 | ||
| weight_decay: 0.1 | ||
| adam_beta1: 0.9 | ||
| adam_beta2: 0.95 | ||
| lr_scheduler: | ||
| lr: 3.0e-3 | ||
| min_lr: 3.0e-4 | ||
| lr_warmup_fraction: 0.01 | ||
| lr_decay_style: WSD | ||
| lr_wsd_decay_style: cosine | ||
| lr_wsd_decay_iters: 10 | ||
|
|
||
| data: | ||
| reset_position_ids: True | ||
| reset_attention_mask: True | ||
| data_path: /workspace/data/enron_emails_demo_text_document_qwen | ||
| split: 1 | ||
| no_mmap_bin_files: true | ||
| tokenizer: | ||
| legacy_tokenizer: true | ||
| tokenizer_type: QwenTokenizerFS | ||
| tokenizer_path: /workspace/qwentokenizer | ||
| vocab_size: 151851 | ||
| make_vocab_size_divisible_by: 64 | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| defaults: | ||
| - _self_ | ||
| - train: engram | ||
|
|
||
| experiment: | ||
| exp_name: DeepSeek-Engram | ||
| seed: 42 | ||
| save_steps: 100 | ||
| load: null | ||
| exp_dir: outputs/${experiment.exp_name} | ||
| ckpt_format: torch | ||
| # ckpt_format: fsdp_dtensor # Just for Megatron FSDP. | ||
| task: | ||
| type: train | ||
| backend: megatron | ||
| entrypoint: flagscale/train/megatron/train_engram.py | ||
| runner: | ||
| # 单机 | ||
| # per_node_task: false | ||
| # no_shared_fs: false | ||
| # rdzv_backend: static | ||
| # hostfile: null | ||
| # ssh_port: 10710 | ||
| # 多机 | ||
| per_node_task: false | ||
| no_shared_fs: false | ||
| backend: torchrun | ||
| nnodes: 3 | ||
| nproc_per_node: 8 | ||
| hostfile: hostfile # Select an available hostfile. Like ip_1 slosts=8\nip_2 slost=8... | ||
|
|
||
| master_port: 10720 # Select an available port. | ||
| ssh_port: 10710 # Select an available port. | ||
| master_addr: <master_ip> | ||
| rdzv_backend: static | ||
| cmds: | ||
| before_start: ulimit -n 1048576 && source /root/miniconda3/bin/activate /root/miniconda3/envs/flagscale-train | ||
| envs: | ||
| LOGLEVEL: "INFO" | ||
| CUDA_VISIBLE_DEVICES: "0,1,2,3,4,5,6,7" | ||
| CUDA_DEVICE_MAX_CONNECTIONS: 1 | ||
| NCCL_IB_HCA: "IB interface" # Select correct IB interface. | ||
| NCCL_SOCKET_IFNAME: "IP interface" # Select correct interface. | ||
| NCCL_IB_DISABLE: 0 | ||
| NCCL_DEBUG: "WARN" | ||
| NCCL_IB_GID_INDEX: 3 | ||
|
|
||
| action: run | ||
|
|
||
| hydra: | ||
| run: | ||
| dir: ${experiment.exp_dir}/hydra | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -14,6 +14,9 @@ | |||||||||||||||||||||||||||||||||||||||||
| from .short_conv import ShortConv | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| ## Megatron | ||||||||||||||||||||||||||||||||||||||||||
| from megatron.core.transformer.utils import sharded_state_dict_default | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| class Engram(nn.Module): | ||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, engram_cfg: EngramConfig, layer_id): | ||||||||||||||||||||||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -34,13 +37,17 @@ def __init__(self, engram_cfg: EngramConfig, layer_id): | |||||||||||||||||||||||||||||||||||||||||
| pad_id=engram_cfg.engram_pad_id, | ||||||||||||||||||||||||||||||||||||||||||
| seed=engram_cfg.engram_seed, | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| self.multi_head_embedding = MultiHeadEmbedding( | ||||||||||||||||||||||||||||||||||||||||||
| self.memory = MultiHeadEmbedding( | ||||||||||||||||||||||||||||||||||||||||||
| engram_cfg, | ||||||||||||||||||||||||||||||||||||||||||
| list_of_N=[ | ||||||||||||||||||||||||||||||||||||||||||
| x for y in global_hash_mapping.vocab_size_across_layers[self.layer_id] for x in y | ||||||||||||||||||||||||||||||||||||||||||
| ], | ||||||||||||||||||||||||||||||||||||||||||
| D=engram_cfg.n_embed_per_ngram // engram_cfg.n_head_per_ngram, | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| self.embedding_cache = None # Cache for pre-computed embeddings | ||||||||||||||||||||||||||||||||||||||||||
| self.embedding_stream = None # Stream for pre-computing embeddings | ||||||||||||||||||||||||||||||||||||||||||
| if torch.cuda.is_available(): | ||||||||||||||||||||||||||||||||||||||||||
| self.embedding_stream = torch.cuda.Stream() | ||||||||||||||||||||||||||||||||||||||||||
| self.short_conv = ShortConv( | ||||||||||||||||||||||||||||||||||||||||||
| hidden_size=self.backbone_config.hidden_size, | ||||||||||||||||||||||||||||||||||||||||||
| kernel_size=engram_cfg.engram_kernel_size, | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -81,8 +88,14 @@ def forward(self, hidden_states, hash_input_ids): | |||||||||||||||||||||||||||||||||||||||||
| # [B, L, N_GRAM * N_HEADS_PER_GRAM] | ||||||||||||||||||||||||||||||||||||||||||
| # fake hyper-connection | ||||||||||||||||||||||||||||||||||||||||||
| hidden_states = hidden_states.unsqueeze(2) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| embeddings = self.multi_head_embedding(hash_input_ids).flatten(start_dim=-2) | ||||||||||||||||||||||||||||||||||||||||||
| if self.embedding_cache is not None: | ||||||||||||||||||||||||||||||||||||||||||
| embeddings, embedding_event = self.embedding_cache | ||||||||||||||||||||||||||||||||||||||||||
| if embedding_event is not None: | ||||||||||||||||||||||||||||||||||||||||||
| torch.cuda.current_stream().wait_event(embedding_event) # Ensure pre-computed embeddings are ready | ||||||||||||||||||||||||||||||||||||||||||
| self.embedding_cache = None # Clear cache after use | ||||||||||||||||||||||||||||||||||||||||||
| del embedding_event # Free the event | ||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||
| embeddings = self.memory(hash_input_ids).flatten(start_dim=-2) | ||||||||||||||||||||||||||||||||||||||||||
| # [L/tp_size, B, N_GRAM * N_HEADS_PER_GRAM, N_EMBED_PER_GRAM // N_HEADS_PER_GRAM] | ||||||||||||||||||||||||||||||||||||||||||
| # [L/tp_size, B, N_GRAM * N_EMBED_PER_NGRAM] | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -120,3 +133,34 @@ def forward(self, hidden_states, hash_input_ids): | |||||||||||||||||||||||||||||||||||||||||
| output = output.squeeze(2) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| return output | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def pre_compute_embedding(self, input_ids: torch.Tensor): | ||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||
| Pre-compute the multi-head embedding for the given input IDs. | ||||||||||||||||||||||||||||||||||||||||||
| This can be called before the forward pass to warm up the embedding cache. | ||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||
| assert input_ids is not None, "Input ids can not be None for EngramModel" | ||||||||||||||||||||||||||||||||||||||||||
| self.embedding_stream.synchronize() # Ensure previous computations on the stream are finished | ||||||||||||||||||||||||||||||||||||||||||
| with torch.cuda.stream(self.embedding_stream): | ||||||||||||||||||||||||||||||||||||||||||
| embedding_result = self.memory(input_ids).flatten(start_dim=-2) | ||||||||||||||||||||||||||||||||||||||||||
| embedding_event = torch.cuda.Event() | ||||||||||||||||||||||||||||||||||||||||||
| embedding_event.record(self.embedding_stream) | ||||||||||||||||||||||||||||||||||||||||||
| self.embedding_cache = (embedding_result, embedding_event) | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+137
to
+148
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+143
to
+149
|
||||||||||||||||||||||||||||||||||||||||||
| self.embedding_stream.synchronize() # Ensure previous computations on the stream are finished | |
| with torch.cuda.stream(self.embedding_stream): | |
| embedding_result = self.memory(input_ids).flatten(start_dim=-2) | |
| embedding_event = torch.cuda.Event() | |
| embedding_event.record(self.embedding_stream) | |
| self.embedding_cache = (embedding_result, embedding_event) | |
| embedding_stream = getattr(self, "embedding_stream", None) | |
| if embedding_stream is not None and torch.cuda.is_available(): | |
| embedding_stream.synchronize() # Ensure previous computations on the stream are finished | |
| with torch.cuda.stream(embedding_stream): | |
| embedding_result = self.memory(input_ids).flatten(start_dim=-2) | |
| embedding_event = torch.cuda.Event() | |
| embedding_event.record(embedding_stream) | |
| else: | |
| embedding_result = self.memory(input_ids).flatten(start_dim=-2) | |
| embedding_event = None | |
| self.embedding_cache = (embedding_result, embedding_event) |
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,5 +1,7 @@ | ||||||||||||||
| # ruff: noqa: RUF013 | ||||||||||||||
| ## built-in | ||||||||||||||
| from typing import Optional | ||||||||||||||
|
|
||||||||||||||
| import torch | ||||||||||||||
| from torch import Tensor | ||||||||||||||
|
|
||||||||||||||
|
|
@@ -27,26 +29,41 @@ def __init__(self, hash_mapping, input_ids, hash_stream=None): | |||||||||||||
| self.input_ids = input_ids | ||||||||||||||
| self.hash_stream = hash_stream | ||||||||||||||
| self._result = None | ||||||||||||||
| self._computation_started = False | ||||||||||||||
|
|
||||||||||||||
| # torch.cuda.nvtx.range_push("LazyHashInputIds hash") | ||||||||||||||
| # Start async computation immediately if stream is available | ||||||||||||||
| self._is_async_pending = False | ||||||||||||||
| # Async | ||||||||||||||
| if self.hash_stream is not None: | ||||||||||||||
| # self.hash_stream.wait_stream(torch.cuda.current_stream()) | ||||||||||||||
| with torch.cuda.stream(self.hash_stream): | ||||||||||||||
| self._result = self.hash_mapping.hash(self.input_ids) | ||||||||||||||
| self._computation_started = True | ||||||||||||||
| # torch.cuda.nvtx.range_pop() | ||||||||||||||
| self._is_async_pending = True | ||||||||||||||
| # record result to use across stream | ||||||||||||||
| self._record_current_stream() | ||||||||||||||
|
|
||||||||||||||
| def __getitem__(self, key): | ||||||||||||||
| """Access hash result, synchronizing if necessary.""" | ||||||||||||||
| def _record_current_stream(self): | ||||||||||||||
| """Helper to record current stream on all result tensors""" | ||||||||||||||
| if self._result is None: | ||||||||||||||
| if self.hash_stream is not None and self._computation_started: | ||||||||||||||
| # Wait for async computation to complete | ||||||||||||||
| torch.cuda.current_stream().wait_stream(self.hash_stream) | ||||||||||||||
| self._computation_started = False # Mark as synchronized | ||||||||||||||
| else: | ||||||||||||||
| # Compute synchronously if no stream or computation not started | ||||||||||||||
| self._result = self.hash_mapping.hash(self.input_ids) | ||||||||||||||
| return | ||||||||||||||
| current_stream = torch.cuda.current_stream() | ||||||||||||||
| if isinstance(self._result, dict): | ||||||||||||||
| for t in self._result.values(): | ||||||||||||||
| if isinstance(t, torch.Tensor): | ||||||||||||||
| t.record_stream(current_stream) | ||||||||||||||
| elif isinstance(self._result, torch.Tensor): | ||||||||||||||
| self._result.record_stream(current_stream) | ||||||||||||||
|
|
||||||||||||||
| def __getitem__(self, key): | ||||||||||||||
| # Case 1: Async compute -> wait | ||||||||||||||
| if self._is_async_pending: | ||||||||||||||
| torch.cuda.current_stream().wait_stream(self.hash_stream) | ||||||||||||||
| self._is_async_pending = False # Async finish | ||||||||||||||
| self._record_current_stream() | ||||||||||||||
|
|
||||||||||||||
| # Case 2: Sync but no compute -> start compute | ||||||||||||||
| elif self._result is None: | ||||||||||||||
| self._result = self.hash_mapping.hash(self.input_ids) | ||||||||||||||
|
|
||||||||||||||
| # Case 3: Async or sync compute is finished. | ||||||||||||||
| # print(f"[rank{torch.distributed.get_rank()}]: LazyHashInputIds result = {self._result}") | ||||||||||||||
| return self._result[key] | ||||||||||||||
|
|
||||||||||||||
| def get(self, key, default=None): | ||||||||||||||
|
|
@@ -171,7 +188,40 @@ def forward( | |||||||||||||
| inference_context=inference_context, | ||||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
| def sharded_state_dict( | ||||||||||||||
| self, prefix: str = "", sharded_offsets: tuple = (), metadata: dict | None = None | ||||||||||||||
| def build_schedule_plan( | ||||||||||||||
| self, | ||||||||||||||
| input_ids: Tensor, | ||||||||||||||
| position_ids: Tensor, | ||||||||||||||
| attention_mask: Tensor, | ||||||||||||||
| decoder_input: Tensor = None, | ||||||||||||||
| labels: Tensor = None, | ||||||||||||||
| inference_context: BaseInferenceContext = None, | ||||||||||||||
| packed_seq_params: PackedSeqParams = None, | ||||||||||||||
| extra_block_kwargs: dict = None, | ||||||||||||||
| runtime_gather_output: Optional[bool] = None, | ||||||||||||||
| inference_params: Optional[BaseInferenceContext] = None, | ||||||||||||||
| loss_mask: Optional[Tensor] = None, | ||||||||||||||
| ): | ||||||||||||||
| raise NotImplementedError("Sharded state dict is not supported for EngramModel") | ||||||||||||||
| """ | ||||||||||||||
| Adaptation of overlap_moe_expert_parallel_comm. | ||||||||||||||
| """ | ||||||||||||||
| # Precompute the engram_hash_iput_ids, it will be used to create a TransformerChunkSchedulePlan. | ||||||||||||||
| engram_hash_input_ids = LazyHashInputIds( | ||||||||||||||
| hash_mapping=self.engram_hash, | ||||||||||||||
| input_ids=input_ids, | ||||||||||||||
| hash_stream=self._hash_stream, | ||||||||||||||
| ) | ||||||||||||||
| if extra_block_kwargs is None: | ||||||||||||||
| extra_block_kwargs = { | ||||||||||||||
| "engram_hash_input_ids": engram_hash_input_ids, | ||||||||||||||
| } | ||||||||||||||
|
Comment on lines
+214
to
+217
|
||||||||||||||
| if extra_block_kwargs is None: | |
| extra_block_kwargs = { | |
| "engram_hash_input_ids": engram_hash_input_ids, | |
| } | |
| extra_block_kwargs = dict(extra_block_kwargs or {}) | |
| extra_block_kwargs.setdefault("engram_hash_input_ids", engram_hash_input_ids) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The inline comment lists
offloadas a valid value forengram_embedding_parallel_method, but the CLI/config choices added in this PR are onlyalltoallandallreduce(offloading is controlled by a separate boolean flag). Update the comment to avoid misleading users, and optionally documentengram_offload_embedding_optimizer_stateshere instead.