[elian] Update new train method, use TrainConfig mode to change train… #1083
Open
2Elian wants to merge 2 commits into InternLM:main from 2Elian:elian-xtuner-dev-trainConfig
Changes from all commits (2 commits)
@@ -0,0 +1 @@
from .train_config import TrainConfig
@@ -0,0 +1,32 @@
import torch.distributed as dist
from cyclopts import App

from xtuner.v1.train.trainer import Trainer
from train_config import TrainConfig
from xtuner.v1.model import Qwen3Dense4BConfig
from xtuner.v1.config import LRConfig, AdamWConfig

app = App(
    name="entrypoint of sft & pretrain",
    help="Elian-XTuner's entry point for fine-tuning and training, launched using configuration files or arguments.",
)


@app.default()
def main():
    cfg = TrainConfig()
    print(cfg)
    model_cfg = Qwen3Dense4BConfig(max_position_embeddings=cfg.max_position_embeddings)
    optim_cfg = AdamWConfig(lr=cfg.lr)
    lr_cfg = LRConfig(lr_type=cfg.lr_type, lr_min=cfg.lr_min)
    trainer = Trainer(
        **cfg.to_trainer_kwargs(model_cfg, optim_cfg, lr_cfg)
    )
    trainer.fit()

    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    app(exit_on_error=False)
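As a quick sanity check of the wiring above, a hedged dry-run sketch (not part of this PR) that only composes the Trainer kwargs and prints their types, assuming TrainConfig and the sub-config classes can be instantiated without GPUs or the dataset files being present:

# Hypothetical dry run: compose the Trainer kwargs without constructing Trainer,
# just to inspect how TrainConfig wires the sub-configs together.
from train_config import TrainConfig
from xtuner.v1.model import Qwen3Dense4BConfig
from xtuner.v1.config import LRConfig, AdamWConfig

cfg = TrainConfig()
kwargs = cfg.to_trainer_kwargs(
    Qwen3Dense4BConfig(max_position_embeddings=cfg.max_position_embeddings),
    AdamWConfig(lr=cfg.lr),
    LRConfig(lr_type=cfg.lr_type, lr_min=cfg.lr_min),
)
for key, value in kwargs.items():
    print(f"{key:<24} {type(value).__name__}")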
@@ -0,0 +1,48 @@
#!/bin/bash
[Review comment] This would be better placed in a tutorial or the examples directory rather than in the library's core code. Much of it also reflects a personal setup style and is better suited as a recommended best practice in the documentation. In addition, some of the environment variables and steps should not be required, e.g. conda activate, CUDA_HOME, and LD_LIBRARY_PATH; these easily break across environments.
# ==========================================
# Training Script for Multi-GPU Training
# ==========================================
# step1: cd /data1/nuist_llm/TrainLLM/SFT-elian/xtuner
# step2: YOUR_ENV_PATH='/home/202312150002/anaconda3/envs/llm/lib/python3.10/site-packages'
# step3: cp -r ./xtuner $YOUR_ENV_PATH
# step4: bash ./elian/train/qwen3/run.sh

# conda
export PATH="/home/202312150002/anaconda3/bin:$PATH"
source /home/202312150002/anaconda3/etc/profile.d/conda.sh
conda activate xtuner
TRAIN_PATH=/data1/nuist_llm/TrainLLM/SFT-elian/xtuner/elian/train/qwen3
cd $TRAIN_PATH || exit 1

# cuda
export PATH="/usr/local/cuda-12.4/bin:$PATH"
export LD_LIBRARY_PATH="/usr/local/cuda-12.4/lib64:$LD_LIBRARY_PATH"
export CUDA_HOME="/usr/local/cuda-12.4"
echo "CUDA version used: $($CUDA_HOME/bin/nvcc --version | grep 'release' | awk '{print $6}')"

# node
NUM_NODES=1
GPU_LIST="0,1,2,3"
NUM_GPUS_PER_NODE=$(echo $GPU_LIST | awk -F',' '{print NF}')
export NPROC_PER_NODE=4
export NNODES=1
export NODE_RANK=0
export PORT=10171
export NODE_0_ADDR=172.16.107.15
export NCCL_DEBUG=INFO
export NCCL_IB_DISABLE=1
export OMP_NUM_THREADS=1
export NCCL_SOCKET_IFNAME=lo,eth0

# train params
TRAIN_SCRIPT="$TRAIN_PATH/main.py"
export CUDA_VISIBLE_DEVICES=$GPU_LIST
echo "Elian-Xtuner-V0.2.0 (used GPUs: ${CUDA_VISIBLE_DEVICES})"
export XTUNER_DETERMINISTIC=true  # torch.use_deterministic_algorithms

torchrun --nproc_per_node=$NUM_GPUS_PER_NODE \
    --nnodes=$NUM_NODES \
    --node_rank=$NODE_RANK \
    --master_addr=$NODE_0_ADDR \
    --master_port=$PORT \
    $TRAIN_SCRIPT
@@ -0,0 +1,47 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from xtuner.v1.loss.ce_loss import CELossConfig, CELossContextInputItem, CELossContext
import time


hidden_states = torch.randn(32768, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
lm_head = nn.Linear(4096, 151936, bias=False).to(device="cuda", dtype=torch.bfloat16)
torch.cuda.reset_peak_memory_stats()
t1 = time.time()
logits = lm_head(hidden_states)
shifted_labels = torch.randint(0, 151936, (32768, ), device="cuda")
loss = F.cross_entropy(logits, shifted_labels)
loss.backward()
max_memory = torch.cuda.max_memory_allocated()
reserved_memory = torch.cuda.max_memory_reserved()
print(f"Eager mode Loss: {loss.item()}")
print(f"Eager mode hidden_states grad norm: {hidden_states.grad.norm().item()}")
print(f"Eager mode lm_head weight grad norm: {lm_head.weight.grad.norm().item()}")
print(f"Eager mode Max memory allocated: {max_memory / 1024**3:.2f} GB")
print(f"Eager mode Max memory reserved: {reserved_memory / 1024**3:.2f} GB")
print(f"Eager mode Time taken: {time.time() - t1:.2f} seconds")

del logits
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

shifted_labels = shifted_labels.unsqueeze(0)
hidden_states = hidden_states.unsqueeze(0)
hidden_states = hidden_states.clone().detach().requires_grad_(True)
lm_head.weight.grad = None
t1 = time.time()
loss_ctx_input_list = [CELossContextInputItem(shifted_labels=shifted_labels)]
loss_cfg = CELossConfig(mode='chunk', chunk_size=1024, loss_reduction="token")
batches_loss_kwargs = CELossContext.build_batches_loss_kwargs(loss_ctx_input_list, loss_cfg)
loss_ctx = CELossContext(loss_cfg, batches_loss_kwargs[0])
loss, _ = loss_ctx.forward(hidden_states, lm_head.weight)
loss.backward()
max_memory = torch.cuda.max_memory_allocated()
reserved_memory = torch.cuda.max_memory_reserved()
print(f"Chunk mode Loss: {loss.item()}")
print(f"Chunk mode hidden_states grad norm: {hidden_states.grad.norm().item()}")
print(f"Chunk mode lm_head weight grad norm: {lm_head.weight.grad.norm().item()}")
print(f"Chunk mode Max memory allocated: {max_memory / 1024**3:.2f} GB")
print(f"Chunk mode Max memory reserved: {reserved_memory / 1024**3:.2f} GB")
print(f"Chunk mode Time taken: {time.time() - t1:.2f} seconds")
@@ -0,0 +1,198 @@
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Optional

from xtuner.v1.train.trainer import ResumeConfig
from xtuner.v1.config import FSDPConfig
from xtuner.v1.loss.ce_loss import CELossConfig
from xtuner.v1.datasets.config import DatasetCombine, DatasetConfig, DataloaderConfig
from xtuner.v1.datasets.sft_tokenize_fn.openai import OpenaiTokenizeFunctionConfig


@dataclass
class TrainConfig:
    # base path
    model_path: str = "/data1/nuist_llm/TrainLLM/ModelCkpt/qwen3-4b/instruct-base"
    work_dir: str = "/data1/nuist_llm/TrainLLM/SFT-elian/xtuner/elian/save/model-01"
    log_dir: str = "/data1/nuist_llm/TrainLLM/SFT-elian/xtuner/elian/save/model-01"

    # data params
    dataset_cfg: list = field(default_factory=lambda: [
        "/data1/nuist_llm/TrainLLM/datasets/SFT/math/category/code/Nemotron-Post-Training-V2-code-coldStart.jsonl",
        "/data1/nuist_llm/TrainLLM/datasets/SFT/math/category/code/Nemotron-Post-Training-V2-code.jsonl",
        "/data1/nuist_llm/TrainLLM/datasets/SFT/math/category/math/Nemotron-Post-Training-V2-math.jsonl",
        "/data1/nuist_llm/TrainLLM/datasets/SFT/math/category/other/Nemotron-Post-Training-V2-math-coldStart.jsonl"
    ] + [f"/data1/nuist_llm/TrainLLM/datasets/SFT/math/category/other/chat-0000{i}-of-00012.jsonl" for i in range(10)] + [
        "/data1/nuist_llm/TrainLLM/datasets/SFT/math/category/other/chat-00010-of-00012.jsonl",
        "/data1/nuist_llm/TrainLLM/datasets/SFT/math/category/other/chat-00011-of-00012.jsonl"
    ])
    cache_dir: str = "/data1/nuist_llm/cacheTem/elianXtuner"
    class_name: str = "JsonlDataset"  # TODO @elian: new parquest
    sample_ratio: float = 1.0
    cache_tag: str = "elian-xtuner"
    message_template: str = "qwen3"
    max_token_size: int = 4096
    max_position_embeddings: int = 4096
    collator: str = "sft_llm_collator"  # ["sft_llm_collator", "sft_vllm_collator", "fake_collator"]
    pack_level: str = "soft"  # ["soft", "none"]  # soft is True, none is False for Pack
    pack_max_length: int = 8024  # max_position_size
    pack_workers: int = 8
    num_workers: int = 8

    # train params
    global_batch_size: int = 1
    total_epoch: int = 1

    # fsdp params
    sp_size: int = 2
    tp_size: int = 2
    ep_size: int = 1
    recompute_ratio: float = 1.0
    cpu_offload: bool = False

    # loss params
    mode: str = "chunk"
    chunk_size: int = 1024
    loss_reduction: str = "token"  # ["token", "sample", "square"]

    # resume params
    resume_from: Optional[str] = None
    auto_resume: bool = False
    load_optimizer: bool = True
    load_dataset: bool = True
    load_scheduler: bool = True
    strict_load: bool = False

    # save checkpoint step
    hf_interval: Optional[int] = 2000
    hf_max_keep: Optional[int] = 1
    checkpoint_interval: Optional[int] = 1000
    checkpoint_maxkeep: Optional[int] = 2

    # profiling
    profile_step: Optional[int] = 1
    profile_time: bool = True
    profile_memory: bool = True
    intra_layer_micro_batch: int = 1

    # other
    seed: int = 42
    debug: bool = False
    backend: str = "nccl"
    exp_tracker: str = "tensorboard"

    # optim
    lr: float = 6e-5
    weight_decay: float = 0.001
    betas: tuple = (0.9, 0.95)
    max_grad_norm: float = 1.0
    lr_type: str = "cosine"  # ["cosine", "linear", "constant"]
    warmup_ratio: float = 0.03
    lr_min: float = 1e-6

    def build_resume_cfg(self) -> Optional[ResumeConfig]:
        if self.resume_from or self.auto_resume:
            return ResumeConfig(
                resume_from=self.resume_from,
                auto_resume=self.auto_resume,
                load_optimizer=self.load_optimizer,
                load_dataset=self.load_dataset,
                load_scheduler=self.load_scheduler,
            )
        return None

    def build_fsdp_cfg(self) -> Optional[FSDPConfig]:
        if self.tp_size > 1 or self.ep_size > 1:
            return FSDPConfig(
                tp_size=self.tp_size,
                sp_size=self.sp_size,
                ep_size=self.ep_size,
                cpu_offload=self.cpu_offload
            )
        else:
            return None

    def build_loss_cfg(self) -> Optional[CELossConfig]:
        if self.mode != "eager" or self.loss_reduction != "token":
            return CELossConfig(
                mode=self.mode,
                chunk_size=self.chunk_size,
                loss_reduction=self.loss_reduction
            )
        else:
            return None

    def build_datasets_cfg(self) -> list[DatasetCombine]:
        all_datasets = []
        for data_file in self.dataset_cfg:
            data_path = Path(data_file)
            name = data_path.stem
            tokenize_fn_cfg = OpenaiTokenizeFunctionConfig(
                chat_template=self.message_template,
                max_length=self.max_token_size
            )
            all_datasets.append(
                {
                    "dataset": DatasetConfig(
                        anno_path=data_file,
                        cache_dir=self.cache_dir,
                        name=name,
                        cache_tag=self.cache_tag,
                        class_name=self.class_name,
                        sample_ratio=self.sample_ratio
                    ),
                    "tokenize_fn": tokenize_fn_cfg
                }
            )
        return all_datasets

    def build_dataloader(self) -> DataloaderConfig:
        return DataloaderConfig(
            collator=self.collator,
            pack_level=self.pack_level,
            pack_max_length=self.pack_max_length,
            pack_workers=self.pack_workers,
            num_workers=self.num_workers
        )

    def to_trainer_kwargs(self, model_cfg, optim_cfg, lr_cfg):
        return dict(
            model_cfg=model_cfg,
            tokenizer_path=self.model_path,
            load_from=self.model_path,
            optim_cfg=optim_cfg,
            lr_cfg=lr_cfg,
            global_batch_size=self.global_batch_size,
            work_dir=self.work_dir,
            log_dir=self.log_dir,
            sp_size=self.sp_size,
            total_epoch=self.total_epoch,
            checkpoint_interval=self.checkpoint_interval,
            checkpoint_maxkeep=self.checkpoint_maxkeep,
            hf_interval=self.hf_interval,
            hf_max_keep=self.hf_max_keep,
            exp_tracker=self.exp_tracker,
            profile_step=self.profile_step,
            profile_time=self.profile_time,
            profile_memory=self.profile_memory,
            intra_layer_micro_batch=self.intra_layer_micro_batch,
            seed=self.seed,
            debug=self.debug,
            backend=self.backend,
            resume_cfg=self.build_resume_cfg(),
            fsdp_cfg=self.build_fsdp_cfg(),
            loss_cfg=self.build_loss_cfg(),
            dataset_cfg=self.build_datasets_cfg(),
            dataloader_cfg=self.build_dataloader()
        )

    def __str__(self):
        cfg_dict = asdict(self)
        max_key_len = max(len(k) for k in cfg_dict.keys())
        lines = []
        for k, v in cfg_dict.items():
            lines.append(f"{k:<{max_key_len}} : {v}")
        return "\n".join(lines)


if __name__ == "__main__":
    cfg = TrainConfig()
    print(cfg)
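For illustration, a short usage sketch of the builder methods above (assuming the module is importable as train_config, as in main.py; the overridden paths are placeholders, not the datasets from this PR):

# Hypothetical override example: disable chunked loss and TP/EP and check
# which sub-configs fall back to None so the Trainer uses its defaults.
from train_config import TrainConfig

cfg = TrainConfig(
    mode="eager",           # with loss_reduction left at "token", build_loss_cfg() returns None
    tp_size=1, ep_size=1,   # no TP/EP requested, so build_fsdp_cfg() returns None
    model_path="/path/to/model",
    work_dir="./work_dir",
    log_dir="./work_dir",
)
print(cfg.build_loss_cfg())    # None
print(cfg.build_fsdp_cfg())    # None
print(cfg.build_dataloader())  # DataloaderConfig with soft packing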
[Review comment] xtuner/v1/train/cli/sft.py is already a training entry point; what does this new main entry point do differently?