-
Couldn't load subscription status.
- Fork 100
Add Draft LoRA scripts #138
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
27b5213
6d59c18
ab206b8
d3ca10a
f9d5201
383cfb8
e968da1
9f34883
8c86059
6b9201f
d77ad6d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| { | ||
| "alpha_pattern": {}, | ||
| "auto_mapping": null, | ||
| "base_model_name_or_path": "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B", | ||
| "bias": "none", | ||
| "fan_in_fan_out": false, | ||
| "inference_mode": false, | ||
| "init_lora_weights": true, | ||
| "layer_replication": null, | ||
| "layers_pattern": null, | ||
| "layers_to_transform": null, | ||
| "loftq_config": {}, | ||
| "lora_alpha": 128, | ||
| "lora_dropout": 0.1, | ||
| "megatron_config": null, | ||
| "megatron_core": "megatron.core", | ||
| "modules_to_save": null, | ||
| "peft_type": "LORA", | ||
| "qalora_group_size": 16, | ||
| "r": 64, | ||
| "rank_pattern": {}, | ||
| "revision": null, | ||
| "target_modules": [ | ||
| "gate_proj", | ||
| "o_proj", | ||
| "q_proj", | ||
| "v_proj", | ||
| "k_proj" | ||
| ], | ||
| "task_type": "CAUSAL_LM", | ||
| "trainable_token_indices": null, | ||
| "use_dora": false, | ||
| "use_qalora": false, | ||
| "use_rslora": false | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| import os | ||
|
|
||
| from huggingface_hub import snapshot_download | ||
|
|
||
|
|
||
| def download_model(model_id, local_dir): | ||
| print(f"downloading model: {model_id}") | ||
| print(f"will save to: {local_dir}") | ||
|
|
||
| try: | ||
| snapshot_download( | ||
| repo_id=model_id, | ||
| local_dir=local_dir, | ||
| local_dir_use_symlinks=False, | ||
| ) | ||
| print("download success!") | ||
| except Exception as e: | ||
| print(f"error: {e}") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| model_identifier = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B" | ||
| save_directory = f"./{model_identifier.replace('/', '_')}" | ||
|
|
||
| if not os.path.exists(save_directory): | ||
| os.makedirs(save_directory) | ||
|
|
||
| download_model(model_id=model_identifier, local_dir=save_directory) | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,33 @@ | ||||||
| #!/bin/bash | ||||||
|
|
||||||
| SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) | ||||||
| ROOT_DIR=$(dirname $SCRIPT_DIR) | ||||||
|
|
||||||
| NUM_GPUS=${1:-8} | ||||||
| TARGET_LORA_PATH=${2:-/sgl-workspace/llama-duo_llama3.1-8b-summarize-gpt4o-128k} | ||||||
| DRAFT_LORA_CONFIG=${3:-$ROOT_DIR/configs/draft_lora_trainable_config.json} | ||||||
| BASE_DRAFT_MODEL_PATH=${4:-/sgl-workspace/jamesliu1_sglang-EAGLE3-Llama-3.1-Instruct-8B} | ||||||
|
|
||||||
| torchrun \ | ||||||
| --standalone \ | ||||||
| --nproc_per_node $NUM_GPUS \ | ||||||
| $ROOT_DIR/scripts/train_eagle3_lora_online.py \ | ||||||
| --target-model-path meta-llama/Llama-3.1-8B-Instruct \ | ||||||
| --draft-model-config /sgl-workspace/jamesliu1_sglang-EAGLE3-Llama-3.1-Instruct-8B/config.json \ | ||||||
| --base-draft-model-path $BASE_DRAFT_MODEL_PATH \ | ||||||
| --train-data-path $ROOT_DIR/cache/dataset/synth_summarize.jsonl \ | ||||||
| --output-dir $ROOT_DIR/outputs/llama3-8b-eagle3-lora-fixed \ | ||||||
| --use-lora \ | ||||||
| --lora-config $DRAFT_LORA_CONFIG \ | ||||||
| --target-lora-path $TARGET_LORA_PATH \ | ||||||
| --num-epochs 1 \ | ||||||
| --batch-size 1 \ | ||||||
| --learning-rate 1e-4 \ | ||||||
| --max-length 2048 \ | ||||||
| --chat-template llama3 \ | ||||||
| --cache-dir $ROOT_DIR/cache \ | ||||||
| --skip-vocab-mapping \ | ||||||
| --wandb \ | ||||||
| --wandb-key "f3b46a484034ca1fe99fc5ae4d19402c94da12c1" \ | ||||||
|
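There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. A hardcoded W&B API key is passed to `--wandb-key` on the line above. Credentials should not be committed to the repository; remove the key from the script and supply it at run time, for example through the `WANDB_API_KEY` environment variable.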
Suggested change
|
||||||
| --wandb-project "specforge-training" \ | ||||||
| --wandb-name "llama3-8b-lora-online-fixed-run-1" | ||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,32 @@ | ||||||
| #!/bin/bash | ||||||
|
|
||||||
| SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) | ||||||
| ROOT_DIR=$(dirname $SCRIPT_DIR) | ||||||
|
|
||||||
| NUM_GPUS=${1:-8} | ||||||
| TARGET_LORA_PATH=${2:-/sgl-workspace/llama-duo_llama3.1-8b-summarize-gpt4o-128k} | ||||||
| DRAFT_LORA_CONFIG=${3:-$ROOT_DIR/configs/draft_lora_trainable_config.json} | ||||||
| BASE_DRAFT_MODEL_PATH=${4:-/sgl-workspace/jamesliu1_sglang-EAGLE3-Llama-3.1-Instruct-8B} | ||||||
|
|
||||||
| torchrun \ | ||||||
| --standalone \ | ||||||
| --nproc_per_node $NUM_GPUS \ | ||||||
| $ROOT_DIR/scripts/train_eagle3_lora_online.py \ | ||||||
| --target-model-path meta-llama/Llama-3.1-8B-Instruct \ | ||||||
| --draft-model-config /sgl-workspace/jamesliu1_sglang-EAGLE3-Llama-3.1-Instruct-8B/config.json \ | ||||||
| --base-draft-model-path $BASE_DRAFT_MODEL_PATH \ | ||||||
| --train-data-path $ROOT_DIR/cache/dataset/synth_summarize.jsonl \ | ||||||
| --output-dir $ROOT_DIR/outputs/llama3-8b-eagle3-lora-fixed \ | ||||||
| --use-lora \ | ||||||
| --target-lora-path $TARGET_LORA_PATH \ | ||||||
| --num-epochs 1 \ | ||||||
| --batch-size 1 \ | ||||||
| --learning-rate 1e-4 \ | ||||||
| --max-length 2048 \ | ||||||
| --chat-template llama3 \ | ||||||
| --cache-dir $ROOT_DIR/cache \ | ||||||
| --skip-vocab-mapping \ | ||||||
| --wandb \ | ||||||
| --wandb-key "f3b46a484034ca1fe99fc5ae4d19402c94da12c1" \ | ||||||
|
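There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. A hardcoded W&B API key is passed to `--wandb-key` on the line above. Credentials should not be committed to the repository; remove the key from the script and supply it at run time, for example through the `WANDB_API_KEY` environment variable.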
Suggested change
|
||||||
| --wandb-project "specforge-training" \ | ||||||
| --wandb-name "llama3-8b-lora-online-fixed-run-1" | ||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -25,3 +25,8 @@ torchrun \ | |||||
| # --mlflow-tracking-uri http://mlflow.grid1.ard.grid.linkedin.com:31812 \ | ||||||
| # --eval-data-split 0.01 \ | ||||||
| --attention-backend flex_attention | ||||||
| --cache-dir $ROOT_DIR/cache \ | ||||||
| --wandb \ | ||||||
| --wandb-key "f3b46a484034ca1fe99fc5ae4d19402c94da12c1" \ | ||||||
|
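There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. A hardcoded W&B API key is passed to `--wandb-key` on the line above. Credentials should not be committed to the repository; remove the key from the script and supply it at run time, for example through the `WANDB_API_KEY` environment variable.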
Suggested change
|
||||||
| --wandb-project "specforge-training" \ | ||||||
| --wandb-name "llama3-8b-online-fixed-run-1" | ||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,109 @@ | ||
| #!/usr/bin/env bash | ||
| set -euo pipefail | ||
|
|
||
| usage() { | ||
| cat <<'EOF' | ||
| Usage: | ||
| HF_TOKEN=xxx ./push_to_hf.sh --repo-id <namespace/name> --path <abs_path> [--branch main] [--private true|false] [--repo-type model|dataset|space] [--commit "msg"] [--force true|false] | ||
|
|
||
| Examples: | ||
| HF_TOKEN=hf_xxx ./push_to_hf.sh \ | ||
| --repo-id yourname/llama3-8b-eagle3-lora-fixed \ | ||
| --path /sgl-workspace/SpecForge/outputs/llama3-8b-eagle3-lora-fixed/epoch_0/draft_lora \ | ||
| --private true | ||
| EOF | ||
| } | ||
|
|
||
| # defaults | ||
| BRANCH="main" | ||
| PRIVATE="false" | ||
| REPO_TYPE="model" | ||
| COMMIT_MSG="" | ||
| FORCE="false" | ||
|
|
||
| # parse args | ||
| REPO_ID="" | ||
| SRC_PATH="" | ||
| while [[ $# -gt 0 ]]; do | ||
| case "$1" in | ||
| --repo-id) REPO_ID="${2:-}"; shift 2;; | ||
| --path) SRC_PATH="${2:-}"; shift 2;; | ||
| --branch) BRANCH="${2:-}"; shift 2;; | ||
| --private) PRIVATE="${2:-}"; shift 2;; | ||
| --repo-type) REPO_TYPE="${2:-}"; shift 2;; | ||
| --commit) COMMIT_MSG="${2:-}"; shift 2;; | ||
| --force) FORCE="${2:-}"; shift 2;; | ||
| -h|--help) usage; exit 0;; | ||
| *) echo "Unknown arg: $1"; usage; exit 1;; | ||
| esac | ||
| done | ||
|
|
||
| # validate | ||
| : "${HF_TOKEN:?Set HF_TOKEN in env}" | ||
| : "${REPO_ID:?--repo-id is required}" | ||
| : "${SRC_PATH:?--path is required}" | ||
|
|
||
| if [[ ! -d "$SRC_PATH" ]]; then | ||
| echo "Path not found: $SRC_PATH" >&2 | ||
| exit 1 | ||
| fi | ||
|
|
||
| if ! command -v git >/dev/null 2>&1; then | ||
| echo "git not found. Please install git." >&2 | ||
| exit 1 | ||
| fi | ||
| if ! command -v git-lfs >/dev/null 2>&1; then | ||
| echo "git-lfs not found. Please install git-lfs." >&2 | ||
| exit 1 | ||
| fi | ||
| if ! command -v hf >/dev/null 2>&1 && ! command -v huggingface-cli >/dev/null 2>&1; then | ||
| echo "huggingface CLI not found; attempting to install huggingface_hub..." | ||
| python3 -m pip install --user -U huggingface_hub >/dev/null | ||
| fi | ||
|
|
||
| # create repo if missing (ignore error if exists) | ||
| if command -v hf >/dev/null 2>&1; then | ||
| CREATE_FLAGS=(--repo-type "$REPO_TYPE" --token "$HF_TOKEN") | ||
| [[ "$PRIVATE" == "true" ]] && CREATE_FLAGS+=(--private) | ||
| hf repo create "$REPO_ID" "${CREATE_FLAGS[@]}" -y 2>/dev/null || true | ||
| else | ||
| CREATE_FLAGS=(--type "$REPO_TYPE" -y --token "$HF_TOKEN") | ||
| [[ "$PRIVATE" == "true" ]] && CREATE_FLAGS+=(--private) | ||
| huggingface-cli repo create "$REPO_ID" "${CREATE_FLAGS[@]}" 2>/dev/null || true | ||
| fi | ||
|
|
||
| # commit msg | ||
| if [[ -z "$COMMIT_MSG" ]]; then | ||
| COMMIT_MSG="Upload from script on $(date -u +'%Y-%m-%dT%H:%M:%SZ')" | ||
| fi | ||
|
|
||
| # work in a temp dir to avoid touching source | ||
| WORKDIR="$(mktemp -d)" | ||
| trap 'rm -rf "$WORKDIR"' EXIT | ||
|
|
||
| # copy content excluding .git | ||
| tar -C "$SRC_PATH" --exclude='.git' -cf - . | tar -C "$WORKDIR" -xf - | ||
|
|
||
| cd "$WORKDIR" | ||
| git init -q | ||
| # Some git-lfs versions do not support -q; keep output quiet via redirection | ||
| git lfs install --skip-repo >/dev/null 2>&1 || true | ||
|
|
||
| # sensible LFS defaults for model assets | ||
| git lfs track "*.safetensors" "*.bin" "*.pt" "*.ckpt" "*.h5" "*.gguf" "*.onnx" "*.tflite" "*.tar" "*.zip" 2>/dev/null || true | ||
| # track large tokenizer assets to satisfy HF pre-receive hooks (>10 MiB) | ||
| git lfs track "tokenizer.json" "tokenizer.model" "spiece.model" "sentencepiece.bpe.model" "*.spm" 2>/dev/null || true | ||
| echo ".gitattributes" >> .gitignore || true | ||
|
|
||
| git add -A | ||
| git commit -m "$COMMIT_MSG" -q | ||
|
|
||
| REMOTE="https://oauth2:${HF_TOKEN}@huggingface.co/${REPO_ID}" | ||
| git branch -M "$BRANCH" | ||
| git remote add origin "$REMOTE" | ||
|
|
||
| PUSH_FLAGS=() | ||
| [[ "$FORCE" == "true" ]] && PUSH_FLAGS+=("--force") | ||
| git push "${PUSH_FLAGS[@]}" -u origin "$BRANCH" | ||
|
|
||
| echo "Pushed $SRC_PATH to https://huggingface.co/${REPO_ID} (branch: $BRANCH)" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,3 +11,4 @@ psutil | |
| numpy | ||
| accelerate | ||
| pydantic | ||
| peft | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -33,7 +33,7 @@ def parse_args(): | |
| parser.add_argument( | ||
| "--dataset", | ||
| type=str, | ||
| choices=["ultrachat", "sharegpt", "opc"], | ||
| choices=["ultrachat", "sharegpt", "opc", "synth_summarize", "opc"], | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| help="The demo dataset to quickly run the training for speculative decoding", | ||
| ) | ||
| parser.add_argument( | ||
|
|
@@ -110,6 +110,43 @@ def load_dataset_from_path(data_path: Path): | |
| import hashlib | ||
|
|
||
|
|
||
| def process_opc_sft_stage1(row) -> Dict: | ||
| row_id = hashlib.md5((row["instruction"] + row["output"]).encode()).hexdigest() | ||
| return { | ||
| "id": row_id, | ||
| "conversations": [ | ||
| {"role": "user", "content": row["instruction"]}, | ||
| {"role": "assistant", "content": row["output"]}, | ||
| ], | ||
| } | ||
|
|
||
|
|
||
| def process_synth_summarize_row(row) -> Dict: | ||
| """Process a row from the synth_summarize dataset. | ||
|
|
||
| The function expects a row with the following schema: | ||
| "messages": [ | ||
| { | ||
| "role": "user" | "assistant", | ||
| "content": str | ||
| } | ||
| ], | ||
| "prompt_id": str | ||
| """ | ||
| conversations = row["messages"] | ||
| formatted_conversations = [] | ||
| for message in conversations: | ||
| role = message["role"] | ||
| content = message["content"] | ||
| assert role in ["user", "assistant"] | ||
| formatted_conversations.append({"role": role, "content": content}) | ||
| row = {"id": row["prompt_id"], "conversations": formatted_conversations} | ||
| return row, 0 | ||
|
|
||
|
|
||
| import hashlib | ||
|
|
||
|
|
||
| def process_opc_sft_stage1(row) -> Dict: | ||
| row_id = hashlib.md5((row["instruction"] + row["output"]).encode()).hexdigest() | ||
| return { | ||
|
|
@@ -139,9 +176,23 @@ def main(): | |
| "OpenCoder-LLM/opc-sft-stage1", "largescale_diverse_instruct" | ||
| )["train"] | ||
| proc_fn = process_opc_sft_stage1 | ||
| elif args.dataset == "synth_summarize": | ||
| if args.data_path is None: | ||
| ds = load_dataset("llama-duo/synth_summarize_dataset_dedup")[ | ||
| "train_sft_claude3sonnet" | ||
| ] | ||
| else: | ||
| print("Loading dataset from custom data path: ", args.data_path) | ||
| ds = load_dataset_from_path(Path(args.data_path)) | ||
| proc_fn = process_synth_summarize_row | ||
| elif args.dataset == "opc": | ||
| ds = load_dataset( | ||
| "OpenCoder-LLM/opc-sft-stage1", "largescale_diverse_instruct" | ||
| )["train"] | ||
| proc_fn = process_opc_sft_stage1 | ||
|
Comment on lines
+188
to
+192
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| else: | ||
| raise ValueError( | ||
| "This script only supports ultrachat_200k and sharegpt datasets for demo purpose, if you wish to use other datasets, please modify this script." | ||
| "This script only supports ultrachat_200k, sharegpt, opc, and synth_summarize datasets for demo purpose, if you wish to use other datasets, please modify this script." | ||
| ) | ||
|
|
||
| if args.output_path is None: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The error handling in `download_model` could be improved. When an exception occurs, the script prints to standard output and exits with status 0, which can make it difficult for calling scripts to detect failures. It's better practice to print errors to `stderr` and exit with a non-zero status code. Additionally, the success message contains a non-ASCII exclamation mark, which should be replaced for consistency and to avoid potential encoding issues. This suggestion also adds the required `sys` import to the top of the file.