2 changes: 1 addition & 1 deletion configs/deepseek-v2-lite-eagle3.json
@@ -25,7 +25,7 @@
         "mscale": 0.707,
         "mscale_all_dim": 0.707,
         "original_max_position_embeddings": 4096,
-        "type": "yarn"
+        "rope_type": "yarn"
     },
     "rope_theta": 10000,
     "sliding_window": null,
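For reference (not part of the diff): a minimal sketch of how the renamed key can be read from this draft-model config. It assumes the JSON is loaded as a plain dict and keeps a fallback to the legacy "type" key for older configs.

```python
import json

# Load the draft-model config touched by this PR.
with open("configs/deepseek-v2-lite-eagle3.json") as f:
    cfg = json.load(f)

rope_scaling = cfg["rope_scaling"]
# Newer configs spell the key "rope_type"; fall back to the legacy "type".
scaling_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
assert scaling_type == "yarn"
print(rope_scaling["original_max_position_embeddings"])  # 4096 per the hunk above
```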
21 changes: 21 additions & 0 deletions examples/run_deepseek_v2_lite_eagle3_online.sh
@@ -0,0 +1,21 @@
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname $SCRIPT_DIR)

# train eagle3 for deepseek-v2-lite
NUM_GPUS=${1:-8}

torchrun \
--standalone \
--nproc_per_node $NUM_GPUS \
$ROOT_DIR/scripts/train_eagle3_online.py \
--target-model-path DeepSeek-V2-Lite \
--draft-model-config $ROOT_DIR/configs/deepseek-v2-lite-eagle3.json \
--train-data-path $ROOT_DIR/cache/dataset/sharegpt.jsonl \
--output-dir $ROOT_DIR/outputs/deepseek-v2-lite-eagle3 \
--num-epochs 10 \
--batch-size 1 \
--tp-size 1 \
--learning-rate 1e-4 \
--max-length 2048 \
--chat-template deepseek \
--cache-dir $ROOT_DIR/cache
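To launch, pass the GPU count as the first argument, e.g. `bash examples/run_deepseek_v2_lite_eagle3_online.sh 4`; with no argument the script defaults to 8 GPUs (`NUM_GPUS=${1:-8}`).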
114 changes: 114 additions & 0 deletions specforge/modeling/draft/llama3_eagle.py
@@ -336,6 +336,109 @@ def forward(self, x, position_ids):

return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(
num_rotations, dim, base=10000, max_position_embeddings=2048
):
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
2 * math.log(base)
)


# Find dim range bounds based on rotations
def yarn_find_correction_range(
low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
low = math.floor(
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
)
high = math.ceil(
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
)
return max(low, 0), min(high, dim - 1) # Clamp values just in case


def yarn_get_mscale(scale=1, mscale=1):
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0


def yarn_linear_ramp_mask(min_val, max_val, dim):
if min_val == max_val:
max_val += 0.001 # Prevent singularity
linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func


class LlamaYarnRotaryEmbedding(LlamaRotaryEmbedding):

def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
original_max_position_embeddings=4096,
beta_fast=32,
beta_slow=1,
mscale=1,
mscale_all_dim=0,
):
self.scaling_factor = scaling_factor
self.original_max_position_embeddings = original_max_position_embeddings
self.beta_fast = beta_fast
self.beta_slow = beta_slow
self.mscale = mscale
self.mscale_all_dim = mscale_all_dim
super().__init__(dim, max_position_embeddings, base, device)

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
dim = self.dim

freq_extra = 1.0 / (
self.base
** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
)
freq_inter = 1.0 / (
self.scaling_factor
* self.base
** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
)

low, high = yarn_find_correction_range(
self.beta_fast,
self.beta_slow,
dim,
self.base,
self.original_max_position_embeddings,
)
inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
device=device, dtype=torch.float32
)
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
self.register_buffer("inv_freq", inv_freq, persistent=False)

t = torch.arange(seq_len, device=device, dtype=torch.float32)

freqs = torch.outer(t, inv_freq)

_mscale = float(
yarn_get_mscale(self.scaling_factor, self.mscale)
/ yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
)

emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer(
"cos_cached", (emb.cos() * _mscale)[None, None, :, :].to(dtype), persistent=False
)
self.register_buffer(
"sin_cached", (emb.sin() * _mscale)[None, None, :, :].to(dtype), persistent=False
)


class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -397,6 +500,17 @@ def _init_rope(self):
self.rotary_emb = LlamaMutiRotaryEmbedding(
self.head_dim, max_position_embeddings=self.max_position_embeddings
)
elif scaling_type == "yarn":
self.rotary_emb = LlamaYarnRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
original_max_position_embeddings=self.config.rope_scaling["original_max_position_embeddings"],
scaling_factor=self.config.rope_scaling["factor"],
beta_fast=self.config.rope_scaling["beta_fast"],
beta_slow=self.config.rope_scaling["beta_slow"],
mscale=self.config.rope_scaling["mscale"],
mscale_all_dim=self.config.rope_scaling["mscale_all_dim"],
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

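For reference (not part of the diff): because configs/deepseek-v2-lite-eagle3.json sets mscale == mscale_all_dim == 0.707, the _mscale correction computed in _set_cos_sin_cache cancels to 1.0, so the cos/sin caches are effectively unscaled for this config. A quick check, assuming the specforge package is importable; the scaling factor below is a hypothetical placeholder (the real value comes from rope_scaling["factor"]):

```python
from specforge.modeling.draft.llama3_eagle import yarn_get_mscale

scaling_factor = 40.0  # hypothetical placeholder, not the repo's value
num = yarn_get_mscale(scaling_factor, mscale=0.707)  # numerator in _set_cos_sin_cache
den = yarn_get_mscale(scaling_factor, mscale=0.707)  # denominator (mscale_all_dim)
print(num / den)  # 1.0 -> cos/sin are left unscaled when mscale == mscale_all_dim
```

Note that the new `elif scaling_type == "yarn"` branch reads factor, beta_fast, beta_slow, mscale, and mscale_all_dim directly from config.rope_scaling, so all five keys must be present in the config.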