3 changes: 2 additions & 1 deletion ci/scripts/test_sft_trainer_dsv3.py
@@ -283,7 +283,8 @@ def main():
profile_step=20,
profile_memory=True,
strict_load=False,
intra_layer_micro_batch=2,
packed_samples_per_forward=2,
domino_forward=True,
)
trainer.fit()
if dist.get_rank() == 0:
12 changes: 6 additions & 6 deletions ci/scripts/test_sft_trainer_intralayer.py
@@ -15,10 +15,9 @@
from xtuner.v1.model.moe.moe import BalancingLossConfig, ZLossConfig
from xtuner.v1.datasets import DatasetConfig, DataloaderConfig
from xtuner.v1.datasets import FTDPTokenizeFnConfig
from xtuner.v1.loss import CELossContext
from xtuner.v1.model.moe.qwen3 import Qwen3MoE30BA3Config
from xtuner.v1.train.trainer import Trainer
from xtuner.v1.utils.compile import maybe_compile
from xtuner.v1.loss import CELossConfig
import argparse


@@ -250,24 +249,25 @@ def main():
]

dataloader_config = DataloaderConfig(
pack_max_length=8192,
pack_max_length=16384,
)
work_dir = f"{args.work_dir}-{name}"
loss_ctx = CELossContext(loss_class="liger_cross_entropy")
loss_cfg = CELossConfig(mode="chunk", chunk_size=1024, ignore_idx=-100)
trainer = Trainer(
load_from=QWEN3_MOE_PATH,
model_cfg=moe_cfg,
optim_cfg=optim_cfg,
fsdp_cfg=fsdp_cfg,
dataset_cfg=dataset_config,
dataloader_cfg=dataloader_config,
loss_ctx=loss_ctx,
loss_cfg=loss_cfg,
lr_cfg=lr_cfg,
tokenizer_path=QWEN3_MOE_PATH,
global_batch_size=32,
total_epoch=1,
work_dir=work_dir,
intra_layer_micro_batch=2,
packed_samples_per_forward=2,
domino_forward=True,
seed=0,
debug=False,
profile_memory=False,
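Condensed, the API migration exercised by this test script looks as follows (a sketch based only on the lines above; the remaining Trainer arguments are unchanged):

```python
from xtuner.v1.loss import CELossConfig

# Old API (removed): CELossContext(loss_class="liger_cross_entropy") passed as loss_ctx,
# plus the intra_layer_micro_batch=2 trainer argument.
# New API (this PR): a CELossConfig passed as loss_cfg, and the renamed/added flags below.
loss_cfg = CELossConfig(mode="chunk", chunk_size=1024, ignore_idx=-100)
trainer_kwargs = dict(
    loss_cfg=loss_cfg,
    packed_samples_per_forward=2,  # renamed from intra_layer_micro_batch
    domino_forward=True,           # new flag selecting the list-based forward path
)
```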
3 changes: 3 additions & 0 deletions xtuner/v1/data_proto/sequence_context.py
@@ -137,6 +137,9 @@ def split(self, sequence_parallel_mesh: DeviceMesh | None = None) -> Self:

@classmethod
def pack(cls, sequence_context_list: list["SequenceContext"]):
if len(sequence_context_list) == 1:
return sequence_context_list[0]

packed_input_ids: list[torch.Tensor] = []
cu_seq_lens_q: list[torch.IntTensor] = []
cu_seq_lens_k: list[torch.IntTensor] = []
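The early return keeps a lone context untouched; only the multi-context case concatenates fields. To illustrate the kind of bookkeeping packing implies for flash-attention style cumulative sequence lengths (a hypothetical standalone helper, not the repository implementation), consider:

```python
import torch

def merge_cu_seq_lens(first: torch.Tensor, second: torch.Tensor) -> torch.Tensor:
    # cu_seq_lens is a prefix sum of per-sample lengths, e.g. [0, 5, 9].
    # Appending a second pack means shifting its entries by the first pack's
    # total token count and dropping its leading zero.
    return torch.cat([first, second[1:] + first[-1]])

a = torch.tensor([0, 5, 9], dtype=torch.int32)
b = torch.tensor([0, 3, 7], dtype=torch.int32)
print(merge_cu_seq_lens(a, b))  # tensor([ 0,  5,  9, 12, 16], dtype=torch.int32)
```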
1 change: 0 additions & 1 deletion xtuner/v1/datasets/config.py
@@ -298,7 +298,6 @@ def build(
micro_batch_size: int,
seed: int,
shuffle: bool = True,
total_step: int | None = None,
) -> Dataloader:
if self.dataset_config_list is None:
raise ValueError("dataset_config_list is required.")
33 changes: 18 additions & 15 deletions xtuner/v1/engine/train_engine.py
@@ -29,7 +29,7 @@
from xtuner.v1.config import FSDPConfig, OptimConfig
from xtuner.v1.data_proto.sequence_context import SequenceContext
from xtuner.v1.float8.float8_handler import Float8Handler
from xtuner.v1.model.base import BaseModel, ModelItem, TransformerConfig
from xtuner.v1.model.base import BaseLossContext, BaseModel, ModelItem, TransformerConfig
from xtuner.v1.module.router import NoAuxRouterConfig
from xtuner.v1.utils import get_device, get_logger, get_torch_device_module

@@ -142,14 +142,16 @@ def __init__(
model_cfg: TransformerConfig,
optim_cfg: OptimConfig,
fsdp_cfg: FSDPConfig,
intra_layer_micro_batch: int = 1,
packed_samples_per_forward: int = 1,
domino_forward: bool = False,
) -> None:
self.model_cfg = model_cfg
self.optim_cfg = optim_cfg
self.fsdp_cfg = fsdp_cfg
self.model = self.build_model()
self.optimizer = self.build_optimizer(optim_cfg)
self.intra_layer_micro_batch = intra_layer_micro_batch
self.packed_samples_per_forward = packed_samples_per_forward
self.domino_forward = domino_forward
self._count = 0

def build_model(self) -> BaseModel:
@@ -203,8 +205,7 @@ def forward_only(self, seq_ctx: SequenceContext):
return output

def grad_accumulation_steps(self, data_batches_len: int):
intra_layer_micro_batch = self.intra_layer_micro_batch
return data_batches_len // intra_layer_micro_batch
return data_batches_len // self.packed_samples_per_forward

def train_step(self, data_batches: list[ModelItem]):
"""Perform a training step with the given data batches and mesh.
@@ -217,9 +218,9 @@ def train_step(self, data_batches: list[ModelItem]):

loss_log = {}
other_log = {}
intra_layer_micro_batch = self.intra_layer_micro_batch
assert len(data_batches) % intra_layer_micro_batch == 0, (
f"data_batches length {len(data_batches)} is not divisible by intra_layer_micro_batch {intra_layer_micro_batch}"
packed_samples_per_forward = self.packed_samples_per_forward
assert len(data_batches) % self.packed_samples_per_forward == 0, (
f"data_batches length {len(data_batches)} is not divisible by packed_samples_per_forward {packed_samples_per_forward}"
)
iters_per_step = self.grad_accumulation_steps(len(data_batches))

@@ -245,21 +246,23 @@
logger.info(f"grad_accumulation_steps: {iters_per_step}")
self._count += 1

for i in range(0, len(data_batches), intra_layer_micro_batch):
data_batch = data_batches[i : i + intra_layer_micro_batch]
seq_ctx_list = []
loss_ctx_list = []
for i in range(0, len(data_batches), packed_samples_per_forward):
data_batch = data_batches[i : i + packed_samples_per_forward]
seq_ctx_list: list[SequenceContext] = []
loss_ctx_list: list[BaseLossContext] = []
for data in data_batch:
seq_ctx = data["seq_ctx"]
loss_ctx = data["loss_ctx"]
seq_ctx_list.append(seq_ctx)
loss_ctx_list.append(loss_ctx)
step_consumed_tokens += seq_ctx.mask.sum()

if self.intra_layer_micro_batch == 1:
output = self.model(seq_ctx=seq_ctx_list[0], loss_ctx=loss_ctx_list[0])
if not self.domino_forward:
cat_seq_ctx = seq_ctx_list[0].__class__.pack(seq_ctx_list)
cat_loss_ctx = loss_ctx_list[0].__class__.pack(loss_ctx_list)
output = self.model(seq_ctx=cat_seq_ctx, loss_ctx=cat_loss_ctx)
else:
# For intra_layer_micro_batch > 1, we need to handle the data batches differently.
# When domino_forward is enabled, the data batches are handled differently.
# Here we assume that the model can handle a list of seq_ctx and loss_ctx.
output = self.model(
seq_ctx=seq_ctx_list,
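In other words, the loop now has two paths: without domino_forward the per-forward samples are fused into one packed context, while with domino_forward the model receives the lists directly. A standalone sketch of that dispatch (assuming only the `pack` classmethods added in this PR):

```python
def forward_micro_batch(model, seq_ctx_list, loss_ctx_list, domino_forward: bool):
    if not domino_forward:
        # Fuse the packed samples into a single forward pass.
        seq_ctx = seq_ctx_list[0].__class__.pack(seq_ctx_list)
        loss_ctx = loss_ctx_list[0].__class__.pack(loss_ctx_list)
        return model(seq_ctx=seq_ctx, loss_ctx=loss_ctx)
    # Domino path: pass the lists so the model can overlap the micro-batches.
    return model(seq_ctx=seq_ctx_list, loss_ctx=loss_ctx_list)
```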
38 changes: 25 additions & 13 deletions xtuner/v1/engine/vision_compose_train_engine.py
@@ -8,7 +8,7 @@
from typing_extensions import Self

from transformers.configuration_utils import PretrainedConfig
from xtuner.v1.config import FSDPConfig
from xtuner.v1.config import FSDPConfig, OptimConfig
from xtuner.v1.data_proto import SequenceContext
from xtuner.v1.float8.float8_handler import Float8Handler
from xtuner.v1.loss import BaseLossContext
@@ -70,10 +70,20 @@ class VisionComposeTrainEngine(TrainEngine):
def __init__(
self,
model_cfg: VisionComposeConfigProtocol,
*args,
**kwargs,
optim_cfg: OptimConfig,
fsdp_cfg: FSDPConfig,
packed_samples_per_forward: int = 1,
domino_forward: bool = False,
) -> None:
super().__init__(model_cfg, *args, **kwargs) # type: ignore
if domino_forward:
raise NotImplementedError
super().__init__(
model_cfg=model_cfg, # type: ignore
optim_cfg=optim_cfg,
fsdp_cfg=fsdp_cfg,
packed_samples_per_forward=packed_samples_per_forward,
domino_forward=domino_forward,
)

def build_model(self) -> VisionComposeModelProtocol: # type: ignore
with torch.device("meta"):
@@ -142,9 +152,9 @@ def train_step(self, data_batches: List[ModelItem]):

loss_log = {}
other_log = {}
intra_layer_micro_batch = self.intra_layer_micro_batch
assert len(data_batches) % intra_layer_micro_batch == 0, (
f"data_batches length {len(data_batches)} is not divisible by intra_layer_micro_batch {intra_layer_micro_batch}"
packed_samples_per_forward = self.packed_samples_per_forward
assert len(data_batches) % packed_samples_per_forward == 0, (
f"data_batches length {len(data_batches)} is not divisible by packed_samples_per_forward {packed_samples_per_forward}"
)
iters_per_step = self.grad_accumulation_steps(len(data_batches))

@@ -165,19 +175,21 @@
step_z_loss: torch.Tensor | None = None
step_consumed_tokens = torch.tensor(0.0, device=DEVICE)

for i in range(0, len(data_batches), intra_layer_micro_batch):
data_batch = data_batches[i : i + intra_layer_micro_batch]
seq_ctx_list = []
loss_ctx_list = []
for i in range(0, len(data_batches), packed_samples_per_forward):
data_batch = data_batches[i : i + packed_samples_per_forward]
seq_ctx_list: list[SequenceContext] = []
loss_ctx_list: list[BaseLossContext] = []
for data in data_batch:
seq_ctx = data["seq_ctx"]
loss_ctx = data["loss_ctx"]
seq_ctx_list.append(seq_ctx)
loss_ctx_list.append(loss_ctx)
step_consumed_tokens += seq_ctx.mask.sum()

# todo: support intra_layer_micro_batch
output = self.model(seq_ctx=seq_ctx_list[0], loss_ctx=loss_ctx_list[0])
# todo: support domino_forward
cat_seq_ctx = seq_ctx_list[0].__class__.pack(seq_ctx_list)
cat_loss_ctx = loss_ctx_list[0].__class__.pack(loss_ctx_list)
output = self.model(seq_ctx=cat_seq_ctx, loss_ctx=cat_loss_ctx)
# llm loss has been global averaged
llm_loss = output["loss"]
step_llm_loss += llm_loss.detach().clone()
6 changes: 6 additions & 0 deletions xtuner/v1/loss/base_loss_ctx.py
@@ -9,6 +9,7 @@
from pydantic import BaseModel, ConfigDict
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.nn.functional import all_reduce
from typing_extensions import Self

from .chunk_loss import ChunkLoss

@@ -151,3 +152,8 @@ def forward(
if dist.is_initialized():
loss = all_reduce(loss, op=dist.ReduceOp.SUM, group=dist.group.WORLD)
return loss, logits

@classmethod
def pack(cls, loss_ctx_list: list[Self]) -> Self:
# TODO: Implement pack for all loss_ctx
raise NotImplementedError
11 changes: 11 additions & 0 deletions xtuner/v1/loss/ce_loss.py
@@ -170,3 +170,14 @@ def loss_fn(
loss = (loss * loss_weight).sum()

return loss, logits

@classmethod
def pack(cls, loss_ctx_list: list[Self]) -> Self: # type: ignore
assert len(loss_ctx_list) > 0, "loss_ctx_list should not be empty"
loss_cfg = loss_ctx_list[0].loss_cfg
loss_kwargs = [i.loss_kwargs for i in loss_ctx_list]
shifted_labels = torch.cat([i.shifted_labels for i in loss_kwargs], dim=-1)
shifted_loss_weights = torch.cat([i.loss_weight for i in loss_kwargs], dim=-1)
cat_loss_kwargs = CELossKwargs(shifted_labels=shifted_labels, loss_weight=shifted_loss_weights)
return cls(loss_cfg=loss_cfg, loss_kwargs=cat_loss_kwargs)
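Since `pack` concatenates the shifted labels and loss weights along the sequence dimension, two contexts carrying `[1, seq]` tensors combine into one `[1, seq_a + seq_b]` context (the shapes here are an assumption inferred from the `dim=-1` concatenation):

```python
import torch

labels_a = torch.tensor([[1, 2, -100]])           # [1, 3]
labels_b = torch.tensor([[7, -100]])              # [1, 2]
weights_a = torch.ones_like(labels_a, dtype=torch.float)
weights_b = torch.ones_like(labels_b, dtype=torch.float)

# Mirrors what CELossContext.pack does with its loss_kwargs tensors.
packed_labels = torch.cat([labels_a, labels_b], dim=-1)     # [1, 5]
packed_weights = torch.cat([weights_a, weights_b], dim=-1)  # [1, 5]
print(packed_labels.shape, packed_weights.shape)
```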
18 changes: 9 additions & 9 deletions xtuner/v1/model/moe/moe.py
@@ -157,25 +157,25 @@ def _select_non_pad_router_logits(
) -> torch.Tensor:
assert len(router_logits_list) > 0, "router_logits_list should not be empty"
if isinstance(router_logits_list[0], torch.Tensor):
router_logits_list = [cast(list[torch.Tensor], router_logits_list)] # intra_layer_micro_batch is 1
router_logits_list = [cast(list[torch.Tensor], router_logits_list)] # packed_samples_per_forward is 1
attn_mask_list = [cast(torch.Tensor, attn_mask_list)]
# router_logits_list [intra_layer_micro_batch, num_layers][seq, num_experts]
# attn_mask_list [intra_layer_micro_batch, ][1, seq]
intra_layer_micro_batch = len(router_logits_list)
# router_logits_list [packed_samples_per_forward, num_layers][seq, num_experts]
# attn_mask_list [packed_samples_per_forward, ][1, seq]
packed_samples_per_forward = len(router_logits_list)
num_layers = len(router_logits_list[0])

router_logits_list_new = [] # [num_layers, intra_layer_micro_batch] -> [num_layers * intra_layer_micro_batch]
router_logits_list_new = [] # [num_layers, packed_samples_per_forward] -> [num_layers * packed_samples_per_forward]
for layer_idx in range(num_layers):
for micro_batch_idx in range(intra_layer_micro_batch):
for micro_batch_idx in range(packed_samples_per_forward):
router_logits_list_new.append(router_logits_list[micro_batch_idx][layer_idx])

router_logits = torch.stack(
router_logits_list_new, dim=0
) # [num_layers * intra_layer_micro_batch, seq, num_experts]
) # [num_layers * packed_samples_per_forward, seq, num_experts]
router_logits = router_logits.view(
num_layers, -1, router_logits.shape[-1]
) # [num_layers, intra_layer_micro_batch * seq, num_experts]
attn_mask = torch.stack(attn_mask_list, dim=0) # type: ignore # [intra_layer_micro_batch, 1, seq]
) # [num_layers, packed_samples_per_forward * seq, num_experts]
attn_mask = torch.stack(attn_mask_list, dim=0) # type: ignore # [packed_samples_per_forward, 1, seq]
attn_mask = attn_mask.flatten()
router_logits = router_logits[:, attn_mask].contiguous().float() # [num_layers, non_pad_seq, num_experts]
return router_logits
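The reordering and reshape in `_select_non_pad_router_logits` can be sanity-checked with toy shapes (the sizes below are made up; the point is that the final view groups all micro-batches of a layer along the sequence axis):

```python
import torch

num_layers, packed, seq, experts = 2, 2, 3, 4  # toy sizes
# Nested as in the diff comment: [packed_samples_per_forward, num_layers][seq, num_experts]
router_logits_list = [
    [torch.randn(seq, experts) for _ in range(num_layers)] for _ in range(packed)
]

# Reorder so every micro-batch of a given layer sits next to the others, then stack.
reordered = [
    router_logits_list[mb][layer] for layer in range(num_layers) for mb in range(packed)
]
stacked = torch.stack(reordered, dim=0)            # [num_layers * packed, seq, experts]
per_layer = stacked.view(num_layers, -1, experts)  # [num_layers, packed * seq, experts]
print(per_layer.shape)  # torch.Size([2, 6, 4])
```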
4 changes: 2 additions & 2 deletions xtuner/v1/module/decoder_layer/moe_decoder_layer.py
@@ -375,7 +375,7 @@ def _micro_batch_forward(
assert all(hidden_states.shape == origin_shape for hidden_states in hidden_states_list), (
"All hidden states should have the same shape"
)
intra_layer_micro_batch = len(hidden_states_list)
packed_samples_per_forward = len(hidden_states_list)
residual_list: list[torch.Tensor] = []
router_results_list: list[RouterResults] = []

@@ -466,7 +466,7 @@ def _micro_batch_forward(
combined_list.append(combined)

hidden_states_out_list: list[torch.Tensor] = []
for i in range(intra_layer_micro_batch):
for i in range(packed_samples_per_forward):
post_combined = self.dispatcher.combine_postprocess(
pre_dispatched=pre_dispatched_list[i],
dispatched=dispatched_list[i],