examples/speculative_decoding/eagle_config.json (8 changes: 0 additions & 8 deletions)
@@ -1,11 +1,3 @@
 {
-  "rope_scaling": {
Contributor (review comment on the removed "rope_scaling" block): why do we need to delete this configuration?

-    "factor": 32.0,
-    "low_freq_factor": 1.0,
-    "high_freq_factor": 4.0,
-    "original_max_position_embeddings": 8192,
-    "rope_type": "llama3"
-  },
   "initializer_range": 0.02,
   "_attn_implementation": "sdpa"
 }
examples/speculative_decoding/eagle_utils.py (17 changes: 12 additions & 5 deletions)
@@ -518,18 +518,25 @@ def on_log(self, args, state, control, **kwargs):
         average_acc = np.mean(state.training_accs, axis=0)
         if self.estimate_ar:
             # Calculate mean training AR since last log
-            # NOTE: This is only a estimate of the real AR.
+            # NOTE: This is only an estimate of the real AR.
             est_ar = 1
             acc_cumprod = 1
-            for step_acc in average_acc:
-                est_ar += acc_cumprod * step_acc
+            for step_acc in average_acc[0]:
                 acc_cumprod *= step_acc
+                est_ar += acc_cumprod
+            # Parallel draft tokens only used after all eagle tokens
+            for draft_acc in average_acc[1:]:
+                acc_cumprod *= draft_acc[-1]
+                est_ar += acc_cumprod
             print_rank_0(f"Step {state.global_step} Estimated Training AR: {est_ar:.4f}")

         # log to wandb
         if wandb and is_master():
-            for i, step_acc in enumerate(average_acc):
-                wandb.log({f"step_{i}_train_acc": step_acc}, step=state.global_step)
+            for i, draft_acc in enumerate(average_acc):
+                for j, step_acc in enumerate(draft_acc):
+                    wandb.log(
+                        {f"parallel_{i}_step_{j}_train_acc": step_acc}, step=state.global_step
+                    )
             if self.estimate_ar:
                 wandb.log({"estimated_training_ar": est_ar}, step=state.global_step)
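For readers checking the arithmetic of the new AR estimate, the snippet below reproduces the same calculation on made-up acceptance rates. The array shape ([parallel drafts, eagle steps]) follows the loop structure above, but the numbers and names are purely illustrative.

import numpy as np

# Hypothetical mean acceptance rates, shaped [num_parallel_drafts, num_eagle_steps].
# Row 0 is the regular EAGLE chain; later rows are parallel draft heads.
average_acc = np.array([
    [0.80, 0.70, 0.60],
    [0.50, 0.45, 0.40],
])

est_ar = 1.0  # the target model's own token is always emitted
acc_cumprod = 1.0

# An EAGLE token at step k is only kept if every earlier step was accepted,
# hence the cumulative product over the first row.
for step_acc in average_acc[0]:
    acc_cumprod *= step_acc
    est_ar += acc_cumprod

# Parallel draft tokens are only used after all EAGLE tokens, so each one
# extends the chain by its last-step acceptance rate.
for draft_acc in average_acc[1:]:
    acc_cumprod *= draft_acc[-1]
    est_ar += acc_cumprod

print(f"Estimated AR: {est_ar:.4f}")
# 1 + 0.8 + 0.8*0.7 + 0.8*0.7*0.6 + (0.8*0.7*0.6)*0.4 = 2.8304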
examples/speculative_decoding/launch_train.sh (8 changes: 6 additions & 2 deletions)
@@ -90,6 +90,10 @@ while [ $# -gt 0 ]; do
       if [[ "$1" != *=* ]]; then shift; fi
       VLM_IMG_DIR="${1#*=}"
       ;;
+    --estimate_ar*)
+      if [[ "$1" != *=* ]]; then shift; fi
+      ESTIMATE_AR="${1#*=}"
+      ;;
     --ar_validate_steps*)
       if [[ "$1" != *=* ]]; then shift; fi
       AR_VALIDATE_STEPS="${1#*=}"
@@ -120,8 +124,6 @@ LR=${LR:-"1e-4"}
 TRAIN_BS=${TRAIN_BS:-4}
 MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1}
 MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1}
-REDRAFTER_TOKENS=${REDRAFTER_TOKENS:-1}
-REDRAFTER_NUM_LAYERS=${REDRAFTER_NUM_LAYERS:-1}
 FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"}
 NUM_GPU=${NUM_GPU:-1}
 TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048}
@@ -130,6 +132,7 @@ DISABLE_TQDM=${DISABLE_TQDM:-False}
 VLM_PROCESSOR=${VLM_PROCESSOR:-}
 VLM_IMG_DIR=${VLM_IMG_DIR:-}
 AR_VALIDATE_STEPS=${AR_VALIDATE_STEPS:-1000}
+ESTIMATE_AR=${ESTIMATE_AR:-False}

 if [[ "$MODE" == "medusa" ]]; then
   SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS"
@@ -192,6 +195,7 @@ CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
     --tf32 True \
     --data_path $DATA \
     --disable_tqdm $DISABLE_TQDM \
+    --estimate_ar $ESTIMATE_AR \
     --ar_validate_steps $AR_VALIDATE_STEPS \
     $VLM_ARGS \
     $OFFLINE_TRAINING_ARGS \
examples/speculative_decoding/main.py (21 changes: 4 additions & 17 deletions)
@@ -92,6 +92,9 @@ class TrainingArguments(transformers.TrainingArguments):
     dataloader_drop_last: bool = field(default=True)
     bf16: bool = field(default=True)
     mode: Literal["eagle1", "eagle3", "medusa"] = "eagle3"
+    estimate_ar: bool = field(
+        default=False, metadata={"help": "Whether to estimate AR during training for logging."}
+    )
     ar_validate_steps: int = field(default=1000, metadata={"help": "Steps between AR validation."})
     disable_tqdm: bool = field(default=False, metadata={"help": "Disable tqdm progress bar."})
     remove_unused_columns: bool = field(
@@ -193,22 +196,6 @@ def train():
                 custom_config = json.load(f)
             config["eagle_architecture_config"].update(custom_config)

-        # Hidden size and vocab size must match base model
-        llm_config = (
-            model.config.llm_config if hasattr(model.config, "llm_config") else model.config
-        )
-        config["eagle_architecture_config"].update(
-            {
-                "hidden_size": llm_config.hidden_size,
-                "vocab_size": llm_config.vocab_size,
-                # we also overwrite max_pos_embedding for deployment compatibility
-                "max_position_embeddings": llm_config.max_position_embeddings,
-                "draft_vocab_size": custom_config["draft_vocab_size"]
-                if eagle_args.eagle_config and "draft_vocab_size" in custom_config
-                else llm_config.vocab_size,
-            }
-        )
-
         mtsp.convert(model, [("eagle", config)])

         # read draft vocab cache
@@ -238,7 +225,7 @@ def train():
         model=model,
         processing_class=tokenizer,
         args=training_args,
-        callbacks=[EagleTrainingPlot(training_args.ar_validate_steps)],
+        callbacks=[EagleTrainingPlot(training_args.ar_validate_steps, training_args.estimate_ar)],
         **data_module,
     )

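To see how the new flag travels from launch_train.sh into this dataclass, here is a minimal, assumed sketch of the standard transformers.HfArgumentParser round trip; the actual argument-parsing wiring in main.py is outside this diff, so only the field definition below is taken from it.

from dataclasses import dataclass, field

import transformers


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    # The field added in this diff; everything else here is sketch scaffolding.
    estimate_ar: bool = field(
        default=False, metadata={"help": "Whether to estimate AR during training for logging."}
    )


if __name__ == "__main__":
    # Invoked e.g. as: python sketch.py --output_dir /tmp/out --estimate_ar True
    # (launch_train.sh passes --estimate_ar $ESTIMATE_AR to main.py the same way).
    parser = transformers.HfArgumentParser(TrainingArguments)
    (training_args,) = parser.parse_args_into_dataclasses()
    print(training_args.estimate_ar)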
modelopt/torch/export/unified_export_megatron.py (32 changes: 8 additions & 24 deletions)
@@ -212,26 +212,14 @@ def __init__(
                 )

                 eagle_config = {
-                    "use_input_layernorm_in_first_layer": mode_cfg["config"][
-                        "eagle_architecture_config"
-                    ]["use_input_layernorm_in_first_layer"],
-                    "use_last_layernorm": mode_cfg["config"]["eagle_architecture_config"][
-                        "use_last_layernorm"
-                    ],
-                    "use_mtp_layernorm": mode_cfg["config"]["eagle_architecture_config"][
-                        "use_mtp_layernorm"
-                    ],
-                    "use_aux_hidden_state": mode_cfg["config"]["eagle_architecture_config"][
-                        "use_aux_hidden_state"
-                    ],
+                    "use_input_layernorm_in_first_layer": model.eagle_config.use_input_layernorm_in_first_layer,
+                    "use_last_layernorm": model.eagle_config.use_last_layernorm,
+                    "use_mtp_layernorm": model.eagle_config.use_mtp_layernorm,
+                    "use_aux_hidden_state": model.eagle_config.use_aux_hidden_state,
                     "eagle_aux_hidden_state_layer_ids": model.eagle_config.eagle_aux_hidden_state_layer_ids,
                     "next_layer_regular": True,
-                    "parallel_draft_step": mode_cfg["config"]["eagle_architecture_config"][
-                        "parallel_draft_step"
-                    ],
-                    "parallel_draft_heads_num_layers": mode_cfg["config"][
-                        "eagle_architecture_config"
-                    ]["parallel_draft_heads_num_layers"],
+                    "parallel_draft_step": model.eagle_config.parallel_draft_step,
+                    "parallel_draft_heads_num_layers": model.eagle_config.parallel_draft_heads_num_layers,
                 }

                 eagle_config_update = {
@@ -243,9 +231,7 @@ def __init__(
                     "max_position_embeddings": self._hf_text_config.max_position_embeddings,
                     "num_attention_heads": model.eagle_module.config.num_attention_heads,
                     "num_key_value_heads": model.eagle_module.config.num_query_groups,
-                    "num_hidden_layers": mode_cfg["config"]["eagle_architecture_config"][
-                        "num_hidden_layers"
-                    ],
+                    "num_hidden_layers": model.eagle_config.num_hidden_layers,
                     "vocab_size": self._hf_text_config.vocab_size,
                     # Unset any special token ids given that the tokenizer can change here.
                     "bos_token_id": None,
@@ -254,9 +240,7 @@ def __init__(
                     "sep_token_id": None,
                     # The following attributes are EAGLE specific
                     "eagle_config": eagle_config,
-                    "draft_vocab_size": mode_cfg["config"]["eagle_architecture_config"][
-                        "draft_vocab_size"
-                    ],
+                    "draft_vocab_size": model.eagle_config.draft_vocab_size,
                 }

                 self._hf_extra_config.update(eagle_config_update)
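The export-side change is a pure lookup refactor: values previously dug out of the nested mode_cfg dictionary are now read as attributes of model.eagle_config. A tiny illustration of the pattern, using stand-in objects rather than the real ModelOpt types:

from types import SimpleNamespace

# Hypothetical stand-ins for what the exporter receives; the field value is made up.
mode_cfg = {"config": {"eagle_architecture_config": {"parallel_draft_step": 3}}}
model = SimpleNamespace(eagle_config=SimpleNamespace(parallel_draft_step=3))

# Before: three levels of dictionary indexing per field.
old_value = mode_cfg["config"]["eagle_architecture_config"]["parallel_draft_step"]
# After: a single attribute access on the converted model.
new_value = model.eagle_config.parallel_draft_step
assert old_value == new_value == 3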
modelopt/torch/speculative/eagle/default_config.py (6 changes: 1 addition & 5 deletions)
@@ -18,9 +18,6 @@
 default_eagle_config = {
     "hidden_act": "silu",
     "torch_dtype": "bfloat16",
-    "vocab_size": 128256,
-    "draft_vocab_size": 128256,
-    "max_position_embeddings": 8192,
     "position_embedding_type": "rope",
     "rope_scaling": {
         "factor": 8.0,
@@ -31,7 +28,6 @@
     },
     "rope_theta": 500000.0,
     "num_hidden_layers": 1,
-    "hidden_size": 4096,
     "intermediate_size": 14336,
     "num_attention_heads": 32,
     "num_key_value_heads": 8,
@@ -47,6 +43,6 @@
     "use_mtp_layernorm": False,
     "parallel_draft_step": 1,
     "parallel_draft_heads_num_layers": 1,
-    "has_lm_head": False,
+    "has_lm_head": True,
     "head_dim": 128,
 }