From ad6b6e0130089e8104b8f32836d1938d6a641666 Mon Sep 17 00:00:00 2001
From: HAOCHENYE <21724054@zju.edu.cn>
Date: Fri, 26 Sep 2025 03:13:58 +0800
Subject: [PATCH] [Refactor] Rename intra_layer_micro_batch to
 packed_samples_per_forward and add domino_forward flag

- Rename intra_layer_micro_batch parameter to packed_samples_per_forward across all files
- Add domino_forward boolean flag for controlling forward mode
- Update test scripts to use new parameter names
- Add packing support for sequence contexts and loss contexts
- Update CELossContext to CELossConfig in test scripts
---
 ci/scripts/test_sft_trainer_dsv3.py           |  3 +-
 ci/scripts/test_sft_trainer_intralayer.py     | 12 ++--
 xtuner/v1/data_proto/sequence_context.py      |  3 +
 xtuner/v1/datasets/config.py                  |  1 -
 xtuner/v1/engine/train_engine.py              | 33 +++++----
 .../v1/engine/vision_compose_train_engine.py  | 38 ++++++----
 xtuner/v1/loss/base_loss_ctx.py               |  6 ++
 xtuner/v1/loss/ce_loss.py                     | 11 +++
 xtuner/v1/model/moe/moe.py                    | 18 ++---
 .../module/decoder_layer/moe_decoder_layer.py |  4 +-
 xtuner/v1/train/trainer.py                    | 69 ++++++++++++-------
 11 files changed, 128 insertions(+), 70 deletions(-)

diff --git a/ci/scripts/test_sft_trainer_dsv3.py b/ci/scripts/test_sft_trainer_dsv3.py
index bb2c2f5ed..a8661fa78 100644
--- a/ci/scripts/test_sft_trainer_dsv3.py
+++ b/ci/scripts/test_sft_trainer_dsv3.py
@@ -283,7 +283,8 @@ def main():
         profile_step=20,
         profile_memory=True,
         strict_load=False,
-        intra_layer_micro_batch=2,
+        packed_samples_per_forward=2,
+        domino_forward=True,
     )
     trainer.fit()
     if dist.get_rank() == 0:
diff --git a/ci/scripts/test_sft_trainer_intralayer.py b/ci/scripts/test_sft_trainer_intralayer.py
index 0c97e8c5e..60fc5c78d 100644
--- a/ci/scripts/test_sft_trainer_intralayer.py
+++ b/ci/scripts/test_sft_trainer_intralayer.py
@@ -15,10 +15,9 @@
 from xtuner.v1.model.moe.moe import BalancingLossConfig, ZLossConfig
 from xtuner.v1.datasets import DatasetConfig, DataloaderConfig
 from xtuner.v1.datasets import FTDPTokenizeFnConfig
-from xtuner.v1.loss import CELossContext
 from xtuner.v1.model.moe.qwen3 import Qwen3MoE30BA3Config
 from xtuner.v1.train.trainer import Trainer
-from xtuner.v1.utils.compile import maybe_compile
+from xtuner.v1.loss import CELossConfig
 
 import argparse
 
@@ -250,10 +249,10 @@ def main():
     ]
     dataloader_config = DataloaderConfig(
-        pack_max_length=8192,
+        pack_max_length=16384,
     )
 
     work_dir = f"{args.work_dir}-{name}"
-    loss_ctx = CELossContext(loss_class="liger_cross_entropy")
+    loss_cfg = CELossConfig(mode="chunk", chunk_size=1024, ignore_idx=-100)
     trainer = Trainer(
         load_from=QWEN3_MOE_PATH,
         model_cfg=moe_cfg,
@@ -261,13 +260,14 @@ def main():
         fsdp_cfg=fsdp_cfg,
         dataset_cfg=dataset_config,
         dataloader_cfg=dataloader_config,
-        loss_ctx=loss_ctx,
+        loss_cfg=loss_cfg,
         lr_cfg=lr_cfg,
         tokenizer_path=QWEN3_MOE_PATH,
         global_batch_size=32,
         total_epoch=1,
         work_dir=work_dir,
-        intra_layer_micro_batch=2,
+        packed_samples_per_forward=2,
+        domino_forward=True,
         seed=0,
         debug=False,
         profile_memory=False,
diff --git a/xtuner/v1/data_proto/sequence_context.py b/xtuner/v1/data_proto/sequence_context.py
index d7593781d..d817560be 100644
--- a/xtuner/v1/data_proto/sequence_context.py
+++ b/xtuner/v1/data_proto/sequence_context.py
@@ -137,6 +137,9 @@ def split(self, sequence_parallel_mesh: DeviceMesh | None = None) -> Self:
 
     @classmethod
     def pack(cls, sequence_context_list: list["SequenceContext"]):
+        if len(sequence_context_list) == 1:
+            return sequence_context_list[0]
+
         packed_input_ids: list[torch.Tensor] = []
         cu_seq_lens_q: list[torch.IntTensor] = []
         cu_seq_lens_k: list[torch.IntTensor] = []
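Note: the new early return in `SequenceContext.pack` keeps the single-context case cheap; for multiple contexts, packing essentially concatenates token ids and re-offsets the cumulative sequence lengths. A minimal, self-contained sketch of that cu_seq_lens bookkeeping, using plain tensors rather than the actual SequenceContext API:

    import torch

    def pack_cu_seq_lens(cu_seq_lens_list: list[torch.Tensor]) -> torch.Tensor:
        # Merge per-context prefix sums [0, l1, l1+l2, ...] into one packed prefix sum.
        merged = [torch.zeros(1, dtype=torch.int32)]
        offset = 0
        for cu in cu_seq_lens_list:
            # drop each context's leading 0 and shift by the tokens packed so far
            merged.append(cu[1:] + offset)
            offset += int(cu[-1])
        return torch.cat(merged)

    ctx_a = torch.tensor([0, 3, 5], dtype=torch.int32)  # two samples of length 3 and 2
    ctx_b = torch.tensor([0, 4], dtype=torch.int32)     # one sample of length 4
    print(pack_cu_seq_lens([ctx_a, ctx_b]))             # tensor([0, 3, 5, 9], dtype=torch.int32)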
diff --git a/xtuner/v1/datasets/config.py b/xtuner/v1/datasets/config.py
index 05b95bd13..51036ef85 100644
--- a/xtuner/v1/datasets/config.py
+++ b/xtuner/v1/datasets/config.py
@@ -298,7 +298,6 @@ def build(
         micro_batch_size: int,
         seed: int,
         shuffle: bool = True,
-        total_step: int | None = None,
     ) -> Dataloader:
         if self.dataset_config_list is None:
             raise ValueError("dataset_config_list is required.")
diff --git a/xtuner/v1/engine/train_engine.py b/xtuner/v1/engine/train_engine.py
index 49540aa75..9b2d61b76 100644
--- a/xtuner/v1/engine/train_engine.py
+++ b/xtuner/v1/engine/train_engine.py
@@ -29,7 +29,7 @@
 from xtuner.v1.config import FSDPConfig, OptimConfig
 from xtuner.v1.data_proto.sequence_context import SequenceContext
 from xtuner.v1.float8.float8_handler import Float8Handler
-from xtuner.v1.model.base import BaseModel, ModelItem, TransformerConfig
+from xtuner.v1.model.base import BaseLossContext, BaseModel, ModelItem, TransformerConfig
 from xtuner.v1.module.router import NoAuxRouterConfig
 from xtuner.v1.utils import get_device, get_logger, get_torch_device_module
 
@@ -142,14 +142,16 @@ def __init__(
         self,
         model_cfg: TransformerConfig,
         optim_cfg: OptimConfig,
         fsdp_cfg: FSDPConfig,
-        intra_layer_micro_batch: int = 1,
+        packed_samples_per_forward: int = 1,
+        domino_forward: bool = False,
     ) -> None:
         self.model_cfg = model_cfg
         self.optim_cfg = optim_cfg
         self.fsdp_cfg = fsdp_cfg
         self.model = self.build_model()
         self.optimizer = self.build_optimizer(optim_cfg)
-        self.intra_layer_micro_batch = intra_layer_micro_batch
+        self.packed_samples_per_forward = packed_samples_per_forward
+        self.domino_forward = domino_forward
         self._count = 0
     def build_model(self) -> BaseModel:
@@ -203,8 +205,7 @@ def forward_only(self, seq_ctx: SequenceContext):
         return output
 
     def grad_accumulation_steps(self, data_batches_len: int):
-        intra_layer_micro_batch = self.intra_layer_micro_batch
-        return data_batches_len // intra_layer_micro_batch
+        return data_batches_len // self.packed_samples_per_forward
 
     def train_step(self, data_batches: list[ModelItem]):
         """Perform a training step with the given data batches and mesh.
@@ -217,9 +218,9 @@ def train_step(self, data_batches: list[ModelItem]):
 
         loss_log = {}
         other_log = {}
-        intra_layer_micro_batch = self.intra_layer_micro_batch
-        assert len(data_batches) % intra_layer_micro_batch == 0, (
-            f"data_batches length {len(data_batches)} is not divisible by intra_layer_micro_batch {intra_layer_micro_batch}"
+        packed_samples_per_forward = self.packed_samples_per_forward
+        assert len(data_batches) % packed_samples_per_forward == 0, (
+            f"data_batches length {len(data_batches)} is not divisible by packed_samples_per_forward {packed_samples_per_forward}"
         )
 
         iters_per_step = self.grad_accumulation_steps(len(data_batches))
@@ -245,10 +246,10 @@ def train_step(self, data_batches: list[ModelItem]):
             logger.info(f"grad_accumulation_steps: {iters_per_step}")
         self._count += 1
 
-        for i in range(0, len(data_batches), intra_layer_micro_batch):
-            data_batch = data_batches[i : i + intra_layer_micro_batch]
-            seq_ctx_list = []
-            loss_ctx_list = []
+        for i in range(0, len(data_batches), packed_samples_per_forward):
+            data_batch = data_batches[i : i + packed_samples_per_forward]
+            seq_ctx_list: list[SequenceContext] = []
+            loss_ctx_list: list[BaseLossContext] = []
             for data in data_batch:
                 seq_ctx = data["seq_ctx"]
                 loss_ctx = data["loss_ctx"]
@@ -256,10 +257,12 @@ def train_step(self, data_batches: list[ModelItem]):
                 loss_ctx_list.append(loss_ctx)
                 step_consumed_tokens += seq_ctx.mask.sum()
 
-            if self.intra_layer_micro_batch == 1:
-                output = self.model(seq_ctx=seq_ctx_list[0], loss_ctx=loss_ctx_list[0])
+            if not self.domino_forward:
+                cat_seq_ctx = seq_ctx_list[0].__class__.pack(seq_ctx_list)
+                cat_loss_ctx = loss_ctx_list[0].__class__.pack(loss_ctx_list)
+                output = self.model(seq_ctx=cat_seq_ctx, loss_ctx=cat_loss_ctx)
             else:
-                # For intra_layer_micro_batch > 1, we need to handle the data batches differently.
+                # For domino_forward, we need to handle the data batches differently.
                 # Here we assume that the model can handle a list of seq_ctx and loss_ctx.
                 output = self.model(
                     seq_ctx=seq_ctx_list,
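Note: with the rename, one forward call now consumes `packed_samples_per_forward` items from `data_batches`, either packed into a single sequence (`domino_forward=False`) or passed as a list for the overlapped path (`domino_forward=True`), so the gradient-accumulation count shrinks by the same factor. A toy check of that arithmetic, mirroring `grad_accumulation_steps` above (hypothetical numbers):

    def grad_accumulation_steps(data_batches_len: int, packed_samples_per_forward: int) -> int:
        # same division the engine performs after this patch
        assert data_batches_len % packed_samples_per_forward == 0
        return data_batches_len // packed_samples_per_forward

    print(grad_accumulation_steps(8, 1))  # 8 forwards, one sample each
    print(grad_accumulation_steps(8, 2))  # 4 forwards, two samples per forward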
diff --git a/xtuner/v1/engine/vision_compose_train_engine.py b/xtuner/v1/engine/vision_compose_train_engine.py
index 1548735dc..1ecce5d1b 100644
--- a/xtuner/v1/engine/vision_compose_train_engine.py
+++ b/xtuner/v1/engine/vision_compose_train_engine.py
@@ -8,7 +8,7 @@
 from typing_extensions import Self
 
 from transformers.configuration_utils import PretrainedConfig
-from xtuner.v1.config import FSDPConfig
+from xtuner.v1.config import FSDPConfig, OptimConfig
 from xtuner.v1.data_proto import SequenceContext
 from xtuner.v1.float8.float8_handler import Float8Handler
 from xtuner.v1.loss import BaseLossContext
@@ -70,10 +70,20 @@ class VisionComposeTrainEngine(TrainEngine):
     def __init__(
         self,
         model_cfg: VisionComposeConfigProtocol,
-        *args,
-        **kwargs,
+        optim_cfg: OptimConfig,
+        fsdp_cfg: FSDPConfig,
+        packed_samples_per_forward: int = 1,
+        domino_forward: bool = False,
     ) -> None:
-        super().__init__(model_cfg, *args, **kwargs)  # type: ignore
+        if domino_forward:
+            raise NotImplementedError("VisionComposeTrainEngine does not support domino_forward yet")
+        super().__init__(
+            model_cfg=model_cfg,  # type: ignore
+            optim_cfg=optim_cfg,
+            fsdp_cfg=fsdp_cfg,
+            packed_samples_per_forward=packed_samples_per_forward,
+            domino_forward=domino_forward,
+        )
 
     def build_model(self) -> VisionComposeModelProtocol:  # type: ignore
         with torch.device("meta"):
@@ -142,9 +152,9 @@ def train_step(self, data_batches: List[ModelItem]):
 
         loss_log = {}
         other_log = {}
-        intra_layer_micro_batch = self.intra_layer_micro_batch
-        assert len(data_batches) % intra_layer_micro_batch == 0, (
-            f"data_batches length {len(data_batches)} is not divisible by intra_layer_micro_batch {intra_layer_micro_batch}"
+        packed_samples_per_forward = self.packed_samples_per_forward
+        assert len(data_batches) % packed_samples_per_forward == 0, (
+            f"data_batches length {len(data_batches)} is not divisible by packed_samples_per_forward {packed_samples_per_forward}"
         )
 
         iters_per_step = self.grad_accumulation_steps(len(data_batches))
@@ -165,10 +175,10 @@ def train_step(self, data_batches: List[ModelItem]):
         step_z_loss: torch.Tensor | None = None
         step_consumed_tokens = torch.tensor(0.0, device=DEVICE)
 
-        for i in range(0, len(data_batches), intra_layer_micro_batch):
-            data_batch = data_batches[i : i + intra_layer_micro_batch]
-            seq_ctx_list = []
-            loss_ctx_list = []
+        for i in range(0, len(data_batches), packed_samples_per_forward):
+            data_batch = data_batches[i : i + packed_samples_per_forward]
+            seq_ctx_list: list[SequenceContext] = []
+            loss_ctx_list: list[BaseLossContext] = []
             for data in data_batch:
                 seq_ctx = data["seq_ctx"]
                 loss_ctx = data["loss_ctx"]
@@ -176,8 +186,10 @@ def train_step(self, data_batches: List[ModelItem]):
                 loss_ctx_list.append(loss_ctx)
                 step_consumed_tokens += seq_ctx.mask.sum()
 
-            # todo: support intra_layer_micro_batch
-            output = self.model(seq_ctx=seq_ctx_list[0], loss_ctx=loss_ctx_list[0])
+            # todo: support domino_forward
+            cat_seq_ctx = seq_ctx_list[0].__class__.pack(seq_ctx_list)
+            cat_loss_ctx = loss_ctx_list[0].__class__.pack(loss_ctx_list)
+            output = self.model(seq_ctx=cat_seq_ctx, loss_ctx=cat_loss_ctx)
             # llm loss has been global averaged
             llm_loss = output["loss"]
             step_llm_loss += llm_loss.detach().clone()
diff --git a/xtuner/v1/loss/base_loss_ctx.py b/xtuner/v1/loss/base_loss_ctx.py
index d9bd8ddbf..59335c043 100644
--- a/xtuner/v1/loss/base_loss_ctx.py
+++ b/xtuner/v1/loss/base_loss_ctx.py
@@ -9,6 +9,7 @@
 from pydantic import BaseModel, ConfigDict
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.nn.functional import all_reduce
+from typing_extensions import Self
 
 from .chunk_loss import ChunkLoss
 
@@ -151,3 +152,8 @@ def forward(
         if dist.is_initialized():
             loss = all_reduce(loss, op=dist.ReduceOp.SUM, group=dist.group.WORLD)
         return loss, logits
+
+    @classmethod
+    def pack(cls, loss_ctx_list: list[Self]) -> Self:
+        # TODO: Implement pack for all loss_ctx subclasses
+        raise NotImplementedError
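Note: the base-class `pack` relies on `typing_extensions.Self` so that every subclass returns its own type. A standalone illustration of that pattern (generic example, not xtuner code):

    from typing_extensions import Self

    class BaseCtx:
        def __init__(self, items: list[int]):
            self.items = items

        @classmethod
        def pack(cls, ctx_list: list[Self]) -> Self:
            raise NotImplementedError

    class ConcreteCtx(BaseCtx):
        @classmethod
        def pack(cls, ctx_list: list[Self]) -> Self:
            if not ctx_list:
                raise ValueError("ctx_list should not be empty")
            # concatenate the payloads of all contexts into a single instance
            return cls([x for ctx in ctx_list for x in ctx.items])

    packed = ConcreteCtx.pack([ConcreteCtx([1, 2]), ConcreteCtx([3])])
    print(packed.items)  # [1, 2, 3]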
diff --git a/xtuner/v1/loss/ce_loss.py b/xtuner/v1/loss/ce_loss.py
index 8b8ffde8e..dba32d1c4 100644
--- a/xtuner/v1/loss/ce_loss.py
+++ b/xtuner/v1/loss/ce_loss.py
@@ -170,3 +170,14 @@ def loss_fn(
         loss = (loss * loss_weight).sum()
 
         return loss, logits
+
+    @classmethod
+    def pack(cls, loss_ctx_list: list[Self]) -> Self:  # type: ignore
+        if len(loss_ctx_list) == 0:
+            raise ValueError("loss_ctx_list should not be empty")
+        loss_cfg = loss_ctx_list[0].loss_cfg
+        loss_kwargs = [i.loss_kwargs for i in loss_ctx_list]
+        shifted_labels = torch.cat([i.shifted_labels for i in loss_kwargs], dim=-1)
+        shifted_loss_weights = torch.cat([i.loss_weight for i in loss_kwargs], dim=-1)
+        cat_loss_kwargs = CELossKwargs(shifted_labels=shifted_labels, loss_weight=shifted_loss_weights)
+        return cls(loss_cfg=loss_cfg, loss_kwargs=cat_loss_kwargs)
diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py
index 550983196..0679b561f 100644
--- a/xtuner/v1/model/moe/moe.py
+++ b/xtuner/v1/model/moe/moe.py
@@ -157,25 +157,25 @@ def _select_non_pad_router_logits(
     ) -> torch.Tensor:
         assert len(router_logits_list) > 0, "router_logits_list should not be empty"
         if isinstance(router_logits_list[0], torch.Tensor):
-            router_logits_list = [cast(list[torch.Tensor], router_logits_list)]  # intra_layer_micro_batch is 1
+            router_logits_list = [cast(list[torch.Tensor], router_logits_list)]  # packed_samples_per_forward is 1
             attn_mask_list = [cast(torch.Tensor, attn_mask_list)]
-        # router_logits_list [intra_layer_micro_batch, num_layers][seq, num_experts]
-        # attn_mask_list [intra_layer_micro_batch, ][1, seq]
-        intra_layer_micro_batch = len(router_logits_list)
+        # router_logits_list [packed_samples_per_forward, num_layers][seq, num_experts]
+        # attn_mask_list [packed_samples_per_forward, ][1, seq]
+        packed_samples_per_forward = len(router_logits_list)
         num_layers = len(router_logits_list[0])
 
-        router_logits_list_new = []  # [num_layers, intra_layer_micro_batch] -> [num_layers * intra_layer_micro_batch]
+        router_logits_list_new = []  # [num_layers, packed_samples_per_forward] -> [num_layers * packed_samples_per_forward]
         for layer_idx in range(num_layers):
-            for micro_batch_idx in range(intra_layer_micro_batch):
+            for micro_batch_idx in range(packed_samples_per_forward):
                 router_logits_list_new.append(router_logits_list[micro_batch_idx][layer_idx])
 
         router_logits = torch.stack(
             router_logits_list_new, dim=0
-        )  # [num_layers * intra_layer_micro_batch, seq, num_experts]
+        )  # [num_layers * packed_samples_per_forward, seq, num_experts]
         router_logits = router_logits.view(
             num_layers, -1, router_logits.shape[-1]
-        )  # [num_layers, intra_layer_micro_batch * seq, num_experts]
-        attn_mask = torch.stack(attn_mask_list, dim=0)  # type: ignore  # [intra_layer_micro_batch, 1, seq]
+        )  # [num_layers, packed_samples_per_forward * seq, num_experts]
+        attn_mask = torch.stack(attn_mask_list, dim=0)  # type: ignore  # [packed_samples_per_forward, 1, seq]
         attn_mask = attn_mask.flatten()
         router_logits = router_logits[:, attn_mask].contiguous().float()  # [num_layers, non_pad_seq, num_experts]
         return router_logits
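Note: the shape comments in `_select_non_pad_router_logits` are easier to verify with concrete numbers. A toy run with made-up sizes (2 packed samples, 3 layers, sequence length 4, 8 experts) reproducing the interleave, view, and pad-mask steps:

    import torch

    packed, num_layers, seq, num_experts = 2, 3, 4, 8
    router_logits_list = [
        [torch.randn(seq, num_experts) for _ in range(num_layers)] for _ in range(packed)
    ]
    attn_mask_list = [
        torch.tensor([[1, 1, 1, 0]], dtype=torch.bool),  # sample 0: last token is padding
        torch.tensor([[1, 1, 0, 0]], dtype=torch.bool),  # sample 1: two padding tokens
    ]

    # layer-major interleave, as in the renamed loop above
    flat = [router_logits_list[mb][layer] for layer in range(num_layers) for mb in range(packed)]
    router_logits = torch.stack(flat, dim=0)                         # [num_layers * packed, seq, experts]
    router_logits = router_logits.view(num_layers, -1, num_experts)  # [num_layers, packed * seq, experts]

    attn_mask = torch.stack(attn_mask_list, dim=0).flatten()         # [packed * seq]
    non_pad = router_logits[:, attn_mask]
    print(non_pad.shape)  # torch.Size([3, 5, 8]) -> 5 non-pad tokens across both samples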
diff --git a/xtuner/v1/module/decoder_layer/moe_decoder_layer.py b/xtuner/v1/module/decoder_layer/moe_decoder_layer.py
index c78365703..37288e491 100644
--- a/xtuner/v1/module/decoder_layer/moe_decoder_layer.py
+++ b/xtuner/v1/module/decoder_layer/moe_decoder_layer.py
@@ -375,7 +375,7 @@ def _micro_batch_forward(
         assert all(hidden_states.shape == origin_shape for hidden_states in hidden_states_list), (
             "All hidden states should have the same shape"
         )
-        intra_layer_micro_batch = len(hidden_states_list)
+        packed_samples_per_forward = len(hidden_states_list)
 
         residual_list: list[torch.Tensor] = []
         router_results_list: list[RouterResults] = []
@@ -466,7 +466,7 @@ def _micro_batch_forward(
             combined_list.append(combined)
 
         hidden_states_out_list: list[torch.Tensor] = []
-        for i in range(intra_layer_micro_batch):
+        for i in range(packed_samples_per_forward):
             post_combined = self.dispatcher.combine_postprocess(
                 pre_dispatched=pre_dispatched_list[i],
                 dispatched=dispatched_list[i],
diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py
index 07243bcdb..b4a075e3c 100644
--- a/xtuner/v1/train/trainer.py
+++ b/xtuner/v1/train/trainer.py
@@ -145,7 +145,8 @@ class TrainerConfig(BaseModel):
     profile_step: int | None = None
     profile_time: bool = True
     profile_memory: bool = False
-    intra_layer_micro_batch: int = 1
+    packed_samples_per_forward: int = 1
+    domino_forward: bool = False
     seed: int = 42
     dist_backend: str | None = None
     debug: bool = False
@@ -192,7 +193,6 @@ class Trainer:
         profile_step (int | None): Step to perform profiling.
         profile_time (bool): Whether to profile training time.
         profile_memory (bool): Whether to profile memory usage.
-        intra_layer_micro_batch (int): Intra-layer micro batch size.
         seed (int): Random seed for reproducibility.
         debug (bool): Whether to enable debug mode.
         backend (str): Backend for distributed training.
@@ -240,7 +240,8 @@ def __init__(
         profile_step: int | None = None,
         profile_time: bool = True,
         profile_memory: bool = False,
-        intra_layer_micro_batch: int = 1,
+        packed_samples_per_forward: int = 1,
+        domino_forward: bool = False,
         seed: int = 42,
         debug: bool = False,
         backend: str | None = None,
@@ -253,6 +254,9 @@ def __init__(
 
         self._cur_epoch = 1
         self._cur_step = 0
+        self._packed_samples_per_forward = packed_samples_per_forward
+        self._domino_forward = domino_forward
+
         self._trainer_cfg = trainer_cfg
 
         self._micro_batch_size: int | None = None
@@ -323,20 +327,9 @@ def __init__(
         self._global_batch_size = global_batch_size
 
         self._resolve_config_conflicts(self.tokenizer, model_cfg, dataloader_cfg)
-
-        if dataset_cfg is not None:  # TODO: Removed in version 1.1.0
-            # For backward compatibility, reserve the dataset_cfg interface, remove it later
-            if dataloader_cfg.dataset_config_list is not None:
-                logger.warning("Outside dataset_cfg will override inner dataset_config_list")
-            dataloader_cfg.dataset_config_list = dataset_cfg
-
-        self._dataloader = dataloader_cfg.build(
-            tokenizer=self.tokenizer,
-            dp_mesh=self.data_mesh["dp"],
-            global_batch_size=self.global_batch_size,
-            micro_batch_size=self.micro_batch_size,
-            seed=seed,
-            total_step=total_step,
+        self._dataloader = self._build_dataloader(
+            dataset_cfg,
+            dataloader_cfg,
         )
         # streaming dataloader may override `total_step`, so we may move this check after `build_dataloader` later.
 
@@ -355,7 +348,8 @@ def __init__(
             fsdp_config=fsdp_cfg,
             resume_cfg=resume_cfg,
             strict=strict_load,
-            intra_layer_micro_batch=intra_layer_micro_batch,
+            packed_samples_per_forward=self._packed_samples_per_forward,
+            domino_forward=self._domino_forward,
         )
         self._lr_scheduler = self.build_lr_scheduler(lr_cfg)
 
@@ -410,7 +404,8 @@ def from_config(cls, config: TrainerConfig) -> Self:
            profile_step=config.profile_step,
             profile_time=config.profile_time,
             profile_memory=config.profile_memory,
-            intra_layer_micro_batch=config.intra_layer_micro_batch,
+            packed_samples_per_forward=config.packed_samples_per_forward,
+            domino_forward=config.domino_forward,
             seed=config.seed,
             backend=config.dist_backend,
             debug=config.debug,
@@ -620,7 +615,8 @@ def build_engine(
         optim_config: OptimConfig,
         fsdp_config: FSDPConfig,
         resume_cfg: ResumeConfig,
-        intra_layer_micro_batch: int = 1,
+        packed_samples_per_forward: int = 1,
+        domino_forward: bool = False,
         strict: bool = True,
     ):
         """Build the training engine for the transformer model.
@@ -631,7 +627,6 @@ def build_engine(
             optim_config (OptimConfig): Optimizer configuration.
             fsdp_config (FSDPConfig): FSDP configuration for distributed training.
             resume_cfg (ResumeConfig | None): Resume configuration for continuing training.
-            intra_layer_micro_batch (int): Intra-layer micro batch size for gradient accumulation.
             strict (bool): Whether to strictly load model weights.
 
         Returns:
@@ -642,14 +637,15 @@ def build_engine(
                 optim_cfg=optim_config,
                 fsdp_cfg=fsdp_config,
                 model_cfg=model_config,
-                intra_layer_micro_batch=intra_layer_micro_batch,
+                packed_samples_per_forward=packed_samples_per_forward,
             )
         else:
             engine = TrainEngine(  # type: ignore
                 optim_cfg=optim_config,
                 fsdp_cfg=fsdp_config,
                 model_cfg=model_config,
-                intra_layer_micro_batch=intra_layer_micro_batch,
+                packed_samples_per_forward=packed_samples_per_forward,
+                domino_forward=domino_forward,
             )
         if model_path is not None and resume_cfg.resume_from is None:
             engine.from_hf(hf_path=model_path, strict=strict)
@@ -1209,3 +1205,30 @@ def _setup_env(self):
             log_str += f"{k}: {v}\n"
         log_str += "=================================================="
         logger.info(log_str)
+
+    def _build_dataloader(
+        self,
+        dataset_cfg: DatasetConfigList | None,
+        dataloader_cfg: DataloaderConfig,
+    ):
+        if self._domino_forward:
+            if self._packed_samples_per_forward <= 1:
+                raise ValueError(
+                    "`domino_forward` is only valid for `packed_samples_per_forward` > 1, "
+                    f"but got {self._packed_samples_per_forward}"
+                )
+
+        if dataset_cfg is not None:  # TODO: Removed in version 1.1.0
+            # For backward compatibility, reserve the dataset_cfg interface, remove it later
+            if dataloader_cfg.dataset_config_list is not None:
+                logger.warning("Outside dataset_cfg will override inner dataset_config_list")
+            dataloader_cfg.dataset_config_list = dataset_cfg
+
+        dataloader = dataloader_cfg.build(
+            tokenizer=self.tokenizer,
+            dp_mesh=self.data_mesh["dp"],
+            global_batch_size=self.global_batch_size,
+            micro_batch_size=self.micro_batch_size,
+            seed=self._seed,
+        )
+        return dataloader
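Note: taken together, `packed_samples_per_forward` controls how many dataloader items feed one forward pass, and `domino_forward` selects between packing them into a single sequence and handing the engine a list for the overlapped path; the guard added in `_build_dataloader` rejects the one combination that cannot work. A rough standalone sketch of that validation (hypothetical helper, values for illustration only):

    def check_forward_cfg(packed_samples_per_forward: int, domino_forward: bool) -> None:
        # mirrors the guard added in Trainer._build_dataloader
        if domino_forward and packed_samples_per_forward <= 1:
            raise ValueError(
                "`domino_forward` is only valid when `packed_samples_per_forward` > 1, "
                f"but got {packed_samples_per_forward}"
            )

    check_forward_cfg(packed_samples_per_forward=2, domino_forward=True)   # ok: domino path
    check_forward_cfg(packed_samples_per_forward=4, domino_forward=False)  # ok: pack into one sequence
    # check_forward_cfg(packed_samples_per_forward=1, domino_forward=True)  # would raise ValueError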