4 changes: 2 additions & 2 deletions ci/scripts/test_sft_trainer.py
@@ -231,7 +231,7 @@ def main():
optim_cfg = AdamWConfig(lr=6e-05)
lr_cfg = LRConfig(lr_type="cosine", lr_min=1e-6)
fsdp_cfg = FSDPConfig(
torch_compile=False, #get_device() == "cuda",
torch_compile=True, #get_device() == "cuda",
cpu_offload=False,
ep_size=moe_cfg.ep_size,
# hsdp_sharding_size=4,
@@ -260,7 +260,7 @@ def main():
loss_cfg=loss_cfg,
lr_cfg=lr_cfg,
tokenizer_path=QWEN3_MOE_PATH,
global_batch_size=16,
global_batch_size=8,
total_epoch=1,
work_dir=work_dir,
seed=0,
102 changes: 84 additions & 18 deletions xtuner/v1/data_proto/sequence_context.py
@@ -35,6 +35,7 @@ class SequenceContext:
block_table: torch.Tensor | None = None
device: str | torch.device = "cpu" # TODO: this is a bit messy; device handling is scattered all over the place
position_ids: torch.LongTensor | None = None
cu_seq_lens_pad_len: int = 0 # records the length of the cu_seq_lens padding so it can be restored in pad_cu_seq_lens

# Intern-S1
image_flags: torch.LongTensor | None = None
@@ -57,6 +58,8 @@ def __post_init__(self):

self.position_ids = position_ids

self.pad_cu_seq_lens()

@classmethod
def from_input_ids(
cls,
@@ -98,15 +101,27 @@ def split(self, sequence_parallel_mesh: DeviceMesh | None = None) -> Self:
)
new_padding = pad_input_ids.numel() - self.input_ids.numel()
if new_padding > 0:
if self.num_padding > 0:
new_cu_seq_lens = self.cu_seq_lens_q.clone()
new_cu_seq_lens[-1] += new_padding
if self.cu_seq_lens_pad_len == 0:
if self.num_padding > 0:
new_cu_seq_lens = self.cu_seq_lens_q.clone()
new_cu_seq_lens[-1] += new_padding
else:
new_cu_seq_lens = torch.ones(
self.cu_seq_lens_q.numel() + 1, dtype=torch.int32, device=self.device
)
new_cu_seq_lens[: self.cu_seq_lens_q.numel()] = self.cu_seq_lens_q.clone()
new_cu_seq_lens[-1] = self.cu_seq_lens_q[-1] + new_padding
else:
new_cu_seq_lens = torch.ones(self.cu_seq_lens_q.numel() + 1, dtype=torch.int32, device=self.device)
new_cu_seq_lens[: self.cu_seq_lens_q.numel()] = self.cu_seq_lens_q.clone()
new_cu_seq_lens[-1] = self.cu_seq_lens_q[-1] + new_padding
new_cu_seq_lens = self.cu_seq_lens_q.clone()
if self.num_padding > 0:
new_cu_seq_lens[-(self.cu_seq_lens_pad_len + 1) :].add_(new_padding)
else:
new_cu_seq_lens[-self.cu_seq_lens_pad_len :].add_(new_padding)
# one cu_seq_lens entry has gone from being meaningless cu_seq_lens padding to having real meaning (even though it corresponds to pad tokens)
self.cu_seq_lens_pad_len -= 1
else:
new_cu_seq_lens = self.cu_seq_lens_q.clone()

new_cu_seq_lens = cast(torch.IntTensor, new_cu_seq_lens)
new_max_length = cast(int, max(self.seq_lens_q.max().item(), new_padding))
num_non_padding = self.input_ids.shape[1] - self.num_padding
@@ -142,21 +157,26 @@ def pack(cls, sequence_context_list: list["SequenceContext"]):
num_padding = 0
device = []
inputs_embeds = []
for seq_ctx in sequence_context_list:
cu_seq_lens_is_padded = False
for i, seq_ctx in enumerate(sequence_context_list):
assert seq_ctx.sequence_parallel_mesh is None
# todo: support vlm model
assert seq_ctx.pixel_values is None
packed_input_ids.append(seq_ctx.input_ids)
cu_seq_lens_q.append(
seq_ctx.cu_seq_lens_q # type: ignore
if len(cu_seq_lens_q) == 0
else (seq_ctx.cu_seq_lens_q + cu_seq_lens_q[-1][-1])[1:]
)
cu_seq_lens_k.append(
seq_ctx.cu_seq_lens_k # type: ignore
if len(cu_seq_lens_k) == 0
else (seq_ctx.cu_seq_lens_k + cu_seq_lens_k[-1][-1])[1:]
)
if seq_ctx.cu_seq_lens_pad_len != 0:
new_cu_seq_lens_q = seq_ctx.cu_seq_lens_q.clone()
new_cu_seq_lens_k = seq_ctx.cu_seq_lens_k.clone()
new_cu_seq_lens_q = new_cu_seq_lens_q[: -seq_ctx.cu_seq_lens_pad_len]
new_cu_seq_lens_k = new_cu_seq_lens_k[: -seq_ctx.cu_seq_lens_pad_len]
cu_seq_lens_is_padded = True
else:
new_cu_seq_lens_q = seq_ctx.cu_seq_lens_q.clone()
new_cu_seq_lens_k = seq_ctx.cu_seq_lens_k.clone()
if i > 0:
new_cu_seq_lens_q = (new_cu_seq_lens_q + cu_seq_lens_q[-1][-1])[1:]
new_cu_seq_lens_k = (new_cu_seq_lens_k + cu_seq_lens_k[-1][-1])[1:]
cu_seq_lens_q.append(new_cu_seq_lens_q)
cu_seq_lens_k.append(new_cu_seq_lens_k)
max_length_q = max(max_length_q, seq_ctx.max_length_q)
max_length_k = max(max_length_k, seq_ctx.max_length_k)
num_padding += seq_ctx.num_padding
@@ -165,7 +185,7 @@ def pack(cls, sequence_context_list: list["SequenceContext"]):
inputs_embeds.append(seq_ctx.inputs_embeds)
assert len(set(device)) == 1, f"All sequence contexts must be on the same device. Got {set(device)}"

return cls(
out = cls(
input_ids=torch.cat(packed_input_ids, dim=1), # type: ignore
cu_seq_lens_q=torch.cat(cu_seq_lens_q, dim=0), # type: ignore
cu_seq_lens_k=torch.cat(cu_seq_lens_k, dim=0), # type: ignore
@@ -176,6 +196,11 @@ def pack(cls, sequence_context_list: list["SequenceContext"]):
inputs_embeds=torch.cat(inputs_embeds, dim=1) if inputs_embeds else None, # type: ignore
)

if cu_seq_lens_is_padded:
out = out.pad_cu_seq_lens()

return out

@property
def mask(self) -> torch.BoolTensor:
mask: torch.BoolTensor
@@ -189,14 +214,19 @@ def mask(self) -> torch.BoolTensor:

@property
def seq_lens_q(self) -> torch.LongTensor:
# do not slice off the padded cu_seq_lens here, otherwise cu_seq_lens with varying shapes would again be exposed to torch.compile
return self.cu_seq_lens_q[1:] - self.cu_seq_lens_q[:-1] # type: ignore

@property
def seq_lens_k(self) -> torch.LongTensor:
# do not slice off the padded cu_seq_lens here, otherwise cu_seq_lens with varying shapes would again be exposed to torch.compile
return self.cu_seq_lens_k[1:] - self.cu_seq_lens_k[:-1] # type: ignore

# TODO: currently unused, may be removed
def chunk(self, num_chunks: int) -> list[Self]:
# not used for now, so padded cu_seq_lens is not supported here yet
if self.cu_seq_lens_pad_len != 0:
raise NotImplementedError
n = self.seq_lens_q.numel()
assert n // num_chunks
n_per_chunk = n // num_chunks
@@ -233,6 +263,42 @@ def set_sp_mesh(self, sp_mesh: DeviceMesh) -> Self:
self.sequence_parallel_mesh = sp_mesh
return self

def pad_cu_seq_lens(self) -> Self:
"""Pad the cumulative sequence lengths to the specified maximum length.

In large-scale training (1024 GPUs or more), varying data leads to different
cu_seq_lens shapes, causing frequent recompilations when using torch.compile
for optimization and significantly slowing down training.
To address this, we pad cu_seq_lens to a fixed shape (inferred from seq_len)
and slice out the padded content during attention calculation using torch.library.custom_op,
ensuring training behavior remains unaffected.

Returns:
Self: The context with padded cumulative sequence lengths.
"""
current_len = self.cu_seq_lens_q.shape[0]
seq_len = self.input_ids.shape[1]
cu_seq_lens_max_len_estimation = seq_len // 64 + 1
if self.cu_seq_lens_pad_len != 0:
assert current_len == cu_seq_lens_max_len_estimation
# assert self.cu_seq_lens_pad_len == 0, "pad_cu_seq_lens should only be called once."
if current_len >= cu_seq_lens_max_len_estimation:
return self
pad_len = cu_seq_lens_max_len_estimation - current_len
self.cu_seq_lens_pad_len = pad_len
assert torch.equal(self.cu_seq_lens_q, self.cu_seq_lens_k), (
"cu_seq_lens_q and cu_seq_lens_k must be equal to pad."
)
pad_tensor = torch.full(
(pad_len,), self.cu_seq_lens_q[-1], dtype=self.cu_seq_lens_q.dtype, device=self.cu_seq_lens_q.device
)
self.cu_seq_lens_q = torch.cat([self.cu_seq_lens_q, pad_tensor], dim=0)
self.cu_seq_lens_k = torch.cat([self.cu_seq_lens_k, pad_tensor], dim=0)
return self

def to(self, device: torch.device | str):
"""Move all tensors in the context to the specified device.

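A reviewer-style note on `pad_cu_seq_lens` above: the target length is `seq_len // 64 + 1`, and padding simply repeats the last cumulative value, so the extra entries describe zero-length sequences. A minimal, self-contained sketch of that arithmetic with a worked example; the helper name and the numbers are mine, not part of this diff:

```python
import torch

def pad_cu_seq_lens_demo(cu_seq_lens: torch.Tensor, seq_len: int) -> tuple[torch.Tensor, int]:
    """Sketch of the padding rule in SequenceContext.pad_cu_seq_lens: grow
    cu_seq_lens to a fixed length (seq_len // 64 + 1) by repeating the last
    cumulative value, so torch.compile always sees the same shape."""
    target_len = seq_len // 64 + 1
    pad_len = max(target_len - cu_seq_lens.numel(), 0)
    if pad_len == 0:
        return cu_seq_lens, 0
    pad = torch.full((pad_len,), cu_seq_lens[-1].item(), dtype=cu_seq_lens.dtype)
    return torch.cat([cu_seq_lens, pad]), pad_len

# A packed batch of 256 tokens holding two sequences (100 + 156 tokens):
cu, pad_len = pad_cu_seq_lens_demo(torch.tensor([0, 100, 256], dtype=torch.int32), seq_len=256)
# cu -> tensor([0, 100, 256, 256, 256]) with 256 // 64 + 1 = 5 entries, pad_len == 2.
# seq_lens = cu[1:] - cu[:-1] -> [100, 156, 0, 0]: the pad only adds zero-length
# sequences, so the fixed-shape tensors stay valid and kernels can drop the last
# pad_len entries when they need the exact boundaries.
```

This is also why `seq_lens_q` / `seq_lens_k` deliberately keep the padded entries (see the comments added to those properties): slicing there would reintroduce shape variation for torch.compile.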
5 changes: 5 additions & 0 deletions xtuner/v1/module/attention/kv_cache.py
@@ -10,10 +10,14 @@ def fill_paged_kv_cache(
value_cache: torch.Tensor,
cu_seq_lens_q: torch.Tensor,
cu_seq_lens_k: torch.Tensor,
cu_seqlens_pad_len: int,
max_length_q: int,
max_length_k: int,
block_table: torch.Tensor,
) -> None:
if cu_seqlens_pad_len > 0:
cu_seq_lens_q = cu_seq_lens_q[:-cu_seqlens_pad_len]
cu_seq_lens_k = cu_seq_lens_k[:-cu_seqlens_pad_len]
bs = block_table.size(0)
from lmdeploy.pytorch.kernels import fill_kv_cache

@@ -40,6 +44,7 @@ def fill_paged_kv_cache_fake(
value_cache: torch.Tensor,
cu_seq_lens_q: torch.Tensor,
cu_seq_lens_k: torch.Tensor,
cu_seqlens_pad_len: int,
max_length_q: int,
max_length_k: int,
block_table: torch.Tensor,
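The `fill_paged_kv_cache` / `fill_paged_kv_cache_fake` pair reads like a `torch.library` custom op plus its fake (meta) kernel, which would explain why `cu_seqlens_pad_len` arrives as a plain `int`: the padded `cu_seq_lens` tensors keep a fixed shape at the graph boundary, and the data-dependent slice happens inside the op where torch.compile does not trace. A toy sketch of that pattern, assuming PyTorch 2.4+ (`torch.library.custom_op` / `register_fake`); the op name and body are illustrative, not the repo's actual registration:

```python
import torch

@torch.library.custom_op("demo::real_seq_lens", mutates_args=())
def real_seq_lens(cu_seq_lens: torch.Tensor, pad_len: int) -> torch.Tensor:
    # Data-dependent work stays inside the op: drop the repeated pad entries and
    # return the true per-sequence lengths.
    if pad_len > 0:
        cu_seq_lens = cu_seq_lens[:-pad_len]
    return cu_seq_lens[1:] - cu_seq_lens[:-1]

@real_seq_lens.register_fake
def _(cu_seq_lens: torch.Tensor, pad_len: int) -> torch.Tensor:
    # The fake kernel only describes output metadata; because pad_len is a plain
    # Python int, the output shape is still static from torch.compile's view.
    return torch.empty(
        cu_seq_lens.numel() - 1 - pad_len, dtype=cu_seq_lens.dtype, device=cu_seq_lens.device
    )

# real_seq_lens(torch.tensor([0, 100, 256, 256, 256], dtype=torch.int32), 2)
# -> tensor([100, 156], dtype=torch.int32)
```

Keeping the padded, fixed-shape `cu_seq_lens` outside the op and slicing only inside it is presumably what makes flipping `torch_compile=True` in the CI script above safe against recompilation as batch composition changes.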
5 changes: 5 additions & 0 deletions xtuner/v1/module/attention/mha.py
@@ -216,6 +216,7 @@ def prefilling(
past_key_values[self.layer_idx][1],
seq_ctx.cu_seq_lens_q,
seq_ctx.cu_seq_lens_k,
seq_ctx.cu_seq_lens_pad_len,
seq_ctx.max_length_q,
seq_ctx.max_length_k,
seq_ctx.block_table,
@@ -233,6 +234,7 @@ def prefilling(
value_states.transpose(1, 2).squeeze(0),
cu_seqlens_q=seq_ctx.cu_seq_lens_q,
cu_seqlens_k=seq_ctx.cu_seq_lens_k,
cu_seqlens_pad_len=seq_ctx.cu_seq_lens_pad_len,
max_seqlen_q=seq_ctx.max_length_q,
max_seqlen_k=seq_ctx.max_length_k,
dropout_p=self.dropout,
@@ -253,6 +255,8 @@ def decoding(
) -> torch.Tensor:
assert seq_ctx.block_table is not None
assert self.layer_idx is not None
if seq_ctx.cu_seq_lens_pad_len != 0:
raise NotImplementedError

input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
@@ -387,6 +391,7 @@ def forward(
value_states,
cu_seqlens_q=seq_ctx.cu_seq_lens_q,
cu_seqlens_k=seq_ctx.cu_seq_lens_k,
cu_seqlens_pad_len=seq_ctx.cu_seq_lens_pad_len,
max_seqlen_q=seq_ctx.max_length_q,
max_seqlen_k=seq_ctx.max_length_k,
window_size=self.window_size,
6 changes: 6 additions & 0 deletions xtuner/v1/module/attention/mla.py
@@ -324,6 +324,7 @@ def forward_training(
value_states.transpose(1, 2).squeeze(0),
cu_seqlens_q=attn_meta.cu_seq_lens_q,
cu_seqlens_k=attn_meta.cu_seq_lens_k,
cu_seqlens_pad_len=attn_meta.cu_seq_lens_pad_len,
max_seqlen_q=attn_meta.max_length_q,
max_seqlen_k=attn_meta.max_length_k,
dropout_p=self.dropout,
@@ -349,6 +350,8 @@ def prefilling(
seq_ctx: SequenceContext,
past_key_values: list[list[torch.Tensor]],
) -> torch.Tensor:
if seq_ctx.cu_seq_lens_pad_len != 0:
raise NotImplementedError
bsz, q_len, _ = hidden_states.size()

if self.q_lora_rank is None:
@@ -451,6 +454,8 @@ def decoding(
seq_ctx: SequenceContext,
past_key_values: list[list[torch.Tensor]],
) -> torch.Tensor:
if seq_ctx.cu_seq_lens_pad_len != 0:
raise NotImplementedError
bsz, q_len, _ = hidden_states.size()

if self.q_lora_rank is None:
@@ -606,6 +611,7 @@ def forward(
value_states.transpose(1, 2).squeeze(0),
cu_seqlens_q=seq_ctx.cu_seq_lens_q,
cu_seqlens_k=seq_ctx.cu_seq_lens_k,
cu_seqlens_pad_len=seq_ctx.cu_seq_lens_pad_len,
max_seqlen_q=seq_ctx.max_length_q,
max_seqlen_k=seq_ctx.max_length_k,
dropout_p=self.dropout,
5 changes: 4 additions & 1 deletion xtuner/v1/ops/attn_imp.py
@@ -210,7 +210,10 @@ def flash_attention(q, k, v, window_size=(-1, -1), s_aux=None, **kwargs) -> torc
attention_output = flash_attn_varlen_func(q, k, v, **kwargs)
else:
cu_seqlens_q = kwargs["cu_seqlens_q"]
attention_output = flash_sink_attn_varlen_func(q, k, v, s_aux, cu_seqlens_q, window_size[0])
cu_seqlens_pad_len = kwargs["cu_seqlens_pad_len"]
attention_output = flash_sink_attn_varlen_func(
q, k, v, s_aux, cu_seqlens_q, cu_seqlens_pad_len, window_size[0]
)
return attention_output[None]


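Kernels that cannot consume padded `cu_seq_lens` need the pad stripped right before the call, the same way `fill_paged_kv_cache` slices it off above; in this file the diff instead forwards `cu_seqlens_pad_len` into `flash_sink_attn_varlen_func`. A generic, hedged sketch of the strip-before-call hand-off; the wrapper name and callable-based signature are mine and the real wiring may differ:

```python
import torch
from typing import Callable

def varlen_attn_without_pad(
    attn_fn: Callable[..., torch.Tensor],
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    cu_seqlens_pad_len: int,
    **kwargs,
) -> torch.Tensor:
    # Drop the repeated tail entries so the kernel sees only real sequence
    # boundaries; the int pad length keeps this decision static for compile.
    if cu_seqlens_pad_len > 0:
        cu_seqlens_q = cu_seqlens_q[:-cu_seqlens_pad_len]
        cu_seqlens_k = cu_seqlens_k[:-cu_seqlens_pad_len]
    return attn_fn(q, k, v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, **kwargs)

# e.g. varlen_attn_without_pad(flash_attn_varlen_func, q, k, v, cu_q, cu_k,
#                              seq_ctx.cu_seq_lens_pad_len,
#                              max_seqlen_q=..., max_seqlen_k=..., causal=True)
```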
2 changes: 2 additions & 0 deletions xtuner/v1/ops/flash_attn/flash_sink_varlen_attn_gpt_oss.py
@@ -480,6 +480,7 @@ def forward(
v: torch.Tensor,
sink: torch.Tensor,
cu_seqlen: torch.Tensor,
cu_seqlens_pad_len: int,
window_size=None,
):
if window_size == -1:
@@ -492,6 +493,7 @@ def forward(
)

ctx.save_for_backward(q, k, v, o, lse)
ctx.cu_seqlens_pad_len = cu_seqlens_pad_len
ctx.sink = sink
ctx.window_size = window_size
ctx.cu_seqlen = cu_seqlen
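The forward above stashes `cu_seqlens_pad_len` as a plain attribute on `ctx` instead of in `save_for_backward`, which is the standard way to carry non-tensor state into backward. A tiny, self-contained illustration of that pattern; this toy Function is mine and is not the sink-attention kernel:

```python
import torch

class ScaleByRealSeqCount(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, cu_seqlen: torch.Tensor, pad_len: int) -> torch.Tensor:
        ctx.save_for_backward(cu_seqlen)   # tensors go through save_for_backward
        ctx.pad_len = pad_len              # plain ints live as ctx attributes
        real_seqs = cu_seqlen.numel() - 1 - pad_len
        return x * real_seqs

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        (cu_seqlen,) = ctx.saved_tensors
        real_seqs = cu_seqlen.numel() - 1 - ctx.pad_len
        # One gradient per forward input; the int tensor and the Python int get None.
        return grad_out * real_seqs, None, None

# y = ScaleByRealSeqCount.apply(x, cu_seqlen, pad_len)
```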