35 changes: 35 additions & 0 deletions configs/draft_lora_trainable_config.json
@@ -0,0 +1,35 @@
{
"alpha_pattern": {},
"auto_mapping": null,
"base_model_name_or_path": "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
"bias": "none",
"fan_in_fan_out": false,
"inference_mode": false,
"init_lora_weights": true,
"layer_replication": null,
"layers_pattern": null,
"layers_to_transform": null,
"loftq_config": {},
"lora_alpha": 128,
"lora_dropout": 0.1,
"megatron_config": null,
"megatron_core": "megatron.core",
"modules_to_save": null,
"peft_type": "LORA",
"qalora_group_size": 16,
"r": 64,
"rank_pattern": {},
"revision": null,
"target_modules": [
"gate_proj",
"o_proj",
"q_proj",
"v_proj",
"k_proj"
],
"task_type": "CAUSAL_LM",
"trainable_token_indices": null,
"use_dora": false,
"use_qalora": false,
"use_rslora": false
}
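
For reference, a minimal sketch (assuming peft is installed) of materializing this file as a LoraConfig. The config path is the one added in this PR; loading the EAGLE3 draft checkpoint through AutoModelForCausalLM is an assumption for illustration, not something this PR does:

import json

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Read the JSON added above and keep only the fields LoraConfig needs.
with open("configs/draft_lora_trainable_config.json") as f:
    cfg = json.load(f)

lora_config = LoraConfig(
    r=cfg["r"],                        # rank 64
    lora_alpha=cfg["lora_alpha"],      # scaling factor 128
    lora_dropout=cfg["lora_dropout"],  # 0.1
    target_modules=cfg["target_modules"],
    bias=cfg["bias"],
    task_type=cfg["task_type"],        # CAUSAL_LM
)

# Assumption: the draft checkpoint loads as a standard causal LM.
model = AutoModelForCausalLM.from_pretrained(cfg["base_model_name_or_path"])
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the LoRA adapters are trainable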
28 changes: 28 additions & 0 deletions download.py
@@ -0,0 +1,28 @@
import os

from huggingface_hub import snapshot_download


def download_model(model_id, local_dir):
print(f"downloading model: {model_id}")
print(f"will save to: {local_dir}")

try:
snapshot_download(
repo_id=model_id,
local_dir=local_dir,
local_dir_use_symlinks=False,
)
print("download success!")
except Exception as e:
print(f"error: {e}")


if __name__ == "__main__":
model_identifier = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"
save_directory = f"./{model_identifier.replace('/', '_')}"

if not os.path.exists(save_directory):
os.makedirs(save_directory)

download_model(model_id=model_identifier, local_dir=save_directory)
Comment on lines +1 to +28
Contributor

medium

The error handling in download_model could be improved. When an exception occurs, the script prints to standard output and exits with status 0, which can make it difficult for calling scripts to detect failures. It's better practice to print errors to stderr and exit with a non-zero status code. Additionally, the success message contains a non-ASCII exclamation mark, which should be replaced for consistency and to avoid potential encoding issues. This suggestion also adds the required sys import to the top of the file.

import os
import sys

from huggingface_hub import snapshot_download


def download_model(model_id, local_dir):
    print(f"downloading model: {model_id}")
    print(f"will save to: {local_dir}")

    try:
        snapshot_download(
            repo_id=model_id,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
        )
        print("Download successful!")
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    model_identifier = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"
    save_directory = f"./{model_identifier.replace('/', '_')}"

    if not os.path.exists(save_directory):
        os.makedirs(save_directory)

    download_model(model_id=model_identifier, local_dir=save_directory)
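
With the non-zero exit status in place, callers can actually detect failure. A minimal sketch of a hypothetical wrapper (check_download.py is not part of this PR):

# check_download.py -- hypothetical wrapper that relies on download.py
# exiting non-zero on failure, per the suggestion above.
import subprocess
import sys

# Run the downloader and propagate its exit status.
result = subprocess.run([sys.executable, "download.py"])
if result.returncode != 0:
    print("download failed; aborting pipeline", file=sys.stderr)
    sys.exit(result.returncode)
print("download ok; continuing")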

33 changes: 33 additions & 0 deletions examples/run_llama3_eagle3_lora_online.sh
@@ -0,0 +1,33 @@
#!/bin/bash

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname $SCRIPT_DIR)

NUM_GPUS=${1:-8}
TARGET_LORA_PATH=${2:-/sgl-workspace/llama-duo_llama3.1-8b-summarize-gpt4o-128k}
DRAFT_LORA_CONFIG=${3:-$ROOT_DIR/configs/draft_lora_trainable_config.json}
BASE_DRAFT_MODEL_PATH=${4:-/sgl-workspace/jamesliu1_sglang-EAGLE3-Llama-3.1-Instruct-8B}

torchrun \
--standalone \
--nproc_per_node $NUM_GPUS \
$ROOT_DIR/scripts/train_eagle3_lora_online.py \
--target-model-path meta-llama/Llama-3.1-8B-Instruct \
--draft-model-config /sgl-workspace/jamesliu1_sglang-EAGLE3-Llama-3.1-Instruct-8B/config.json \
--base-draft-model-path $BASE_DRAFT_MODEL_PATH \
--train-data-path $ROOT_DIR/cache/dataset/synth_summarize.jsonl \
--output-dir $ROOT_DIR/outputs/llama3-8b-eagle3-lora-fixed \
--use-lora \
--lora-config $DRAFT_LORA_CONFIG \
--target-lora-path $TARGET_LORA_PATH \
--num-epochs 1 \
--batch-size 1 \
--learning-rate 1e-4 \
--max-length 2048 \
--chat-template llama3 \
--cache-dir $ROOT_DIR/cache \
--skip-vocab-mapping \
--wandb \
--wandb-key "f3b46a484034ca1fe99fc5ae4d19402c94da12c1" \
Contributor

critical

A hardcoded wandb-key has been found in the script. Committing secrets like API keys to version control is a significant security risk. This key should be removed and loaded from a secure source, such as an environment variable. You should also add a check at the beginning of the script to ensure the environment variable is set, for example: : "${WANDB_API_KEY:?WANDB_API_KEY is not set}"

Suggested change
--wandb-key "f3b46a484034ca1fe99fc5ae4d19402c94da12c1" \
--wandb-key "${WANDB_API_KEY}" \

--wandb-project "specforge-training" \
--wandb-name "llama3-8b-lora-online-fixed-run-1"
32 changes: 32 additions & 0 deletions examples/run_llama3_eagle3_lora_online_fixed.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#!/bin/bash

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname $SCRIPT_DIR)

NUM_GPUS=${1:-8}
TARGET_LORA_PATH=${2:-/sgl-workspace/llama-duo_llama3.1-8b-summarize-gpt4o-128k}
DRAFT_LORA_CONFIG=${3:-$ROOT_DIR/configs/draft_lora_trainable_config.json}
BASE_DRAFT_MODEL_PATH=${4:-/sgl-workspace/jamesliu1_sglang-EAGLE3-Llama-3.1-Instruct-8B}

torchrun \
--standalone \
--nproc_per_node $NUM_GPUS \
$ROOT_DIR/scripts/train_eagle3_lora_online.py \
--target-model-path meta-llama/Llama-3.1-8B-Instruct \
--draft-model-config /sgl-workspace/jamesliu1_sglang-EAGLE3-Llama-3.1-Instruct-8B/config.json \
--base-draft-model-path $BASE_DRAFT_MODEL_PATH \
--train-data-path $ROOT_DIR/cache/dataset/synth_summarize.jsonl \
--output-dir $ROOT_DIR/outputs/llama3-8b-eagle3-lora-fixed \
--use-lora \
--target-lora-path $TARGET_LORA_PATH \
--num-epochs 1 \
--batch-size 1 \
--learning-rate 1e-4 \
--max-length 2048 \
--chat-template llama3 \
--cache-dir $ROOT_DIR/cache \
--skip-vocab-mapping \
--wandb \
--wandb-key "f3b46a484034ca1fe99fc5ae4d19402c94da12c1" \
Contributor

critical

A hardcoded wandb-key has been found in the script. Committing secrets like API keys to version control is a significant security risk. This key should be removed and loaded from a secure source, such as an environment variable. You should also add a check at the beginning of the script to ensure the environment variable is set, for example: : "${WANDB_API_KEY:?WANDB_API_KEY is not set}"

Suggested change
--wandb-key "f3b46a484034ca1fe99fc5ae4d19402c94da12c1" \
--wandb-key "${WANDB_API_KEY}" \

--wandb-project "specforge-training" \
--wandb-name "llama3-8b-lora-online-fixed-run-1"
5 changes: 5 additions & 0 deletions examples/run_llama3_eagle3_online.sh
@@ -25,3 +25,8 @@ torchrun \
# --mlflow-tracking-uri http://mlflow.grid1.ard.grid.linkedin.com:31812 \
# --eval-data-split 0.01 \
--attention-backend flex_attention
--cache-dir $ROOT_DIR/cache \
--wandb \
--wandb-key "f3b46a484034ca1fe99fc5ae4d19402c94da12c1" \
Contributor

critical

A hardcoded wandb-key has been found in the script. Committing secrets like API keys to version control is a significant security risk. This key should be removed and loaded from a secure source, such as an environment variable. You should also add a check at the beginning of the script to ensure the environment variable is set, for example: : "${WANDB_API_KEY:?WANDB_API_KEY is not set}"

Suggested change
--wandb-key "f3b46a484034ca1fe99fc5ae4d19402c94da12c1" \
--wandb-key "${WANDB_API_KEY}" \

--wandb-project "specforge-training" \
--wandb-name "llama3-8b-online-fixed-run-1"
109 changes: 109 additions & 0 deletions push_to_hf.sh
@@ -0,0 +1,109 @@
#!/usr/bin/env bash
set -euo pipefail

usage() {
cat <<'EOF'
Usage:
HF_TOKEN=xxx ./push_to_hf.sh --repo-id <namespace/name> --path <abs_path> [--branch main] [--private true|false] [--repo-type model|dataset|space] [--commit "msg"] [--force true|false]

Examples:
HF_TOKEN=hf_xxx ./push_to_hf.sh \
--repo-id yourname/llama3-8b-eagle3-lora-fixed \
--path /sgl-workspace/SpecForge/outputs/llama3-8b-eagle3-lora-fixed/epoch_0/draft_lora \
--private true
EOF
}

# defaults
BRANCH="main"
PRIVATE="false"
REPO_TYPE="model"
COMMIT_MSG=""
FORCE="false"

# parse args
REPO_ID=""
SRC_PATH=""
while [[ $# -gt 0 ]]; do
case "$1" in
--repo-id) REPO_ID="${2:-}"; shift 2;;
--path) SRC_PATH="${2:-}"; shift 2;;
--branch) BRANCH="${2:-}"; shift 2;;
--private) PRIVATE="${2:-}"; shift 2;;
--repo-type) REPO_TYPE="${2:-}"; shift 2;;
--commit) COMMIT_MSG="${2:-}"; shift 2;;
--force) FORCE="${2:-}"; shift 2;;
-h|--help) usage; exit 0;;
*) echo "Unknown arg: $1"; usage; exit 1;;
esac
done

# validate
: "${HF_TOKEN:?Set HF_TOKEN in env}"
: "${REPO_ID:?--repo-id is required}"
: "${SRC_PATH:?--path is required}"

if [[ ! -d "$SRC_PATH" ]]; then
echo "Path not found: $SRC_PATH" >&2
exit 1
fi

if ! command -v git >/dev/null 2>&1; then
echo "git not found. Please install git." >&2
exit 1
fi
if ! command -v git-lfs >/dev/null 2>&1; then
echo "git-lfs not found. Please install git-lfs." >&2
exit 1
fi
if ! command -v hf >/dev/null 2>&1 && ! command -v huggingface-cli >/dev/null 2>&1; then
echo "huggingface CLI not found; attempting to install huggingface_hub..."
python3 -m pip install --user -U huggingface_hub >/dev/null
fi

# create repo if missing (ignore error if exists)
if command -v hf >/dev/null 2>&1; then
CREATE_FLAGS=(--repo-type "$REPO_TYPE" --token "$HF_TOKEN")
[[ "$PRIVATE" == "true" ]] && CREATE_FLAGS+=(--private)
hf repo create "$REPO_ID" "${CREATE_FLAGS[@]}" -y 2>/dev/null || true
else
CREATE_FLAGS=(--type "$REPO_TYPE" -y --token "$HF_TOKEN")
[[ "$PRIVATE" == "true" ]] && CREATE_FLAGS+=(--private)
huggingface-cli repo create "$REPO_ID" "${CREATE_FLAGS[@]}" 2>/dev/null || true
fi

# commit msg
if [[ -z "$COMMIT_MSG" ]]; then
COMMIT_MSG="Upload from script on $(date -u +'%Y-%m-%dT%H:%M:%SZ')"
fi

# work in a temp dir to avoid touching source
WORKDIR="$(mktemp -d)"
trap 'rm -rf "$WORKDIR"' EXIT

# copy content excluding .git
tar -C "$SRC_PATH" --exclude='.git' -cf - . | tar -C "$WORKDIR" -xf -

cd "$WORKDIR"
git init -q
# Some git-lfs versions do not support -q; keep output quiet via redirection
git lfs install --skip-repo >/dev/null 2>&1 || true

# sensible LFS defaults for model assets
git lfs track "*.safetensors" "*.bin" "*.pt" "*.ckpt" "*.h5" "*.gguf" "*.onnx" "*.tflite" "*.tar" "*.zip" 2>/dev/null || true
# track large tokenizer assets to satisfy HF pre-receive hooks (>10 MiB)
git lfs track "tokenizer.json" "tokenizer.model" "spiece.model" "sentencepiece.bpe.model" "*.spm" 2>/dev/null || true
echo ".gitattributes" >> .gitignore || true

git add -A
git commit -m "$COMMIT_MSG" -q

REMOTE="https://oauth2:${HF_TOKEN}@huggingface.co/${REPO_ID}"
git branch -M "$BRANCH"
git remote add origin "$REMOTE"

PUSH_FLAGS=()
[[ "$FORCE" == "true" ]] && PUSH_FLAGS+=("--force")
git push "${PUSH_FLAGS[@]}" -u origin "$BRANCH"

echo "Pushed $SRC_PATH to https://huggingface.co/${REPO_ID} (branch: $BRANCH)"
1 change: 1 addition & 0 deletions requirements.txt
@@ -11,3 +11,4 @@ psutil
numpy
accelerate
pydantic
peft
2 changes: 2 additions & 0 deletions scripts/build_eagle3_dataset.py
@@ -6,8 +6,10 @@
import hashlib
import os
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
from datasets import load_dataset
from transformers import AutoTokenizer

55 changes: 53 additions & 2 deletions scripts/prepare_data.py
@@ -33,7 +33,7 @@ def parse_args():
parser.add_argument(
"--dataset",
type=str,
choices=["ultrachat", "sharegpt", "opc"],
choices=["ultrachat", "sharegpt", "opc", "synth_summarize", "opc"],
Contributor

medium

The choice "opc" is duplicated in the choices list for the --dataset argument. The duplicate should be removed for clarity and correctness.

Suggested change
choices=["ultrachat", "sharegpt", "opc", "synth_summarize", "opc"],
choices=["ultrachat", "sharegpt", "opc", "synth_summarize"],

help="The demo dataset to quickly run the training for speculative decoding",
)
parser.add_argument(
@@ -110,6 +110,43 @@ def load_dataset_from_path(data_path: Path):
import hashlib


def process_opc_sft_stage1(row) -> Dict:
row_id = hashlib.md5((row["instruction"] + row["output"]).encode()).hexdigest()
return {
"id": row_id,
"conversations": [
{"role": "user", "content": row["instruction"]},
{"role": "assistant", "content": row["output"]},
],
}


def process_synth_summarize_row(row) -> Dict:
"""Process a row from the synth_summarize dataset.

The function expects a row with the following schema:
"messages": [
{
"role": "user" | "assistant",
"content": str
}
],
"prompt_id": str
"""
conversations = row["messages"]
formatted_conversations = []
for message in conversations:
role = message["role"]
content = message["content"]
assert role in ["user", "assistant"]
formatted_conversations.append({"role": role, "content": content})
row = {"id": row["prompt_id"], "conversations": formatted_conversations}
return row, 0


import hashlib


def process_opc_sft_stage1(row) -> Dict:
row_id = hashlib.md5((row["instruction"] + row["output"]).encode()).hexdigest()
return {
@@ -139,9 +176,23 @@ def main():
"OpenCoder-LLM/opc-sft-stage1", "largescale_diverse_instruct"
)["train"]
proc_fn = process_opc_sft_stage1
elif args.dataset == "synth_summarize":
if args.data_path is None:
ds = load_dataset("llama-duo/synth_summarize_dataset_dedup")[
"train_sft_claude3sonnet"
]
else:
print("Loading dataset from custom data path: ", args.data_path)
ds = load_dataset_from_path(Path(args.data_path))
proc_fn = process_synth_summarize_row
elif args.dataset == "opc":
ds = load_dataset(
"OpenCoder-LLM/opc-sft-stage1", "largescale_diverse_instruct"
)["train"]
proc_fn = process_opc_sft_stage1
Comment on lines +188 to +192
Contributor

high

This elif block for args.dataset == "opc" is a duplicate of the one at lines 174-178. It should be removed to avoid redundant code and potential inconsistencies.

else:
raise ValueError(
"This script only supports ultrachat_200k and sharegpt datasets for demo purpose, if you wish to use other datasets, please modify this script."
"This script only supports ultrachat_200k, sharegpt, opc, and synth_summarize datasets for demo purpose, if you wish to use other datasets, please modify this script."
)

if args.output_path is None: