From 8e7c5123d95a26066e1656ac052e9722c40c17c2 Mon Sep 17 00:00:00 2001
From: wjp666666 <1969554248@qq.com>
Date: Mon, 8 Sep 2025 22:03:08 +0800
Subject: [PATCH 1/4] support deepseek-v2-lite online train and support yarn rope

---
 configs/deepseek-v2-lite-eagle3.json           |   2 +-
 .../run_deepseek_v2_lite_eagle3_online.sh      |  21 ++++
 specforge/modeling/draft/llama3_eagle.py       | 115 ++++++++++++++++++
 3 files changed, 137 insertions(+), 1 deletion(-)
 create mode 100644 examples/run_deepseek_v2_lite_eagle3_online.sh

diff --git a/configs/deepseek-v2-lite-eagle3.json b/configs/deepseek-v2-lite-eagle3.json
index 9ddad46d..da12c0fb 100644
--- a/configs/deepseek-v2-lite-eagle3.json
+++ b/configs/deepseek-v2-lite-eagle3.json
@@ -25,7 +25,7 @@
     "mscale": 0.707,
     "mscale_all_dim": 0.707,
     "original_max_position_embeddings": 4096,
-    "type": "yarn"
+    "rope_type": "yarn"
   },
   "rope_theta": 10000,
   "sliding_window": null,
diff --git a/examples/run_deepseek_v2_lite_eagle3_online.sh b/examples/run_deepseek_v2_lite_eagle3_online.sh
new file mode 100644
index 00000000..4f5a0512
--- /dev/null
+++ b/examples/run_deepseek_v2_lite_eagle3_online.sh
@@ -0,0 +1,21 @@
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+ROOT_DIR=$(dirname $SCRIPT_DIR)
+
+# train eagle3 for deepseek-v2-lite
+NUM_GPUS=${1:-8}
+
+torchrun \
+    --standalone \
+    --nproc_per_node $NUM_GPUS \
+    $ROOT_DIR/scripts/train_eagle3_online.py \
+    --target-model-path DeepSeek-V2-Lite \
+    --draft-model-config $ROOT_DIR/configs/deepseek-v2-lite-eagle3.json \
+    --train-data-path $ROOT_DIR/cache/dataset/sharegpt.jsonl \
+    --output-dir $ROOT_DIR/outputs/deepseek-v2-lite-eagle3 \
+    --num-epochs 10 \
+    --batch-size 1 \
+    --tp-size 1 \
+    --learning-rate 1e-4 \
+    --max-length 2048 \
+    --chat-template deepseek \
+    --cache-dir $ROOT_DIR/cache
\ No newline at end of file
diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py
index 22e36a94..f31d9ad1 100644
--- a/specforge/modeling/draft/llama3_eagle.py
+++ b/specforge/modeling/draft/llama3_eagle.py
@@ -336,6 +336,110 @@ def forward(self, x, position_ids):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
+# Inverse dim formula to find dim based on number of rotations
+def yarn_find_correction_dim(
+    num_rotations, dim, base=10000, max_position_embeddings=2048
+):
+    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
+        2 * math.log(base)
+    )
+
+
+# Find dim range bounds based on rotations
+def yarn_find_correction_range(
+    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
+):
+    low = math.floor(
+        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
+    )
+    high = math.ceil(
+        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
+    )
+    return max(low, 0), min(high, dim - 1)  # Clamp values just in case
+
+
+def yarn_get_mscale(scale=1, mscale=1):
+    if scale <= 1:
+        return 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0
+
+
+def yarn_linear_ramp_mask(min, max, dim):
+    if min == max:
+        max += 0.001  # Prevent singularity
+
+    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
+    ramp_func = torch.clamp(linear_func, 0, 1)
+    return ramp_func
+
+
+class LlamaYarnRotaryEmbedding(LlamaRotaryEmbedding):
+
+    def __init__(
+        self,
+        dim,
+        max_position_embeddings=2048,
+        base=10000,
+        device=None,
+        scaling_factor=1.0,
+        original_max_position_embeddings=4096,
+        beta_fast=32,
+        beta_slow=1,
+        mscale=1,
+        mscale_all_dim=0,
+    ):
+        self.scaling_factor = scaling_factor
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        self.mscale = mscale
+        self.mscale_all_dim = mscale_all_dim
+        super().__init__(dim, max_position_embeddings, base, device)
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        dim = self.dim
+
+        freq_extra = 1.0 / (
+            self.base
+            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
+        )
+        freq_inter = 1.0 / (
+            self.scaling_factor
+            * self.base
+            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
+        )
+
+        low, high = yarn_find_correction_range(
+            self.beta_fast,
+            self.beta_slow,
+            dim,
+            self.base,
+            self.original_max_position_embeddings,
+        )
+        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
+            device=device, dtype=torch.float32
+        )
+        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        t = torch.arange(seq_len, device=device, dtype=torch.float32)
+
+        freqs = torch.outer(t, inv_freq)
+
+        _mscale = float(
+            yarn_get_mscale(self.scaling_factor, self.mscale)
+            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
+        )
+
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer(
+            "cos_cached", (emb.cos() * _mscale)[None, None, :, :].to(dtype), persistent=False
+        )
+        self.register_buffer(
+            "sin_cached", (emb.sin() * _mscale)[None, None, :, :].to(dtype), persistent=False
+        )
+
+
 class LlamaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -397,6 +501,17 @@ def _init_rope(self):
             self.rotary_emb = LlamaMutiRotaryEmbedding(
                 self.head_dim, max_position_embeddings=self.max_position_embeddings
             )
+        elif scaling_type == "yarn":
+            self.rotary_emb = LlamaYarnRotaryEmbedding(
+                self.head_dim,
+                max_position_embeddings=self.max_position_embeddings,
+                original_max_position_embeddings=self.config.rope_scaling["original_max_position_embeddings"],
+                scaling_factor=self.config.rope_scaling["factor"],
+                beta_fast=self.config.rope_scaling["beta_fast"],
+                beta_slow=self.config.rope_scaling["beta_slow"],
+                mscale=self.config.rope_scaling["mscale"],
+                mscale_all_dim=self.config.rope_scaling["mscale_all_dim"],
+            )
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
 

From d79d56d275000de2431353b59e5c0b50ca399b2d Mon Sep 17 00:00:00 2001
From: jiapingW <56055330+jiapingW@users.noreply.github.com>
Date: Mon, 8 Sep 2025 22:15:10 +0800
Subject: [PATCH 2/4] Update run_deepseek_v2_lite_eagle3_online.sh

---
 examples/run_deepseek_v2_lite_eagle3_online.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/run_deepseek_v2_lite_eagle3_online.sh b/examples/run_deepseek_v2_lite_eagle3_online.sh
index 4f5a0512..449dd10f 100644
--- a/examples/run_deepseek_v2_lite_eagle3_online.sh
+++ b/examples/run_deepseek_v2_lite_eagle3_online.sh
@@ -18,4 +18,4 @@ torchrun \
     --learning-rate 1e-4 \
     --max-length 2048 \
     --chat-template deepseek \
-    --cache-dir $ROOT_DIR/cache
\ No newline at end of file
+    --cache-dir $ROOT_DIR/cache \

From cd9d19dc3dc779a0978cff78baabbc612255408d Mon Sep 17 00:00:00 2001
From: jiapingW <56055330+jiapingW@users.noreply.github.com>
Date: Mon, 8 Sep 2025 22:16:26 +0800
Subject: [PATCH 3/4] Update llama3_eagle.py

---
 specforge/modeling/draft/llama3_eagle.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py
index f31d9ad1..52d72796 100644
--- a/specforge/modeling/draft/llama3_eagle.py
+++ b/specforge/modeling/draft/llama3_eagle.py
@@ -364,11 +364,10 @@ def yarn_get_mscale(scale=1, mscale=1):
     return 0.1 * mscale * math.log(scale) + 1.0
 
 
-def yarn_linear_ramp_mask(min, max, dim):
-    if min == max:
-        max += 0.001  # Prevent singularity
-
-    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
+def yarn_linear_ramp_mask(min_val, max_val, dim):
+    if min_val == max_val:
+        max_val += 0.001  # Prevent singularity
+    linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val)
     ramp_func = torch.clamp(linear_func, 0, 1)
     return ramp_func
 

From 112a31453e79e0d699d5078cbb0b779d9abdbba9 Mon Sep 17 00:00:00 2001
From: wjp666666 <1969554248@qq.com>
Date: Mon, 15 Sep 2025 20:55:29 +0800
Subject: [PATCH 4/4] fix fmt

---
 specforge/modeling/draft/llama3_eagle.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py
index 52d72796..3b900a09 100644
--- a/specforge/modeling/draft/llama3_eagle.py
+++ b/specforge/modeling/draft/llama3_eagle.py
@@ -336,6 +336,7 @@ def forward(self, x, position_ids):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
+
 # Inverse dim formula to find dim based on number of rotations
 def yarn_find_correction_dim(
     num_rotations, dim, base=10000, max_position_embeddings=2048
@@ -367,7 +368,9 @@ def yarn_get_mscale(scale=1, mscale=1):
 def yarn_linear_ramp_mask(min_val, max_val, dim):
     if min_val == max_val:
         max_val += 0.001  # Prevent singularity
-    linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val)
+    linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (
+        max_val - min_val
+    )
     ramp_func = torch.clamp(linear_func, 0, 1)
     return ramp_func
 
@@ -433,10 +436,14 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
 
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer(
-            "cos_cached", (emb.cos() * _mscale)[None, None, :, :].to(dtype), persistent=False
+            "cos_cached",
+            (emb.cos() * _mscale)[None, None, :, :].to(dtype),
+            persistent=False,
         )
         self.register_buffer(
-            "sin_cached", (emb.sin() * _mscale)[None, None, :, :].to(dtype), persistent=False
+            "sin_cached",
+            (emb.sin() * _mscale)[None, None, :, :].to(dtype),
+            persistent=False,
         )
 
@@ -504,7 +511,9 @@ def _init_rope(self):
             self.rotary_emb = LlamaYarnRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
-                original_max_position_embeddings=self.config.rope_scaling["original_max_position_embeddings"],
+                original_max_position_embeddings=self.config.rope_scaling[
+                    "original_max_position_embeddings"
+                ],
                 scaling_factor=self.config.rope_scaling["factor"],
                 beta_fast=self.config.rope_scaling["beta_fast"],
                 beta_slow=self.config.rope_scaling["beta_slow"],
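
Note on the YaRN math introduced in PATCH 1/4: the standalone sketch below mirrors the inverse-frequency blend that the new LlamaYarnRotaryEmbedding._set_cos_sin_cache computes, so it can be sanity-checked outside the model. The dim and factor values are illustrative placeholders, not taken from the DeepSeek-V2-Lite config; rope_theta = 10000 and original_max_position_embeddings = 4096 do appear in the config hunk of PATCH 1/4, and beta_fast = 32 / beta_slow = 1 are the defaults of the new class signature.

    import math
    import torch

    dim = 64          # placeholder rotary dim; substitute the model's real head dim
    base = 10000.0    # rope_theta from the config hunk
    factor = 40.0     # placeholder for rope_scaling["factor"]
    beta_fast, beta_slow = 32, 1   # defaults from LlamaYarnRotaryEmbedding
    orig_max_pos = 4096            # original_max_position_embeddings from the config hunk

    def find_correction_dim(num_rotations, d, b, max_pos):
        # Dimension index whose frequency completes `num_rotations` full turns over
        # `max_pos` positions (same formula as yarn_find_correction_dim in the patch).
        return (d * math.log(max_pos / (num_rotations * 2 * math.pi))) / (2 * math.log(b))

    low = max(math.floor(find_correction_dim(beta_fast, dim, base, orig_max_pos)), 0)
    high = min(math.ceil(find_correction_dim(beta_slow, dim, base, orig_max_pos)), dim - 1)

    # Ramp from 0 (below `low`) to 1 (above `high`) across the dim/2 frequency slots.
    ramp = torch.clamp(
        (torch.arange(dim // 2, dtype=torch.float32) - low) / (high - low), 0, 1
    )
    mask = 1.0 - ramp  # 1 -> keep the original frequency, 0 -> interpolate by `factor`

    freq_extra = 1.0 / base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    freq_inter = freq_extra / factor  # same as 1 / (factor * base ** (2i / dim))

    inv_freq = freq_inter * (1 - mask) + freq_extra * mask
    print(inv_freq[:3])   # high-frequency dims keep the original rope_theta schedule
    print(inv_freq[-3:])  # low-frequency dims are slowed down by `factor`

One observation that follows from the patch itself: because the config sets mscale == mscale_all_dim (both 0.707), the _mscale ratio in _set_cos_sin_cache evaluates to 1.0, so the cached cos/sin values are not rescaled for this model.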