Skip to content

Commit 8f00ea5

Browse files
committed
squash: new trainer with HF and SGL backend
Signed-off-by: h-guo18 <[email protected]>
1 parent e3e399a commit 8f00ea5

File tree

6 files changed

+852
-4
lines changed

6 files changed

+852
-4
lines changed

examples/speculative_decoding/ar_validate.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,14 @@
2626
mto.enable_huggingface_checkpointing()
2727

2828

29-
def validate_ar(model, tokenizer, ds, steps=3, osl=20, num_samples=80, device=None):
29+
def validate_ar(
30+
model, tokenizer, ds, steps=3, osl=20, num_samples=80, device=None, disable_pbar=False
31+
):
3032
validator = HFARValidation(model, tokenizer)
3133
num_samples = min(num_samples, len(ds))
3234
ars = []
33-
for i in tqdm(range(num_samples), desc="Validating AR"):
35+
print("validating AR...")
36+
for i in tqdm(range(num_samples), disable=disable_pbar):
3437
prompt = ds[i]["prompt"][0]
3538
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
3639
# Apply chat template to the prompt, continuing with assistant response

examples/speculative_decoding/eagle_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ def compute_loss(self, *args, **kwargs):
498498
kwargs.pop("num_items_in_batch", None)
499499
loss, outputs = super().compute_loss(return_outputs=True, *args, **kwargs)
500500
if hasattr(outputs, "train_acc"):
501-
self.state.training_accs.append(outputs.train_acc)
501+
self.state.training_accs.append([acc.item() for acc in outputs.train_acc])
502502
return loss
503503

504504

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import argparse
17+
import os
18+
19+
import torch
20+
import torch.distributed as dist
21+
import torch.multiprocessing as mp
22+
from eagle_utils import DataCollatorWithPadding, make_eagle_supervised_data_module
23+
from trainer.distill_trainer import EagleSGLTrainer, EagleTPTrainer
24+
from transformers import AutoTokenizer
25+
26+
# Seed PyTorch's default random number generator with a fixed value as soon
# as this module is loaded, before any of the functions below run.
torch.manual_seed(0)
27+
28+
29+
def _setup_distributed(rank, args, backend="nccl"):
    """Join the default process group and bind this rank to its GPU.

    Exports the rendezvous variables (``MASTER_ADDR``, ``MASTER_PORT``,
    ``LOCAL_RANK``), initialises ``torch.distributed``, selects the CUDA
    device assigned to this rank, and stores the two sub-groups on ``args``
    as ``teacher_pgroup`` and ``student_pgroup``.

    Args:
        rank: Index of this process; used as the process-group rank and
            written to ``LOCAL_RANK``.
        args: Namespace providing ``master_port``, ``world_size``,
            ``teacher_ranks``, ``student_ranks``, ``teacher_devices`` and
            ``student_devices``. It is mutated in place.
        backend: Backend name handed to ``dist.init_process_group``.
    """
    os.environ.update(
        {
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": args.master_port,
            "LOCAL_RANK": str(rank),
        }
    )

    dist.init_process_group(backend, rank=rank, world_size=args.world_size)

    # Teacher ranks index straight into the teacher device list; student
    # ranks follow the teachers, so shift them back to a zero-based index.
    is_teacher = rank in args.teacher_ranks
    device = (
        args.teacher_devices[rank]
        if is_teacher
        else args.student_devices[rank - len(args.teacher_ranks)]
    )
    torch.cuda.set_device(device)

    current = torch.cuda.current_device()
    print(f"Starting process rank={rank}, device={current}, world_size={args.world_size}")

    # Both groups are created by every rank, teacher group first.
    # NOTE(review): torch.distributed documents that new_group must be
    # entered by all processes in the same order -- keep this sequence.
    args.teacher_pgroup = dist.new_group(ranks=args.teacher_ranks)
    args.student_pgroup = dist.new_group(ranks=args.student_ranks)
45+
46+
47+
def train(rank, args):
    """Per-process entry point: set up, build the data loader, train, save.

    Args:
        rank: Index of this process, forwarded to the distributed setup and
            to the trainer constructor.
        args: Parsed command-line namespace. It is mutated in place: the
            distributed setup attaches process groups and this function
            sets ``use_offline_training``, ``vlm_processor`` and
            ``offline_data_path`` before building the dataset.
    """
    _setup_distributed(rank, args)

    seq_len = args.training_seq_len
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, model_max_length=seq_len)

    # Fixed settings that the data-module factory reads from ``args``.
    for option, value in (
        ("use_offline_training", False),
        ("vlm_processor", None),
        ("offline_data_path", None),
    ):
        setattr(args, option, value)
    train_dataset = make_eagle_supervised_data_module(tokenizer, args)["train_dataset"]

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=0,
        collate_fn=DataCollatorWithPadding(max_length=seq_len),
        drop_last=True,
    )

    # Map the requested teacher backend to its trainer implementation; an
    # unknown backend surfaces as a KeyError here.
    backend_to_trainer = {"sglang": EagleSGLTrainer, "hf": EagleTPTrainer}
    trainer_cls = backend_to_trainer[args.teacher_backend]

    trainer = trainer_cls(rank, args, tokenizer, train_dataloader)
    trainer.train()
    trainer.save(args.out_path)
73+
74+
75+
def main():
    """Parse command-line options and spawn one training process per device.

    The world size is the number of teacher devices plus the number of
    student devices. Ranks ``0 .. len(teacher_devices) - 1`` are teacher
    ranks and the remaining ranks are student ranks; these are stored on the
    parsed namespace as ``world_size``, ``teacher_ranks`` and
    ``student_ranks`` before the workers are launched.
    """

    def _parse_bool(text):
        """Convert a command-line token to a bool.

        ``type=bool`` cannot be used for this: ``bool("False")`` is ``True``
        because any non-empty string is truthy.
        """
        lowered = text.strip().lower()
        if lowered in {"1", "true", "t", "yes", "y"}:
            return True
        if lowered in {"0", "false", "f", "no", "n"}:
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")

    parser = argparse.ArgumentParser(description="Multi-GPU distributed two-stage forward example")

    # Training args
    parser.add_argument("--model_path", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    parser.add_argument("--data_path", type=str, required=True, help="Training dataset.")
    # Parsed as int so a value given on the command line has the same type
    # as the integer default.
    parser.add_argument("--training_seq_len", type=int, default=1024)
    parser.add_argument("--eagle_config_path", type=str, default="eagle_config.json")
    parser.add_argument("--out_path", type=str, default="ckpts/fast-trained")
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--epoch", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=8, help="Total bs across all ranks.")

    # Trainer args
    parser.add_argument("--teacher_backend", type=str, choices=["sglang", "hf"], default="sglang")
    parser.add_argument(
        "--teacher_ep_size",
        type=int,
        default=1,
        help="Teacher EP size, only used for sglang backend.",
    )
    # ``type=list`` would split the raw string into single characters, so the
    # device ids are taken as space-separated integers instead.
    parser.add_argument(
        "--teacher_devices",
        type=int,
        nargs="+",
        default=[0, 1, 2, 3],
        help="Space-separated CUDA device ids for the teacher ranks.",
    )
    parser.add_argument(
        "--student_devices",
        type=int,
        nargs="+",
        default=[4, 5, 6, 7],
        help="Space-separated CUDA device ids for the student ranks.",
    )
    parser.add_argument(
        "--lazy_preprocess",
        type=_parse_bool,
        default=True,
        help="Whether to use lazy preprocessing.",
    )
    parser.add_argument("--log_interval", type=int, default=50)
    parser.add_argument("--save_interval", type=int, default=20000)
    parser.add_argument(
        "--total_steps", type=int, default=60000, help="Total number of steps for debugging."
    )
    # Kept as a string: the value is written to os.environ["MASTER_PORT"],
    # which only accepts strings.
    parser.add_argument("--master_port", type=str, default="12357")

    args = parser.parse_args()
    # TODO: add sanity check for args

    # Teachers take the first ranks, students the ranks that follow.
    num_teachers = len(args.teacher_devices)
    num_students = len(args.student_devices)
    args.world_size = num_teachers + num_students
    args.teacher_ranks = list(range(num_teachers))
    args.student_ranks = list(range(num_teachers, args.world_size))

    # Launch one process per rank and block until all of them exit.
    mp.spawn(
        train,
        args=(args,),
        nprocs=args.world_size,
        join=True,
    )
126+
127+
128+
if __name__ == "__main__":
    # Start the launcher only when this file is executed as a script.
    main()

0 commit comments

Comments
 (0)