diff --git a/ci/scripts/test_sft_trainer.py b/ci/scripts/test_sft_trainer.py
index 2e60523cf..bc0cd3b58 100644
--- a/ci/scripts/test_sft_trainer.py
+++ b/ci/scripts/test_sft_trainer.py
@@ -231,7 +231,7 @@ def main():
     optim_cfg = AdamWConfig(lr=6e-05)
     lr_cfg = LRConfig(lr_type="cosine", lr_min=1e-6)
     fsdp_cfg = FSDPConfig(
-        torch_compile=False,  # get_device() == "cuda",
+        torch_compile=True,  # get_device() == "cuda",
         cpu_offload=False,
         ep_size=moe_cfg.ep_size,
         # hsdp_sharding_size=4,
@@ -260,7 +260,7 @@ def main():
         loss_cfg=loss_cfg,
         lr_cfg=lr_cfg,
         tokenizer_path=QWEN3_MOE_PATH,
-        global_batch_size=16,
+        global_batch_size=8,
         total_epoch=1,
         work_dir=work_dir,
         seed=0,
diff --git a/xtuner/v1/data_proto/sequence_context.py b/xtuner/v1/data_proto/sequence_context.py
index cbd7dbebe..2ec9a9409 100644
--- a/xtuner/v1/data_proto/sequence_context.py
+++ b/xtuner/v1/data_proto/sequence_context.py
@@ -35,6 +35,7 @@ class SequenceContext:
     block_table: torch.Tensor | None = None
     device: str | torch.device = "cpu"  # TODO: this is messy, device is handled all over the place
     position_ids: torch.LongTensor | None = None
+    cu_seq_lens_pad_len: int = 0  # records how much cu_seq_lens was padded so the padding can be stripped later
 
     # Intern-S1
     image_flags: torch.LongTensor | None = None
@@ -57,6 +58,8 @@ def __post_init__(self):
 
         self.position_ids = position_ids
 
+        self.pad_cu_seq_lens()
+
     @classmethod
     def from_input_ids(
         cls,
@@ -98,15 +101,27 @@ def split(self, sequence_parallel_mesh: DeviceMesh | None = None) -> Self:
         )
         new_padding = pad_input_ids.numel() - self.input_ids.numel()
         if new_padding > 0:
-            if self.num_padding > 0:
-                new_cu_seq_lens = self.cu_seq_lens_q.clone()
-                new_cu_seq_lens[-1] += new_padding
+            if self.cu_seq_lens_pad_len == 0:
+                if self.num_padding > 0:
+                    new_cu_seq_lens = self.cu_seq_lens_q.clone()
+                    new_cu_seq_lens[-1] += new_padding
+                else:
+                    new_cu_seq_lens = torch.ones(
+                        self.cu_seq_lens_q.numel() + 1, dtype=torch.int32, device=self.device
+                    )
+                    new_cu_seq_lens[: self.cu_seq_lens_q.numel()] = self.cu_seq_lens_q.clone()
+                    new_cu_seq_lens[-1] = self.cu_seq_lens_q[-1] + new_padding
             else:
-                new_cu_seq_lens = torch.ones(self.cu_seq_lens_q.numel() + 1, dtype=torch.int32, device=self.device)
-                new_cu_seq_lens[: self.cu_seq_lens_q.numel()] = self.cu_seq_lens_q.clone()
-                new_cu_seq_lens[-1] = self.cu_seq_lens_q[-1] + new_padding
+                new_cu_seq_lens = self.cu_seq_lens_q.clone()
+                if self.num_padding > 0:
+                    new_cu_seq_lens[-(self.cu_seq_lens_pad_len + 1) :].add_(new_padding)
+                else:
+                    new_cu_seq_lens[-self.cu_seq_lens_pad_len :].add_(new_padding)
+                    # one padded cu_seq_lens entry now becomes a real boundary (it corresponds to pad tokens)
+                    self.cu_seq_lens_pad_len -= 1
         else:
             new_cu_seq_lens = self.cu_seq_lens_q.clone()
 
+        new_cu_seq_lens = cast(torch.IntTensor, new_cu_seq_lens)
         new_max_length = cast(int, max(self.seq_lens_q.max().item(), new_padding))
         num_non_padding = self.input_ids.shape[1] - self.num_padding
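The in-place `add_` on the tail is the subtle part of the new pad-aware branch in `split()`: the padded `cu_seq_lens` entries all mirror the last real cumulative value, so when sequence-parallel splitting introduces `new_padding` extra tokens they must move together with it. A standalone numeric sketch of the `num_padding == 0` case (illustrative values only, not part of the patch):

```python
import torch

# Two real segments (lengths 5 and 4) padded to a fixed cu_seq_lens length of 6;
# the three trailing entries are pure padding, so cu_seq_lens_pad_len == 3.
cu_seq_lens = torch.tensor([0, 5, 9, 9, 9, 9], dtype=torch.int32)
cu_seq_lens_pad_len = 3
new_padding = 2  # pad tokens introduced by the sequence-parallel split

# num_padding == 0 branch: every padded entry grows by new_padding, which turns
# the first padded entry into a real boundary for the new pad-token segment ...
cu_seq_lens[-cu_seq_lens_pad_len:].add_(new_padding)
cu_seq_lens_pad_len -= 1  # ... so one entry stops being padding

print(cu_seq_lens)                         # tensor([ 0,  5,  9, 11, 11, 11], dtype=torch.int32)
print(cu_seq_lens[1:] - cu_seq_lens[:-1])  # segment lengths: [5, 4, 2, 0, 0]
```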
@@ -142,21 +157,26 @@ def pack(cls, sequence_context_list: list["SequenceContext"]):
         num_padding = 0
         device = []
         inputs_embeds = []
-        for seq_ctx in sequence_context_list:
+        cu_seq_lens_is_padded = False
+        for i, seq_ctx in enumerate(sequence_context_list):
             assert seq_ctx.sequence_parallel_mesh is None
             # todo: support vlm model
             assert seq_ctx.pixel_values is None
             packed_input_ids.append(seq_ctx.input_ids)
-            cu_seq_lens_q.append(
-                seq_ctx.cu_seq_lens_q  # type: ignore
-                if len(cu_seq_lens_q) == 0
-                else (seq_ctx.cu_seq_lens_q + cu_seq_lens_q[-1][-1])[1:]
-            )
-            cu_seq_lens_k.append(
-                seq_ctx.cu_seq_lens_k  # type: ignore
-                if len(cu_seq_lens_k) == 0
-                else (seq_ctx.cu_seq_lens_k + cu_seq_lens_k[-1][-1])[1:]
-            )
+            if seq_ctx.cu_seq_lens_pad_len != 0:
+                new_cu_seq_lens_q = seq_ctx.cu_seq_lens_q.clone()
+                new_cu_seq_lens_k = seq_ctx.cu_seq_lens_k.clone()
+                new_cu_seq_lens_q = new_cu_seq_lens_q[: -seq_ctx.cu_seq_lens_pad_len]
+                new_cu_seq_lens_k = new_cu_seq_lens_k[: -seq_ctx.cu_seq_lens_pad_len]
+                cu_seq_lens_is_padded = True
+            else:
+                new_cu_seq_lens_q = seq_ctx.cu_seq_lens_q.clone()
+                new_cu_seq_lens_k = seq_ctx.cu_seq_lens_k.clone()
+            if i > 0:
+                new_cu_seq_lens_q = (new_cu_seq_lens_q + cu_seq_lens_q[-1][-1])[1:]
+                new_cu_seq_lens_k = (new_cu_seq_lens_k + cu_seq_lens_k[-1][-1])[1:]
+            cu_seq_lens_q.append(new_cu_seq_lens_q)
+            cu_seq_lens_k.append(new_cu_seq_lens_k)
             max_length_q = max(max_length_q, seq_ctx.max_length_q)
             max_length_k = max(max_length_k, seq_ctx.max_length_k)
             num_padding += seq_ctx.num_padding
@@ -165,7 +185,7 @@ def pack(cls, sequence_context_list: list["SequenceContext"]):
                 inputs_embeds.append(seq_ctx.inputs_embeds)
 
         assert len(set(device)) == 1, f"All sequence contexts must be on the same device. Got {set(device)}"
-        return cls(
+        out = cls(
             input_ids=torch.cat(packed_input_ids, dim=1),  # type: ignore
             cu_seq_lens_q=torch.cat(cu_seq_lens_q, dim=0),  # type: ignore
             cu_seq_lens_k=torch.cat(cu_seq_lens_k, dim=0),  # type: ignore
@@ -176,6 +196,11 @@ def pack(cls, sequence_context_list: list["SequenceContext"]):
             inputs_embeds=torch.cat(inputs_embeds, dim=1) if inputs_embeds else None,  # type: ignore
         )
+        if cu_seq_lens_is_padded:
+            out = out.pad_cu_seq_lens()
+
+        return out
+
 
     @property
     def mask(self) -> torch.BoolTensor:
         mask: torch.BoolTensor
@@ -189,14 +214,19 @@ def mask(self) -> torch.BoolTensor:
 
     @property
     def seq_lens_q(self) -> torch.LongTensor:
+        # do not slice the padded entries off cu_seq_lens here, otherwise cu_seq_lens with varying shapes would be exposed to torch.compile again
        return self.cu_seq_lens_q[1:] - self.cu_seq_lens_q[:-1]  # type: ignore
 
     @property
     def seq_lens_k(self) -> torch.LongTensor:
+        # do not slice the padded entries off cu_seq_lens here, otherwise cu_seq_lens with varying shapes would be exposed to torch.compile again
         return self.cu_seq_lens_k[1:] - self.cu_seq_lens_k[:-1]  # type: ignore
 
     # TODO: currently unused, consider removing it
     def chunk(self, num_chunks: int) -> list[Self]:
+        # currently unused, so cu_seq_lens padding is not supported here for now
+        if self.cu_seq_lens_pad_len != 0:
+            raise NotImplementedError
         n = self.seq_lens_q.numel()
         assert n // num_chunks
         n_per_chunk = n // num_chunks
@@ -233,6 +263,42 @@ def set_sp_mesh(self, sp_mesh: DeviceMesh) -> Self:
         self.sequence_parallel_mesh = sp_mesh
         return self
 
+    def pad_cu_seq_lens(self) -> Self:
+        """Pad the cumulative sequence lengths to a fixed length inferred from the sequence length.
+
+        In large-scale training (1024 GPUs or more), varying data leads to different
+        cu_seq_lens shapes, causing frequent recompilations when using torch.compile
+        for optimization and significantly slowing down training.
+        To address this, we pad cu_seq_lens to a fixed shape (inferred from seq_len)
+        and slice out the padded content during attention calculation using torch.library.custom_op,
+        ensuring training behavior remains unaffected.
+
+        Note:
+            The target length is ``seq_len // 64 + 1``, so no argument is required.
+
+        Returns:
+            Self: The context with padded cumulative sequence lengths.
+        """
+        current_len = self.cu_seq_lens_q.shape[0]
+        seq_len = self.input_ids.shape[1]
+        cu_seq_lens_max_len_estimation = seq_len // 64 + 1
+        if self.cu_seq_lens_pad_len != 0:
+            assert current_len == cu_seq_lens_max_len_estimation
+        # assert self.cu_seq_lens_pad_len == 0, "pad_cu_seq_lens should only be called once."
+        if current_len >= cu_seq_lens_max_len_estimation:
+            return self
+        pad_len = cu_seq_lens_max_len_estimation - current_len
+        self.cu_seq_lens_pad_len = pad_len
+        assert torch.equal(self.cu_seq_lens_q, self.cu_seq_lens_k), (
+            "cu_seq_lens_q and cu_seq_lens_k must be equal to pad."
+        )
+        pad_tensor = torch.full(
+            (pad_len,), self.cu_seq_lens_q[-1], dtype=self.cu_seq_lens_q.dtype, device=self.cu_seq_lens_q.device
+        )
+        self.cu_seq_lens_q = torch.cat([self.cu_seq_lens_q, pad_tensor], dim=0)
+        self.cu_seq_lens_k = torch.cat([self.cu_seq_lens_k, pad_tensor], dim=0)
+        return self
+
     def to(self, device: torch.device | str):
         """Move all tensors in the context to the specified device.
 
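To make the shape contract in `pad_cu_seq_lens` concrete: the target length depends only on the packed sequence length (`seq_len // 64 + 1`), the padded entries repeat the final cumulative value, so the extra "segments" have length zero, and the real `cu_seq_lens` is recovered by slicing off `cu_seq_lens_pad_len` entries. A minimal standalone sketch of that scheme (the helper name and the sample numbers are invented for illustration):

```python
import torch


def pad_cu_seq_lens(cu_seq_lens: torch.Tensor, seq_len: int) -> tuple[torch.Tensor, int]:
    """Pad cu_seq_lens to a length that depends only on seq_len (sketch)."""
    target_len = seq_len // 64 + 1
    pad_len = max(target_len - cu_seq_lens.shape[0], 0)
    if pad_len == 0:
        return cu_seq_lens, 0
    # Repeat the last cumulative value so the padded "segments" are empty.
    pad = cu_seq_lens.new_full((pad_len,), int(cu_seq_lens[-1]))
    return torch.cat([cu_seq_lens, pad]), pad_len


cu, pad_len = pad_cu_seq_lens(torch.tensor([0, 300, 448], dtype=torch.int32), seq_len=448)
assert cu.shape[0] == 448 // 64 + 1 and pad_len == 5
assert (cu[1:] - cu[:-1])[-pad_len:].eq(0).all()  # padded segments are empty
assert torch.equal(cu[:-pad_len], torch.tensor([0, 300, 448], dtype=torch.int32))
```

Every batch with the same packed `seq_len` therefore hands torch.compile an identically shaped `cu_seq_lens`, which is what stops the recompilations described in the docstring.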
diff --git a/xtuner/v1/module/attention/kv_cache.py b/xtuner/v1/module/attention/kv_cache.py
index 6ea65b8c0..3788e9125 100644
--- a/xtuner/v1/module/attention/kv_cache.py
+++ b/xtuner/v1/module/attention/kv_cache.py
@@ -10,10 +10,14 @@ def fill_paged_kv_cache(
     value_cache: torch.Tensor,
     cu_seq_lens_q: torch.Tensor,
     cu_seq_lens_k: torch.Tensor,
+    cu_seqlens_pad_len: int,
     max_length_q: int,
     max_length_k: int,
     block_table: torch.Tensor,
 ) -> None:
+    if cu_seqlens_pad_len > 0:
+        cu_seq_lens_q = cu_seq_lens_q[:-cu_seqlens_pad_len]
+        cu_seq_lens_k = cu_seq_lens_k[:-cu_seqlens_pad_len]
     bs = block_table.size(0)
 
     from lmdeploy.pytorch.kernels import fill_kv_cache
@@ -40,6 +44,7 @@ def fill_paged_kv_cache_fake(
     value_cache: torch.Tensor,
     cu_seq_lens_q: torch.Tensor,
     cu_seq_lens_k: torch.Tensor,
+    cu_seqlens_pad_len: int,
     max_length_q: int,
     max_length_k: int,
     block_table: torch.Tensor,
diff --git a/xtuner/v1/module/attention/mha.py b/xtuner/v1/module/attention/mha.py
index fed5091ac..03ddde80e 100644
--- a/xtuner/v1/module/attention/mha.py
+++ b/xtuner/v1/module/attention/mha.py
@@ -216,6 +216,7 @@ def prefilling(
             past_key_values[self.layer_idx][1],
             seq_ctx.cu_seq_lens_q,
             seq_ctx.cu_seq_lens_k,
+            seq_ctx.cu_seq_lens_pad_len,
             seq_ctx.max_length_q,
             seq_ctx.max_length_k,
             seq_ctx.block_table,
@@ -233,6 +234,7 @@ def prefilling(
             value_states.transpose(1, 2).squeeze(0),
             cu_seqlens_q=seq_ctx.cu_seq_lens_q,
             cu_seqlens_k=seq_ctx.cu_seq_lens_k,
+            cu_seqlens_pad_len=seq_ctx.cu_seq_lens_pad_len,
             max_seqlen_q=seq_ctx.max_length_q,
             max_seqlen_k=seq_ctx.max_length_k,
             dropout_p=self.dropout,
@@ -253,6 +255,8 @@ def decoding(
     ) -> torch.Tensor:
         assert seq_ctx.block_table is not None
         assert self.layer_idx is not None
+        if seq_ctx.cu_seq_lens_pad_len != 0:
+            raise NotImplementedError
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
 
@@ -387,6 +391,7 @@ def forward(
             value_states,
             cu_seqlens_q=seq_ctx.cu_seq_lens_q,
             cu_seqlens_k=seq_ctx.cu_seq_lens_k,
+            cu_seqlens_pad_len=seq_ctx.cu_seq_lens_pad_len,
             max_seqlen_q=seq_ctx.max_length_q,
             max_seqlen_k=seq_ctx.max_length_k,
             window_size=self.window_size,
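The `cu_seqlens_pad_len > 0` slicing in `fill_paged_kv_cache` above happens inside a registered custom op, which is the trick the `pad_cu_seq_lens` docstring alludes to: the compiled graph only ever sees the fixed padded shape plus an integer, while the eager kernel receives the exact `cu_seq_lens`. A hypothetical, self-contained sketch of that pattern (op name and helper are invented; assumes PyTorch >= 2.4 for `torch.library.custom_op`):

```python
import torch


@torch.library.custom_op("xtuner_demo::seq_lens_from_padded", mutates_args=())
def seq_lens_from_padded(cu_seq_lens: torch.Tensor, cu_seqlens_pad_len: int) -> torch.Tensor:
    # The data-dependent slice stays hidden inside the op body.
    if cu_seqlens_pad_len > 0:
        cu_seq_lens = cu_seq_lens[:-cu_seqlens_pad_len]
    return cu_seq_lens[1:] - cu_seq_lens[:-1]


@seq_lens_from_padded.register_fake
def _(cu_seq_lens, cu_seqlens_pad_len):
    # From the compiler's point of view the output shape is a plain function
    # of the (fixed) input shape and the integer pad length.
    return cu_seq_lens.new_empty(cu_seq_lens.shape[0] - 1 - cu_seqlens_pad_len)


padded = torch.tensor([0, 300, 448, 448, 448], dtype=torch.int32)
print(seq_lens_from_padded(padded, 2))  # tensor([300, 148], dtype=torch.int32)
```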
diff --git a/xtuner/v1/module/attention/mla.py b/xtuner/v1/module/attention/mla.py
index 192209bb3..d33353fe2 100644
--- a/xtuner/v1/module/attention/mla.py
+++ b/xtuner/v1/module/attention/mla.py
@@ -324,6 +324,7 @@ def forward_training(
             value_states.transpose(1, 2).squeeze(0),
             cu_seqlens_q=attn_meta.cu_seq_lens_q,
             cu_seqlens_k=attn_meta.cu_seq_lens_k,
+            cu_seqlens_pad_len=attn_meta.cu_seq_lens_pad_len,
             max_seqlen_q=attn_meta.max_length_q,
             max_seqlen_k=attn_meta.max_length_k,
             dropout_p=self.dropout,
@@ -349,6 +350,8 @@ def prefilling(
         seq_ctx: SequenceContext,
         past_key_values: list[list[torch.Tensor]],
     ) -> torch.Tensor:
+        if seq_ctx.cu_seq_lens_pad_len != 0:
+            raise NotImplementedError
         bsz, q_len, _ = hidden_states.size()
 
         if self.q_lora_rank is None:
@@ -451,6 +454,8 @@ def decoding(
         seq_ctx: SequenceContext,
         past_key_values: list[list[torch.Tensor]],
     ) -> torch.Tensor:
+        if seq_ctx.cu_seq_lens_pad_len != 0:
+            raise NotImplementedError
         bsz, q_len, _ = hidden_states.size()
 
         if self.q_lora_rank is None:
@@ -606,6 +611,7 @@ def forward(
             value_states.transpose(1, 2).squeeze(0),
             cu_seqlens_q=seq_ctx.cu_seq_lens_q,
             cu_seqlens_k=seq_ctx.cu_seq_lens_k,
+            cu_seqlens_pad_len=seq_ctx.cu_seq_lens_pad_len,
             max_seqlen_q=seq_ctx.max_length_q,
             max_seqlen_k=seq_ctx.max_length_k,
             dropout_p=self.dropout,
diff --git a/xtuner/v1/ops/attn_imp.py b/xtuner/v1/ops/attn_imp.py
index 45e4364bc..c15da39ba 100644
--- a/xtuner/v1/ops/attn_imp.py
+++ b/xtuner/v1/ops/attn_imp.py
@@ -210,7 +210,10 @@ def flash_attention(q, k, v, window_size=(-1, -1), s_aux=None, **kwargs) -> torc
         attention_output = flash_attn_varlen_func(q, k, v, **kwargs)
     else:
         cu_seqlens_q = kwargs["cu_seqlens_q"]
-        attention_output = flash_sink_attn_varlen_func(q, k, v, s_aux, cu_seqlens_q, window_size[0])
+        cu_seqlens_pad_len = kwargs["cu_seqlens_pad_len"]
+        attention_output = flash_sink_attn_varlen_func(
+            q, k, v, s_aux, cu_seqlens_q, cu_seqlens_pad_len, window_size[0]
+        )
 
     return attention_output[None]
 
diff --git a/xtuner/v1/ops/flash_attn/flash_sink_varlen_attn_gpt_oss.py b/xtuner/v1/ops/flash_attn/flash_sink_varlen_attn_gpt_oss.py
index aa0614cd8..20165c214 100644
--- a/xtuner/v1/ops/flash_attn/flash_sink_varlen_attn_gpt_oss.py
+++ b/xtuner/v1/ops/flash_attn/flash_sink_varlen_attn_gpt_oss.py
@@ -480,6 +480,7 @@ def forward(
         v: torch.Tensor,
         sink: torch.Tensor,
         cu_seqlen: torch.Tensor,
+        cu_seqlens_pad_len: int,
         window_size=None,
     ):
         if window_size == -1:
@@ -492,6 +493,7 @@ def forward(
         )
 
         ctx.save_for_backward(q, k, v, o, lse)
+        ctx.cu_seqlens_pad_len = cu_seqlens_pad_len
         ctx.sink = sink
         ctx.window_size = window_size
         ctx.cu_seqlen = cu_seqlen
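Threading `cu_seqlens_pad_len` through the autograd Functions has one easy-to-miss consequence, visible in the gpu.py hunks below: `backward` must return exactly one gradient slot per `forward` input, so the new integer argument needs an extra `None` in the return tuple (and, being a non-tensor, it is stashed on `ctx` rather than passed to `save_for_backward`). A minimal sketch of that rule, using a toy Function unrelated to the real kernels:

```python
import torch


class ScaleByValidLen(torch.autograd.Function):
    """Toy Function with one tensor input and one int input (like cu_seqlens_pad_len)."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, pad_len: int) -> torch.Tensor:
        ctx.pad_len = pad_len  # non-tensor state lives on ctx
        return x * (x.shape[-1] - pad_len)

    @staticmethod
    def backward(ctx, grad_out):
        # One slot per forward input: a gradient for x, None for the int.
        return grad_out * (grad_out.shape[-1] - ctx.pad_len), None


x = torch.randn(4, requires_grad=True)
ScaleByValidLen.apply(x, 1).sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 3.0))
```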
diff --git a/xtuner/v1/ops/flash_attn/gpu.py b/xtuner/v1/ops/flash_attn/gpu.py
index 6a1b3fc06..f73439cf7 100644
--- a/xtuner/v1/ops/flash_attn/gpu.py
+++ b/xtuner/v1/ops/flash_attn/gpu.py
@@ -9,6 +9,7 @@ def _flash_attn_varlen_forward_v3(
    v: torch.Tensor,
     cu_seqlens_q: torch.Tensor,
     cu_seqlens_k: torch.Tensor,
+    cu_seqlens_pad_len: int,
     max_seqlen_q: int,
     max_seqlen_k: int,
     softmax_scale: float,
@@ -17,6 +18,9 @@ def _flash_attn_varlen_forward_v3(
     window_size_right: int = -1,
     softcap: float = 0.0,  # 0.0 means deactivated
 ) -> tuple[torch.Tensor, torch.Tensor]:
+    if cu_seqlens_pad_len > 0:
+        cu_seqlens_q = cu_seqlens_q[:-cu_seqlens_pad_len]
+        cu_seqlens_k = cu_seqlens_k[:-cu_seqlens_pad_len]
     out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
         q,
         k,
@@ -63,6 +67,7 @@ def _flash_attn_varlen_forward_v3_fake(
     v: torch.Tensor,
     cu_seqlens_q: torch.Tensor,
     cu_seqlens_k: torch.Tensor,
+    cu_seqlens_pad_len: int,
     max_seqlen_q: int,
     max_seqlen_k: int,
     softmax_scale: float,
@@ -89,6 +94,7 @@ def _flash_attn_varlen_backward_v3(
     softmax_lse: torch.Tensor,
     cu_seqlens_q: torch.Tensor,
     cu_seqlens_k: torch.Tensor,
+    cu_seqlens_pad_len: int,
     max_seqlen_q: int,
     max_seqlen_k: int,
     dq: torch.Tensor,
@@ -101,6 +107,9 @@ def _flash_attn_varlen_backward_v3(
     softcap: float = 0.0,
     deterministic: bool = False,
 ) -> None:
+    if cu_seqlens_pad_len > 0:
+        cu_seqlens_q = cu_seqlens_q[:-cu_seqlens_pad_len]
+        cu_seqlens_k = cu_seqlens_k[:-cu_seqlens_pad_len]
     flash_attn_3_cuda.bwd(
         dout,
         q,
@@ -137,6 +146,7 @@ def _flash_attn_varlen_backward_v3_fake(
     softmax_lse: torch.Tensor,
     cu_seqlens_q: torch.Tensor,
     cu_seqlens_k: torch.Tensor,
+    cu_seqlens_pad_len: int,
     max_seqlen_q: int,
     max_seqlen_k: int,
     dq: torch.Tensor,
@@ -161,6 +171,7 @@ def forward(
         v,
         cu_seqlens_q,
         cu_seqlens_k,
+        cu_seqlens_pad_len,
         max_seqlen_q,
         max_seqlen_k,
         softmax_scale,
@@ -182,6 +193,7 @@ def forward(
             v,
             cu_seqlens_q,
             cu_seqlens_k,
+            cu_seqlens_pad_len,
             max_seqlen_q,
             max_seqlen_k,
             softmax_scale,
@@ -193,6 +205,7 @@ def forward(
         ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k)
         ctx.max_seqlen_q = max_seqlen_q
         ctx.max_seqlen_k = max_seqlen_k
+        ctx.cu_seqlens_pad_len = cu_seqlens_pad_len
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size_left = window_size[0]
@@ -218,6 +231,7 @@ def backward(ctx, dout, *args):
             softmax_lse,
             cu_seqlens_q,
             cu_seqlens_k,
+            ctx.cu_seqlens_pad_len,
             ctx.max_seqlen_q,
             ctx.max_seqlen_k,
             dq,
@@ -233,7 +247,7 @@ def backward(ctx, dout, *args):
         dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : k.shape[-1]]
         dv = dv[..., : v.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None
 
 
 def gpu_flash_varlen_attn_v3(
@@ -242,6 +256,7 @@ def gpu_flash_varlen_attn_v3(
     v,
     cu_seqlens_q,
     cu_seqlens_k,
+    cu_seqlens_pad_len,
     max_seqlen_q,
     max_seqlen_k,
     dropout_p=0.0,
@@ -263,6 +278,7 @@ def gpu_flash_varlen_attn_v3(
         v,
         cu_seqlens_q,
         cu_seqlens_k,
+        cu_seqlens_pad_len,
         max_seqlen_q,
         max_seqlen_k,
         softmax_scale,
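Because the padded tail is sliced off before `flash_attn_3_cuda.fwd`/`bwd` are invoked, the kernels see exactly the same `cu_seq_lens` as before, so the numerics are untouched. A quick standalone parity check of that invariant, using a looped SDPA reference in place of the real varlen kernel (illustration only, not the project's API):

```python
import torch
import torch.nn.functional as F


def varlen_attn_ref(q, k, v, cu_seq_lens, cu_seqlens_pad_len=0):
    """Reference varlen attention: per-segment SDPA (stands in for the FA3 kernel)."""
    if cu_seqlens_pad_len > 0:
        cu_seq_lens = cu_seq_lens[:-cu_seqlens_pad_len]
    out = torch.empty_like(q)
    for start, end in zip(cu_seq_lens[:-1].tolist(), cu_seq_lens[1:].tolist()):
        out[start:end] = F.scaled_dot_product_attention(
            q[None, None, start:end], k[None, None, start:end], v[None, None, start:end], is_causal=True
        )[0, 0]
    return out


q, k, v = (torch.randn(16, 8) for _ in range(3))
cu = torch.tensor([0, 10, 16])
cu_padded = torch.tensor([0, 10, 16, 16, 16])  # padded by 2 entries
assert torch.allclose(
    varlen_attn_ref(q, k, v, cu), varlen_attn_ref(q, k, v, cu_padded, cu_seqlens_pad_len=2)
)
```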