@@ -1,43 +1,33 @@
-import gc
-import logging
-import time
-from functools import partial
-from typing import Dict, List, Optional, Union
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
 
-import torch
+from internlm.checkpoint.checkpoint_manager import CheckpointManager
 import torch.distributed as dist
 from torch.utils.data import DataLoader
-
-from internlm.checkpoint.checkpoint_manager import CheckpointManager
-from internlm.core.context import ParallelMode
-from internlm.core.context import global_context as gpc
+from functools import partial
+from typing import Dict, List
+from internlm.core.context import ParallelMode, global_context as gpc
 from internlm.core.parallel.comm import initialize_offload_manager
-from internlm.core.trainer import (
-    Trainer,
-    get_scheduler_hooks,
-    load_new_batch,
-    record_current_batch_training_metrics,
+from internlm.train.utils import get_scheduler_hooks, load_new_batch, record_current_batch_training_metrics
+from internlm.data import (
+    build_train_loader_with_data_type,
+    build_valid_loader_with_data_type,
 )
 from internlm.data.streaming.utils import streaming_simple_resume
-from internlm.data.train_state import get_train_state
 from internlm.eval import evaluate_on_val_dls
-from internlm.initialize import initialize_trainer
-from internlm.initialize.initialize_model import (
-    initialize_model_and_parallel_communicator,
-)
+from internlm.initialize import initialize_launcher, initialize_trainer
+from internlm.initialize.initialize_model import initialize_model_and_parallel_communicator
 from internlm.initialize.initialize_optimizer import initialize_optimizer
 from internlm.initialize.initialize_profiler import initialize_llm_profile
+from internlm.launch.trainer_builder import logger
+from internlm.model.model_implementations.builder import create_model
+from internlm.model.model_implementations.registry import register_model_initializer
 from internlm.model.model_ops.losses.ce_loss import InternLoss
 from internlm.model.model_ops.metrics import AccPerplex
-from internlm.monitor import send_alert_message
-from internlm.utils.common import (
-    BatchSkipper,
-    check_cuda_env,
-    enable_pytorch_expandable_segments,
-    get_current_device,
-    get_megatron_flops,
-    launch_time,
-)
+from internlm.monitor import internevo_monitor, send_alert_message
+from internlm.train.train_state import TrainState
+from internlm.train.trainer import Trainer
+from internlm.utils.common import BatchSkipper, check_cuda_env, enable_pytorch_expandable_segments, get_current_device, get_megatron_flops, launch_time, parse_args
 from internlm.utils.gputest import empty_cache_and_diag
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
@@ -46,9 +36,6 @@
 from internlm.utils.utils import DataType
 from internlm.utils.writer import Writer
 
-# global llm logger
-logger = logging.getLogger(__file__)
-
 
 class TrainerBuilder(Trainer):
     """
@@ -117,7 +104,7 @@ def __init__(
         initialize_offload_manager(gpc.config.get("selective_checkpoint_offload", False))
 
         # initialize train state
-        train_state = get_train_state(train_dl)
+        train_state = TrainState(gpc.config, train_dl.batch_sampler)
 
         # initialize optimizer
         optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model, isp_communicator)
@@ -385,3 +372,34 @@ def _update_profilers(self, batch_count: int, prof):
             self.memory_profiler.step()
         if batch_count % 2 == 0:
             prof.step()
+
+
+@internevo_monitor(feishu_alert=True, clean_run=True)
+def main(args):
+    # initialize model
+    register_model_initializer()
+    model = create_model()
+
+    # initialize train dataloader
+    train_dl, dataset_types = build_train_loader_with_data_type()
+
+    # initialize validation dataloader
+    val_dls = build_valid_loader_with_data_type()
+
+    # build trainer
+    merged_args = {**vars(args), "dataset_types": dataset_types}
+    trainer = TrainerBuilder(model, train_dl, val_dls, **merged_args)
+
+    # training
+    trainer.fit()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    # Initialize distributed environment
+    initialize_launcher(config=args.config, launcher=args.launcher, distributed_port=args.port, seed=args.seed)
+    assert hasattr(gpc, "config") and gpc.config is not None
+
+    # Run the main function with parsed arguments
+    main(args)
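For reference, a minimal sketch of driving the same flow programmatically from within this module, mirroring the new __main__ block. It assumes parse_args() produces exactly the attributes consumed above (config, launcher, port, seed); the wrapper name run_training and its default values are hypothetical and not part of this PR.

# Hypothetical helper (illustration only, not part of this PR). It reproduces the
# __main__ sequence above: initialize the distributed launcher, then call main().
from argparse import Namespace

from internlm.core.context import global_context as gpc
from internlm.initialize import initialize_launcher


def run_training(config_path, launcher="torch", port=8888, seed=1024):
    # Build an args object with the same attribute names parse_args() is expected
    # to expose (config, launcher, port, seed); the defaults here are placeholders.
    args = Namespace(config=config_path, launcher=launcher, port=port, seed=seed)
    initialize_launcher(config=args.config, launcher=args.launcher, distributed_port=args.port, seed=args.seed)
    assert hasattr(gpc, "config") and gpc.config is not None
    # main() (defined above) builds the model, the dataloaders, and the
    # TrainerBuilder, then runs trainer.fit().
    main(args)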