Binary file added docs/assets/images/float8/fp8_autograd.png
Binary file added docs/assets/images/float8/fp8_granularity.png
Binary file added docs/assets/images/float8/fp8_overall.png
171 changes: 171 additions & 0 deletions docs/zh_cn/pretrain_sft/advanced_tutorial/float8.md
@@ -0,0 +1,171 @@
# FP8 Training

Hopper-architecture GPUs introduce a new data type, FP8 (8-bit floating point), which substantially improves the compute efficiency of matrix multiplication. This tutorial describes how to train with FP8 in XTuner.

## Why FP8

1. Lower communication volume and faster communication: XTuner V1 is built on PyTorch FSDP. Compared with BF16, FP8 communication significantly alleviates FSDP's inherent bottleneck of heavy communication traffic.
2. Higher matrix-multiplication throughput.
3. Lower memory usage: compared with BF16 training, the Linear and Grouped Linear layers keep FP8 tensors instead of BF16 tensors in the PyTorch computation graph, which greatly reduces the memory held by saved activations.
4. Accuracy is preserved: to avoid falling into the trap of chasing speed regardless of correctness, XTuner adopts a fine-grained FP8 quantization scheme that speeds up training while maintaining training accuracy.

## Benchmark

| Parallelism   | Precision | SeqLen | GlobalBatchSize | GPUNum | TimePerIter (s) | Tokens/GPU/Second |
| ------------- | --------- | ------ | --------------- | ------ | --------------- | ----------------- |
| tp1, ep1, pp1 | BF16      | 65536  | 256             | 256    | 32.77           | 2000              |
| tp1, ep1, pp1 | FP8       | 65536  | 256             | 256    | 26.75           | 2450              |
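
The Tokens/GPU/Second column can be reproduced from the other columns, assuming it is simply tokens per iteration divided by GPU count and iteration time:

```python
seq_len, global_batch, gpus = 65536, 256, 256
print(seq_len * global_batch / (gpus * 32.77))  # ≈ 2000 tokens/GPU/s (BF16)
print(seq_len * global_batch / (gpus * 26.75))  # ≈ 2450 tokens/GPU/s (FP8)
```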

[profile data](https://drive.google.com/file/d/1TW-DbsUCckKJS36-5YHJo73L1Nvlpv6h/view?usp=sharing)

## How to Use FP8 Training in XTuner

### Environment Setup

First, check whether the GPU is a Hopper-or-newer architecture:

```python
import torch

print(torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9))
```

Install the `AdaptiveGEMM` library:

```{code-block} shell
:caption: Install AdaptiveGEMM

pip install git+https://github.com/InternLM/AdaptiveGEMM.git@main
```

### Using XTuner's Linear and Grouped Linear Modules

```python
import torch
from xtuner.v1.float8 import TileWiseFloat8Linear, TileWiseFloat8GroupedLinear

# (bs, seq, dim)
x = torch.randn(1, 32768, 1024, device='cuda', dtype=torch.bfloat16, requires_grad=True)
linear = TileWiseFloat8Linear(in_features=1024, out_features=2048, bias=False, device='cuda', dtype=torch.bfloat16)
out = linear(x)
out.mean().backward()

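# Grouped Linear: tokens of the (1, 32768, 1024) input are routed to 4 experts; tokens_per_expert must sum to 32768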
x = torch.randn(1, 32768, 1024, device='cuda', dtype=torch.bfloat16)
grouped_linear = TileWiseFloat8GroupedLinear(in_features=1024, out_features=2048, num_routed_experts=4, moe_bias=False).to(dtype=torch.bfloat16, device='cuda')
tokens_per_expert = torch.tensor([1000, 4000, 6000, 32768 - 11000], device='cuda')
out = grouped_linear(x, tokens_per_expert)
out.mean().backward()
```

```{tip}
:class: margin

1. Benchmarking `TileWiseFloat8Linear` and `TileWiseFloat8GroupedLinear` in isolation does not reflect the ideal end-to-end speed, because quantizing the weights is relatively expensive. Combine them with FSDP to reach the best training efficiency (FP8 communication can be used, and each rank only quantizes its own parameter shard, making the weight-quantization overhead negligible). Usage is described in the next subsection.

2. The first fwd + bwd pass being slow is expected; subsequent passes run at normal speed.
```

### Training with XTuner FP8

Step 1: build a model_cfg instance as described in the [Selecting a Model](model-cfg) section, and configure float8_cfg:

```{code-block} python
:caption: Build the model config

from xtuner.v1.model import Qwen3Dense8BConfig
from xtuner.v1.float8.config import Float8Config, ScalingGranularity

float8_cfg = Float8Config(
scaling_granularity_gemm=ScalingGranularity.TILEWISE,
scaling_granularity_grouped_gemm=ScalingGranularity.TILEWISE,
)

model_cfg = Qwen3Dense8BConfig(float8_cfg=float8_cfg)
```

Step 2: build the `trainer` by following the rest of the [Fine-tuning Large Models with the Trainer](trainer-sft) section.

Step 3: launch training. The complete code is shown below:

````{toggle}
```diff
from xtuner.v1.model import Qwen3Dense8BConfig
from xtuner.v1.config import LRConfig, AdamWConfig
from xtuner.v1.train import Trainer
+ from xtuner.v1.float8.config import Float8Config, ScalingGranularity

+ float8_cfg = Float8Config(
+ scaling_granularity_gemm=ScalingGranularity.TILEWISE,
+ scaling_granularity_grouped_gemm=ScalingGranularity.TILEWISE,
+ )

- model_cfg = Qwen3Dense8BConfig()
+ model_cfg = Qwen3Dense8BConfig(float8_cfg=float8_cfg)
dataset_cfg = []
optim_cfg = AdamWConfig(lr=6e-05)
lr_cfg = LRConfig(lr_type="cosine", lr_min=1e-6)

load_from = "<model path>"  # required when fine-tuning; otherwise training starts from scratch
tokenizer = "<tokenizer path, usually the same as the model path>"

trainer = Trainer(
model_cfg=model_cfg,
tokenizer_path=tokenizer,
load_from=load_from,
optim_cfg=optim_cfg,
dataset_cfg=dataset_cfg,
lr_cfg=lr_cfg,
)
trainer.fit()
```
````

Save the Python script above as `toy_train.py`, and we can launch distributed training with `torchrun`:

```{code-block} bash
:caption: Launch training

torchrun --nproc_per_node=8 toy_train.py
```

Congratulations, you have just built your own FP8 training entry point for XTuner! Feel free to customize the training parameters in this script.

## XTuner FP8 Training Strategy

### FP8 Quantization

XTuner uses symmetric quantization:

```python
s = x.abs().amax() / q_max             # symmetric scale
q = torch.clamp(x / s, q_min, q_max)   # quantized values, then cast to FP8
```
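
As a minimal runnable sketch of this formula (per-tensor granularity with `float8_e4m3fn`; the variable names are illustrative, not XTuner API):

```python
import torch

finfo = torch.finfo(torch.float8_e4m3fn)        # q_max = finfo.max = 448.0 for e4m3
x = torch.randn(256, 256, dtype=torch.bfloat16)

s = x.abs().amax().float() / finfo.max                                   # symmetric scale
q = (x.float() / s).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)  # quantize to FP8
x_hat = q.float() * s                                                    # dequantize, approximates x
```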

XTuner supports three quantization granularities, Tensor-Wise, Block-Wise and Tile-Wise, as shown in the figure below. Elements of the same color share the same quantization parameters. In practice, block_size and tile_size are usually set to 128.

![fp8_granularity](../../../assets/images/float8/fp8_granularity.png)

XTuner uses "just-in-time scaling": the scaling factors (scales) are computed on the fly from the input tensor.
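
A sketch of Tile-Wise just-in-time scaling, assuming 1 x 128 tiles along the last dimension as in the figure above (the helper below is illustrative, not the XTuner implementation):

```python
import torch

def tile_wise_quant(x: torch.Tensor, tile_size: int = 128):
    """Give every 1 x tile_size tile of the last dim its own just-in-time scale."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    m, k = x.shape
    assert k % tile_size == 0
    tiles = x.float().view(m, k // tile_size, tile_size)
    scales = (tiles.abs().amax(dim=-1, keepdim=True) / finfo.max).clamp_min(1e-12)
    q = (tiles / scales).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q.view(m, k), scales.squeeze(-1)         # FP8 values plus one scale per tile

q, scales = tile_wise_quant(torch.randn(4, 512, dtype=torch.bfloat16))
```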

### FP8 算子

We extended [DeepGemm](https://github.com/deepseek-ai/DeepGEMM/tree/3b3783d06cd4d06ac4ba048633e604151d1ee535) with the following two Grouped-GEMM-related capabilities (thanks to the DeepSeek team for their contributions to the open-source community):

1. Support for Group Size M that is not a multiple of 128, to meet real training needs; see our paper [TMA-Adaptive FP8 Grouped GEMM](https://arxiv.org/abs/2508.16584) for details.
2. Support for Group K GEMM, the backward operator of Grouped Linear.

Note that, to keep performance on target, the Group K GEMM operator requires Group Size K to be a multiple of 128, which places extra demands on our autograd design; see the next subsection for details.

### FP8 Mixed-Precision Training

XTuner FP8 follows the FP8 training strategy of DeepSeek V3, as shown in the figure below. The main compute-intensive operators (e.g., GEMM and Grouped GEMM) run in FP8 to accelerate computation: each operator takes FP8 inputs and produces BF16 outputs. The GEMMs of the three Linear modules in the figure all run in FP8; we call them Fprop (forward pass), Dgrad (activation backward pass) and Wgrad (weight backward pass). Compared with BF16, FP8 halves the theoretical GEMM time. In addition, only FP8 tensors need to be kept in the PyTorch computation graph for the backward pass, which reduces the memory held by saved activations.
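
For reference, here is a plain BF16 sketch of the three contractions that Fprop, Dgrad and Wgrad compute for a single Linear layer; the FP8 kernels perform the same matmuls on quantized inputs:

```python
import torch

x = torch.randn(4096, 1024, dtype=torch.bfloat16)     # input  (M, K)
w = torch.randn(2048, 1024, dtype=torch.bfloat16)     # weight (N, K)
dout = torch.randn(4096, 2048, dtype=torch.bfloat16)  # gradient of the output (M, N)

out = x @ w.t()      # Fprop: forward output          (M, N)
dx = dout @ w        # Dgrad: gradient of the input   (M, K)
dw = dout.t() @ x    # Wgrad: gradient of the weight  (N, K)
```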

![fp8_overall](../../../assets/images/float8/fp8_overall.png)

Going further, XTuner refines the autograd logic of FP8 Linear and Grouped Linear. We take the more involved Grouped Linear as an example. As shown in the figure below, in the forward and backward-dx computations we apply Tile-Wise quantization to the activations and Block-Wise quantization to the model weights. In the backward-dw computation, for better performance we apply Tile-Wise quantization to the output gradient and Block-Wise quantization to the forward input X.

One detail in the figure deserves special mention: in the backward-dw computation, the forward input X goes through Transpose + Block-Wise FP8 Quantize + Pad to 128x + Transpose. This is because, to reach the desired compute efficiency, the FP8 GEMM and Grouped GEMM operators require the lhs matrix to be Row-Major and the rhs matrix to be Column-Major, and, as noted in the previous subsection, the Group K GEMM operator requires Group Size K to be divisible by 128. We fuse Transpose + Block-Wise FP8 Quantize + Pad to 128x into a single operator to improve efficiency; an unfused sketch of the same steps is shown below.
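
A minimal unfused sketch of these steps, with Block-Wise quantization simplified to one per-tensor scale purely for illustration (padding with zeros before or after symmetric quantization gives the same values; the real fused kernel quantizes block-wise in a single pass):

```python
import torch
import torch.nn.functional as F

x = torch.randn(3000, 1024, dtype=torch.bfloat16)       # forward input X: (tokens_in_group, hidden)

x_t = x.t().contiguous()                                 # 1. Transpose -> (hidden, tokens)
pad = (-x_t.shape[1]) % 128                              # 2. Pad the token dim up to a multiple of 128
x_t = F.pad(x_t, (0, pad))                               #    -> (hidden, 3072)
scale = x_t.abs().amax() / torch.finfo(torch.float8_e4m3fn).max
x_t_q = (x_t.float() / scale).to(torch.float8_e4m3fn)    # 3. Quantize (per-tensor stand-in for Block-Wise)
x_q = x_t_q.t()                                          # 4. Transpose back -> (3072, hidden)
```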

![fp8_autograd](../../../assets/images/float8/fp8_autograd.png)

1 change: 1 addition & 0 deletions docs/zh_cn/pretrain_sft/advanced_tutorial/index.rst
@@ -8,4 +8,5 @@
model.md
dataset.md
loss.md
float8.md
profile.md
2 changes: 2 additions & 0 deletions docs/zh_cn/pretrain_sft/tutorial/llm_trainer.md
@@ -1,8 +1,10 @@
(trainer-sft)=
# Fine-tuning Large Models with the Trainer

In the previous [tutorial](../../get_started/sft.md) we launched a fine-tuning run from the command line in the simplest possible way. Behind that quick start, XTuner's core component, the `Trainer`, was doing the work. In this section we take a first look at the Trainer and control each stage of training at a finer granularity.


(model-cfg)=
## Selecting a Model

The Trainer builds the model from a configuration. Taking the built-in `Qwen3 8B` as an example, we can quickly obtain a model config instance.
@@ -19,8 +19,14 @@ def mock_experts(hidden_states: torch.Tensor, tokens_per_exprts: torch.Tensor):


class TestMoETorchAll2AllDispatcher(DistributedTestBase):
@parametrize.parametrize("dtype,device", [(torch.bfloat16, "cuda")])
def test_dispatch_and_combine(self, dtype, device):
@parametrize.parametrize(
"dtype,device,async_op",
[
(torch.bfloat16, "cuda", False),
(torch.bfloat16, "cuda", True),
]
)
def test_dispatch_and_combine(self, dtype, device, async_op):
self.create_pg(device)
num_experts = 16

@@ -52,41 +58,63 @@ def test_dispatch_and_combine(self, dtype, device):
dispatcher=all2all_dispatcher,
hidden_states=hidden_states,
topk_ids=topk_idx,
topk_weights=topk_weights
topk_weights=topk_weights,
async_op=async_op,
)

self.assertTrue(torch.allclose(noep_results, all2all_results, atol=1e-6, rtol=1e-4))

def _dispatcher_call(
self,
dispatcher: GenericDispatcher,
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor
self,
dispatcher: GenericDispatcher,
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
async_op: bool=False
):
pre_dispatched = dispatcher.dispatch_preprocess(
hidden_states=hidden_states,
topk_ids=topk_ids,
topk_weights=topk_weights,
async_op=async_op,
)
dispatched = dispatcher.dispatch(
pre_dispatched=pre_dispatched,
topk_weights=topk_weights,
decoding=False,
async_op=async_op,
)
post_dispatched = dispatcher.dispatch_postprocess(
pre_dispatched=pre_dispatched,
dispatched=dispatched,
async_op=async_op,
)
experts_results = mock_experts(
hidden_states=dispatched["hidden_states"],
tokens_per_exprts=dispatched["tokens_per_experts"],
hidden_states=post_dispatched["hidden_states"],
tokens_per_exprts=post_dispatched["tokens_per_expert"],
)
combined = dispatcher.combine(hidden_states=experts_results,
pre_combined = dispatcher.combine_preprocess(
hidden_states=experts_results,
pre_dispatched=pre_dispatched,
dispatched=dispatched,
decoding=False,
post_dispatched=post_dispatched,
async_op=async_op,
)
combined = dispatcher.combine(
pre_dispatched=pre_dispatched,
dispatched=dispatched,
post_dispatched=post_dispatched,
pre_combined=pre_combined,
async_op=async_op,
)
return dispatcher.combine_post_process(
post_combined = dispatcher.combine_postprocess(
pre_dispatched=pre_dispatched,
dispatch_result=dispatched,
combine_result=combined,
dispatched=dispatched,
post_dispatched=post_dispatched,
pre_combined=pre_combined,
combined=combined,
async_op=async_op,
)
return post_combined["hidden_states"]

@property
def world_size(self) -> int:
5 changes: 4 additions & 1 deletion xtuner/v1/float8/float8_gmm_tile_wise.py
@@ -291,9 +291,12 @@ def forward(self, input: torch.Tensor, tokens_per_expert, decoding: bool = False
weight_fp8 = view_weight.apply(weight_fp8, self.ori_local_shape)
else:
weight = weight.view(*self.ori_local_shape)
weight_fp8 = weight_to_per_block_float8_dynamic.apply(weight, torch.float8_e4m3fn, group_size=128)
weight_fp8 = weight_to_per_block_float8_dynamic.apply(weight, torch.float8_e4m3fn, 128)

orig_shape = input.shape
input = input.view(-1, input.shape[-1])
out = fp8_gmm_weight_per_block_act_per_tile.apply(input, weight_fp8, tokens_per_expert)
out = out.view(*orig_shape[:-1], -1)
return out

@property
2 changes: 1 addition & 1 deletion xtuner/v1/float8/float8_linear_tile_wise.py
@@ -229,7 +229,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
weight_fp8 = slice_weight.apply(weight, self.ori_shape) if self.is_padded else weight
else:
weight = weight.view(*self.ori_shape)
weight_fp8 = weight_to_per_block_float8_dynamic.apply(weight, torch.float8_e4m3fn, group_size=128)
weight_fp8 = weight_to_per_block_float8_dynamic.apply(weight, torch.float8_e4m3fn, 128)

out = fp8_matmul_weight_per_block_act_per_tile.apply(input, weight_fp8)

2 changes: 0 additions & 2 deletions xtuner/v1/float8/fsdp_utils.py
@@ -22,7 +22,6 @@
DEVICE_MODULE = get_torch_device_module()


@maybe_compile(fullgraph=True)
def tensor_to_per_block_fp8_devided_64_scales(
tensor: "WeightWithDynamicTilewiseFloat8CastTensor",
reduce_mesh_devided_64: Optional[DeviceMesh] = None,
@@ -224,7 +223,6 @@ def cast_to_per_block_fp8_with_scales(
return tensor_bits_fp8


@maybe_compile(fullgraph=True)
def cast_to_per_block_fp8_devided_64_with_scales(
tensor: torch.Tensor,
scales: torch.Tensor,
43 changes: 42 additions & 1 deletion xtuner/v1/loss/ce_loss.py
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Annotated, Literal, cast
from typing import Annotated, Any, Literal, cast

import torch
import torch.distributed as dist
@@ -27,12 +27,17 @@ class CELossConfig(BaseLossConfig):
loss_reduction (str): The reduction mode for the loss. Options are "token", "sample", and "square".
"""

mode: Annotated[Literal["eager", "chunk", "liger"], Parameter(help="loss calculation mode")] = "eager" # type: ignore
loss_reduction: Annotated[Literal["token", "sample", "square"], Parameter(help="loss reduction mode")] = "token"

@property
def loss_ctx_cls(self) -> type["CELossContext"]:
return CELossContext

def model_post_init(self, __context: Any) -> None:
if self.mode == "liger":
assert self.loss_reduction == "token", "Currently, cannot use liger kernel with sample or square reduction"


class CELossKwargs(BaseLossKwargs):
"""Keyword arguments for cross-entropy loss computation.
@@ -76,6 +81,18 @@ class CELossContext(BaseLossContext[CELossContextInputItem]):
loss_cfg: CELossConfig
loss_kwargs: CELossKwargs

def __init__(self, loss_cfg: CELossConfig, loss_kwargs: CELossKwargs):
super().__init__(loss_cfg, loss_kwargs)

if loss_cfg.mode == "liger":
from liger_kernel.transformers.fused_linear_cross_entropy import (
LigerFusedLinearCrossEntropyLoss,
)

self.liger_loss_fct = LigerFusedLinearCrossEntropyLoss(reduction="sum")
else:
self.liger_loss_fct = None

@classmethod
def build_batches_loss_kwargs(
cls,
@@ -170,3 +187,27 @@ def loss_fn(
loss = (loss * loss_weight).sum()

return loss, logits

def chunk_mode(
self,
hidden_states: torch.Tensor,
head_weight: torch.Tensor,
head_bias: torch.Tensor | None,
loss_kwargs: CELossKwargs,
):
if self.loss_cfg.mode == "chunk":
return super().chunk_mode(hidden_states, head_weight, head_bias, loss_kwargs)
else:
assert self.liger_loss_fct is not None, "liger_loss_fct must be initialized in liger mode"
shifted_labels = loss_kwargs.shifted_labels # (bs, seq_len)
loss_weight = loss_kwargs.loss_weight # (bs, seq_len)

bs, seq, dim = hidden_states.shape
hidden_states = hidden_states.reshape(bs * seq, dim)
shifted_labels = shifted_labels.flatten()
# liger kernel dont support reduction=="none"
loss = self.liger_loss_fct(head_weight, hidden_states, shifted_labels)
mask = loss_weight != 0
w = loss_weight.sum() / mask.sum()
loss = loss * w
return loss, None