diff --git a/docs/assets/images/float8/fp8_autograd.png b/docs/assets/images/float8/fp8_autograd.png new file mode 100644 index 000000000..e3becba55 Binary files /dev/null and b/docs/assets/images/float8/fp8_autograd.png differ diff --git a/docs/assets/images/float8/fp8_granularity.png b/docs/assets/images/float8/fp8_granularity.png new file mode 100644 index 000000000..9b5aa821e Binary files /dev/null and b/docs/assets/images/float8/fp8_granularity.png differ diff --git a/docs/assets/images/float8/fp8_overall.png b/docs/assets/images/float8/fp8_overall.png new file mode 100644 index 000000000..796d5a80f Binary files /dev/null and b/docs/assets/images/float8/fp8_overall.png differ diff --git a/docs/zh_cn/pretrain_sft/advanced_tutorial/float8.md b/docs/zh_cn/pretrain_sft/advanced_tutorial/float8.md new file mode 100644 index 000000000..7c8151a15 --- /dev/null +++ b/docs/zh_cn/pretrain_sft/advanced_tutorial/float8.md @@ -0,0 +1,171 @@ +# FP8 训练 + +Hopper 架构的 GPU 引入了新的数据类型 FP8(8-bit floating point),可显著提升矩阵乘法的计算效率。下面将介绍如何在 XTuner 中使用 FP8 进行训练。 + +## 为什么选择 FP8 + +1. 降低通信量、提升通信速度:XTuner V1 基于 PyTorch FSDP 开发。相较 BF16,使用 FP8 通信可显著缓解 FSDP 通信量大的固有瓶颈。 +2. 提升矩阵乘计算效率。 +3. 节约显存:与 BF16 训练相比,FP8 训练中 Linear 和 Grouped Linear 层 PyTorch 计算图中保存的是 FP8 Tensor 而非 BF16 Tensor。可大幅降低计算图的显存开销。 +4. 精度具有保证:为了避免陷入“你别管我对不对,就问你快不快”的窘境,XTuner 采用了细粒度的 FP8 量化模式,在保证训练精度的前提下优化了训练速度。 + +## BenchMark + +并行配置 | 训练配置 | SeqLen | GlobalBatchSize | GPUNum | TimePerIter (s) | Tokens/GPU/Second +-- | -- | -- | -- | -- | -- | -- +tp1, ep1, pp1 | BF16 | 65536 | 256 | 256 | 32.77 | 2000 +tp1, ep1, pp1 | FP8 | 65536 | 256 | 256 | 26.75 | 2450 + +[profile data](https://drive.google.com/file/d/1TW-DbsUCckKJS36-5YHJo73L1Nvlpv6h/view?usp=sharing) + +## 如何使用 XTuner FP8 训练 + +### 环境准备 + +首先检查 GPU 是否为 Hopper 及以上架构: + +```python +import torch + +print(torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)) +``` + +安装 `AdaptiveGEMM` 库: + +```{code-block} shell +:caption: 安装 AdaptiveGEMM + +pip install git+https://github.com/InternLM/AdaptiveGEMM.git@main +``` + +### 使用 XTuner 的 Linear 和 Grouped Linear 模块 + +```python +import torch +from xtuner.v1.float8 import TileWiseFloat8Linear, TileWiseFloat8GroupedLinear + +# (bs, seq, dim) +x = torch.randn(1, 32768, 1024, device='cuda', dtype=torch.bfloat16, requires_grad=True) +linear = TileWiseFloat8Linear(in_features=1024, out_features=2048, bias=False, device='cuda', dtype=torch.bfloat16) +out = linear(x) +out.mean().backward() + +x = torch.randn(1, 32768, 1024, device='cuda', dtype=torch.bfloat16) +grouped_linear = TileWiseFloat8GroupedLinear(in_features=1024, out_features=2048, num_routed_experts=4, moe_bias=False).to(dtype=torch.bfloat16, device='cuda') +tokens_per_expert = torch.tensor([1000, 4000, 6000, 32768 - 11000], device='cuda') +out = grouped_linear(x, tokens_per_expert) +out.mean().backward() +``` + +```{tip} +:class: margin + +1. 单测 `TileWiseFloat8Linear` 与 `TileWiseFloat8GroupedLinear` 难以体现端到端的理想速度,因为对权重的量化较耗时。需结合 FSDP 才能达到最佳训练效率(可用 FP8 通信,且每个 rank 仅量化自身切片参数,使权重量化开销可忽略)。用法见下一小节。 + +2. 
首次执行 fwd + bwd 速度较慢是正常现象,再次执行速度就会恢复正常。 +``` + +### 使用 XTuner FP8 训练 + +第一步,参考 [选择模型](model-cfg) 一节构建 model_cfg 实例,并配置 float8_cfg: + +```{code-block} python +:caption: 构建模型配置 + +from xtuner.v1.model import Qwen3Dense8BConfig +from xtuner.v1.float8.config import Float8Config, ScalingGranularity + +float8_cfg = Float8Config( + scaling_granularity_gemm=ScalingGranularity.TILEWISE, + scaling_granularity_grouped_gemm=ScalingGranularity.TILEWISE, +) + +model_cfg = Qwen3Dense8BConfig(float8_cfg=float8_cfg) +``` + +第二步,参考 [使用 Trainer 进行大模型微调](trainer-sft) 一节后续内容构建 `trainer`。 + +第三步,启动训练,完整代码如下: + +````{toggle} +```diff +from xtuner.v1.model import Qwen3Dense8BConfig +from xtuner.v1.config import LRConfig, AdamWConfig +from xtuner.v1.train import Trainer ++ from xtuner.v1.float8.config import Float8Config, ScalingGranularity + ++ float8_cfg = Float8Config( ++ scaling_granularity_gemm=ScalingGranularity.TILEWISE, ++ scaling_granularity_grouped_gemm=ScalingGranularity.TILEWISE, ++ ) + +- model_cfg = Qwen3Dense8BConfig() ++ model_cfg = Qwen3Dense8BConfig(float8_cfg=float8_cfg) +dataset_cfg = [] +optim_cfg = AdamWConfig(lr=6e-05) +lr_cfg = LRConfig(lr_type="cosine", lr_min=1e-6) + +load_from = "<模型路径>" # 如果是微调模式,必须指定,否则会重头训练 +tokenizer = "" + +trainer = Trainer( + model_cfg=model_cfg, + tokenizer_path=tokenizer, + load_from=load_from, + optim_cfg=optim_cfg, + dataset_cfg=dataset_cfg, + lr_cfg=lr_cfg, +) +trainer.fit() +``` +```` + +写完上述 python 脚本后,命名为 `toy_train.py`,我们就能通过 `torchrun` 启动分布式训练了: + +```{code-block} bash +:caption: 启动训练 + +torchrun --nproc_per_node=8 toy_train.py +``` + +恭喜你,已经自己实现了一个 XTuner 的 FP8 训练入口!你可以在这个脚本里尽情地发挥,定制化自己的训练参数。 + +## XTuner FP8 训练策略 + +### FP8 量化 + +XTuner 采用对称量化: + +```python +s = absmax(x) / q_max +q = clip(x / s, q_min, q_max) +``` + +XTuner 支持以下三种量化粒度:Tensor-Wise, Block-Wise 和 Tile-Wise,如下图所示。相同颜色的元素共享同一个量化参数。在实际使用中,block_size 和 tile_size 一般会设置为 128。 + +![fp8_granularity](../../../assets/images/float8/fp8_granularity.png) + +XTuner 采用了 "just-in-time scaling" 的量化方法,该策略根据输入 Tensor 实时计算出对应的缩放因子 (scales) 。 + +### FP8 算子 + +我们基于 [DeepGemm](https://github.com/deepseek-ai/DeepGEMM/tree/3b3783d06cd4d06ac4ba048633e604151d1ee535) 扩展了以下两项与 Grouped GEMM 相关的能力(感谢 DeepSeek 团队对开源社区的贡献): + +1. 支持 Group Size M 非 128x 的情况以满足实际训练需求,细节见我们的论文 [TMA-Adaptive FP8 Grouped GEMM](https://arxiv.org/abs/2508.16584) 。 +2. 
支持 Grouped Linear 的 Backward 算子 Group K GEMM。 + +需要额外说的是,为确保性能符合预期,Group K GEMM 算子要求 Group Size K 为 128 的倍数,这对我们的 AutoGrad 涉及提出了更高的要求,详情请见下一小节。 + +### FP8 混合精度训练 + +XTuner FP8 参考了 DeepSeek V3 中的 FP8 训练策略,如下图所示。对于主要的计算密集型算子(例如 GEMM 和 Grouped GEMM),我们采用了 FP8 来加速计算。算子接受 FP8 的输入并得到 BF16 的输出。下图中三个 Linear Module 涉及到的 GEMM 计算均使用 FP8 计算,我们将其命名为 Fprop (Forward Pass), Dgrad (Activation Backward +Pass) 和 Wgrad (Weight Backward Pass)。与 BF16 相比,FP8 让 GEMM 的理论耗时减半。同时, PyTorch 计算图中只需保存 FP8 Tensor 即可完成 Backward 计算,进而节约了计算图的显存开销。 + +![fp8_overall](../../../assets/images/float8/fp8_overall.png) + +进一步地,XTuner 细化了 FP8 Linear 和 Grouped Linear 的 AutoGrad 计算逻辑。这里我们以较为复杂的 Grouped Linear 为例展开介绍。如下图所示,在 Forward 和 Backward dx 计算中,我们对激活值采用了 Tile-Wise 的量化策略,对模型权重采用了 Block-Wise 的量化策略。而在 Backward dw 计算中,为了追求性能优势,我们对 Grad Output 采用了 Tile-Wise 的量化策略,而对 Forward 的输入 X 采用了 Block-Wise 的量化策略。 + +图中有一个需要特殊说明的地方是,在 Backward dw 的计算中,我们对 Forward 时的输入 X 进行了 Transpose + Block-Wise FP8 Quantize + Pad to 128x + Transpose 的操作,这是因为,为了达到理想的计算效率,FP8 GEMM 算子和 Grouped GEMM 算子要求 lhs 矩阵的 layout 是 Row-Major 的,而 rhs 矩阵则是 Column-Major。同时,如上一小节所述,Group K GEMM 算子要求 Group Size K 可以被 128 整除,我们把 Transpose + Block-Wise FP8 Quantize + Pad to 128x 融合成了一个算子以提高计算效率。 + +![fp8_overall](../../../assets/images/float8/fp8_autograd.png) + diff --git a/docs/zh_cn/pretrain_sft/advanced_tutorial/index.rst b/docs/zh_cn/pretrain_sft/advanced_tutorial/index.rst index 20c738f7e..edd150ff2 100644 --- a/docs/zh_cn/pretrain_sft/advanced_tutorial/index.rst +++ b/docs/zh_cn/pretrain_sft/advanced_tutorial/index.rst @@ -8,4 +8,5 @@ model.md dataset.md loss.md + float8.md profile.md diff --git a/docs/zh_cn/pretrain_sft/tutorial/llm_trainer.md b/docs/zh_cn/pretrain_sft/tutorial/llm_trainer.md index 5eae75487..3a3f8175e 100644 --- a/docs/zh_cn/pretrain_sft/tutorial/llm_trainer.md +++ b/docs/zh_cn/pretrain_sft/tutorial/llm_trainer.md @@ -1,8 +1,10 @@ +(trainer-sft)= # 使用 Trainer 进行大模型微调 在之前的[教程](../../get_started/sft.md)中我们通过命令行,用最简单的方式启动了一次微调训练,而在这快速启动的背后,则是 XTuner 的核心组件 `Trainer` 在发挥作用。这一节我们将初识 Trainer,用更加细力度的方式控制训练的各个环节。 +(model-cfg)= ## 选择模型: Trainer 通过配置文件的方式来构建模型,我们以 XTuner 内置支持的 `Qwen3 8B` 为例,来快速获取一个模型配置实例 diff --git a/tests/module/dispatcher/_test_deepep.py b/tests/module/dispatcher/test_deepep.py similarity index 60% rename from tests/module/dispatcher/_test_deepep.py rename to tests/module/dispatcher/test_deepep.py index f6da78838..23c49ea49 100644 --- a/tests/module/dispatcher/_test_deepep.py +++ b/tests/module/dispatcher/test_deepep.py @@ -19,8 +19,14 @@ def mock_experts(hidden_states: torch.Tensor, tokens_per_exprts: torch.Tensor): class TestMoETorchAll2AllDispatcher(DistributedTestBase): - @parametrize.parametrize("dtype,device", [(torch.bfloat16, "cuda")]) - def test_dispatch_and_combine(self, dtype, device): + @parametrize.parametrize( + "dtype,device,async_op", + [ + (torch.bfloat16, "cuda", False), + (torch.bfloat16, "cuda", True), + ] + ) + def test_dispatch_and_combine(self, dtype, device, async_op): self.create_pg(device) num_experts = 16 @@ -52,41 +58,63 @@ def test_dispatch_and_combine(self, dtype, device): dispatcher=all2all_dispatcher, hidden_states=hidden_states, topk_ids=topk_idx, - topk_weights=topk_weights + topk_weights=topk_weights, + async_op=async_op, ) self.assertTrue(torch.allclose(noep_results, all2all_results, atol=1e-6, rtol=1e-4)) def _dispatcher_call( - self, - dispatcher: GenericDispatcher, - hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor + self, + dispatcher: GenericDispatcher, + hidden_states: 
torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + async_op: bool=False ): pre_dispatched = dispatcher.dispatch_preprocess( hidden_states=hidden_states, topk_ids=topk_ids, - topk_weights=topk_weights, + async_op=async_op, ) dispatched = dispatcher.dispatch( pre_dispatched=pre_dispatched, + topk_weights=topk_weights, decoding=False, + async_op=async_op, + ) + post_dispatched = dispatcher.dispatch_postprocess( + pre_dispatched=pre_dispatched, + dispatched=dispatched, + async_op=async_op, ) experts_results = mock_experts( - hidden_states=dispatched["hidden_states"], - tokens_per_exprts=dispatched["tokens_per_experts"], + hidden_states=post_dispatched["hidden_states"], + tokens_per_exprts=post_dispatched["tokens_per_expert"], ) - combined = dispatcher.combine(hidden_states=experts_results, + pre_combined = dispatcher.combine_preprocess( + hidden_states=experts_results, pre_dispatched=pre_dispatched, dispatched=dispatched, - decoding=False, + post_dispatched=post_dispatched, + async_op=async_op, + ) + combined = dispatcher.combine( + pre_dispatched=pre_dispatched, + dispatched=dispatched, + post_dispatched=post_dispatched, + pre_combined=pre_combined, + async_op=async_op, ) - return dispatcher.combine_post_process( + post_combined = dispatcher.combine_postprocess( pre_dispatched=pre_dispatched, - dispatch_result=dispatched, - combine_result=combined, + dispatched=dispatched, + post_dispatched=post_dispatched, + pre_combined=pre_combined, + combined=combined, + async_op=async_op, ) + return post_combined["hidden_states"] @property def world_size(self) -> int: diff --git a/xtuner/v1/float8/float8_gmm_tile_wise.py b/xtuner/v1/float8/float8_gmm_tile_wise.py index ad93661fe..25c71b617 100644 --- a/xtuner/v1/float8/float8_gmm_tile_wise.py +++ b/xtuner/v1/float8/float8_gmm_tile_wise.py @@ -291,9 +291,12 @@ def forward(self, input: torch.Tensor, tokens_per_expert, decoding: bool = False weight_fp8 = view_weight.apply(weight_fp8, self.ori_local_shape) else: weight = weight.view(*self.ori_local_shape) - weight_fp8 = weight_to_per_block_float8_dynamic.apply(weight, torch.float8_e4m3fn, group_size=128) + weight_fp8 = weight_to_per_block_float8_dynamic.apply(weight, torch.float8_e4m3fn, 128) + orig_shape = input.shape + input = input.view(-1, input.shape[-1]) out = fp8_gmm_weight_per_block_act_per_tile.apply(input, weight_fp8, tokens_per_expert) + out = out.view(*orig_shape[:-1], -1) return out @property diff --git a/xtuner/v1/float8/float8_linear_tile_wise.py b/xtuner/v1/float8/float8_linear_tile_wise.py index da3979828..9cad1c9f3 100644 --- a/xtuner/v1/float8/float8_linear_tile_wise.py +++ b/xtuner/v1/float8/float8_linear_tile_wise.py @@ -229,7 +229,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: weight_fp8 = slice_weight.apply(weight, self.ori_shape) if self.is_padded else weight else: weight = weight.view(*self.ori_shape) - weight_fp8 = weight_to_per_block_float8_dynamic.apply(weight, torch.float8_e4m3fn, group_size=128) + weight_fp8 = weight_to_per_block_float8_dynamic.apply(weight, torch.float8_e4m3fn, 128) out = fp8_matmul_weight_per_block_act_per_tile.apply(input, weight_fp8) diff --git a/xtuner/v1/float8/fsdp_utils.py b/xtuner/v1/float8/fsdp_utils.py index f0b8f893b..f65a4efde 100644 --- a/xtuner/v1/float8/fsdp_utils.py +++ b/xtuner/v1/float8/fsdp_utils.py @@ -22,7 +22,6 @@ DEVICE_MODULE = get_torch_device_module() -@maybe_compile(fullgraph=True) def tensor_to_per_block_fp8_devided_64_scales( tensor: "WeightWithDynamicTilewiseFloat8CastTensor", 
reduce_mesh_devided_64: Optional[DeviceMesh] = None, @@ -224,7 +223,6 @@ def cast_to_per_block_fp8_with_scales( return tensor_bits_fp8 -@maybe_compile(fullgraph=True) def cast_to_per_block_fp8_devided_64_with_scales( tensor: torch.Tensor, scales: torch.Tensor, diff --git a/xtuner/v1/loss/ce_loss.py b/xtuner/v1/loss/ce_loss.py index 8b8ffde8e..1d9ad8bf7 100644 --- a/xtuner/v1/loss/ce_loss.py +++ b/xtuner/v1/loss/ce_loss.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Annotated, Literal, cast +from typing import Annotated, Any, Literal, cast import torch import torch.distributed as dist @@ -27,12 +27,17 @@ class CELossConfig(BaseLossConfig): loss_reduction (str): The reduction mode for the loss. Options are "token", "sample", and "square". """ + mode: Annotated[Literal["eager", "chunk", "liger"], Parameter(help="loss calculation mode")] = "eager" # type: ignore loss_reduction: Annotated[Literal["token", "sample", "square"], Parameter(help="loss reduction mode")] = "token" @property def loss_ctx_cls(self) -> type["CELossContext"]: return CELossContext + def model_post_init(self, __context: Any) -> None: + if self.mode == "liger": + assert self.loss_reduction == "token", "Currently, cannot use liger kernel with sample or square reduction" + class CELossKwargs(BaseLossKwargs): """Keyword arguments for cross-entropy loss computation. @@ -76,6 +81,18 @@ class CELossContext(BaseLossContext[CELossContextInputItem]): loss_cfg: CELossConfig loss_kwargs: CELossKwargs + def __init__(self, loss_cfg: CELossConfig, loss_kwargs: CELossKwargs): + super().__init__(loss_cfg, loss_kwargs) + + if loss_cfg.mode == "liger": + from liger_kernel.transformers.fused_linear_cross_entropy import ( + LigerFusedLinearCrossEntropyLoss, + ) + + self.liger_loss_fct = LigerFusedLinearCrossEntropyLoss(reduction="sum") + else: + self.liger_loss_fct = None + @classmethod def build_batches_loss_kwargs( cls, @@ -170,3 +187,27 @@ def loss_fn( loss = (loss * loss_weight).sum() return loss, logits + + def chunk_mode( + self, + hidden_states: torch.Tensor, + head_weight: torch.Tensor, + head_bias: torch.Tensor | None, + loss_kwargs: CELossKwargs, + ): + if self.loss_cfg.mode == "chunk": + return super().chunk_mode(hidden_states, head_weight, head_bias, loss_kwargs) + else: + assert self.liger_loss_fct is not None, "liger_loss_fct must be initialized in liger mode" + shifted_labels = loss_kwargs.shifted_labels # (bs, seq_len) + loss_weight = loss_kwargs.loss_weight # (bs, seq_len) + + bs, seq, dim = hidden_states.shape + hidden_states = hidden_states.reshape(bs * seq, dim) + shifted_labels = shifted_labels.flatten() + # liger kernel dont support reduction=="none" + loss = self.liger_loss_fct(head_weight, hidden_states, shifted_labels) + mask = loss_weight != 0 + w = loss_weight.sum() / mask.sum() + loss = loss * w + return loss, None diff --git a/xtuner/v1/model/base.py b/xtuner/v1/model/base.py index 24acab5dc..5986bc501 100644 --- a/xtuner/v1/model/base.py +++ b/xtuner/v1/model/base.py @@ -1080,7 +1080,7 @@ def _fsdp_foreach_allgather( def _maybe_compile_layers(self): if self.fsdp_config is not None: if self.fsdp_config.torch_compile: - torch._dynamo.config.cache_size_limit = 128 + torch._dynamo.config.cache_size_limit = 256 if self.fsdp_config.compile_targets is not None: maybe_compile.clear_compile_targets() for target in self.fsdp_config.compile_targets: diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py index 550983196..e130897b5 100644 --- a/xtuner/v1/model/moe/moe.py +++ 
b/xtuner/v1/model/moe/moe.py @@ -305,6 +305,7 @@ def _micro_batch_forward( cat_position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None cat_hidden_states: torch.Tensor | None = None + moe_forawrd = False for idx, decoder_layer in self.layers.items(): layer_idx = int(idx) @@ -322,14 +323,37 @@ def _micro_batch_forward( seq_ctx=cat_seq_ctx, ) else: - if cat_hidden_states is not None: - hidden_states_list = list(cat_hidden_states.chunk(len(seq_ctx_list), dim=1)) + if cat_hidden_states is not None and not moe_forawrd: + # TODO: `i.clone()` here is weird. However, the current Implementation of + # `async_save_on_cpu` is not friendly with `chunk` op (maybe caused by shared storage? not sure), + # resulting in nan grad norm. So we have to clone the chunked tensors here to make sure each + # hidden state has its own storage. This workaround may introduce extra memory and time cost, and + # should be optimized in the future. + hidden_states_list = [i.clone() for i in cat_hidden_states.chunk(len(seq_ctx_list), dim=1)] + moe_forawrd = True - layer_results = decoder_layer( - *hidden_states_list, - position_embeddings=position_embeddings_list, - seq_ctx=seq_ctx_list, - ) + if int(os.getenv("XTUNER_ACTIVATION_OFFLOAD", "0")) == 1: + offload_stream = decoder_layer._get_fsdp_state()._comm_ctx.all_gather_copy_in_stream + with async_save_on_cpu( + h2d_stream=offload_stream, # type: ignore + d2h_stream=offload_stream, # type: ignore + block_idx=layer_idx - self.config.first_k_dense_replace, + depth=len(self.layers) - self.config.first_k_dense_replace, + custom_check_fn=lambda x: x.data_ptr() + in [hidden_states.data_ptr() for hidden_states in hidden_states_list], + prefetch=True, + ): + layer_results = decoder_layer( + *hidden_states_list, + position_embeddings=position_embeddings_list, + seq_ctx=seq_ctx_list, + ) + else: + layer_results = decoder_layer( + *hidden_states_list, + position_embeddings=position_embeddings_list, + seq_ctx=seq_ctx_list, + ) hidden_states = layer_results[: len(hidden_states_list)] router_logits = layer_results[len(hidden_states_list) :] @@ -782,7 +806,7 @@ def traverse(module): def _maybe_compile_layers(self): if self.fsdp_config is not None: if self.fsdp_config.torch_compile: - torch._dynamo.config.cache_size_limit = 128 + torch._dynamo.config.cache_size_limit = 256 if self.fsdp_config.compile_targets is None: if self.ep_mesh.size() > 1: # all_to_all_single_autograd in TorchAll2AllDispatcher.dispatch can not be compiled even if the fullgraph=False diff --git a/xtuner/v1/module/decoder_layer/dense_decoder_layer.py b/xtuner/v1/module/decoder_layer/dense_decoder_layer.py index eb8d269cf..d978c49e1 100644 --- a/xtuner/v1/module/decoder_layer/dense_decoder_layer.py +++ b/xtuner/v1/module/decoder_layer/dense_decoder_layer.py @@ -9,8 +9,9 @@ from xtuner.v1.module import MHAConfig, MLAConfig, RMSNorm from xtuner.v1.ops.act_fn import get_act_fn from xtuner.v1.utils import ForwardState +from xtuner.v1.utils.compile import maybe_compile -from ..linear.linear import _Linear +from ..linear.linear import build_linear class DenseMLP(nn.Module): @@ -21,11 +22,12 @@ def __init__( intermediate_size: int, bias: bool = False, hidden_act: str, + float8_cfg: Float8Config | None = None, ): super().__init__() - self.gate_proj = _Linear(hidden_size, intermediate_size, bias=bias) - self.up_proj = _Linear(hidden_size, intermediate_size, bias=bias) - self.down_proj = _Linear(intermediate_size, hidden_size, bias=bias) + self.gate_proj = build_linear(hidden_size, intermediate_size, bias=bias, 
float8_cfg=float8_cfg) + self.up_proj = build_linear(hidden_size, intermediate_size, bias=bias, float8_cfg=float8_cfg) + self.down_proj = build_linear(intermediate_size, hidden_size, bias=bias, float8_cfg=float8_cfg) self.act_fn = get_act_fn(hidden_act) def forward(self, x): @@ -64,10 +66,12 @@ def __init__( intermediate_size=intermediate_size, bias=mlp_bias, hidden_act=hidden_act, + float8_cfg=float8_cfg, ) self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + @maybe_compile(fullgraph=True) def forward( self, hidden_states: torch.Tensor, diff --git a/xtuner/v1/module/dispatcher/deepep.py b/xtuner/v1/module/dispatcher/deepep.py index ffb97ce80..9acb8f996 100644 --- a/xtuner/v1/module/dispatcher/deepep.py +++ b/xtuner/v1/module/dispatcher/deepep.py @@ -1,56 +1,220 @@ -# type: ignore from typing import Literal, TypeAlias, cast import torch +import torch.distributed as dist +from deep_ep import EventOverlap from mmengine.utils import is_installed -from typing_extensions import Required, overload, override - -from xtuner.v1.ops import buffer_capture, deep_ep_combine, deep_ep_dispatch, get_low_latency_buffer -from xtuner.v1.ops.moe_permute import permute, unpermute +from typing_extensions import override + +from xtuner.v1.ops import permute, unpermute +from xtuner.v1.ops.comm.deepep_op import ( + buffer_capture, + combine_backward, + combine_forward, + dispatch_backward, + dispatch_forward, +) +from xtuner.v1.utils import copy_method_signature, get_device, get_logger +from . import XTUNER_DISPATCHER_DEBUG from .base import ( - DecodingCombineResult, - DecodingDispatchResult, + CombineResult, + DispatchResult, GenericDispatcher, - HiddenStates, + PostCombineResult, + PostDispatchResult, + PreCombineResult, PreDispatchResult, - PrefillingCombineResult, - PrefillingDispatchResult, ) +if get_device() == "npu": + from torch_npu.contrib import transfer_to_npu # noqa + + +DEVICE = get_device() +logger = get_logger() + + +DeepEPHandle = tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + + # DeepEP handle include 6 tensor: # (rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head) -DeepEPHandle = tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] +class DeepEPPreDispatchResult(PreDispatchResult): + backward_previous_event: EventOverlap | None + forward_finished_event: EventOverlap | None + -DeepEPPreDispatchResult: TypeAlias = PreDispatchResult +class DeepEPDispatchResult(DispatchResult): + handle: DeepEPHandle + topk_ids: torch.Tensor + num_recv_tokens_per_expert_list: list[int] + forward_finished_event: EventOverlap | None -# TODO: Broken inheritance, we should fix it later -class DeepEPPrefillingDispatchResult(PrefillingDispatchResult): - handle: Required[DeepEPHandle] # type: ignore - row_id_map: torch.Tensor +class DeepEPPostDispatchResult(PostDispatchResult): + row_ids_map: torch.Tensor -# TODO: Broken inheritance, we should fix it later -class DeepEPDecodingDispatchResult(DecodingDispatchResult): - handle: DeepEPHandle # type: ignore +class DeepEPPreCombineResult(PreCombineResult): + backward_previous_event: EventOverlap | None + forward_finished_event: EventOverlap | None -DeepEPPrefillingCombineResult: TypeAlias = PrefillingCombineResult -DeepEPDecodingCombineResult: TypeAlias = DecodingCombineResult +class DeepEPCombineResult(CombineResult): + forward_finished_event: 
EventOverlap | None + backward_previous_event: EventOverlap | None + + +DeepEPPostCombineResult = PostCombineResult + + +HiddenStates: TypeAlias = torch.Tensor + + +class DeepEPDispatch(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + num_experts: int, + group: dist.ProcessGroup, + forward_previous_event: EventOverlap | None = None, + backward_finished_event: EventOverlap | None = None, + ) -> tuple[ + torch.Tensor | tuple[torch.Tensor, torch.Tensor], + torch.Tensor, + torch.Tensor, + list, + tuple, + EventOverlap, + ]: + ( + recv_x, + recv_topk_idx, + recv_topk_weights, + num_recv_tokens_per_expert_list, + handle, + event, + ) = dispatch_forward(x, topk_idx, topk_weights, num_experts, group, forward_previous_event) + # save deep comm handle + ctx.save_for_backward(*handle) + ctx.group = group + ctx.num_experts = num_experts + ctx.backward_finished_event = backward_finished_event + return ( + recv_x, + recv_topk_idx, + recv_topk_weights, + num_recv_tokens_per_expert_list, + handle, + event, + ) + + @staticmethod + def backward( # type: ignore[invalid-override] + ctx, + grad_recv_x: torch.Tensor, + grad_recv_topk_idx: torch.Tensor, + grad_recv_topk_weights: torch.Tensor, + *args, + ) -> tuple[torch.Tensor, None, torch.Tensor | None, None, None, None, None, None, None]: + # load saved comm handle + handle = ctx.saved_tensors + combined_grad_x, combined_grad_recv_topk_weights, event = dispatch_backward( + grad_recv_x, grad_recv_topk_weights, ctx.num_experts, handle, ctx.group, buffer_capture() + ) + if ctx.backward_finished_event is not None: + ctx.backward_finished_event.event = event.event + return ( + combined_grad_x, + None, + combined_grad_recv_topk_weights, + None, + None, + None, + None, + None, + None, + ) + + +_async_dispatch = copy_method_signature(DeepEPDispatch.forward)(DeepEPDispatch.apply) + + +class DeepEPCombine(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: torch.Tensor, + num_experts: int, + handle: DeepEPHandle, + group: dist.ProcessGroup, + forward_previous_event: EventOverlap | None = None, + backward_previous_event: EventOverlap | None = None, + backward_finished_event: EventOverlap | None = None, + ) -> tuple[torch.Tensor, EventOverlap]: + combined_x, event = combine_forward(x, num_experts, handle, group, forward_previous_event) + # save deep comm handle + ctx.save_for_backward(*handle) + ctx.group = group + ctx.num_experts = num_experts + ctx.backward_finished_event = backward_finished_event + ctx.backward_previous_event = backward_previous_event + return combined_x, event + + @staticmethod + def backward( # type: ignore[invalid-override] + ctx, grad_combined_x: torch.Tensor, *args + ) -> tuple[torch.Tensor | tuple[torch.Tensor, torch.Tensor], None, None, None, None, None, None]: + # load saved comm handle + handle = ctx.saved_tensors + grad_x, event = combine_backward( + grad_combined_x, ctx.num_experts, handle, ctx.group, ctx.backward_previous_event + ) + ctx.backward_finished_event.event = event.event + return grad_x, None, None, None, None, None, None + + +_async_combine = copy_method_signature(DeepEPCombine.forward)(DeepEPCombine.apply) + + +def get_backward_pre_hook(backward_previous_event: EventOverlap, name: str | None = None, debug: bool = False): + def _backward_pre_hook(*_): + if debug: + logger.info(f"[{name}] backward pre hook") + if backward_previous_event is not None: + backward_previous_event.current_stream_wait() + + return 
_backward_pre_hook + + +def get_backward_hook(backward_finished_event: EventOverlap, name: str | None = None, debug: bool = False): + def _backward_hook(*_): + if debug: + logger.info(f"[{name}] backward hook") + if backward_finished_event is not None: + event = buffer_capture() + backward_finished_event.event = event.event + + return _backward_hook class DeepEPDispatcher( GenericDispatcher[ DeepEPPreDispatchResult, - DeepEPPrefillingDispatchResult, - DeepEPDecodingDispatchResult, - DeepEPPrefillingCombineResult, - DeepEPDecodingCombineResult, + DeepEPDispatchResult, + DeepEPPostDispatchResult, + DeepEPPreCombineResult, + DeepEPCombineResult, + DeepEPPostCombineResult, ] ): - _process_group: torch.distributed.ProcessGroup + _comm_stream = None + _process_group: dist.ProcessGroup def __init__( self, @@ -68,41 +232,10 @@ def __init__( training_dtype=training_dtype, generate_dtype=generate_dtype, ) - - @overload - def dispatch( - self, - *, - pre_dispatched: DeepEPPreDispatchResult, - decoding: Literal[True], - ) -> DeepEPDecodingDispatchResult: ... - - @overload - def dispatch( - self, - *, - pre_dispatched: DeepEPPreDispatchResult, - decoding: Literal[False], - ) -> DeepEPPrefillingDispatchResult: ... - - @override - def dispatch( - self, - *, - pre_dispatched: DeepEPPreDispatchResult, - decoding: bool = False, - ) -> DeepEPPrefillingDispatchResult | DeepEPDecodingDispatchResult: - if not decoding: - return self._dispatch_prefilling( - hidden_states=pre_dispatched["hidden_states"], - topk_weights=pre_dispatched["topk_weights"], - topk_ids=pre_dispatched["topk_ids"], - ) - else: - return self._dispatch_decoding( - hidden_states=pre_dispatched["hidden_states"], - topk_ids=pre_dispatched["topk_ids"], - ) + assert self._process_group is not None, ( + "Process group must be provided for `DeepEPDispatcher`. " + "If you are training a MoE model, it means that `expert parallel` is not enabled in the config." + ) @override def dispatch_preprocess( @@ -110,115 +243,39 @@ def dispatch_preprocess( *, hidden_states: torch.Tensor, topk_ids: torch.Tensor, - topk_weights: torch.Tensor, + async_op: bool = False, ) -> DeepEPPreDispatchResult: + if async_op: + backward_previous_event = EventOverlap(None) + forward_finished_event = buffer_capture() + if hidden_states.grad_fn is not None: + hidden_states.grad_fn.register_prehook( + get_backward_pre_hook( + backward_previous_event=backward_previous_event, + name="TorchAll2AllDispatcher.dispatch_preprocess", + debug=XTUNER_DISPATCHER_DEBUG, + ) + ) + else: + forward_finished_event = None + backward_previous_event = None + return DeepEPPreDispatchResult( hidden_states=hidden_states, - topk_ids=topk_ids, - topk_weights=topk_weights, + topk_ids=topk_ids.to(torch.int64), + backward_previous_event=backward_previous_event, + forward_finished_event=forward_finished_event, ) - @overload - def combine( - self, - *, - hidden_states: torch.Tensor, - pre_dispatched: PreDispatchResult, - dispatch_result: DeepEPPrefillingDispatchResult, - decoding: Literal[False], - ) -> DeepEPPrefillingCombineResult: ... - - @overload - def combine( - self, - *, - hidden_states: torch.Tensor, - pre_dispatched: PreDispatchResult, - dispatch_result: DeepEPDecodingDispatchResult, - decoding: Literal[True], - ) -> DeepEPDecodingCombineResult: ... 
- - @override - def combine( - self, - *, - hidden_states: torch.Tensor, - pre_dispatched: PreDispatchResult, - dispatch_result: DeepEPPrefillingDispatchResult | DeepEPDecodingDispatchResult, - decoding: bool = False, - ) -> DeepEPPrefillingCombineResult | DeepEPDecodingCombineResult: - if not decoding: - return self._combine_prefilling( - hidden_states=hidden_states, - dispatched_result=cast(DeepEPPrefillingDispatchResult, dispatch_result), - ) - else: - return self._combine_decoding( - hidden_states=hidden_states, - pre_dispatched=cast(DeepEPPreDispatchResult, dispatch_result), - dispatched_result=cast(DeepEPDecodingDispatchResult, dispatch_result), - ) - @override - def combine_post_process( + def dispatch( self, *, pre_dispatched: DeepEPPreDispatchResult, - dispatch_result: DeepEPPrefillingDispatchResult | DeepEPDecodingDispatchResult, - combine_result: DeepEPPrefillingCombineResult | DeepEPDecodingCombineResult, - ) -> HiddenStates: - return combine_result["hidden_states"] - - def _dispatch_decoding( - self, - hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - ) -> DeepEPDecodingDispatchResult: - hidden_size = hidden_states.shape[-1] - x = hidden_states.view(-1, hidden_states.size()[-1]) - _buffer = get_low_latency_buffer(self._process_group, hidden=hidden_size, num_experts=self._n_routed_experts) - - # Do MoE dispatch, compatible with CUDA graph (but you may restore some buffer status once you replay) - recv_x, tokens_per_expert, handle, _, _ = _buffer.low_latency_dispatch( - x, - topk_ids, - x.size(0), - self._n_routed_experts, - async_finish=False, - use_fp8=self._training_dtype == "fp8", - return_recv_hook=False, - ) - - # NOTES: the actual tensor will not be received only if you call `hook()`, - # it is useful for double-batch overlapping, but **without any SM occupation** - # If you don't want to overlap, please set `return_recv_hook=False` - # Later, you can use our GEMM library to do the computation with this specific format - if self._training_dtype == "fp8": - assert isinstance(recv_x, tuple), "When using FP8, `recv_x` should be a tuple." 
- hidden_states, fp_8_scale = recv_x - return DeepEPDecodingDispatchResult( - hidden_states=hidden_states, - tokens_per_experts=tokens_per_expert, - fp8_scale=fp_8_scale, - handle=handle, - ) - else: - raise NotImplementedError("DeepEP decoding dispatch only supports FP8 for now.") - - def _dispatch_prefilling( - self, - hidden_states: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - ) -> DeepEPPrefillingDispatchResult: - hidden_dim = hidden_states.shape[-1] - hidden_states = hidden_states.view(-1, hidden_dim) - hidden_states = hidden_states.view(-1, hidden_dim) - - topk_ids = topk_ids.to(torch.int64) - - # TODO: Maybe we should decouple the sync and async interface of deepep - previous_event = buffer_capture() + async_op: bool = False, + decoding: bool = False, + ) -> DeepEPDispatchResult: ( dispatched_hidden_states, dispatched_topk_idx, @@ -226,113 +283,177 @@ def _dispatch_prefilling( num_recv_tokens_per_expert_list, dispatch_handle, event, - ) = deep_ep_dispatch( - x=hidden_states, - topk_idx=topk_ids, - topk_weights=topk_weights, - num_routed_experts=self._n_routed_experts, - group=self._process_group, - previous_event=previous_event, + ) = _async_dispatch( + pre_dispatched["hidden_states"], + pre_dispatched["topk_ids"], + topk_weights, + self._n_routed_experts, + self._process_group, + pre_dispatched["forward_finished_event"], + pre_dispatched["backward_previous_event"], ) - event.current_stream_wait() - permuted_hidden_states, row_id_map = self._training_permute_dispatched( - dispatched_hidden_states=dispatched_hidden_states, - dispatched_topk_ids=dispatched_topk_idx, - num_recv_tokens_per_expert_list=num_recv_tokens_per_expert_list, - ) - tokens_per_experts = torch.tensor( - num_recv_tokens_per_expert_list, - dtype=torch.long, - device=topk_weights.device, - ) - return DeepEPPrefillingDispatchResult( - hidden_states=permuted_hidden_states, - tokens_per_experts=tokens_per_experts, - handle=dispatch_handle, + if not async_op: + event.current_stream_wait() + forward_finished_event = None + else: + forward_finished_event = event + + ret = DeepEPDispatchResult( + hidden_states=cast(HiddenStates, dispatched_hidden_states), topk_weights=dispatched_topk_weights, - row_id_map=row_id_map, + topk_ids=dispatched_topk_idx, + handle=dispatch_handle, + num_recv_tokens_per_expert_list=num_recv_tokens_per_expert_list, + forward_finished_event=forward_finished_event, ) + return ret - def _training_permute_dispatched( + @override + def dispatch_postprocess( self, *, - dispatched_hidden_states: torch.Tensor, - dispatched_topk_ids: torch.Tensor, - num_recv_tokens_per_expert_list, - ): - num_out_tokens = sum(num_recv_tokens_per_expert_list) - recv_topk_idx_numel = dispatched_topk_ids.numel() + pre_dispatched: DeepEPPreDispatchResult, + dispatched: DeepEPDispatchResult, + async_op: bool = False, + decoding: bool = False, + ) -> DeepEPPostDispatchResult: + if async_op: + assert dispatched["forward_finished_event"] is not None, "Please use `async_op=True` for dispatch!" 
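+            # In async mode the dispatch kernel has not been synchronized yet; make the
+            # current stream wait on the event recorded by the async dispatch before the
+            # `permute` below reads the received tokens.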
+ dispatched["forward_finished_event"].current_stream_wait() + + num_recv_tokens_per_expert_list = dispatched["num_recv_tokens_per_expert_list"] + num_out_tokens = sum(dispatched["num_recv_tokens_per_expert_list"]) + recv_topk_idx_numel = dispatched["topk_ids"].numel() num_neg_one_idx = recv_topk_idx_numel - num_out_tokens - permuted_hidden_states, row_id_map = permute( - dispatched_hidden_states, - dispatched_topk_ids.int(), + permuted_hidden_states, row_ids_map = permute( + dispatched["hidden_states"], + dispatched["topk_ids"].int(), num_out_tokens=num_out_tokens, num_negative_one_in_indices=num_neg_one_idx, ) - return permuted_hidden_states, row_id_map + tokens_per_expert = torch.tensor( + num_recv_tokens_per_expert_list, + dtype=torch.long, + device=dispatched["topk_weights"].device, + ) + + if decoding: + raise NotImplementedError + else: + return DeepEPPostDispatchResult( + hidden_states=permuted_hidden_states, + row_ids_map=row_ids_map, + tokens_per_expert=tokens_per_expert, + ) - def _combine_prefilling( + @override + def combine_preprocess( self, + *, hidden_states: torch.Tensor, - dispatched_result: DeepEPPrefillingDispatchResult, - ): - unpermuted_hidden_states = self._training_unpermute_activation( - hidden_states=hidden_states, - row_id_map=dispatched_result["row_id_map"], - topk_weights=dispatched_result["topk_weights"], + pre_dispatched: DeepEPPreDispatchResult, + dispatched: DeepEPDispatchResult, + post_dispatched: DeepEPPostDispatchResult, + async_op: bool = False, + decoding: bool = False, + ) -> DeepEPPreCombineResult: + hidden_states = unpermute( + hidden_states, + post_dispatched["row_ids_map"], + probs=dispatched["topk_weights"], ) - # TODO: Maybe we should decouple the sync and async interface of deepep - event = buffer_capture() - combined_hidden_states, event = deep_ep_combine( - x=unpermuted_hidden_states, - num_experts=self._n_routed_experts, - deepep_comm_handle=dispatched_result["handle"], - group=self._process_group, - ) - event.current_stream_wait() - # For event management, please refer to the docs of the `EventOverlap` class - return DeepEPPrefillingCombineResult( - hidden_states=combined_hidden_states, - ) + if async_op: + backward_previous_event = EventOverlap(None) + forward_finished_event = buffer_capture() + if hidden_states.grad_fn is not None: + hidden_states.grad_fn.register_prehook( + get_backward_pre_hook( + backward_previous_event=backward_previous_event, + name="TorchAll2AllDispatcher.combine_preprocess", + debug=XTUNER_DISPATCHER_DEBUG, + ) + ) + else: + backward_previous_event = None + forward_finished_event = None + + if decoding: + raise NotImplementedError + else: + return DeepEPPreCombineResult( + hidden_states=hidden_states, + forward_finished_event=forward_finished_event, + backward_previous_event=backward_previous_event, + ) - def _combine_decoding( + @override + def combine( self, - hidden_states: torch.Tensor, + *, pre_dispatched: DeepEPPreDispatchResult, - dispatched_result: DeepEPDecodingDispatchResult, - ): - hidden_size = hidden_states.shape[-1] - _buffer = get_low_latency_buffer( - self._process_group, - hidden=hidden_size, - num_experts=self._n_routed_experts, - ) + dispatched: DeepEPDispatchResult, + post_dispatched: DeepEPPostDispatchResult, + pre_combined: DeepEPPreCombineResult, + async_op: bool = False, + decoding: bool = False, + ) -> CombineResult: + if async_op: + backward_previous_event = EventOverlap(None) + assert pre_combined["forward_finished_event"] is not None, "Please use `async_op=True` for combine!" 
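+            # Make the current stream wait for the event recorded in `combine_preprocess`
+            # so the unpermuted hidden states are ready before the combine kernel is launched.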
+ pre_combined["forward_finished_event"].current_stream_wait() + else: + backward_previous_event = None - # Do MoE combine, compatible with CUDA graph (but you may restore some buffer status once you replay) - combined_x, _, _ = _buffer.low_latency_combine( - x=hidden_states, - topk_idx=pre_dispatched["topk_ids"], - topk_weights=pre_dispatched["topk_weights"], - handle=dispatched_result["handle"], - async_finish=False, - return_recv_hook=False, + combined_hidden_states, event = _async_combine( + pre_combined["hidden_states"], + self._n_routed_experts, + dispatched["handle"], + self._process_group, + pre_combined["forward_finished_event"], + backward_previous_event, + pre_combined["backward_previous_event"], ) + if not async_op: + event.current_stream_wait() - return DeepEPDecodingCombineResult(hidden_states=combined_x) + if not decoding: + return DeepEPCombineResult( + hidden_states=combined_hidden_states, + forward_finished_event=event, + backward_previous_event=backward_previous_event, + ) + else: + raise NotImplementedError - def _training_unpermute_activation( + @override + def combine_postprocess( self, - hidden_states: torch.Tensor, - row_id_map: torch.Tensor, - topk_weights: torch.Tensor | None = None, - ): - # assert self.ep_mesh.size() > 1 - activation = unpermute( - input_act=hidden_states, - row_id_map=row_id_map, - probs=topk_weights, - ) - return activation + *, + pre_dispatched: DeepEPPreDispatchResult, + dispatched: DeepEPDispatchResult, + post_dispatched: DeepEPPostDispatchResult, + pre_combined: DeepEPPreCombineResult, + combined: DeepEPCombineResult, + async_op: bool = False, + ) -> PostCombineResult: + hidden_states = combined["hidden_states"] + forward_previous_event = combined["forward_finished_event"] + + hidden_states = hidden_states.view_as(hidden_states) + + if hidden_states.grad_fn is not None: + hidden_states.grad_fn.register_hook( + get_backward_hook( + backward_finished_event=combined["backward_previous_event"], + name="DeeEPDispatcher.combine_postprocess", + debug=XTUNER_DISPATCHER_DEBUG, + ) + ) + + if async_op: + assert forward_previous_event is not None, "Please use `async_op=True` for combine!" + forward_previous_event.current_stream_wait() + return PostCombineResult(hidden_states=hidden_states) diff --git a/xtuner/v1/module/router/noaux_router.py b/xtuner/v1/module/router/noaux_router.py index f915d516e..6aea89f2c 100644 --- a/xtuner/v1/module/router/noaux_router.py +++ b/xtuner/v1/module/router/noaux_router.py @@ -59,10 +59,6 @@ def __init__( ) def forward(self, logits) -> RouterResults: - if os.getenv("XTUNER_ROUTER_DEBUG") == "true": - noise = torch.randn_like(logits) * 50 - logits = logits + noise - if self.scoring_func == "sigmoid": scores = logits.sigmoid() else: @@ -71,6 +67,10 @@ def forward(self, logits) -> RouterResults: scores_for_choice = scores + self.e_score_correction_bias.unsqueeze(0) + if os.getenv("XTUNER_ROUTER_DEBUG") == "true": + noise = torch.randn_like(scores) * 50 + scores_for_choice = scores + noise + # select top-k experts # (only applicable when ep_size >= 64. 
when ep_size=32 (4 nodes), there is no need to employ this strategy) _, topk_idx = torch.topk(scores_for_choice, k=self.top_k, dim=-1) diff --git a/xtuner/v1/ops/comm/deepep_op.py b/xtuner/v1/ops/comm/deepep_op.py index b2e0ebdf6..575d1ea26 100644 --- a/xtuner/v1/ops/comm/deepep_op.py +++ b/xtuner/v1/ops/comm/deepep_op.py @@ -20,7 +20,8 @@ _low_latency_buffer: Optional[Buffer] = None # Set the number of SMs to use # NOTES: this is a static variable -Buffer.set_num_sms(24) +# Buffer.set_num_sms(24) +Buffer.set_num_sms(20) # You may call this function at the framework initialization diff --git a/xtuner/v1/ops/flash_attn/gpu.py b/xtuner/v1/ops/flash_attn/gpu.py index 6a1b3fc06..ecddab00b 100644 --- a/xtuner/v1/ops/flash_attn/gpu.py +++ b/xtuner/v1/ops/flash_attn/gpu.py @@ -72,6 +72,7 @@ def _flash_attn_varlen_forward_v3_fake( softcap: float = 0.0, # 0.0 means deactivated ) -> tuple[torch.Tensor, torch.Tensor]: total_q, num_heads, _ = q.shape + q = q.contiguous() out = torch.empty_like(q) softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout) return out, softmax_lse @@ -190,6 +191,7 @@ def forward( window_size[1], softcap, ) + # torch.distributed.breakpoint() ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k) ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_k = max_seqlen_k diff --git a/xtuner/v1/ops/rms_norm.py b/xtuner/v1/ops/rms_norm/__init__.py similarity index 62% rename from xtuner/v1/ops/rms_norm.py rename to xtuner/v1/ops/rms_norm/__init__.py index d65707480..9dfd4b65e 100644 --- a/xtuner/v1/ops/rms_norm.py +++ b/xtuner/v1/ops/rms_norm/__init__.py @@ -1,10 +1,8 @@ -from typing import Protocol +from functools import partial import torch - -class RMSNormProtocol(Protocol): - def __call__(self, x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor: ... +from .protocol import RMSNormProtocol def native_rms_norm(x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor: @@ -19,7 +17,13 @@ def npu_rms_norm(x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch return torch_npu.npu_rms_norm(x, weight, epsilon=epsilon)[0] -def get_rms_norm() -> RMSNormProtocol: +def gpu_rms_norm(x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor: + from .gpu import rms_norm_fn + + return rms_norm_fn(x, weight, bias=None, eps=epsilon) + + +def get_rms_norm_fn() -> RMSNormProtocol: from xtuner.v1.utils import get_device device = get_device() @@ -28,7 +32,7 @@ def get_rms_norm() -> RMSNormProtocol: elif device == "npu": return npu_rms_norm else: - return native_rms_norm + return gpu_rms_norm -rms_norm = get_rms_norm() +rms_norm = get_rms_norm_fn() diff --git a/xtuner/v1/ops/rms_norm/gpu.py b/xtuner/v1/ops/rms_norm/gpu.py new file mode 100644 index 000000000..7aa523275 --- /dev/null +++ b/xtuner/v1/ops/rms_norm/gpu.py @@ -0,0 +1,1126 @@ +# Copied from https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/ops/triton/layer_norm.py +# To reduce version constraints on the flash_attn library, we copied an updated +# rms_norm_fn operator from the flash_attn library that supports torch.compile, +# as the lower version's implementation does not. +# Copyright (c) 2024, Tri Dao. +# Implement dropout + residual + layer_norm / rms_norm. + +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. 
+# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. + +import math +from typing import Callable, Iterable, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from torch import Tensor +from torch._library.triton import set_wrap_triton_enabled +from torch.library import CustomOpDef, custom_op + + +def triton_op( + name: str, + fn: Optional[Callable] = None, + /, + *, + mutates_args: Union[str, Iterable[str]], + schema: Optional[str] = None, + # If allow_decomposition=True, this matches torch.library.triton_op behavior. If set to False, + # then it behaves like torch.library.custom_op instead, which doesn't decompose the operator + # and so inductor can't trace inside. + allow_decomposition=True, +) -> Callable: + def dec(fn: Callable[..., object]) -> CustomOpDef: + def backend_fn(*args, **kwargs): # type: ignore[no-untyped-def] + # Optimization: we're passing regular Tensors into the triton kernel, so + # no need to go through HOP dispatch + with set_wrap_triton_enabled(False): + return fn(*args, **kwargs) + + result = custom_op( + name, + backend_fn, + mutates_args=mutates_args, + # This is the only difference with the PyTorch implementation + schema=schema, + ) + from torch._subclasses.functional_tensor import FunctionalTensorMode + + # We require that the user pass us a function that is make_fx traceable, + # so we can just register it as the Fake/meta kernel. + result.register_fake(fn) + + if allow_decomposition: + # We decompose the operator when FunctionalTensorMode is active. + # The goal is to decompose the operator in AOTDispatcher. + # - With torch.compile, this means that the backend (usually Inductor) + # can see a call to the triton kernel(s) and so it can directly optimize + # them by inlining them into the lowering process. 
+ def functional_decomp( # type: ignore[no-untyped-def] + mode, op, types, args, kwargs + ): + # from torch.export._trace import custom_triton_ops_decomposition_disabled + + # if custom_triton_ops_decomposition_disabled(): + # return mode.__torch_dispatch__(op, types, args, kwargs) + # else: + with mode: + return fn(*args, **kwargs) + + result.register_torch_dispatch(FunctionalTensorMode, functional_decomp) + + return result + + if fn is None: + return dec + else: + return dec(fn) + + +def maybe_contiguous_lastdim(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + + +def maybe_contiguous(x): + return x.contiguous() if x is not None else None + + +def triton_autotune_configs(): + # Return configs with a valid warp count for the current device + configs = [] # noqa: F841 + # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024 + max_threads_per_block = 1024 + # Default to warp size 32 if not defined by device + warp_size = getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32) + # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit + return [ + triton.Config({}, num_warps=warp_count) + for warp_count in [1, 2, 4, 8, 16, 32] + if warp_count * warp_size <= max_threads_per_block + ] + # return [triton.Config({}, num_warps=8)] + + +def layer_norm_ref( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + zero_centered_weight=False, + dropout_mask=None, + dropout_mask1=None, + upcast=False, +): + dtype = x.dtype + if upcast: + x = x.float() + weight = weight.float() + bias = bias.float() if bias is not None else None + residual = residual.float() if residual is not None else residual + x1 = x1.float() if x1 is not None else None + weight1 = weight1.float() if weight1 is not None else None + bias1 = bias1.float() if bias1 is not None else None + if zero_centered_weight: + weight = weight + 1.0 + if weight1 is not None: + weight1 = weight1 + 1.0 + if x1 is not None: + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + if rowscale is not None: + x = x * rowscale[..., None] + if dropout_p > 0.0: + if dropout_mask is not None: + x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) + else: + x = F.dropout(x, p=dropout_p) + if x1 is not None: + if dropout_mask1 is not None: + x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) + else: + x1 = F.dropout(x1, p=dropout_p) + if x1 is not None: + x = x + x1 + if residual is not None: + x = (x + residual).to(x.dtype) + out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(dtype) + if weight1 is None: + return out if not prenorm else (out, x) + else: + out1 = F.layer_norm(x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps).to(dtype) + return (out, out1) if not prenorm else (out, out1, x) + + +def rms_norm_ref( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + zero_centered_weight=False, + dropout_mask=None, + dropout_mask1=None, + upcast=False, +): + dtype = x.dtype + if upcast: + x = x.float() + weight = weight.float() + bias = bias.float() if bias is not None else None + residual = residual.float() if residual is not None else residual + x1 = x1.float() if x1 is not None else None + weight1 = weight1.float() if weight1 is not None else None + bias1 
= bias1.float() if bias1 is not None else None + if zero_centered_weight: + weight = weight + 1.0 + if weight1 is not None: + weight1 = weight1 + 1.0 + if x1 is not None: + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + if rowscale is not None: + x = x * rowscale[..., None] + if dropout_p > 0.0: + if dropout_mask is not None: + x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) + else: + x = F.dropout(x, p=dropout_p) + if x1 is not None: + if dropout_mask1 is not None: + x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) + else: + x1 = F.dropout(x1, p=dropout_p) + if x1 is not None: + x = x + x1 + if residual is not None: + x = (x + residual).to(x.dtype) + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype) + if weight1 is None: + return out if not prenorm else (out, x) + else: + out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(dtype) + return (out, out1) if not prenorm else (out, out1, x) + + +@triton.autotune( + configs=triton_autotune_configs(), + key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS", "HAS_X1", "HAS_W1", "HAS_B1"], +) +# torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) +# @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) +# @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) +# @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + RESIDUAL, # pointer to the residual + X1, + W1, + B1, + Y1, + RESIDUAL_OUT, # pointer to the residual + ROWSCALE, + SEEDS, # Dropout seeds for each row + DROPOUT_MASK, + DROPOUT_MASK1, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_res_row, + stride_res_out_row, + stride_x1_row, + stride_y1_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + dropout_p, # Dropout probability + zero_centered_weight, # If true, add 1.0 to the weight + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_DROPOUT: tl.constexpr, + STORE_DROPOUT_MASK: tl.constexpr, + HAS_ROWSCALE: tl.constexpr, + HAS_X1: tl.constexpr, + HAS_W1: tl.constexpr, + HAS_B1: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
+ row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_y_row + if HAS_RESIDUAL: + RESIDUAL += row * stride_res_row + if STORE_RESIDUAL_OUT: + RESIDUAL_OUT += row * stride_res_out_row + if HAS_X1: + X1 += row * stride_x1_row + if HAS_W1: + Y1 += row * stride_y1_row + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + row).to(tl.float32) + x *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N) + if HAS_X1: + x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + M + row).to(tl.float32) + x1 *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK1 + row * N + cols, keep_mask, mask=cols < N) + x += x1 + if HAS_RESIDUAL: + residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) + x += residual + if STORE_RESIDUAL_OUT: + tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w += 1.0 + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + # Write output + tl.store(Y + cols, y, mask=mask) + if HAS_W1: + w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w1 += 1.0 + if HAS_B1: + b1 = tl.load(B1 + cols, mask=mask).to(tl.float32) + y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1 + tl.store(Y1 + cols, y1, mask=mask) + + +def _layer_norm_fwd( + x: Tensor, + weight: Tensor, + bias: Tensor, + eps: float, + residual: Optional[Tensor] = None, + x1: Optional[Tensor] = None, + weight1: Optional[Tensor] = None, + bias1: Optional[Tensor] = None, + dropout_p: float = 0.0, + rowscale: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + residual_dtype: Optional[torch.dtype] = None, + zero_centered_weight: bool = False, + is_rms_norm: bool = False, + return_dropout_mask: bool = False, + out: Optional[Tensor] = None, + residual_out: Optional[Tensor] = None, +) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + # Need to wrap to handle the case where residual_out is a alias of x, which makes torch.library + # and torch.compile unhappy. Also allocate memory for out and residual_out if they are None + # so that _layer_norm_fwd_impl doesn't have to return them. 
+ if out is None: + out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) + if residual is not None: + residual_dtype = residual.dtype + if residual_out is None and ( + residual is not None + or (residual_dtype is not None and residual_dtype != x.dtype) + or dropout_p > 0.0 + or rowscale is not None + or x1 is not None + ): + residual_out = torch.empty_like(x, dtype=residual_dtype if residual_dtype is not None else x.dtype) + else: + residual_out = None + y1, mean, rstd, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd_impl( + x, + weight, + bias, + eps, + out, + residual=residual, + x1=x1, + weight1=weight1, + bias1=bias1, + dropout_p=dropout_p, + rowscale=rowscale, + zero_centered_weight=zero_centered_weight, + is_rms_norm=is_rms_norm, + return_dropout_mask=return_dropout_mask, + residual_out=residual_out, + ) + # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 + if residual_out is None: + residual_out = x + return out, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 + + +# [2025-04-28] torch.library.triton_op ignores the schema argument, but here we need the schema +# since we're returning a tuple of tensors +@triton_op( + "flash_attn::layer_norm_fwd_impl", + mutates_args={"out", "residual_out"}, + schema="(Tensor x, Tensor weight, Tensor bias, float eps, Tensor(a!) out, Tensor? residual, Tensor? x1, Tensor? weight1, Tensor? bias1, float dropout_p, Tensor? rowscale, bool zero_centered_weight, bool is_rms_norm, bool return_dropout_mask, Tensor(a!)? residual_out) -> (Tensor y1, Tensor mean, Tensor rstd, Tensor seeds, Tensor dropout_mask, Tensor dropout_mask1)", +) +def _layer_norm_fwd_impl( + x: Tensor, + weight: Tensor, + bias: Tensor, + eps: float, + out: Tensor, + residual: Optional[Tensor] = None, + x1: Optional[Tensor] = None, + weight1: Optional[Tensor] = None, + bias1: Optional[Tensor] = None, + dropout_p: float = 0.0, + rowscale: Optional[Tensor] = None, + zero_centered_weight: bool = False, + is_rms_norm: bool = False, + return_dropout_mask: bool = False, + residual_out: Optional[Tensor] = None, +) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + M, N = x.shape + assert x.stride(-1) == 1 + if residual is not None: + assert residual.stride(-1) == 1 + assert residual.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + if x1 is not None: + assert x1.shape == x.shape + assert rowscale is None + assert x1.stride(-1) == 1 + if weight1 is not None: + assert weight1.shape == (N,) + assert weight1.stride(-1) == 1 + if bias1 is not None: + assert bias1.shape == (N,) + assert bias1.stride(-1) == 1 + if rowscale is not None: + assert rowscale.is_contiguous() + assert rowscale.shape == (M,) + assert out.shape == x.shape + assert out.stride(-1) == 1 + if residual_out is not None: + assert residual_out.shape == x.shape + assert residual_out.stride(-1) == 1 + if weight1 is not None: + y1 = torch.empty_like(out) + assert y1.stride(-1) == 1 + else: + y1 = None + mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + if dropout_p > 0.0: + seeds = torch.randint(2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64) + else: + seeds = None + if return_dropout_mask and dropout_p > 0.0: + dropout_mask = torch.empty(M, N, device=x.device, dtype=torch.bool) + if x1 is not 
None: + dropout_mask1 = torch.empty(M, N, device=x.device, dtype=torch.bool) + else: + dropout_mask1 = None + else: + dropout_mask, dropout_mask1 = None, None + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + with torch.cuda.device(x.device.index): + torch.library.wrap_triton(_layer_norm_fwd_1pass_kernel)[(M,)]( + x, + out, + weight, + bias, + residual, + x1, + weight1, + bias1, + y1, + residual_out, + rowscale, + seeds, + dropout_mask, + dropout_mask1, + mean, + rstd, + x.stride(0), + out.stride(0), + residual.stride(0) if residual is not None else 0, + residual_out.stride(0) if residual_out is not None else 0, + x1.stride(0) if x1 is not None else 0, + y1.stride(0) if y1 is not None else 0, + M, + N, + eps, + dropout_p, + # Passing bool make torch inductor very unhappy since it then tries to compare to int_max + int(zero_centered_weight), + is_rms_norm, + BLOCK_N, + residual is not None, + residual_out is not None, + bias is not None, + dropout_p > 0.0, + dropout_mask is not None, + rowscale is not None, + HAS_X1=x1 is not None, + HAS_W1=weight1 is not None, + HAS_B1=bias1 is not None, + ) + return y1, mean, rstd, seeds, dropout_mask, dropout_mask1 # type: ignore + + +@triton.autotune( + configs=triton_autotune_configs(), + key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"], +) +# torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) +# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) +# @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None}) +# @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None}) +# @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None}) +# @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None}) +# @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) +@triton.jit +def _layer_norm_bwd_kernel( + X, # pointer to the input + W, # pointer to the weights + B, # pointer to the biases + Y, # pointer to the output to be recomputed + DY, # pointer to the output gradient + DX, # pointer to the input gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + DRESIDUAL, + W1, + DY1, + DX1, + DW1, + DB1, + DRESIDUAL_IN, + ROWSCALE, + SEEDS, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_dy_row, + stride_dx_row, + stride_dres_row, + stride_dy1_row, + stride_dx1_row, + stride_dres_in_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + dropout_p, + zero_centered_weight, + rows_per_program, + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_DRESIDUAL: tl.constexpr, + STORE_DRESIDUAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_DROPOUT: tl.constexpr, + HAS_ROWSCALE: tl.constexpr, + HAS_DY1: tl.constexpr, + HAS_DX1: tl.constexpr, + HAS_B1: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. 
+ row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + # Do not early exit if row_start >= M, because we need to write DW and DB + cols = tl.arange(0, BLOCK_N) + mask = cols < N + X += row_start * stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += row_start * stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += row_start * stride_dres_in_row + DY += row_start * stride_dy_row + DX += row_start * stride_dx_row + if HAS_DY1: + DY1 += row_start * stride_dy1_row + if HAS_DX1: + DX1 += row_start * stride_dx1_row + if RECOMPUTE_OUTPUT: + Y += row_start * stride_y_row + w = tl.load(W + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w += 1.0 + if RECOMPUTE_OUTPUT and HAS_BIAS: + b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) + if HAS_DY1: + w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w1 += 1.0 + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_BIAS: + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_DY1: + dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_B1: + db1 = tl.zeros((BLOCK_N,), dtype=tl.float32) + row_end = min((row_block_id + 1) * rows_per_program, M) + for row in range(row_start, row_end): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if HAS_DY1: + dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + xhat = tl.where(mask, xhat, 0.0) + if RECOMPUTE_OUTPUT: + y = xhat * w + b if HAS_BIAS else xhat * w + tl.store(Y + cols, y, mask=mask) + wdy = w * dy + dw += dy * xhat + if HAS_BIAS: + db += dy + if HAS_DY1: + wdy += w1 * dy1 + dw1 += dy1 * xhat + if HAS_B1: + db1 += dy1 + if not IS_RMS_NORM: + c1 = tl.sum(xhat * wdy, axis=0) / N + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + else: + c1 = tl.sum(xhat * wdy, axis=0) / N + dx = (wdy - xhat * c1) * rstd + if HAS_DRESIDUAL: + dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) + dx += dres + # Write dx + if STORE_DRESIDUAL: + tl.store(DRESIDUAL_IN + cols, dx, mask=mask) + if HAS_DX1: + if HAS_DROPOUT: + keep_mask = tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) + else: + dx1 = dx + tl.store(DX1 + cols, dx1, mask=mask) + if HAS_DROPOUT: + keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + row).to(tl.float32) + dx *= rowscale + tl.store(DX + cols, dx, mask=mask) + + X += stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += stride_dres_in_row + if RECOMPUTE_OUTPUT: + Y += stride_y_row + DY += stride_dy_row + DX += stride_dx_row + if HAS_DY1: + DY1 += stride_dy1_row + if HAS_DX1: + DX1 += stride_dx1_row + tl.store(DW + row_block_id * N + cols, dw, mask=mask) + if HAS_BIAS: + tl.store(DB + row_block_id * N + cols, db, mask=mask) + if HAS_DY1: + tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask) + if HAS_B1: + tl.store(DB1 + row_block_id * N + cols, db1, mask=mask) + + +def _layer_norm_bwd( + dy: Tensor, + x: Tensor, + weight: Tensor, + bias: Tensor, + eps: float, + mean: Tensor, + rstd: Tensor, + dresidual: Optional[Tensor] = None, + dy1: Optional[Tensor] = None, + 
weight1: Optional[Tensor] = None, + bias1: Optional[Tensor] = None, + seeds: Optional[Tensor] = None, + dropout_p: float = 0.0, + rowscale: Optional[Tensor] = None, + has_residual: bool = False, + has_x1: bool = False, + zero_centered_weight: bool = False, + is_rms_norm: bool = False, + x_dtype: Optional[torch.dtype] = None, + recompute_output: bool = False, +) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + # Need to wrap to handle the case where dresidual_in or dx1 are aliases of x, + # which makes torch.library unhappy + dx, dw, db, dresidual_in, dx1, dw1, db1, y = _layer_norm_bwd_impl( + dy, + x, + weight, + bias, + eps, + mean, + rstd, + dresidual, + dy1, + weight1, + bias1, + seeds, + dropout_p, + rowscale, + has_residual, + has_x1, + zero_centered_weight, + is_rms_norm, + x_dtype=x_dtype, + recompute_output=recompute_output, + ) + # Don't need to compute dresidual_in separately in this case + if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None: + dresidual_in = dx + if has_x1 and dropout_p == 0.0: + dx1 = dx + return dx, dw, db, dresidual_in, dx1, dw1, db1, y + + +@triton_op( + "flash_attn::layer_norm_bwd_impl", + mutates_args={}, + schema="(Tensor dy, Tensor x, Tensor weight, Tensor bias, float eps, Tensor mean, Tensor rstd, Tensor? dresidual, Tensor? dy1, Tensor? weight1, Tensor? bias1, Tensor? seeds, float dropout_p, Tensor? rowscale, bool has_residual, bool has_x1, bool zero_centered_weight, bool is_rms_norm, ScalarType? x_dtype, bool recompute_output) -> (Tensor dx, Tensor dw, Tensor db, Tensor dresidual_in, Tensor dx1, Tensor dw1, Tensor db1, Tensor y)", + allow_decomposition=False, # Don't let torch.compile trace inside +) +def _layer_norm_bwd_impl( + dy: Tensor, + x: Tensor, + weight: Tensor, + bias: Tensor, + eps: float, + mean: Tensor, + rstd: Tensor, + dresidual: Optional[Tensor] = None, + dy1: Optional[Tensor] = None, + weight1: Optional[Tensor] = None, + bias1: Optional[Tensor] = None, + seeds: Optional[Tensor] = None, + dropout_p: float = 0.0, + rowscale: Optional[Tensor] = None, + has_residual: bool = False, + has_x1: bool = False, + zero_centered_weight: bool = False, + is_rms_norm: bool = False, + x_dtype: Optional[torch.dtype] = None, + recompute_output: bool = False, +) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + M, N = x.shape + assert x.stride(-1) == 1 + dy = maybe_contiguous_lastdim(dy) + assert dy.stride(-1) == 1 + assert dy.shape == (M, N) + if dresidual is not None: + dresidual = maybe_contiguous_lastdim(dresidual) + assert dresidual.stride(-1) == 1 + assert dresidual.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + if dy1 is not None: + dy1 = maybe_contiguous_lastdim(dy1) + assert weight1 is not None + assert dy1.shape == dy.shape + assert dy1.stride(-1) == 1 + if weight1 is not None: + assert weight1.shape == (N,) + assert weight1.stride(-1) == 1 + if bias1 is not None: + assert bias1.shape == (N,) + assert bias1.stride(-1) == 1 + if seeds is not None: + assert seeds.is_contiguous() + assert seeds.shape == (M if not has_x1 else M * 2,) + if rowscale is not None: + assert rowscale.is_contiguous() + assert rowscale.shape == (M,) + # allocate output + dx = torch.empty_like(x) if x_dtype is None else torch.empty(M, N, dtype=x_dtype, device=x.device) + dresidual_in = ( + torch.empty_like(x) + if has_residual and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale 
is not None or has_x1) + else None + ) + dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None + y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None + if recompute_output: + assert weight1 is None, "recompute_output is not supported with parallel LayerNorm" + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the + # latency of the gmem reads/writes, but will increase the time of summing up dw / db. + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8 + _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) + _db = torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) if bias is not None else None + _dw1 = torch.empty_like(_dw) if weight1 is not None else None + _db1 = torch.empty_like(_db) if bias1 is not None else None + rows_per_program = math.ceil(M / sm_count) + grid = (sm_count,) + with torch.cuda.device(x.device.index): + torch.library.wrap_triton(_layer_norm_bwd_kernel)[grid]( + x, + weight, + bias, + y, + dy, + dx, + _dw, + _db, + dresidual, + weight1, + dy1, + dx1, + _dw1, + _db1, + dresidual_in, + rowscale, + seeds, + mean, + rstd, + x.stride(0), + 0 if not recompute_output else y.stride(0), # type: ignore + dy.stride(0), + dx.stride(0), + dresidual.stride(0) if dresidual is not None else 0, + dy1.stride(0) if dy1 is not None else 0, + dx1.stride(0) if dx1 is not None else 0, + dresidual_in.stride(0) if dresidual_in is not None else 0, + M, + N, + eps, + dropout_p, + # Passing bool make torch inductor very unhappy since it then tries to compare to int_max + int(zero_centered_weight), + rows_per_program, + is_rms_norm, + BLOCK_N, + dresidual is not None, + dresidual_in is not None, + bias is not None, + dropout_p > 0.0, + HAS_ROWSCALE=rowscale is not None, + HAS_DY1=dy1 is not None, + HAS_DX1=dx1 is not None, + HAS_B1=bias1 is not None, + RECOMPUTE_OUTPUT=y is not None, + ) + dw = _dw.sum(0).to(weight.dtype) + db = _db.sum(0).to(bias.dtype) if bias is not None else None + dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None # type: ignore + db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None # type: ignore + # dresidual_in and dx1 could be None, the wrapper will handle assigning them from dx + return dx, dw, db, dresidual_in, dx1, dw1, db1, y # type: ignore + + +class LayerNormFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + is_rms_norm=False, + return_dropout_mask=False, + out_dtype=None, + out=None, + residual_out=None, + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1])) + if residual is not None: + assert residual.shape == x_shape_og + residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1])) + if x1 is not None: + assert x1.shape == x_shape_og + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + x1 = maybe_contiguous_lastdim(x1.reshape(-1, x1.shape[-1])) + weight = weight.contiguous() + bias = maybe_contiguous(bias) 
+ weight1 = maybe_contiguous(weight1) + bias1 = maybe_contiguous(bias1) + if rowscale is not None: + rowscale = rowscale.reshape(-1).contiguous() + residual_dtype = residual.dtype if residual is not None else (torch.float32 if residual_in_fp32 else None) + if out is not None: + out = out.reshape(-1, out.shape[-1]) + if residual_out is not None: + residual_out = residual_out.reshape(-1, residual_out.shape[-1]) + y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd( + x, + weight, + bias, + eps, + residual, + x1, + weight1, + bias1, + dropout_p=dropout_p, + rowscale=rowscale, + out_dtype=out_dtype, + residual_dtype=residual_dtype, + zero_centered_weight=zero_centered_weight, + is_rms_norm=is_rms_norm, + return_dropout_mask=return_dropout_mask, + out=out, + residual_out=residual_out, + ) + ctx.save_for_backward(residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.dropout_p = dropout_p + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.has_x1 = x1 is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + ctx.zero_centered_weight = zero_centered_weight + y = y.reshape(x_shape_og) + y1 = y1.reshape(x_shape_og) if y1 is not None else None + residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None + dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None + dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None + if not return_dropout_mask: + if weight1 is None: + return y if not prenorm else (y, residual_out) + else: + return (y, y1) if not prenorm else (y, y1, residual_out) + else: + if weight1 is None: + return ( + (y, dropout_mask, dropout_mask1) if not prenorm else (y, residual_out, dropout_mask, dropout_mask1) + ) + else: + return ( + (y, y1, dropout_mask, dropout_mask1) + if not prenorm + else (y, y1, residual_out, dropout_mask, dropout_mask1) + ) + + @staticmethod + def backward(ctx, dy, *args): + x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors + dy = dy.reshape(-1, dy.shape[-1]) + if weight1 is not None: + dy1, args = args[0], args[1:] + dy1 = dy1.reshape(-1, dy1.shape[-1]) + assert dy1.shape == x.shape + else: + dy1 = None + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dw, db, dresidual_in, dx1, dw1, db1, _ = _layer_norm_bwd( + dy, + x, + weight, + bias, + ctx.eps, + mean, + rstd, + dresidual, + dy1, + weight1, + bias1, + seeds, + ctx.dropout_p, + rowscale, + ctx.has_residual, + ctx.has_x1, + ctx.zero_centered_weight, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + recompute_output=False, + ) + return ( + dx.reshape(ctx.x_shape_og), + dw, + db, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + dx1.reshape(ctx.x_shape_og) if dx1 is not None else None, + dw1, + db1, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +def layer_norm_fn( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + is_rms_norm=False, + return_dropout_mask=False, + out_dtype=None, + out=None, + residual_out=None, +): + return LayerNormFn.apply( + x, + weight, + bias, + residual, + x1, + weight1, + bias1, + eps, + dropout_p, + rowscale, 
+ prenorm, + residual_in_fp32, + zero_centered_weight, + is_rms_norm, + return_dropout_mask, + out_dtype, + out, + residual_out, + ) + + +def rms_norm_fn( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + return_dropout_mask=False, + out_dtype=None, + out=None, + residual_out=None, +): + return LayerNormFn.apply( + x, + weight, + bias, + residual, + x1, + weight1, + bias1, + eps, + dropout_p, + rowscale, + prenorm, + residual_in_fp32, + zero_centered_weight, + True, + return_dropout_mask, + out_dtype, + out, + residual_out, + ) diff --git a/xtuner/v1/ops/rms_norm/protocol.py b/xtuner/v1/ops/rms_norm/protocol.py new file mode 100644 index 000000000..1ce7b5ea3 --- /dev/null +++ b/xtuner/v1/ops/rms_norm/protocol.py @@ -0,0 +1,7 @@ +from typing import Protocol + +import torch + + +class RMSNormProtocol(Protocol): + def __call__(self, x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor: ... diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py index d8f7deda4..41668218d 100644 --- a/xtuner/v1/train/trainer.py +++ b/xtuner/v1/train/trainer.py @@ -1,4 +1,5 @@ import contextlib +import gc import json import os import pickle @@ -144,7 +145,7 @@ class TrainerConfig(BaseModel): hf_interval: int | None = None hf_max_keep: int | None = None exp_tracker: Literal["tensorboard", "jsonl"] = "jsonl" - profile_step: int | None = None + profile_step: list[int] | int | None = None profile_time: bool = True profile_memory: bool = False intra_layer_micro_batch: int = 1 @@ -191,7 +192,7 @@ class Trainer: checkpoint_maxkeep (int | None): Maximum number of checkpoints to keep. hf_interval (int | None): Interval for saving Huggingface format checkpoints. hf_max_keep (int | None): Maximum number of Huggingface checkpoints to keep. - profile_step (int | None): Step to perform profiling. + profile_step (list[int] | int | None): Step to perform profiling. profile_time (bool): Whether to profile training time. profile_memory (bool): Whether to profile memory usage. intra_layer_micro_batch (int): Intra-layer micro batch size. @@ -240,7 +241,7 @@ def __init__( hf_interval: int | None = None, hf_max_keep: int | None = None, exp_tracker: Literal["tensorboard", "jsonl"] = "jsonl", - profile_step: int | None = None, + profile_step: list[int] | int | None = None, profile_time: bool = True, profile_memory: bool = False, intra_layer_micro_batch: int = 1, @@ -262,6 +263,8 @@ def __init__( if skip_checkpoint_validation: patch_default_save_plan() + if isinstance(profile_step, int): + profile_step = [profile_step] self._profile_step = profile_step self._profile_time = profile_time self._profile_memory = profile_memory @@ -502,6 +505,9 @@ def fit(self): time_before_get_data = time.time() + if self.cur_step % 50 == 0: + gc.collect() + @property def world_size(self) -> int: """Get the total number of processes in the distributed training group. 
@@ -956,14 +962,14 @@ def _init_xtuner_meta(self, work_dir: Path, auto_resume: bool) -> XTunerMeta: @contextmanager def _maybe_profiling(self): """Check if profiling is enabled and perform profiling if necessary.""" - if self._profile_step is not None and self._cur_step == self._profile_step: + if self._profile_step is not None and self._cur_step in self._profile_step: with contextlib.ExitStack() as stack: if self._profile_time: - time_dir = self.work_dir / self._PROFILE_TIME_PATH / f"step-{self._cur_step}" + time_dir = self.exp_dir / self._PROFILE_TIME_PATH / f"step-{self._cur_step}" stack.enter_context(profiling_time(time_dir)) if self._profile_memory: - memory_dir = self.work_dir / self._PROFILE_MEMORY_PATH / f"step-{self._cur_step}" + memory_dir = self.exp_dir / self._PROFILE_MEMORY_PATH / f"step-{self._cur_step}" stack.enter_context(profiling_memory(memory_dir)) yield else: @@ -1198,6 +1204,7 @@ def _resume_dataloader(self, dataloader_path: Path): self._dataloader.load_state_dict(dataloader_state) def _setup_env(self): + gc.disable() os.environ["TOKENIZERS_PARALLELISM"] = "true" log_str = "\n============XTuner Training Environment============\n" diff --git a/xtuner/v1/utils/activation_offload.py b/xtuner/v1/utils/activation_offload.py index 86acc50bd..d5a5055df 100644 --- a/xtuner/v1/utils/activation_offload.py +++ b/xtuner/v1/utils/activation_offload.py @@ -18,7 +18,7 @@ def base_check_fn(tensor): if isinstance(tensor._base, torch.nn.parameter.Parameter) or isinstance(tensor, torch.nn.parameter.Parameter): return False - if tensor.storage().size() <= 0: + if tensor.untyped_storage().size() <= 0: return False return True @@ -99,24 +99,33 @@ def wait_d2h_finished(self, stream, flag): self.stat = "host" # resize storage_size and host to device - def launch_h2d(self, h2d_stream, flag, working_stream): + def launch_h2d(self, h2d_stream: torch.cuda.Stream, flag, working_stream): if self.stat != "host": + # waiting for `non-blocking` storage copy of `prefetch_launch_h2d` + working_stream.wait_event(self.h2d_event) + # working_stream.wait_stream(h2d_stream) return - backward_event = torch.cuda.Event() - backward_event.record() if flag: self.tensor.storage().resize_(self.storage_size) with torch.no_grad(): - with torch.cuda.stream(h2d_stream): - h2d_stream.wait_event(backward_event) - if self.is_slice_tensor: - self.tensor.copy_(self.tensor_cpu, non_blocking=True) - else: - self.tensor.storage().copy_(self.tensor_cpu.storage(), non_blocking=True) - self.h2d_event.record() - self.stat = "device" - - working_stream.wait_stream(h2d_stream) + # if self.tensor_cpu.isnan().any(): + # raise RuntimeError("d2h error") + if self.is_slice_tensor: + self.tensor.copy_(self.tensor_cpu) + else: + self.tensor.storage().copy_(self.tensor_cpu.storage()) + self.stat = "device" + + def load(self): + self.tensor.storage().resize_(self.storage_size) + with torch.no_grad(): + # if self.tensor_cpu.isnan().any(): + # raise RuntimeError("d2h error") + if self.is_slice_tensor: + self.tensor.copy_(self.tensor_cpu) + else: + self.tensor.storage().copy_(self.tensor_cpu.storage()) + self.stat = "device" # resize storage_size and host to device def prefetch_launch_h2d(self, h2dstream, flag): @@ -128,14 +137,16 @@ def prefetch_launch_h2d(self, h2dstream, flag): self.tensor.storage().resize_(self.storage_size) with torch.no_grad(): with torch.cuda.stream(h2dstream): + # if self.tensor_cpu.isnan().any(): + # raise RuntimeError("d2h error") h2dstream.wait_event(backward_event) if self.is_slice_tensor: 
self.tensor.copy_(self.tensor_cpu, non_blocking=True) else: self.tensor.storage().copy_(self.tensor_cpu.storage(), non_blocking=True) self.h2d_event.record() - self.stat = "device" self.tensor.record_stream(h2dstream) + self.stat = "device" # synchronize h2d def wait_h2d_finished(self): @@ -183,6 +194,7 @@ def __init__(self, check=False): self.check = check self.device_item = [] self.getcnt = GetCnt() + self.may_npu_tensors = {} def get_cnt(self, block_idx): return self.getcnt.get_cnt(block_idx) @@ -209,11 +221,19 @@ def put(self, key, act, event=None): def put_npu_tensor(self, act): self.device_item.append(act) - def del_npu_tensor(self, prefile_key, d2h_stream): + def del_npu_tensor(self, profile_key, d2h_stream): for key in self.items.keys(): - if key.startswith(prefile_key): + if key.startswith(profile_key): self.items[key].act.wait_d2h_finished(d2h_stream, True) + def del_may_npu_tensor(self, profile_keys, h2d_stream): + may_npu_tensor_keys = list(self.may_npu_tensors.keys()) + for key in may_npu_tensor_keys: + if key.startswith(profile_keys): + with torch.cuda.stream(h2d_stream): + h2d_stream.wait_event(self.may_npu_tensors[key].act.h2d_event) + del self.may_npu_tensors[key] + def get(self, key): self.assert_exist(key) item = self.items[key] @@ -224,7 +244,9 @@ def get(self, key): item.ref_cnt -= 1 if item.ref_cnt == 0: - self.clear(key) + # self.clear(key) + self.assert_exist(key) + self.may_npu_tensors.update({key: self.items.pop(key)}) return act def prefetch_get(self, block_idx, tensor_idx, h2d_stream, d2h_stream): @@ -232,9 +254,11 @@ def prefetch_get(self, block_idx, tensor_idx, h2d_stream, d2h_stream): for prefetch_key in prefetch_keys: if self.exist(prefetch_key): prefetch_swap_tensor = self.get(prefetch_key) - d2h_stream.wait_stream(h2d_stream) + h2d_stream.wait_stream(d2h_stream) prefetch_swap_tensor.prefetch_launch_h2d(h2d_stream, True) - prefetch_swap_tensor.tensor.record_stream(h2d_stream) + # prefetch_swap_tensor.tensor.record_stream(h2d_stream) + else: + torch.distributed.breakpoint() def empty(self): return len(self.items) == 0 @@ -262,7 +286,15 @@ def has_event(self, key): class async_save_on_cpu(saved_tensors_hooks): - def __init__(self, h2d_stream, d2h_stream, block_idx, depth, custom_check_fn=None, prefetch=True) -> None: + def __init__( + self, + h2d_stream: torch.cuda.Stream, + d2h_stream: torch.cuda.Stream, + block_idx: int, + depth: int, + custom_check_fn=None, + prefetch=True, + ) -> None: def _pack_to_cpu(tensor): if not base_check_fn(tensor): return tensor @@ -277,7 +309,7 @@ def _pack_to_cpu(tensor): swap_tensor = SwapTensor(tensor, key) - if block_idx < depth - 1: + if block_idx <= depth - 1: working_stream = torch.cuda.current_stream() d2h_stream.wait_stream(working_stream) swap_tensor.launch_d2h(d2h_stream) @@ -290,15 +322,24 @@ def _unpack_from_cpu(swap_tensor) -> torch.Tensor: return swap_tensor working_stream = torch.cuda.current_stream() - working_stream.wait_stream(h2d_stream) # make sure all d2h copy is done before into backward + # working_stream.wait_stream(d2h_stream) # make sure all d2h copy is done before into backward h2d_stream.wait_stream(working_stream) + block_idx, tensor_idx = swap_tensor.key.split("_") + + OffloadManager().del_may_npu_tensor(f"{int(block_idx) + 1}_", h2d_stream) swap_tensor.launch_h2d(h2d_stream, True, working_stream) + # if block_idx in ["0", "2", "3"]: + # if block_idx in ["0"]: + # torch.cuda.synchronize() - if prefetch: - block_idx, tensor_idx = swap_tensor.key.split("_") + if prefetch and block_idx != 0: 
OffloadManager().prefetch_get(int(block_idx), int(tensor_idx), h2d_stream, d2h_stream) + + # if block_idx in ["0"] and tensor_idx == "1": + # swap_tensor.load() + # torch.cuda.synchronize() return swap_tensor.tensor super().__init__(_pack_to_cpu, _unpack_from_cpu) diff --git a/xtuner/v1/utils/debug.py b/xtuner/v1/utils/debug.py new file mode 100644 index 000000000..ad2c712ce --- /dev/null +++ b/xtuner/v1/utils/debug.py @@ -0,0 +1,31 @@ +import torch + + +FOUND_NAN = False + + +def register_grad_hook(tensor: torch.Tensor, message): + if (grad_fn := tensor.grad_fn) is not None: + message = f"{tensor.grad_fn}: {message}" + grad_fn.register_hook(get_grad_hook(message)) + + +def get_grad_hook(message: str): + def hook(g_in: tuple, g_out: tuple): + global FOUND_NAN + # torch.distributed.breakpoint() + # if torch.distributed.get_rank() == 0: + if FOUND_NAN: + return + + for idx, i in enumerate(g_in): + if isinstance(i, torch.Tensor) and i.isnan().any().item(): + FOUND_NAN = True + print(f"{message} index {idx} of g_in has nan") + + for idx, o in enumerate(g_out): + if isinstance(o, torch.Tensor) and o.isnan().any().item(): + FOUND_NAN = True + print(f"{message} index {idx} of g_out has nan") + + return hook
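The Triton fused layer norm above and the new `RMSNormProtocol` define the RMSNorm calling convention `(x, weight, epsilon) -> Tensor`. The sketch below is illustrative only: it assumes the package layout in this diff is importable, `rms_norm_ref` and `impl` are made-up names, and the exact import path of the fused `rms_norm_fn` is left unspecified since it is not pinned down here.

```python
import torch

from xtuner.v1.ops.rms_norm.protocol import RMSNormProtocol

# Eager reference that satisfies RMSNormProtocol and mirrors the kernel math
# above (rstd = 1 / sqrt(mean(x^2) + eps)); handy as a numerical baseline.
def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor:
    dtype = x.dtype
    xf = x.float()
    rstd = torch.rsqrt(xf.square().mean(dim=-1, keepdim=True) + epsilon)
    return (xf * rstd * weight.float()).to(dtype)


impl: RMSNormProtocol = rms_norm_ref  # structurally matches the protocol

x = torch.randn(2, 1024, 4096, dtype=torch.bfloat16)
w = torch.ones(4096, dtype=torch.bfloat16)
y = impl(x, w, 1e-6)

# The fused entry point has the signature rms_norm_fn(x, weight, bias, ..., eps=...),
# e.g. `y_fused = rms_norm_fn(x, w, None, eps=1e-6)` on CUDA tensors; adjust the
# import to wherever the Triton file in this diff is placed.
```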
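The `trainer.py` changes do two things: `profile_step` now accepts a list of steps (a bare `int` is normalized to a one-element list) and profiler output is written under `exp_dir`; and Python's cyclic garbage collector is disabled during environment setup, with a manual `gc.collect()` issued every 50 steps inside `fit`. The stand-alone sketch below reproduces that GC pattern; `train_step` is a placeholder, not an XTuner API.

```python
import gc
import time


def train_step(step: int) -> None:
    # placeholder for the real forward/backward/optimizer-step work
    time.sleep(0.001)


# Disable automatic GC so collections cannot land mid-iteration, then collect
# at a deterministic point between iterations, as the trainer now does.
gc.disable()
for step in range(1, 201):
    train_step(step)
    if step % 50 == 0:
        gc.collect()
```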
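`async_save_on_cpu` is a `saved_tensors_hooks` context manager: it wraps a block's forward pass, offloads the tensors saved for backward to host memory on `d2h_stream`, and reloads (with prefetch) on `h2d_stream` during backward. The sketch below shows one plausible way to drive it; the per-block loop, the stream creation, and the `forward_with_offload` helper are assumptions for illustration, not XTuner's actual call site.

```python
import torch

from xtuner.v1.utils.activation_offload import async_save_on_cpu

h2d_stream = torch.cuda.Stream()  # host-to-device copies (reload during backward)
d2h_stream = torch.cuda.Stream()  # device-to-host copies (offload during forward)


def forward_with_offload(blocks: torch.nn.ModuleList, x: torch.Tensor) -> torch.Tensor:
    depth = len(blocks)
    for block_idx, block in enumerate(blocks):
        # Tensors saved for backward inside this block are packed to host
        # memory; unpacking in backward also prefetches the next block's tensors.
        with async_save_on_cpu(h2d_stream, d2h_stream, block_idx=block_idx, depth=depth):
            x = block(x)
    return x
```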
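The new `xtuner/v1/utils/debug.py` helper attaches a hook to a tensor's `grad_fn` and reports the first NaN it sees in the incoming or outgoing gradients during backward. A minimal sketch, with a deliberately injected NaN so the hook has something to report:

```python
import torch

from xtuner.v1.utils.debug import register_grad_hook

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)

h = x @ w
register_grad_hook(h, "first matmul output")

# Multiplying by NaN makes every gradient flowing back through `h` NaN, so the
# hook prints which grad_fn produced it and sets the module-level FOUND_NAN flag.
loss = (h * float("nan")).sum()
loss.backward()
```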