1 change: 1 addition & 0 deletions elian/train/qwen3/__init__.py
@@ -0,0 +1 @@
from .train_config import TrainConfig
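With this re-export, downstream code can import the config from the package instead of from the module file. A small usage sketch, assuming the elian package is on PYTHONPATH:

from elian.train.qwen3 import TrainConfig

cfg = TrainConfig()      # all defaults from train_config.py
print(cfg.model_path)    # base checkpoint directory used for tokenizer and weights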
32 changes: 32 additions & 0 deletions elian/train/qwen3/main.py
@@ -0,0 +1,32 @@
import torch.distributed as dist
Collaborator

xtuner/v1/train/cli/sft.py is already a training entry point. What is the difference between it and this main entry?

from cyclopts import App

from xtuner.v1.train.trainer import Trainer
from train_config import TrainConfig
from xtuner.v1.model import Qwen3Dense4BConfig
from xtuner.v1.config import LRConfig, AdamWConfig

app = App(
    name="entrypoint of sft & pretrain",
    help="Elian-XTuner's entry point for fine-tuning and training, launched using configuration files or arguments.",
)


@app.default()
def main():
    cfg = TrainConfig()
    print(cfg)
    model_cfg = Qwen3Dense4BConfig(max_position_embeddings=cfg.max_position_embeddings)
    optim_cfg = AdamWConfig(lr=cfg.lr)
    lr_cfg = LRConfig(lr_type=cfg.lr_type, lr_min=cfg.lr_min)
    trainer = Trainer(
        **cfg.to_trainer_kwargs(model_cfg, optim_cfg, lr_cfg)
    )
    trainer.fit()

    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
app(exit_on_error=False)
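For comparison with the config-file-driven xtuner/v1/train/cli/sft.py entry the reviewer mentions: this main.py hard-codes its settings in TrainConfig, but cyclopts can also expose selected fields as command-line options. A hedged sketch, not part of this PR; the parameter list and exact flag spelling are illustrative:

from cyclopts import App

from train_config import TrainConfig

app = App(name="sft-sketch")


@app.default()
def main(lr: float = 6e-5, total_epoch: int = 1):
    # cyclopts maps keyword parameters to options, so a run could override
    # them from the shell (see the cyclopts docs for exact flag naming).
    cfg = TrainConfig(lr=lr, total_epoch=total_epoch)
    print(cfg)


if __name__ == "__main__":
    app()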
48 changes: 48 additions & 0 deletions elian/train/qwen3/run.sh
@@ -0,0 +1,48 @@
#!/bin/bash
Collaborator

This belongs in a tutorial or in examples rather than in the library's core code. It also contains quite a few personal-style choices, which would fit better as a recommended best practice in the documentation.

In addition, some of the environment variables and steps should not be required, for example conda activate, CUDA_HOME, and LD_LIBRARY_PATH; they also break easily across environments.

# ==========================================
# Training Script for Multi-GPU Training
# ==========================================
# step1: cd /data1/nuist_llm/TrainLLM/SFT-elian/xtuner
# step2: YOUR_ENV_PATH='/home/202312150002/anaconda3/envs/llm/lib/python3.10/site-packages'
# step3: cp -r ./xtuner $YOUR_ENV_PATH
# step4: bash ./elian/train/qwen3/run.sh

# conda
export PATH="/home/202312150002/anaconda3/bin:$PATH"
source /home/202312150002/anaconda3/etc/profile.d/conda.sh
conda activate xtuner
TRAIN_PATH=/data1/nuist_llm/TrainLLM/SFT-elian/xtuner/elian/train/qwen3
cd $TRAIN_PATH || exit 1

# cuda
export PATH="/usr/local/cuda-12.4/bin:$PATH"
export LD_LIBRARY_PATH="/usr/local/cuda-12.4/lib64:$LD_LIBRARY_PATH"
export CUDA_HOME="/usr/local/cuda-12.4"
echo "CUDA version used: $($CUDA_HOME/bin/nvcc --version | grep 'release' | awk '{print $6}')"

# node
NUM_NODES=1
GPU_LIST="0,1,2,3"
NUM_GPUS_PER_NODE=$(echo $GPU_LIST | awk -F',' '{print NF}')
export NPROC_PER_NODE=4
export NNODES=1
export NODE_RANK=0
export PORT=10171
export NODE_0_ADDR=172.16.107.15
export NCCL_DEBUG=INFO
export NCCL_IB_DISABLE=1
export OMP_NUM_THREADS=1
export NCCL_SOCKET_IFNAME=lo,eth0

# train params
TRAIN_SCRIPT="$TRAIN_PATH/main.py"
export CUDA_VISIBLE_DEVICES=$GPU_LIST
echo "Elian-Xtuner-V0.2.0 (used GPUs: ${CUDA_VISIBLE_DEVICES})"
export XTUNER_DETERMINISTIC=true # torch.use_deterministic_algorithms

torchrun --nproc_per_node=$NUM_GPUS_PER_NODE \
--nnodes=$NUM_NODES \
--node_rank=$NODE_RANK \
--master_addr=$NODE_0_ADDR \
--master_port=$PORT \
$TRAIN_SCRIPT
47 changes: 47 additions & 0 deletions elian/train/qwen3/test.py
@@ -0,0 +1,47 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from xtuner.v1.loss.ce_loss import CELossConfig, CELossContextInputItem, CELossContext
import time


hidden_states = torch.randn(32768, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
lm_head = nn.Linear(4096, 151936, bias=False).to(device="cuda", dtype=torch.bfloat16)
torch.cuda.reset_peak_memory_stats()
t1 = time.time()
logits = lm_head(hidden_states)
shifted_labels = torch.randint(0, 151936, (32768, ), device="cuda")
loss = F.cross_entropy(logits, shifted_labels)
loss.backward()
max_memory = torch.cuda.max_memory_allocated()
reserved_memory = torch.cuda.max_memory_reserved()
print(f"Eager mode Loss: {loss.item()}")
print(f"Eager mode hidden_states grad norm: {hidden_states.grad.norm().item()}")
print(f"Eager mode lm_head weight grad norm: {lm_head.weight.grad.norm().item()}")
print(f"Eager mode Max memory allocated: {max_memory / 1024**3:.2f} GB")
print(f"Eager mode Max memory reserved: {reserved_memory / 1024**3:.2f} GB")
print(f"Eager mode Time taken: {time.time() - t1:.2f} seconds")

del logits
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

shifted_labels = shifted_labels.unsqueeze(0)
hidden_states = hidden_states.unsqueeze(0)
hidden_states = hidden_states.clone().detach().requires_grad_(True)
lm_head.weight.grad = None
t1 = time.time()
loss_ctx_input_list = [CELossContextInputItem(shifted_labels=shifted_labels)]
loss_cfg = CELossConfig(mode='chunk', chunk_size=1024, loss_reduction="token")
batches_loss_kwargs = CELossContext.build_batches_loss_kwargs(loss_ctx_input_list, loss_cfg)
loss_ctx = CELossContext(loss_cfg, batches_loss_kwargs[0])
loss, _ = loss_ctx.forward(hidden_states, lm_head.weight)
loss.backward()
max_memory = torch.cuda.max_memory_allocated()
reserved_memory = torch.cuda.max_memory_reserved()
print(f"Chunk mode Loss: {loss.item()}")
print(f"Chunk mode hidden_states grad norm: {hidden_states.grad.norm().item()}")
print(f"Chunk mode lm_head weight grad norm: {lm_head.weight.grad.norm().item()}")
print(f"Chunk mode Max memory allocated: {max_memory / 1024**3:.2f} GB")
print(f"Chunk mode Max memory reserved: {reserved_memory / 1024**3:.2f} GB")
print(f"Chunk mode Time taken: {time.time() - t1:.2f} seconds")
198 changes: 198 additions & 0 deletions elian/train/qwen3/train_config.py
@@ -0,0 +1,198 @@
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Optional

from xtuner.v1.train.trainer import ResumeConfig
from xtuner.v1.config import FSDPConfig
from xtuner.v1.loss.ce_loss import CELossConfig
from xtuner.v1.datasets.config import DatasetCombine, DatasetConfig, DataloaderConfig
from xtuner.v1.datasets.sft_tokenize_fn.openai import OpenaiTokenizeFunctionConfig

@dataclass
class TrainConfig:
    # base path
    model_path: str = "/data1/nuist_llm/TrainLLM/ModelCkpt/qwen3-4b/instruct-base"
    work_dir: str = "/data1/nuist_llm/TrainLLM/SFT-elian/xtuner/elian/save/model-01"
    log_dir: str = "/data1/nuist_llm/TrainLLM/SFT-elian/xtuner/elian/save/model-01"

    # data params
    dataset_cfg: list = field(default_factory=lambda: [
        "/data1/nuist_llm/TrainLLM/datasets/SFT/math/category/code/Nemotron-Post-Training-V2-code-coldStart.jsonl",
        "/data1/nuist_llm/TrainLLM/datasets/SFT/math/category/code/Nemotron-Post-Training-V2-code.jsonl",
        "/data1/nuist_llm/TrainLLM/datasets/SFT/math/category/math/Nemotron-Post-Training-V2-math.jsonl",
        "/data1/nuist_llm/TrainLLM/datasets/SFT/math/category/other/Nemotron-Post-Training-V2-math-coldStart.jsonl",
    ] + [
        f"/data1/nuist_llm/TrainLLM/datasets/SFT/math/category/other/chat-0000{i}-of-00012.jsonl" for i in range(10)
    ] + [
        "/data1/nuist_llm/TrainLLM/datasets/SFT/math/category/other/chat-00010-of-00012.jsonl",
        "/data1/nuist_llm/TrainLLM/datasets/SFT/math/category/other/chat-00011-of-00012.jsonl",
    ])
    cache_dir: str = "/data1/nuist_llm/cacheTem/elianXtuner"
    class_name: str = "JsonlDataset"  # TODO @elian: new parquet dataset class
    sample_ratio: float = 1.0
    cache_tag: str = "elian-xtuner"
    message_template: str = "qwen3"
    max_token_size: int = 4096
    max_position_embeddings: int = 4096
    collator: str = "sft_llm_collator"  # ["sft_llm_collator", "sft_vllm_collator", "fake_collator"]
    pack_level: str = "soft"  # ["soft", "none"]; "soft" enables sequence packing, "none" disables it
    pack_max_length: int = 8024  # max_position_size
    pack_workers: int = 8
    num_workers: int = 8

    # train params
    global_batch_size: int = 1
    total_epoch: int = 1

    # fsdp params
    sp_size: int = 2
    tp_size: int = 2
    ep_size: int = 1
    recompute_ratio: float = 1.0
    cpu_offload: bool = False

    # loss params
    mode: str = "chunk"
    chunk_size: int = 1024
    loss_reduction: str = "token"  # ["token", "sample", "square"]

    # resume params
    resume_from: Optional[str] = None
    auto_resume: bool = False
    load_optimizer: bool = True
    load_dataset: bool = True
    load_scheduler: bool = True
    strict_load: bool = False

    # save checkpoint step
    hf_interval: Optional[int] = 2000
    hf_max_keep: Optional[int] = 1
    checkpoint_interval: Optional[int] = 1000
    checkpoint_maxkeep: Optional[int] = 2

    # profiling
    profile_step: Optional[int] = 1
    profile_time: bool = True
    profile_memory: bool = True
    intra_layer_micro_batch: int = 1

    # other
    seed: int = 42
    debug: bool = False
    backend: str = "nccl"
    exp_tracker: str = "tensorboard"

    # optim
    lr: float = 6e-5
    weight_decay: float = 0.001
    betas: tuple = (0.9, 0.95)
    max_grad_norm: float = 1.0
    lr_type: str = "cosine"  # ["cosine", "linear", "constant"]
    warmup_ratio: float = 0.03
    lr_min: float = 1e-6

    def build_resume_cfg(self) -> Optional[ResumeConfig]:
        if self.resume_from or self.auto_resume:
            return ResumeConfig(
                resume_from=self.resume_from,
                auto_resume=self.auto_resume,
                load_optimizer=self.load_optimizer,
                load_dataset=self.load_dataset,
                load_scheduler=self.load_scheduler,
            )
        return None

    def build_fsdp_cfg(self) -> Optional[FSDPConfig]:
        if self.tp_size > 1 or self.ep_size > 1:
            return FSDPConfig(
                tp_size=self.tp_size,
                sp_size=self.sp_size,
                ep_size=self.ep_size,
                cpu_offload=self.cpu_offload,
            )
        else:
            return None

    def build_loss_cfg(self) -> Optional[CELossConfig]:
        if self.mode != "eager" or self.loss_reduction != "token":
            return CELossConfig(
                mode=self.mode,
                chunk_size=self.chunk_size,
                loss_reduction=self.loss_reduction,
            )
        else:
            return None

    def build_datasets_cfg(self) -> list[DatasetCombine]:
        all_datasets = []
        for data_file in self.dataset_cfg:
            data_path = Path(data_file)
            name = data_path.stem
            tokenize_fn_cfg = OpenaiTokenizeFunctionConfig(
                chat_template=self.message_template,
                max_length=self.max_token_size
            )
            all_datasets.append(
                {
                    "dataset": DatasetConfig(
                        anno_path=data_file,
                        cache_dir=self.cache_dir,
                        name=name,
                        cache_tag=self.cache_tag,
                        class_name=self.class_name,
                        sample_ratio=self.sample_ratio
                    ),
                    "tokenize_fn": tokenize_fn_cfg
                }
            )
        return all_datasets

    def build_dataloader(self) -> DataloaderConfig:
        return DataloaderConfig(
            collator=self.collator,
            pack_level=self.pack_level,
            pack_max_length=self.pack_max_length,
            pack_workers=self.pack_workers,
            num_workers=self.num_workers
        )

    def to_trainer_kwargs(self, model_cfg, optim_cfg, lr_cfg):
        return dict(
            model_cfg=model_cfg,
            tokenizer_path=self.model_path,
            load_from=self.model_path,
            optim_cfg=optim_cfg,
            lr_cfg=lr_cfg,
            global_batch_size=self.global_batch_size,
            work_dir=self.work_dir,
            log_dir=self.log_dir,
            sp_size=self.sp_size,
            total_epoch=self.total_epoch,
            checkpoint_interval=self.checkpoint_interval,
            checkpoint_maxkeep=self.checkpoint_maxkeep,
            hf_interval=self.hf_interval,
            hf_max_keep=self.hf_max_keep,
            exp_tracker=self.exp_tracker,
            profile_step=self.profile_step,
            profile_time=self.profile_time,
            profile_memory=self.profile_memory,
            intra_layer_micro_batch=self.intra_layer_micro_batch,
            seed=self.seed,
            debug=self.debug,
            backend=self.backend,
            resume_cfg=self.build_resume_cfg(),
            fsdp_cfg=self.build_fsdp_cfg(),
            loss_cfg=self.build_loss_cfg(),
            dataset_cfg=self.build_datasets_cfg(),
            dataloader_cfg=self.build_dataloader()
        )

    def __str__(self):
        cfg_dict = asdict(self)
        max_key_len = max(len(k) for k in cfg_dict.keys())
        lines = []
        for k, v in cfg_dict.items():
            lines.append(f"{k:<{max_key_len}} : {v}")
        return "\n".join(lines)


if __name__ == "__main__":
    cfg = TrainConfig()
    print(cfg)
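Because TrainConfig is a plain dataclass, per-run overrides can be passed at construction time rather than by editing this file. A short sketch; the work_dir value is a hypothetical placeholder:

from train_config import TrainConfig

cfg = TrainConfig(
    lr=1e-5,                       # override a couple of fields for a smoke test
    total_epoch=2,
    work_dir="/tmp/xtuner-smoke",  # hypothetical path, adjust per cluster
)
print(cfg)  # __str__ prints one aligned "key : value" line per field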
3 changes: 0 additions & 3 deletions requirements/runtime.txt
@@ -2,15 +2,12 @@ bitsandbytes==0.45.0
datasets<4.0.0
einops
loguru
mmengine@git+https://github.com/open-mmlab/mmengine.git@c1724c6
openpyxl
peft>=0.14.0
scikit-image
scipy
SentencePiece
tiktoken
torch>=2.6.0
torchvision
transformers==4.56.0
cyclopts
transformers_stream_generator
47 changes: 47 additions & 0 deletions test.py
@@ -0,0 +1,47 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from xtuner.v1.loss.ce_loss import CELossConfig, CELossContextInputItem, CELossContext
import time


hidden_states = torch.randn(32768, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
lm_head = nn.Linear(4096, 151936, bias=False).to(device="cuda", dtype=torch.bfloat16)
torch.cuda.reset_peak_memory_stats()
t1 = time.time()
logits = lm_head(hidden_states)
shifted_labels = torch.randint(0, 151936, (32768, ), device="cuda")
loss = F.cross_entropy(logits, shifted_labels)
loss.backward()
max_memory = torch.cuda.max_memory_allocated()
reserved_memory = torch.cuda.max_memory_reserved()
print(f"Eager mode Loss: {loss.item()}")
print(f"Eager mode hidden_states grad norm: {hidden_states.grad.norm().item()}")
print(f"Eager mode lm_head weight grad norm: {lm_head.weight.grad.norm().item()}")
print(f"Eager mode Max memory allocated: {max_memory / 1024**3:.2f} GB")
print(f"Eager mode Max memory reserved: {reserved_memory / 1024**3:.2f} GB")
print(f"Eager mode Time taken: {time.time() - t1:.2f} seconds")

del logits
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

shifted_labels = shifted_labels.unsqueeze(0)
hidden_states = hidden_states.unsqueeze(0)
hidden_states = hidden_states.clone().detach().requires_grad_(True)
lm_head.weight.grad = None
t1 = time.time()
loss_ctx_input_list = [CELossContextInputItem(shifted_labels=shifted_labels)]
loss_cfg = CELossConfig(mode='chunk', chunk_size=1024, loss_reduction="token")
batches_loss_kwargs = CELossContext.build_batches_loss_kwargs(loss_ctx_input_list, loss_cfg)
loss_ctx = CELossContext(loss_cfg, batches_loss_kwargs[0])
loss, _ = loss_ctx.forward(hidden_states, lm_head.weight)
loss.backward()
max_memory = torch.cuda.max_memory_allocated()
reserved_memory = torch.cuda.max_memory_reserved()
print(f"Chunk mode Loss: {loss.item()}")
print(f"Chunk mode hidden_states grad norm: {hidden_states.grad.norm().item()}")
print(f"Chunk mode lm_head weight grad norm: {lm_head.weight.grad.norm().item()}")
print(f"Chunk mode Max memory allocated: {max_memory / 1024**3:.2f} GB")
print(f"Chunk mode Max memory reserved: {reserved_memory / 1024**3:.2f} GB")
print(f"Chunk mode Time taken: {time.time() - t1:.2f} seconds")