Commit 6ea8d57

debug:sgl backend; use torchrun
Signed-off-by: h-guo18 <[email protected]>
1 parent: 6f8fa51


examples/speculative_decoding/train.py

Lines changed: 29 additions & 32 deletions
@@ -14,39 +14,41 @@
 # limitations under the License.
 
 import argparse
-import os
 
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 from eagle_utils import DataCollatorWithPadding, make_eagle_supervised_data_module
 from trainer.distill_trainer import EagleSGLTrainer, EagleTPTrainer
 from transformers import AutoTokenizer
 
 torch.manual_seed(0)
 
 
-def _setup_distributed(rank, args, backend="nccl"):
-    """Initialize distributed environment"""
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = args.master_port
-    os.environ["LOCAL_RANK"] = str(rank)
-    # Initialize process group
-    dist.init_process_group(backend, rank=rank, world_size=args.world_size)
+def _check_args(args):
+    """Sanity check for arguments."""
+    # TODO: (hg)
+
+
+def _setup_pgroups(args):
+    """Initialize student/teacher pgroups and set devices."""
+    rank = dist.get_rank()
+    args.teacher_ranks = list(range(len(args.teacher_devices)))
+    args.student_ranks = list(
+        range(len(args.teacher_devices), len(args.teacher_devices) + len(args.student_devices))
+    )
     if rank in args.teacher_ranks:
         torch.cuda.set_device(args.teacher_devices[rank])
     else:
         torch.cuda.set_device(args.student_devices[rank - len(args.teacher_ranks)])
     print(
-        f"Starting process rank={rank}, device={torch.cuda.current_device()}, world_size={args.world_size}"
+        f"Starting process rank={rank}, device={torch.cuda.current_device()}, world_size={dist.get_world_size()}"
     )
     args.teacher_pgroup = dist.new_group(ranks=args.teacher_ranks)
     args.student_pgroup = dist.new_group(ranks=args.student_ranks)
 
 
-def train(rank, args):
-    _setup_distributed(rank, args)
-
+def train(args):
+    """Entrance for training."""
     tokenizer = AutoTokenizer.from_pretrained(
         args.model_path, model_max_length=args.training_seq_len
     )
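
The new _setup_pgroups splits the world into disjoint teacher and student process groups and pins each rank to one of the configured GPUs. A minimal standalone sketch of that pattern (illustrative names, not code from this commit); note that dist.new_group is collective, so every rank must call it for both groups even if it belongs to only one:

import torch
import torch.distributed as dist


def setup_pgroups(teacher_devices, student_devices):
    """Illustrative sketch: pin this rank to a GPU and build disjoint process groups."""
    rank = dist.get_rank()
    teacher_ranks = list(range(len(teacher_devices)))
    student_ranks = list(
        range(len(teacher_devices), len(teacher_devices) + len(student_devices))
    )
    # Pin this rank's GPU before any collectives run on it.
    if rank in teacher_ranks:
        torch.cuda.set_device(teacher_devices[rank])
    else:
        torch.cuda.set_device(student_devices[rank - len(teacher_ranks)])
    # new_group is collective: all ranks call it for both groups, members or not.
    teacher_pgroup = dist.new_group(ranks=teacher_ranks)
    student_pgroup = dist.new_group(ranks=student_ranks)
    return teacher_pgroup, student_pgroup
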
@@ -55,19 +57,24 @@ def train(rank, args):
     args.offline_data_path = None
     data_module = make_eagle_supervised_data_module(tokenizer, args)
 
+    # Ensure different ranks load the same data
+    g = torch.Generator()
+    g.manual_seed(0)
+
     train_dataloader = torch.utils.data.DataLoader(
         data_module["train_dataset"],
         batch_size=args.batch_size,
         shuffle=True,
         num_workers=0,
         collate_fn=DataCollatorWithPadding(max_length=args.training_seq_len),
         drop_last=True,
+        generator=g,
     )
     trainer_cls = {
         "sglang": EagleSGLTrainer,
         "hf": EagleTPTrainer,
     }[args.teacher_backend]
-    trainer = trainer_cls(rank, args, tokenizer, train_dataloader)
+    trainer = trainer_cls(dist.get_rank(), args, tokenizer, train_dataloader)
     trainer.train()
     trainer.save(args.out_path)
 
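
The seeded torch.Generator passed to the DataLoader is what keeps the shuffle order identical across ranks: each process builds its own loader, but the identically seeded generator makes shuffle=True draw the same permutation everywhere. A small self-contained check of that behavior (assumed illustration, not repository code):

import torch
from torch.utils.data import DataLoader


def make_loader():
    # Each rank seeds its own generator identically, so shuffle=True
    # produces the same batch order in every process.
    g = torch.Generator()
    g.manual_seed(0)
    return DataLoader(list(range(8)), batch_size=2, shuffle=True, generator=g)


# Two independently constructed loaders yield identical batches.
assert [b.tolist() for b in make_loader()] == [b.tolist() for b in make_loader()]
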

@@ -76,7 +83,7 @@ def main():
     parser = argparse.ArgumentParser(description="Multi-GPU distributed two-stage forward example")
 
     # Training args
-    parser.add_argument("--model_path", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+    parser.add_argument("--model_path", type=str, required=True, help="Target model path.")
    parser.add_argument("--data_path", type=str, required=True, help="Training dataset.")
     parser.add_argument("--training_seq_len", type=str, default=1024)
     parser.add_argument("--eagle_config_path", type=str, default="eagle_config.json")
@@ -103,26 +110,16 @@ def main():
     parser.add_argument(
         "--total_steps", type=int, default=60000, help="Total number of steps for debugging."
     )
+    parser.add_argument("--master_addr", type=str, default="localhost")
     parser.add_argument("--master_port", type=str, default="12357")
 
     args = parser.parse_args()
-    # TODO: add sanity check for args
-
-    def set_ranks(args):
-        args.world_size = len(args.teacher_devices) + len(args.student_devices)
-        args.teacher_ranks = list(range(len(args.teacher_devices)))
-        args.student_ranks = list(
-            range(len(args.teacher_devices), len(args.teacher_devices) + len(args.student_devices))
-        )
-
-    set_ranks(args)
-    # Launch multiple processes
-    mp.spawn(
-        train,
-        args=(args,),
-        nprocs=args.world_size,
-        join=True,
-    )
+
+    dist.init_process_group("nccl")
+
+    _check_args(args)
+    _setup_pgroups(args)
+    train(args)
 
 
 if __name__ == "__main__":
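
With mp.spawn removed, the script relies on torchrun to start one process per device, for example torchrun --nproc_per_node=<N> train.py ... (exact flags depend on the setup). torchrun exports RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR, and MASTER_PORT, which is why dist.init_process_group("nccl") can now rendezvous without explicit rank and world_size arguments. A minimal sketch of that launch contract (assumed illustration, not repository code):

import os

import torch
import torch.distributed as dist

if __name__ == "__main__":
    # Under torchrun, the env:// rendezvous is already populated,
    # so no explicit rank or world_size is needed here.
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    print(
        f"rank={dist.get_rank()} local_rank={local_rank} world_size={dist.get_world_size()}"
    )
    dist.destroy_process_group()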
