From afc993a212dbcc04a8f6d0d54a9e0550f73fad50 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Thu, 4 Jun 2026 10:42:17 -0700 Subject: [PATCH 1/5] megacpp Signed-off-by: Zhongbo Zhu --- tests/pytorch/megacpp/test_grouped_mlp.py | 476 +++++++++++ transformer_engine/pytorch/csrc/extensions.h | 24 + .../pytorch/csrc/extensions/pybind.cpp | 19 + .../pytorch/csrc/megacpp/grouped_mlp.cpp | 797 ++++++++++++++++++ .../pytorch/ops/fused/__init__.py | 6 + .../ops/fused/backward_grouped_mlp_megacpp.py | 392 +++++++++ .../ops/fused/forward_grouped_mlp_megacpp.py | 382 +++++++++ 7 files changed, 2096 insertions(+) create mode 100644 tests/pytorch/megacpp/test_grouped_mlp.py create mode 100644 transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp create mode 100644 transformer_engine/pytorch/ops/fused/backward_grouped_mlp_megacpp.py create mode 100644 transformer_engine/pytorch/ops/fused/forward_grouped_mlp_megacpp.py diff --git a/tests/pytorch/megacpp/test_grouped_mlp.py b/tests/pytorch/megacpp/test_grouped_mlp.py new file mode 100644 index 0000000000..ddddcb7fc4 --- /dev/null +++ b/tests/pytorch/megacpp/test_grouped_mlp.py @@ -0,0 +1,476 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +import torch + +import transformer_engine.pytorch as te +import transformer_engine.pytorch.ops as te_ops + + +_HIDDEN_SIZE = 512 +_FFN_HIDDEN_SIZE = 256 + + +def _megacpp_available() -> tuple[bool, str]: + if not torch.cuda.is_available(): + return False, "CUDA is required" + if not te.is_bf16_available(): + return False, "BF16 is required" + if torch.cuda.get_device_capability() < (10, 0): + return False, "megacpp grouped MLP uses SM100 grouped GEMM" + if not te_ops.fused.ForwardGroupedMLP_MegaCpp.is_supported(): + return False, "ForwardGroupedMLP_MegaCpp is not supported" + if not te_ops.fused.BackwardGroupedMLP_MegaCpp.is_supported(): + return False, "BackwardGroupedMLP_MegaCpp is not supported" + return True, "" + + +_AVAILABLE, _SKIP_REASON = _megacpp_available() +pytestmark = pytest.mark.skipif(not _AVAILABLE, reason=_SKIP_REASON) + + +def _make_grouped_mlp( + *, + num_groups: int, + hidden_size: int, + ffn_hidden_size: int, + activation_kind: str, + bias: bool, + delay_wgrad_compute: bool, + accumulate_into_main_grad: bool, + glu_interleave_size: int | None, + single_grouped_param: bool, +) -> te_ops.Sequential: + gated_activation = activation_kind in ("scaled_swiglu", "scaled_clamped_qgeglu") + fc1_out_features = 2 * ffn_hidden_size if gated_activation else ffn_hidden_size + fc1 = te_ops.GroupedLinear( + num_groups, + hidden_size, + fc1_out_features, + bias=bias, + device="cuda", + dtype=torch.bfloat16, + delay_wgrad_compute=delay_wgrad_compute, + accumulate_into_main_grad=accumulate_into_main_grad, + single_grouped_weight=single_grouped_param, + single_grouped_bias=single_grouped_param and bias, + ) + if activation_kind == "scaled_swiglu": + act = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) + elif activation_kind == "scaled_clamped_qgeglu": + act = te_ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) + elif activation_kind == "scaled_srelu": + act = te_ops.ScaledSReLU() + else: + raise ValueError(f"Unsupported test activation_kind={activation_kind}.") + fc2 = te_ops.GroupedLinear( + num_groups, + ffn_hidden_size, + hidden_size, + bias=bias, + device="cuda", + dtype=torch.bfloat16, + delay_wgrad_compute=delay_wgrad_compute, + accumulate_into_main_grad=accumulate_into_main_grad, + single_grouped_weight=single_grouped_param, + single_grouped_bias=single_grouped_param and bias, + ) + return te_ops.Sequential(fc1, act, fc2) + + +def _copy_grouped_mlp_params(dst: te_ops.Sequential, src: te_ops.Sequential) -> None: + with torch.no_grad(): + for dst_linear, src_linear in ((dst[0], src[0]), (dst[2], src[2])): + if dst_linear.single_grouped_weight: + dst_linear.weight.rowwise_data.copy_(src_linear.weight.rowwise_data) + if dst_linear.has_bias: + dst_linear.bias.rowwise_data.copy_(src_linear.bias.rowwise_data) + else: + for group_idx in range(dst_linear.num_groups): + getattr(dst_linear, f"weight{group_idx}").copy_( + getattr(src_linear, f"weight{group_idx}") + ) + if dst_linear.has_bias: + getattr(dst_linear, f"bias{group_idx}").copy_( + getattr(src_linear, f"bias{group_idx}") + ) + + +def _init_main_grads(module: te_ops.Sequential) -> None: + for linear in (module[0], module[2]): + if linear.single_grouped_weight: + linear.weight.main_grad = torch.zeros( + linear.num_groups, + linear.out_features, + linear.in_features, + device="cuda", + dtype=torch.bfloat16, + ) + else: + for group_idx in range(linear.num_groups): + weight = getattr(linear, f"weight{group_idx}") + weight.main_grad = torch.zeros_like(weight) + + +def _run_grouped_mlp( + module: te_ops.Sequential, + x: torch.Tensor, + split_sizes: torch.Tensor, + act_scales: torch.Tensor, + dy: torch.Tensor, + *, + delay_wgrad_compute: bool, +) -> torch.Tensor: + y = module(x, split_sizes, act_scales, split_sizes) + y.backward(dy) + if delay_wgrad_compute: + module[0].backward_dw() + module[2].backward_dw() + return y + + +def _assert_grouped_mlp_close( + test: te_ops.Sequential, + ref: te_ops.Sequential, + *, + accumulate_into_main_grad: bool, +) -> None: + for test_linear, ref_linear in ((test[0], ref[0]), (test[2], ref[2])): + if test_linear.single_grouped_weight: + if accumulate_into_main_grad: + torch.testing.assert_close( + test_linear.weight.main_grad, + ref_linear.weight.main_grad, + rtol=2e-2, + atol=2e-2, + ) + else: + torch.testing.assert_close( + test_linear.weight.grad, + ref_linear.weight.grad, + rtol=2e-2, + atol=2e-2, + ) + if test_linear.has_bias: + torch.testing.assert_close( + test_linear.bias.grad, + ref_linear.bias.grad, + rtol=2e-2, + atol=2e-2, + ) + continue + for group_idx in range(test_linear.num_groups): + if accumulate_into_main_grad: + torch.testing.assert_close( + getattr(test_linear, f"weight{group_idx}").main_grad, + getattr(ref_linear, f"weight{group_idx}").main_grad, + rtol=2e-2, + atol=2e-2, + ) + else: + torch.testing.assert_close( + getattr(test_linear, f"weight{group_idx}").grad, + getattr(ref_linear, f"weight{group_idx}").grad, + rtol=2e-2, + atol=2e-2, + ) + if test_linear.has_bias: + torch.testing.assert_close( + getattr(test_linear, f"bias{group_idx}").grad, + getattr(ref_linear, f"bias{group_idx}").grad, + rtol=2e-2, + atol=2e-2, + ) + + +def _assert_grouped_mlp_nonzero_expert_grads_close( + test: te_ops.Sequential, + ref: te_ops.Sequential, + split_sizes: list[int], +) -> None: + """Compare only non-empty experts; zero-token expert grads may be unwritten.""" + for test_linear, ref_linear in ((test[0], ref[0]), (test[2], ref[2])): + for group_idx, split_size in enumerate(split_sizes): + if split_size == 0: + continue + torch.testing.assert_close( + getattr(test_linear, f"weight{group_idx}").grad, + getattr(ref_linear, f"weight{group_idx}").grad, + rtol=2e-2, + atol=2e-2, + ) + if test_linear.has_bias: + torch.testing.assert_close( + getattr(test_linear, f"bias{group_idx}").grad, + getattr(ref_linear, f"bias{group_idx}").grad, + rtol=2e-2, + atol=2e-2, + ) + + +def _assert_valid_prefix_close( + test: torch.Tensor, + ref: torch.Tensor, + valid_tokens: int, +) -> None: + """Paged-stashed buffers only guarantee correctness in the valid token prefix.""" + if valid_tokens == 0: + return + torch.testing.assert_close(test[:valid_tokens], ref[:valid_tokens], rtol=2e-2, atol=2e-2) + + +def _make_split_tensor( + split_sizes: list[int], + *, + dtype: torch.dtype = torch.int64, + device: str = "cuda", +) -> torch.Tensor: + return torch.tensor(split_sizes, dtype=dtype, device=device) + + +def _run_megacpp_against_python( + *, + split_sizes_list: list[int], + physical_tokens: int, + split_dtype: torch.dtype, + split_device: str, + bias: bool = True, + glu_interleave_size: int | None = None, + activation_kind: str = "scaled_swiglu", + single_grouped_param: bool = False, + accumulate_into_main_grad: bool = False, + compare_zero_expert_grads: bool = True, + monkeypatch, +) -> None: + num_groups = len(split_sizes_list) + valid_tokens = sum(split_sizes_list) + assert physical_tokens >= valid_tokens + if single_grouped_param: + monkeypatch.setenv("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "1") + split_sizes = _make_split_tensor(split_sizes_list, dtype=split_dtype, device=split_device) + ref = _make_grouped_mlp( + num_groups=num_groups, + hidden_size=_HIDDEN_SIZE, + ffn_hidden_size=_FFN_HIDDEN_SIZE, + activation_kind=activation_kind, + bias=bias, + delay_wgrad_compute=False, + accumulate_into_main_grad=accumulate_into_main_grad, + glu_interleave_size=glu_interleave_size, + single_grouped_param=single_grouped_param, + ) + test = _make_grouped_mlp( + num_groups=num_groups, + hidden_size=_HIDDEN_SIZE, + ffn_hidden_size=_FFN_HIDDEN_SIZE, + activation_kind=activation_kind, + bias=bias, + delay_wgrad_compute=False, + accumulate_into_main_grad=accumulate_into_main_grad, + glu_interleave_size=glu_interleave_size, + single_grouped_param=single_grouped_param, + ) + _copy_grouped_mlp_params(test, ref) + if accumulate_into_main_grad: + _init_main_grads(ref) + _init_main_grads(test) + + # Paged stashing passes a static physical buffer to the op while m_splits + # describe only the valid prefix. Rows after sum(m_splits) are garbage and + # must not affect outputs/gradients for the valid prefix. + x_ref = torch.randn( + physical_tokens, _HIDDEN_SIZE, device="cuda", dtype=torch.bfloat16 + ).requires_grad_() + x_test = x_ref.detach().clone().requires_grad_() + act_scales_ref = torch.rand( + physical_tokens, device="cuda", dtype=torch.bfloat16 + ).requires_grad_() + act_scales_test = act_scales_ref.detach().clone().requires_grad_() + dy = torch.randn(physical_tokens, _HIDDEN_SIZE, device="cuda", dtype=torch.bfloat16) + + monkeypatch.setenv("NVTE_MEGACPP_GROUPED_LINEAR", "0") + y_ref = _run_grouped_mlp( + ref, + x_ref, + split_sizes, + act_scales_ref, + dy, + delay_wgrad_compute=False, + ) + monkeypatch.setenv("NVTE_MEGACPP_GROUPED_LINEAR", "1") + y_test = _run_grouped_mlp( + test, + x_test, + split_sizes, + act_scales_test, + dy, + delay_wgrad_compute=False, + ) + + fuser = test._module_groups[0] + assert isinstance(fuser._forward_ops[0][0], te_ops.fused.ForwardGroupedMLP_MegaCpp) + assert isinstance(fuser._backward_ops[0][0], te_ops.fused.BackwardGroupedMLP_MegaCpp) + + _assert_valid_prefix_close(y_test, y_ref, valid_tokens) + _assert_valid_prefix_close(x_test.grad, x_ref.grad, valid_tokens) + _assert_valid_prefix_close( + act_scales_test.grad, + act_scales_ref.grad, + valid_tokens, + ) + if valid_tokens == physical_tokens and compare_zero_expert_grads: + _assert_grouped_mlp_close(test, ref, accumulate_into_main_grad=accumulate_into_main_grad) + elif valid_tokens > 0 and not single_grouped_param and not accumulate_into_main_grad: + _assert_grouped_mlp_nonzero_expert_grads_close(test, ref, split_sizes_list) + + +@pytest.mark.parametrize( + "single_grouped_param", + [False, True], + ids=["discrete_weight", "packed_weight"], +) +@pytest.mark.parametrize( + "accumulate_into_main_grad", + [False, True], + ids=["cpp_allocated_wgrad", "megatron_main_grad"], +) +def test_megacpp_grouped_mlp_wgrad_storage_matches_python( + single_grouped_param, + accumulate_into_main_grad, + monkeypatch, +): + torch.manual_seed(1234) + _run_megacpp_against_python( + split_sizes_list=[256, 256, 512], + physical_tokens=1024, + split_dtype=torch.int64, + split_device="cuda", + single_grouped_param=single_grouped_param, + accumulate_into_main_grad=accumulate_into_main_grad, + monkeypatch=monkeypatch, + ) + + +@pytest.mark.parametrize( + "split_dtype,split_device", + [ + pytest.param(torch.int64, "cuda", id="i64_cuda"), + pytest.param(torch.int32, "cuda", id="i32_cuda"), + pytest.param(torch.int64, "cpu", id="i64_cpu"), + ], +) +def test_megacpp_grouped_mlp_split_source_matches_python( + split_dtype, + split_device, + monkeypatch, +): + torch.manual_seed(1234) + _run_megacpp_against_python( + split_sizes_list=[256, 256, 512], + physical_tokens=1024, + split_dtype=split_dtype, + split_device=split_device, + monkeypatch=monkeypatch, + ) + + +@pytest.mark.parametrize( + "activation_kind", + ["scaled_swiglu", "scaled_srelu", "scaled_clamped_qgeglu"], + ids=["swiglu", "srelu", "clamped_qgeglu"], +) +@pytest.mark.parametrize( + "glu_interleave_size", + [None, 32], + ids=["no_interleave", "interleave_32"], +) +def test_megacpp_grouped_mlp_activation_matches_python( + activation_kind, + glu_interleave_size, + monkeypatch, +): + if activation_kind == "scaled_srelu" and glu_interleave_size is not None: + pytest.skip("ScaledSReLU is not a GLU activation.") + torch.manual_seed(1234) + _run_megacpp_against_python( + split_sizes_list=[256, 256, 512], + physical_tokens=1024, + split_dtype=torch.int64, + split_device="cuda", + activation_kind=activation_kind, + glu_interleave_size=glu_interleave_size, + monkeypatch=monkeypatch, + ) + + +@pytest.mark.parametrize("bias", [True, False], ids=["bias", "no_bias"]) +def test_megacpp_grouped_mlp_bias_matches_python(bias, monkeypatch): + torch.manual_seed(1234) + _run_megacpp_against_python( + split_sizes_list=[256, 256, 512], + physical_tokens=1024, + split_dtype=torch.int64, + split_device="cuda", + bias=bias, + monkeypatch=monkeypatch, + ) + + +@pytest.mark.parametrize( + "split_sizes_list,physical_tokens", + [ + pytest.param([256, 256, 256, 256], 1024, id="even"), + pytest.param([0, 256, 256, 512], 1024, id="zero_front"), + pytest.param([256, 0, 256, 512], 1024, id="zero_middle"), + pytest.param([256, 256, 512, 0], 1024, id="zero_end"), + pytest.param([256, 256], 1024, id="paged_stashing_even_with_garbage"), + pytest.param([0, 256, 256], 1024, id="paged_stashing_zero_front_with_garbage"), + pytest.param([256, 0, 256], 1024, id="paged_stashing_zero_middle_with_garbage"), + pytest.param([256, 256, 0], 1024, id="paged_stashing_zero_end_with_garbage"), + pytest.param([0, 0, 0, 0], 1024, id="paged_stashing_zero_tokens_all_nonempty_input"), + ], +) +def test_megacpp_grouped_mlp_split_edge_cases( + split_sizes_list, + physical_tokens, + monkeypatch, +): + torch.manual_seed(1234) + _run_megacpp_against_python( + split_sizes_list=split_sizes_list, + physical_tokens=physical_tokens, + split_dtype=torch.int64, + split_device="cuda", + compare_zero_expert_grads=False, + monkeypatch=monkeypatch, + ) + + +def test_megacpp_grouped_mlp_delay_wgrad_raises(monkeypatch): + torch.manual_seed(1234) + num_groups = 3 + split_sizes = torch.tensor([256, 256, 512], dtype=torch.int64, device="cuda") + total_tokens = int(split_sizes.sum().item()) + module = _make_grouped_mlp( + num_groups=num_groups, + hidden_size=_HIDDEN_SIZE, + ffn_hidden_size=_FFN_HIDDEN_SIZE, + activation_kind="scaled_swiglu", + bias=True, + delay_wgrad_compute=True, + accumulate_into_main_grad=False, + glu_interleave_size=None, + single_grouped_param=False, + ) + x = torch.randn(total_tokens, _HIDDEN_SIZE, device="cuda", dtype=torch.bfloat16).requires_grad_() + act_scales = torch.rand( + total_tokens, device="cuda", dtype=torch.bfloat16 + ).requires_grad_() + dy = torch.randn(total_tokens, _HIDDEN_SIZE, device="cuda", dtype=torch.bfloat16) + + monkeypatch.setenv("NVTE_MEGACPP_GROUPED_LINEAR", "1") + with pytest.raises(ValueError, match="delay_wgrad_compute"): + y = module(x, split_sizes, act_scales, split_sizes) + y.backward(dy) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 2b4f899e1d..a59e85456d 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -185,6 +185,30 @@ py::object te_general_grouped_gemm_for_discrete_out(py::handle A, bool transa, p at::Tensor workspace_cublas, bool use_split_accumulator, int math_sm_count); +/*************************************************************************************************** + * Mega C++ grouped MLP + **************************************************************************************************/ + +std::vector megacpp_grouped_mlp_forward( + const at::Tensor &input, const at::Tensor &split_sizes, py::handle fc1_weight, + py::handle fc1_bias, py::handle fc2_weight, py::handle fc2_bias, + const std::optional &act_scales, const std::string &activation, + int64_t glu_interleave_size, double activation_limit, double activation_alpha, + double activation_glu_linear_offset); + +py::tuple megacpp_grouped_mlp_backward( + const at::Tensor &grad_output, const at::Tensor &split_sizes, const at::Tensor &x_offsets, + const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets, + const at::Tensor &fc2_dy_offsets, const at::Tensor &base_offsets, + const at::Tensor &x, const at::Tensor &fc1_activation_input, const at::Tensor &fc2_x, + const std::optional &act_scales, py::handle fc1_weight, + py::handle fc2_weight, + py::handle fc1_wgrad_output, bool fc1_compute_wgrad, bool fc1_accumulate_wgrad, + py::handle fc2_wgrad_output, bool fc2_compute_wgrad, bool fc2_accumulate_wgrad, + const std::string &activation, int64_t glu_interleave_size, double activation_limit, + double activation_alpha, double activation_glu_linear_offset, bool act_scales_requires_grad, + bool input_requires_grad); + /*************************************************************************************************** * Transpose **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d6089b1e01..78c9e280f3 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -357,6 +357,25 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("te_general_grouped_gemm_for_discrete_out", &transformer_engine::pytorch::te_general_grouped_gemm_for_discrete_out, "Grouped GEMM for discrete output list"); + m.def("megacpp_grouped_mlp_forward", + &transformer_engine::pytorch::megacpp_grouped_mlp_forward, + "Mega C++ grouped MLP forward", py::arg("input"), py::arg("split_sizes"), + py::arg("fc1_weight"), py::arg("fc1_bias"), py::arg("fc2_weight"), py::arg("fc2_bias"), + py::arg("act_scales"), py::arg("activation"), py::arg("glu_interleave_size"), + py::arg("activation_limit") = 0.0, py::arg("activation_alpha") = 0.0, + py::arg("activation_glu_linear_offset") = 0.0); + m.def("megacpp_grouped_mlp_backward", + &transformer_engine::pytorch::megacpp_grouped_mlp_backward, + "Mega C++ grouped MLP backward", py::arg("grad_output"), py::arg("split_sizes"), + py::arg("x_offsets"), py::arg("fc1_offsets"), py::arg("fc2_offsets"), + py::arg("fc2_dy_offsets"), py::arg("base_offsets"), py::arg("x"), + py::arg("fc1_activation_input"), py::arg("fc2_x"), py::arg("act_scales"), + py::arg("fc1_weight"), py::arg("fc2_weight"), py::arg("fc1_wgrad_output"), + py::arg("fc1_compute_wgrad"), py::arg("fc1_accumulate_wgrad"), py::arg("fc2_wgrad_output"), + py::arg("fc2_compute_wgrad"), py::arg("fc2_accumulate_wgrad"), py::arg("activation"), + py::arg("glu_interleave_size"), py::arg("activation_limit") = 0.0, + py::arg("activation_alpha") = 0.0, py::arg("activation_glu_linear_offset") = 0.0, + py::arg("act_scales_requires_grad") = true, py::arg("input_requires_grad") = true); m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O", py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard()); diff --git a/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp b/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp new file mode 100644 index 0000000000..2f9a642041 --- /dev/null +++ b/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp @@ -0,0 +1,797 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include + +#include +#include +#include +#include + +#include +#include + +#include "../extensions.h" +#include "../pybind.h" +#include "common/util/cuda_runtime.h" +#include "common/util/system.h" +#include "transformer_engine/activation.h" +#include "transformer_engine/gemm.h" +#include "transformer_engine/transformer_engine.h" + +namespace py = pybind11; + +namespace transformer_engine::pytorch { +namespace { + +constexpr int64_t kGroupedGemmCublasWorkspaceSize = 32 * 1024 * 1024 + 1024; + +bool is_none(py::handle obj) { return obj.is_none(); } + +std::vector tensor_shape_1d(const at::Tensor &tensor) { + return {static_cast(tensor.numel())}; +} + +at::Tensor maybe_cast_dtype(const at::Tensor &tensor, at::ScalarType dtype) { + at::Tensor out = tensor; + if (out.scalar_type() != dtype) { + out = out.to(out.options().dtype(dtype)); + } + return out; +} + +void check_contiguous(const at::Tensor &tensor, const std::string &name) { + NVTE_CHECK(tensor.is_contiguous(), name, " must be contiguous."); +} + +size_t num_groups_from_prepared_split_sizes(const at::Tensor &split_sizes, + const c10::Device &device) { + NVTE_CHECK(split_sizes.dim() == 1, "split_sizes must be a 1D tensor."); + NVTE_CHECK(split_sizes.device() == device, "split_sizes must be on the current CUDA device."); + NVTE_CHECK(split_sizes.scalar_type() == at::kLong, + "split_sizes must be the int64 CUDA tensor returned by splits_to_offsets_multi."); + return static_cast(split_sizes.numel()); +} + +GroupedTensorWrapper make_grouped_tensor(at::Tensor data, const at::Tensor &prepared_split_sizes, + const at::Tensor &tensor_offsets, int64_t logical_last_dim) { + const auto num_groups = static_cast(prepared_split_sizes.numel()); + const auto total_tokens = static_cast(data.numel() / logical_last_dim); + auto grouped = GroupedTensorWrapper( + num_groups, std::vector{total_tokens, static_cast(logical_last_dim)}); + grouped.set_rowwise_data(data.data_ptr(), GetTransformerEngineDType(data.scalar_type()), + tensor_shape_1d(data)); + grouped.set_first_dims(prepared_split_sizes.data_ptr(), DType::kInt64, + std::vector{num_groups}); + grouped.set_tensor_offsets(tensor_offsets.data_ptr(), DType::kInt64, + std::vector{num_groups + 1}); + return grouped; +} + +GroupedTensorWrapper make_uniform_grouped_tensor(at::Tensor data, size_t num_groups, + int64_t first_dim, int64_t last_dim) { + auto grouped = GroupedTensorWrapper( + num_groups, + std::vector{num_groups * static_cast(first_dim), + static_cast(last_dim)}); + grouped.set_rowwise_data(data.data_ptr(), GetTransformerEngineDType(data.scalar_type()), + tensor_shape_1d(data)); + return grouped; +} + +struct GroupedWeightArg { + bool is_grouped = false; + at::Tensor packed; + std::vector discrete; + // Logical per-expert weight shape. For both supported layouts: + // - packed single grouped weight: packed has shape [G, rows, cols] + // - discrete weights: each tensor has shape [rows, cols] + // rows = out_features, cols = in_features. + int64_t rows = 0; + int64_t cols = 0; + + c10::Device device() const { + return is_grouped ? packed.device() : discrete[0].device(); + } +}; + +GroupedWeightArg weight_arg_from_py(py::handle arg, size_t num_groups, at::ScalarType dtype, + const std::string &name) { + GroupedWeightArg out; + if (py::isinstance(arg) || py::isinstance(arg)) { + auto seq = py::reinterpret_borrow(arg); + NVTE_CHECK(static_cast(seq.size()) == num_groups, name, " must have ", num_groups, + " tensors."); + out.discrete.reserve(num_groups); + for (size_t i = 0; i < num_groups; ++i) { + auto tensor = maybe_cast_dtype(seq[i].cast(), dtype); + check_contiguous(tensor, name); + NVTE_CHECK(tensor.dim() == 2, name, " tensors must be rank-2."); + if (i == 0) { + // Discrete case: each expert owns one [out_features, in_features] + // tensor. Cache the shared logical shape for later GEMM setup. + out.rows = tensor.size(0); + out.cols = tensor.size(1); + } else { + NVTE_CHECK(tensor.size(0) == out.rows && tensor.size(1) == out.cols, name, + " tensors must have a uniform shape."); + } + out.discrete.emplace_back(tensor); + } + return out; + } + + out.packed = maybe_cast_dtype(arg.cast(), dtype); + NVTE_CHECK(out.packed.dim() == 3, name, " must be a tensor with shape [num_groups, rows, cols]."); + NVTE_CHECK(static_cast(out.packed.size(0)) == num_groups, name, + " first dimension must be ", num_groups, "."); + check_contiguous(out.packed, name); + out.is_grouped = true; + // Packed case: a single [G, out_features, in_features] tensor stores all + // experts, so dimensions 1 and 2 are the same per-expert logical shape. + out.rows = out.packed.size(1); + out.cols = out.packed.size(2); + return out; +} + +at::Tensor packed_bias_from_arg(py::handle arg, size_t num_groups, at::ScalarType dtype, + int64_t out_features, const std::string &name) { + if (is_none(arg)) { + return at::Tensor(); + } + + auto packed = maybe_cast_dtype(arg.cast(), dtype); + NVTE_CHECK(packed.dim() == 2, name, " must be a tensor with shape [num_groups, features]."); + NVTE_CHECK(static_cast(packed.size(0)) == num_groups, name, " first dimension must be ", + num_groups, "."); + NVTE_CHECK(packed.size(1) == out_features, name, " second dimension must be ", out_features, "."); + check_contiguous(packed, name); + return packed; +} + +std::vector nvte_tensor_list_from_tensors(const std::vector &tensors, + std::vector *wrappers) { + wrappers->clear(); + wrappers->reserve(tensors.size()); + std::vector out; + out.reserve(tensors.size()); + for (const auto &tensor : tensors) { + wrappers->emplace_back(makeTransformerEngineTensor(tensor)); + out.emplace_back(wrappers->back().data()); + } + return out; +} + +int grouped_gemm_math_sm_count(const c10::Device &device) { + const int device_id = static_cast(device.index()); + const int sm_count = transformer_engine::cuda::sm_count(device_id); + return sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); +} + +struct GroupedGemmResources { + c10::Device device; + size_t num_groups; + at::Tensor alpha; + at::Tensor beta_zero; + at::Tensor beta_one; + at::Tensor setup; + at::Tensor cublas; + TensorWrapper te_alpha; + TensorWrapper te_beta_zero; + TensorWrapper te_beta_one; + TensorWrapper te_setup; + TensorWrapper te_cublas; + std::optional config; + + GroupedGemmResources(const c10::Device &device_, size_t num_groups_) + : device(device_), + num_groups(num_groups_), + alpha(at::ones({static_cast(num_groups_)}, at::device(device).dtype(at::kFloat))), + beta_zero( + at::zeros({static_cast(num_groups_)}, at::device(device).dtype(at::kFloat))), + beta_one(alpha), + setup(at::empty( + {static_cast(nvte_get_grouped_gemm_setup_workspace_size(num_groups_))}, + at::device(device).dtype(at::kByte))), + cublas(at::empty({kGroupedGemmCublasWorkspaceSize}, at::device(device).dtype(at::kByte))), + te_alpha(makeTransformerEngineTensor(alpha)), + te_beta_zero(makeTransformerEngineTensor(beta_zero)), + te_beta_one(makeTransformerEngineTensor(beta_one)), + te_setup(makeTransformerEngineTensor(setup.data_ptr(), + std::vector{static_cast(setup.numel())}, + DType::kByte)), + te_cublas(makeTransformerEngineTensor( + cublas.data_ptr(), std::vector{static_cast(cublas.numel())}, + DType::kByte)) { + // These scratch tensors are intentionally local to one megacpp call. They + // are safe after this CPU function returns because every current cuBLAS + // grouped GEMM below is enqueued on at::cuda::getCurrentCUDAStream(), so + // PyTorch's caching allocator observes same-stream allocation/release + // ordering. If a future backend uses auxiliary streams, this helper must + // either record those streams on the tensors or extend resource lifetime. + const int math_sm_count = grouped_gemm_math_sm_count(device); + if (math_sm_count > 0) { + config.emplace(); + config->set_sm_count(math_sm_count); + } + } + + NVTETensor beta(bool accumulate) { + return accumulate ? te_beta_one.data() : te_beta_zero.data(); + } + + NVTEGroupedMatmulConfig config_data() { + return config.has_value() ? static_cast(*config) : nullptr; + } +}; + +GroupedGemmResources make_grouped_mlp_backend_resources(const c10::Device &device, + size_t num_groups) { + // Keep the backend resource policy private to megacpp. Today this is cuBLAS + // grouped GEMM scratch; future backends can change this helper without + // changing the Python or pybind contract. + return GroupedGemmResources(device, num_groups); +} + +void grouped_gemm(GroupedTensorWrapper *A, bool transa, GroupedTensorWrapper *B, bool transb, + GroupedTensorWrapper *D, GroupedGemmResources *resources, bool accumulate) { + NVTE_SCOPED_GIL_RELEASE({ + nvte_grouped_gemm(A->data(), transa, B->data(), transb, D->data(), D->data(), + resources->te_alpha.data(), resources->beta(accumulate), + resources->te_setup.data(), resources->te_cublas.data(), + resources->config_data(), + at::cuda::getCurrentCUDAStream()); + }); +} + +std::vector output_tensor_list_from_arg(py::handle arg, size_t num_groups, + at::ScalarType dtype, + const std::string &name) { + std::vector out; + if (is_none(arg)) { + return out; + } + out.reserve(num_groups); + if (py::isinstance(arg) || py::isinstance(arg)) { + auto seq = py::reinterpret_borrow(arg); + NVTE_CHECK(static_cast(seq.size()) == num_groups, name, " must have ", num_groups, + " tensors."); + for (size_t i = 0; i < num_groups; ++i) { + auto tensor = seq[i].cast(); + NVTE_CHECK(tensor.is_cuda(), name, " tensors must be CUDA tensors."); + NVTE_CHECK(tensor.scalar_type() == dtype, name, " tensors must have the requested dtype."); + NVTE_CHECK(tensor.dim() == 2, name, " tensors must be rank-2 wgrad buffers."); + check_contiguous(tensor, name); + out.emplace_back(tensor); + } + return out; + } + + auto packed = arg.cast(); + NVTE_CHECK(packed.is_cuda(), name, " must be a CUDA tensor."); + NVTE_CHECK(packed.scalar_type() == dtype, name, " must have the requested dtype."); + NVTE_CHECK(packed.dim() == 3, name, " must have shape [num_groups, rows, cols]."); + NVTE_CHECK(static_cast(packed.size(0)) == num_groups, name, " first dimension must be ", + num_groups, "."); + check_contiguous(packed, name); + for (size_t i = 0; i < num_groups; ++i) { + out.emplace_back(packed.select(0, static_cast(i))); + } + return out; +} + +struct WgradOutput { + std::vector tensors; + at::Tensor packed; + bool is_grouped = false; + bool owns_storage = false; +}; + +WgradOutput wgrad_output_from_arg(py::handle arg, bool compute_wgrad, size_t num_groups, + at::ScalarType dtype, const c10::Device &device, int64_t rows, + int64_t cols, const std::string &name, + bool prefer_grouped_output) { + WgradOutput out; + if (!compute_wgrad) { + return out; + } + if (is_none(arg)) { + // Cases 1 and 2: no external wgrad buffer was provided, so C++ owns the + // allocation. Single grouped weight keeps this packed as [G, N, K]; + // discrete weights split the same packed allocation into per-expert views. + out.packed = at::empty({static_cast(num_groups), rows, cols}, + at::device(device).dtype(dtype)); + out.owns_storage = true; + out.is_grouped = prefer_grouped_output; + if (out.is_grouped) { + return out; + } + out.tensors.reserve(num_groups); + for (size_t i = 0; i < num_groups; ++i) { + out.tensors.emplace_back(out.packed.select(0, static_cast(i))); + } + return out; + } + if (!py::isinstance(arg) && !py::isinstance(arg)) { + // Case 3: single grouped weight with externally-owned storage, e.g. + // Megatron main_grad viewed as [G, N, K]. GEMM writes in-place and Python + // should not receive a newly allocated grad tensor from this helper. + out.packed = arg.cast(); + NVTE_CHECK(out.packed.is_cuda(), name, " must be a CUDA tensor."); + NVTE_CHECK(out.packed.scalar_type() == dtype, name, " must have the requested dtype."); + NVTE_CHECK(out.packed.dim() == 3, name, " must have shape [num_groups, rows, cols]."); + NVTE_CHECK(static_cast(out.packed.size(0)) == num_groups, name, + " first dimension must be ", num_groups, "."); + NVTE_CHECK(out.packed.size(1) == rows && out.packed.size(2) == cols, name, + " has an unexpected shape."); + check_contiguous(out.packed, name); + out.is_grouped = true; + return out; + } + // Case 4: discrete weights with externally-owned per-expert buffers, e.g. + // Megatron main_grad list. GEMM writes each tensor in-place and returns no + // allocated grad list to Python. + out.tensors = output_tensor_list_from_arg(arg, num_groups, dtype, name); + return out; +} + +void grouped_gemm_fwd_dgrad(GroupedWeightArg *weights, bool trans_weight, + GroupedTensorWrapper *input, bool trans_input, + GroupedTensorWrapper *output, GroupedGemmResources *resources) { + if (weights->is_grouped) { + // Single grouped weight case: weights are packed as [G, N, K]. Wrap the + // packed buffer as a uniform GroupedTensor and use the grouped-tensor GEMM. + auto grouped_weight = + make_uniform_grouped_tensor(weights->packed, input->num_tensors(), weights->rows, + weights->cols); + grouped_gemm(&grouped_weight, trans_weight, input, trans_input, output, resources, false); + } else { + // Discrete weight case: weights are a list of per-expert tensors. Use the + // discrete-input grouped GEMM variant. + std::vector weight_wrappers; + auto weight_nvte = nvte_tensor_list_from_tensors(weights->discrete, &weight_wrappers); + NVTE_SCOPED_GIL_RELEASE({ + nvte_grouped_gemm_with_discrete_inputA( + weight_nvte.data(), weights->discrete.size(), trans_weight, input->data(), trans_input, + output->data(), output->data(), resources->te_alpha.data(), resources->beta(false), + resources->te_setup.data(), resources->te_cublas.data(), resources->config_data(), + at::cuda::getCurrentCUDAStream()); + }); + } +} + +std::vector grouped_gemm_wgrad(GroupedTensorWrapper *x, GroupedTensorWrapper *dy, + py::handle output, bool compute_wgrad, bool accumulate, + GroupedGemmResources *resources, at::ScalarType dtype, + int64_t rows, int64_t cols, const std::string &name, + bool prefer_grouped_output) { + auto prepared = wgrad_output_from_arg(output, compute_wgrad, resources->num_groups, dtype, + resources->device, rows, cols, name, prefer_grouped_output); + NVTE_CHECK(!(prepared.owns_storage && accumulate), name, + " cannot accumulate into a newly allocated wgrad buffer."); + std::vector returned_wgrads; + + if (prepared.is_grouped) { + // Cases 1 and 3: single grouped weight layout. + // Case 1: C++ allocated packed [G, N, K] storage; return [packed]. + // Case 3: caller provided packed storage, e.g. main_grad; write in-place + // and return nothing because autograd receives dummy wgrad tensors. + auto grouped_output = + make_uniform_grouped_tensor(prepared.packed, resources->num_groups, rows, cols); + grouped_gemm(x, false, dy, true, &grouped_output, resources, accumulate); + if (prepared.owns_storage) { + returned_wgrads.emplace_back(prepared.packed); + } + } else if (!prepared.tensors.empty()) { + // Cases 2 and 4: discrete per-expert weight layout. + // Case 2: C++ allocated packed backing storage and split it into views; + // return those views in parameter order. + // Case 4: caller provided per-expert buffers, e.g. main_grad list; write + // in-place and return nothing because autograd receives dummy wgrads. + std::vector output_wrappers; + auto output_nvte = nvte_tensor_list_from_tensors(prepared.tensors, &output_wrappers); + NVTE_SCOPED_GIL_RELEASE({ + nvte_grouped_gemm_with_discrete_out( + x->data(), false, dy->data(), true, output_nvte.data(), resources->num_groups, + output_nvte.data(), resources->num_groups, resources->te_alpha.data(), + resources->beta(accumulate), resources->te_setup.data(), resources->te_cublas.data(), + resources->config_data(), at::cuda::getCurrentCUDAStream()); + }); + if (prepared.owns_storage) { + returned_wgrads = prepared.tensors; + } + } + return returned_wgrads; +} + +GroupedTensorWrapper make_grouped_bias(const at::Tensor &bias, size_t num_groups, + at::ScalarType dtype, int64_t out_features) { + NVTE_CHECK(bias.defined(), "Bias tensor must be defined."); + auto grouped = GroupedTensorWrapper( + num_groups, std::vector{num_groups, static_cast(out_features)}); + grouped.set_rowwise_data(bias.data_ptr(), GetTransformerEngineDType(dtype), tensor_shape_1d(bias)); + return grouped; +} + +void add_grouped_bias(GroupedTensorWrapper *output, const at::Tensor &bias, size_t num_groups, + at::ScalarType dtype, int64_t out_features, + std::optional bias_scale = std::nullopt) { + if (!bias.defined()) { + return; + } + auto grouped_bias = make_grouped_bias(bias, num_groups, dtype, out_features); + if (bias_scale.has_value()) { + auto scale = maybe_cast_dtype(*bias_scale, at::kFloat); + check_contiguous(scale, "bias_scale"); + scale = scale.view({-1}); + auto te_scale = makeTransformerEngineTensor(scale); + NVTE_SCOPED_GIL_RELEASE({ + nvte_grouped_scaled_bias_add(output->data(), grouped_bias.data(), te_scale.data(), + at::cuda::getCurrentCUDAStream()); + }); + } else { + NVTE_SCOPED_GIL_RELEASE({ + nvte_grouped_bias_add(output->data(), grouped_bias.data(), at::cuda::getCurrentCUDAStream()); + }); + } +} + +bool is_gated_activation(const std::string &activation) { + return activation == "swiglu" || activation == "clamped_swiglu" || activation == "geglu" || + activation == "reglu" || activation == "qgeglu" || activation == "sreglu"; +} + +at::Tensor maybe_deinterleave_glu(const at::Tensor &input, int64_t glu_interleave_size) { + if (glu_interleave_size <= 0) { + return input; + } + auto shape = input.sizes().vec(); + const int64_t last_dim = shape.back(); + NVTE_CHECK(last_dim % (2 * glu_interleave_size) == 0, + "GLU interleaving requires the last dimension to be divisible by 2*interleave."); + check_contiguous(input, "GLU input"); + // Explicit layout materialization: GLU interleave changes memory order. + return input.view({-1, last_dim / (2 * glu_interleave_size), 2, glu_interleave_size}) + .transpose(1, 2) + .contiguous() + .view(shape); +} + +at::Tensor maybe_reinterleave_glu_grad(const at::Tensor &input, int64_t glu_interleave_size) { + if (glu_interleave_size <= 0) { + return input; + } + auto shape = input.sizes().vec(); + const int64_t last_dim = shape.back(); + check_contiguous(input, "GLU grad input"); + // Explicit layout materialization: reverse GLU interleave changes memory order. + return input.view({-1, 2, last_dim / (2 * glu_interleave_size), glu_interleave_size}) + .transpose(1, 2) + .contiguous() + .view(shape); +} + +at::Tensor activation_forward_impl(const at::Tensor &input, const std::string &activation, + double activation_limit, double activation_alpha, + double activation_glu_linear_offset) { + const int64_t out_features = + is_gated_activation(activation) ? input.size(-1) / 2 : input.size(-1); + auto output = at::empty({input.size(0), out_features}, input.options()); + auto te_input = makeTransformerEngineTensor(input); + auto te_output = makeTransformerEngineTensor(output); + auto stream = at::cuda::getCurrentCUDAStream(); + NVTE_SCOPED_GIL_RELEASE({ + if (activation == "swiglu") { + nvte_swiglu(te_input.data(), te_output.data(), stream); + } else if (activation == "glu") { + nvte_glu(te_input.data(), te_output.data(), stream); + } else if (activation == "geglu") { + nvte_geglu(te_input.data(), te_output.data(), stream); + } else if (activation == "qgeglu") { + nvte_qgeglu(te_input.data(), te_output.data(), stream); + } else if (activation == "reglu") { + nvte_reglu(te_input.data(), te_output.data(), stream); + } else if (activation == "sreglu") { + nvte_sreglu(te_input.data(), te_output.data(), stream); + } else if (activation == "clamped_swiglu") { + nvte_clamped_swiglu_v2(te_input.data(), te_output.data(), static_cast(activation_limit), + static_cast(activation_alpha), + static_cast(activation_glu_linear_offset), stream); + } else if (activation == "srelu") { + nvte_srelu(te_input.data(), te_output.data(), stream); + } else if (activation == "gelu") { + nvte_gelu(te_input.data(), te_output.data(), stream); + } else if (activation == "qgelu") { + nvte_qgelu(te_input.data(), te_output.data(), stream); + } else if (activation == "relu") { + nvte_relu(te_input.data(), te_output.data(), stream); + } else if (activation == "silu") { + nvte_silu(te_input.data(), te_output.data(), stream); + } else { + NVTE_ERROR("Unsupported megacpp grouped MLP activation: ", activation); + } + }); + return output; +} + +at::Tensor activation_backward_impl(const at::Tensor &grad, const at::Tensor &input, + const std::string &activation, double activation_limit, + double activation_alpha, + double activation_glu_linear_offset) { + auto output = at::empty_like(input); + auto te_grad = makeTransformerEngineTensor(grad); + auto te_input = makeTransformerEngineTensor(input); + auto te_output = makeTransformerEngineTensor(output); + auto stream = at::cuda::getCurrentCUDAStream(); + NVTE_SCOPED_GIL_RELEASE({ + if (activation == "swiglu") { + nvte_dswiglu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "glu") { + nvte_dglu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "geglu") { + nvte_dgeglu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "qgeglu") { + nvte_dqgeglu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "reglu") { + nvte_dreglu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "sreglu") { + nvte_dsreglu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "clamped_swiglu") { + nvte_clamped_dswiglu_v2(te_grad.data(), te_input.data(), te_output.data(), + static_cast(activation_limit), + static_cast(activation_alpha), + static_cast(activation_glu_linear_offset), stream); + } else if (activation == "srelu") { + nvte_dsrelu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "gelu") { + nvte_dgelu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "qgelu") { + nvte_dqgelu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "relu") { + nvte_drelu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "silu") { + nvte_dsilu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else { + NVTE_ERROR("Unsupported megacpp grouped MLP activation backward: ", activation); + } + }); + return output; +} + +at::Tensor grouped_mlp_activation_forward( + const at::Tensor &input, const std::optional &act_scales, + const std::string &activation, int64_t glu_interleave_size, double activation_limit, + double activation_alpha, double activation_glu_linear_offset, at::ScalarType dtype) { + auto activation_input = maybe_deinterleave_glu(input, glu_interleave_size); + auto activation_output = activation_forward_impl(activation_input, activation, activation_limit, + activation_alpha, activation_glu_linear_offset); + if (!act_scales.has_value()) { + return activation_output; + } + auto act_scales_for_fc2 = maybe_cast_dtype(*act_scales, dtype); + check_contiguous(act_scales_for_fc2, "act_scales"); + return activation_output * act_scales_for_fc2.view({-1, 1}); +} + +struct ActivationBackwardResult { + at::Tensor grad_input; + at::Tensor grad_act_scales; +}; + +ActivationBackwardResult grouped_mlp_activation_backward( + const at::Tensor &grad_output, const at::Tensor &input, + const std::optional &act_scales, const std::string &activation, + int64_t glu_interleave_size, double activation_limit, double activation_alpha, + double activation_glu_linear_offset, at::ScalarType dtype, bool act_scales_requires_grad) { + auto activation_input = maybe_deinterleave_glu(input, glu_interleave_size); + + at::Tensor grad_activation_output = grad_output; + at::Tensor grad_act_scales; + if (act_scales.has_value()) { + if (act_scales_requires_grad) { + // Scaled activations compute y = activation(x) * act_scales[:, None]. + // Recompute activation(x) for dact_scales to match the Python basic-op + // path without saving another [tokens, hidden] activation tensor. + auto activation_output = + activation_forward_impl(activation_input, activation, activation_limit, activation_alpha, + activation_glu_linear_offset); + grad_act_scales = (activation_output * grad_output).sum(-1); + } + auto act_scales_for_grad = maybe_cast_dtype(*act_scales, dtype); + check_contiguous(act_scales_for_grad, "act_scales"); + grad_activation_output = grad_output * act_scales_for_grad.view({-1, 1}); + } + + auto grad_activation_input = + activation_backward_impl(grad_activation_output, activation_input, activation, activation_limit, + activation_alpha, activation_glu_linear_offset); + return {maybe_reinterleave_glu_grad(grad_activation_input, glu_interleave_size), + grad_act_scales}; +} + +} // namespace + +std::vector megacpp_grouped_mlp_forward( + const at::Tensor &input, const at::Tensor &split_sizes, py::handle fc1_weight, + py::handle fc1_bias, py::handle fc2_weight, py::handle fc2_bias, + const std::optional &act_scales, const std::string &activation, + int64_t glu_interleave_size, double activation_limit, double activation_alpha, + double activation_glu_linear_offset) { + NVTE_CHECK(input.is_cuda(), "megacpp_grouped_mlp_forward requires CUDA input."); + at::cuda::CUDAGuard device_guard(input.device()); + + const auto num_groups = static_cast(split_sizes.numel()); + NVTE_CHECK(num_groups > 0, "megacpp grouped MLP requires at least one group."); + + const auto dtype = input.scalar_type(); + NVTE_CHECK(dtype == at::kBFloat16 || dtype == at::kHalf, + "megacpp grouped MLP currently supports BF16/FP16 only."); + + auto fc1_weights = weight_arg_from_py(fc1_weight, num_groups, dtype, "fc1_weight"); + auto fc2_weights = weight_arg_from_py(fc2_weight, num_groups, dtype, "fc2_weight"); + const int64_t in_features = fc1_weights.cols; + const int64_t fc1_out_features = fc1_weights.rows; + const int64_t fc2_out_features = fc2_weights.rows; + const int64_t fc2_in_features = fc2_weights.cols; + const int64_t activation_out_features = + is_gated_activation(activation) ? fc1_out_features / 2 : fc1_out_features; + NVTE_CHECK(activation_out_features == fc2_in_features, + "FC1 activation output dimension must match FC2 input dimension."); + auto fc1_bias_tensor = + packed_bias_from_arg(fc1_bias, num_groups, dtype, fc1_out_features, "fc1_bias"); + auto fc2_bias_tensor = + packed_bias_from_arg(fc2_bias, num_groups, dtype, fc2_out_features, "fc2_bias"); + + auto x = maybe_cast_dtype(input, dtype); + check_contiguous(x, "input"); + x = x.view({-1, in_features}); + auto [split_sizes_i64, split_offsets] = splits_to_offsets_multi( + split_sizes, x.device(), + std::vector{1, in_features, fc1_out_features, fc2_in_features, fc2_out_features}, + std::vector{true, true, true, true, true}, + std::vector{at::kLong, at::kLong, at::kLong, at::kLong, at::kLong}, + true); + // splits_to_offsets_multi returns the canonical int64 CUDA split sizes and + // offsets in the same order as the stride list above. The CuTe path also asks + // for int32 split_points, but cuBLAS grouped GEMM does not consume them. + NVTE_CHECK(split_offsets.size() == 5, "Expected five grouped split-offset tensors."); + auto base_offsets = split_offsets[0]; + auto x_offsets = split_offsets[1]; + auto fc1_offsets = split_offsets[2]; + auto fc2_offsets = split_offsets[3]; + auto output_offsets = split_offsets[4]; + const int64_t total_tokens = x.size(0); + auto gemm_resources = make_grouped_mlp_backend_resources(x.device(), num_groups); + + auto fc1_preact = at::empty({total_tokens, fc1_out_features}, x.options()); + auto grouped_x = make_grouped_tensor(x.view({-1}), split_sizes_i64, x_offsets, in_features); + auto grouped_fc1_preact = + make_grouped_tensor(fc1_preact.view({-1}), split_sizes_i64, fc1_offsets, fc1_out_features); + grouped_gemm_fwd_dgrad(&fc1_weights, true, &grouped_x, false, &grouped_fc1_preact, + &gemm_resources); + add_grouped_bias(&grouped_fc1_preact, fc1_bias_tensor, num_groups, dtype, fc1_out_features); + + auto fc2_x = + grouped_mlp_activation_forward(fc1_preact, act_scales, activation, glu_interleave_size, + activation_limit, activation_alpha, + activation_glu_linear_offset, dtype); + + std::vector out_shape = input.sizes().vec(); + out_shape.back() = fc2_out_features; + auto output = at::empty(out_shape, x.options()); + auto output_2d = output.view({-1, fc2_out_features}); + auto grouped_fc2_x = + make_grouped_tensor(fc2_x.view({-1}), split_sizes_i64, fc2_offsets, fc2_in_features); + auto grouped_output = + make_grouped_tensor(output_2d.view({-1}), split_sizes_i64, output_offsets, fc2_out_features); + grouped_gemm_fwd_dgrad(&fc2_weights, true, &grouped_fc2_x, false, &grouped_output, + &gemm_resources); + add_grouped_bias(&grouped_output, fc2_bias_tensor, num_groups, dtype, fc2_out_features); + + return {output, x, split_sizes_i64, base_offsets, x_offsets, fc1_offsets, fc2_offsets, + output_offsets, fc1_preact, fc2_x}; +} + +py::tuple megacpp_grouped_mlp_backward( + const at::Tensor &grad_output, const at::Tensor &split_sizes, const at::Tensor &x_offsets, + const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets, + const at::Tensor &fc2_dy_offsets, const at::Tensor &base_offsets, + const at::Tensor &x, const at::Tensor &fc1_activation_input, const at::Tensor &fc2_x, + const std::optional &act_scales, py::handle fc1_weight, + py::handle fc2_weight, + py::handle fc1_wgrad_output, bool fc1_compute_wgrad, bool fc1_accumulate_wgrad, + py::handle fc2_wgrad_output, bool fc2_compute_wgrad, bool fc2_accumulate_wgrad, + const std::string &activation, int64_t glu_interleave_size, double activation_limit, + double activation_alpha, double activation_glu_linear_offset, bool act_scales_requires_grad, + bool input_requires_grad) { + (void)base_offsets; + NVTE_CHECK(grad_output.is_cuda(), "megacpp_grouped_mlp_backward requires CUDA grad_output."); + at::cuda::CUDAGuard device_guard(grad_output.device()); + + const auto num_groups = num_groups_from_prepared_split_sizes(split_sizes, grad_output.device()); + const auto dtype = grad_output.scalar_type(); + auto fc1_weights = weight_arg_from_py(fc1_weight, num_groups, dtype, "fc1_weight"); + auto fc2_weights = weight_arg_from_py(fc2_weight, num_groups, dtype, "fc2_weight"); + + const int64_t in_features = fc1_weights.cols; + const int64_t fc1_out_features = fc1_weights.rows; + const int64_t fc2_out_features = fc2_weights.rows; + const int64_t fc2_in_features = fc2_weights.cols; + + auto dy = maybe_cast_dtype(grad_output, dtype); + check_contiguous(dy, "grad_output"); + dy = dy.view({-1, fc2_out_features}); + const int64_t total_tokens = dy.size(0); + auto gemm_resources = make_grouped_mlp_backend_resources(grad_output.device(), num_groups); + + auto grouped_dy = + make_grouped_tensor(dy.view({-1}), split_sizes, fc2_dy_offsets, fc2_out_features); + std::vector fc2_wgrads; + if (fc2_compute_wgrad) { + auto fc2_x_for_wgrad = maybe_cast_dtype(fc2_x, dtype); + check_contiguous(fc2_x_for_wgrad, "fc2_x"); + fc2_x_for_wgrad = fc2_x_for_wgrad.view({-1, fc2_in_features}); + auto grouped_fc2_x_for_wgrad = + make_grouped_tensor(fc2_x_for_wgrad.view({-1}), split_sizes, fc2_offsets, fc2_in_features); + fc2_wgrads = + grouped_gemm_wgrad(&grouped_fc2_x_for_wgrad, &grouped_dy, fc2_wgrad_output, + fc2_compute_wgrad, fc2_accumulate_wgrad, &gemm_resources, dtype, + fc2_out_features, fc2_in_features, "fc2_wgrad_output", + fc2_weights.is_grouped); + } + + auto fc2_dx = at::empty({total_tokens, fc2_in_features}, dy.options()); + auto grouped_fc2_dx = + make_grouped_tensor(fc2_dx.view({-1}), split_sizes, fc2_offsets, fc2_in_features); + grouped_gemm_fwd_dgrad(&fc2_weights, false, &grouped_dy, false, &grouped_fc2_dx, + &gemm_resources); + + auto activation_grads = grouped_mlp_activation_backward( + fc2_dx, fc1_activation_input, act_scales, activation, glu_interleave_size, + activation_limit, activation_alpha, activation_glu_linear_offset, dtype, + act_scales_requires_grad); + auto fc1_dy = activation_grads.grad_input; + auto grad_act_scales = activation_grads.grad_act_scales; + auto grouped_fc1_dy = + make_grouped_tensor(fc1_dy.view({-1}), split_sizes, fc1_offsets, fc1_out_features); + + std::vector fc1_wgrads; + if (fc1_compute_wgrad) { + auto x_for_wgrad = maybe_cast_dtype(x, dtype); + check_contiguous(x_for_wgrad, "x"); + x_for_wgrad = x_for_wgrad.view({-1, in_features}); + auto grouped_x_for_wgrad = + make_grouped_tensor(x_for_wgrad.view({-1}), split_sizes, x_offsets, in_features); + fc1_wgrads = + grouped_gemm_wgrad(&grouped_x_for_wgrad, &grouped_fc1_dy, fc1_wgrad_output, + fc1_compute_wgrad, fc1_accumulate_wgrad, &gemm_resources, dtype, + fc1_out_features, in_features, "fc1_wgrad_output", + fc1_weights.is_grouped); + } + + at::Tensor grad_input; + if (input_requires_grad) { + std::vector grad_input_shape = grad_output.sizes().vec(); + grad_input_shape.back() = in_features; + grad_input = at::empty(grad_input_shape, dy.options()); + auto grad_input_2d = grad_input.view({-1, in_features}); + auto grouped_grad_input = make_grouped_tensor(grad_input_2d.view({-1}), split_sizes, + x_offsets, in_features); + grouped_gemm_fwd_dgrad(&fc1_weights, false, &grouped_fc1_dy, false, &grouped_grad_input, + &gemm_resources); + } else { + grad_input = at::empty({0}, dy.options()); + } + + auto empty_return = at::empty({0}, dy.options()); + if (!grad_act_scales.defined()) { + grad_act_scales = empty_return; + } + return py::make_tuple(grad_input, fc1_dy, grad_act_scales, fc1_wgrads, fc2_wgrads); +} + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index 78f9d880ba..fd09162ade 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -39,3 +39,9 @@ BackwardGroupedMLP_CuTeGEMMDGLU, BackwardGroupedMLP_CuTeGEMMDUnary, ) +from .forward_grouped_mlp_megacpp import ( # pylint: disable=wrong-import-position + ForwardGroupedMLP_MegaCpp, +) +from .backward_grouped_mlp_megacpp import ( # pylint: disable=wrong-import-position + BackwardGroupedMLP_MegaCpp, +) diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp_megacpp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp_megacpp.py new file mode 100644 index 0000000000..ebaf30d075 --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp_megacpp.py @@ -0,0 +1,392 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Mega C++ grouped MLP backward fuser.""" + +from __future__ import annotations +import functools +from typing import Optional + +import torch + +import transformer_engine_torch as tex +from ...quantization import Recipe +from ...utils import clear_tensor_data, get_device_compute_capability +from ...triton.grouped_dbias_dscales import compute_grouped_dbias +from ..basic import GroupedLinear +from ..fuser import register_backward_fusion +from ..op import FusedOperation, FusibleOperation, OperationContext +from .._common import ( + get_accumulate_flag_in_param, + get_dummy_wgrads_for_params, + get_main_grad_from_param, + view_main_grad_as_grouped_buffer, +) +from .forward_grouped_mlp_megacpp import ( + _megacpp_activation_config, + _megacpp_enabled, + _megacpp_supports_recipe, + _resolve_megacpp_grouped_mlp_config, +) + + +def _megacpp_saved_weight_arg( + saved_tensors: tuple[torch.Tensor, ...], + *, + single_weight_arg: bool, + num_groups: int, +) -> tuple[torch.Tensor | list[torch.Tensor], tuple[torch.Tensor, ...]]: + """Unpack saved C++ weight argument in the same shape used by forward.""" + if single_weight_arg: + return saved_tensors[0], saved_tensors[1:] + return list(saved_tensors[:num_groups]), saved_tensors[num_groups:] + + +def _delay_wgrad(fc_op: GroupedLinear, ctx: OperationContext) -> bool: + """Whether this FC op requested unsupported delayed wgrad.""" + return bool( + ctx.weight_requires_grad + and fc_op.wgrad_store is not None + and fc_op.wgrad_store.delay_wgrad_compute() + ) + + +def _compute_bias_grad_params( + fc_op: GroupedLinear, + dy_2d: torch.Tensor, + base_offsets: torch.Tensor, + *, + num_groups: int, + dtype: torch.dtype, +) -> tuple[Optional[list[torch.Tensor]], Optional[torch.Tensor]]: + """Compute bias grads in GroupedLinear parameter layout.""" + if not fc_op.has_bias: + return None, None + dbias_packed = compute_grouped_dbias(dy_2d, base_offsets, num_groups).to(dtype=dtype) + if fc_op.single_grouped_bias: + return None, dbias_packed + return [dbias_packed[idx] for idx in range(num_groups)], None + + +def _prepare_cpp_wgrad_output( + fc_op: GroupedLinear, + ctx: OperationContext, + *, + num_groups: int, + weight_shape: tuple[int, int], + label: str, +) -> tuple[Optional[torch.Tensor | list[torch.Tensor]], bool, bool, list[Optional[torch.Tensor]]]: + """Return an optional externally-owned wgrad buffer for C++. + + If Megatron has already installed ``main_grad`` buffers, C++ writes into + them. Otherwise this returns ``None`` and C++ allocates/returns a packed + ``[num_groups, out_features, in_features]`` wgrad tensor. + """ + weights = fc_op._get_weight_tensors() + weight_grads: list[Optional[torch.Tensor]] = ( + [None] if fc_op.single_grouped_weight else [None] * num_groups + ) + if _delay_wgrad(fc_op, ctx): + raise ValueError("megacpp grouped MLP does not support delay_wgrad_compute=True.") + if not ctx.weight_requires_grad: + return None, False, False, weight_grads + + accumulate_into_main_grad = False + if fc_op.single_grouped_weight: + if fc_op._accumulate_into_main_grad: + main_grad = get_main_grad_from_param(weights[0], op_label=label) + wgrad_output = view_main_grad_as_grouped_buffer( + main_grad, + num_groups, + weight_shape, + label=f"{label} weight", + ) + accumulate_into_main_grad = get_accumulate_flag_in_param(weights[0]) + weight_grads = get_dummy_wgrads_for_params(weights) + else: + wgrad_output = None + else: + if fc_op._accumulate_into_main_grad: + wgrad_output = [get_main_grad_from_param(w, op_label=label) for w in weights] + accumulate_into_main_grad = get_accumulate_flag_in_param(weights[0]) + weight_grads = get_dummy_wgrads_for_params(weights) + else: + wgrad_output = None + + return wgrad_output, True, accumulate_into_main_grad, weight_grads + + +def _assemble_grad_params( + fc_op: GroupedLinear, + weight_grads: list[Optional[torch.Tensor]], + bias_grads: Optional[list[torch.Tensor]], + bias_grad_packed: Optional[torch.Tensor], + *, + num_groups: int, +) -> list[Optional[torch.Tensor]]: + """Assemble parameter grads in GroupedLinear registration order.""" + if not fc_op.has_bias: + return weight_grads + if fc_op.single_grouped_bias: + return weight_grads + [bias_grad_packed] + bias_list = bias_grads if bias_grads is not None else [None] * num_groups + if fc_op.single_grouped_weight: + return bias_list + weight_grads + return weight_grads + bias_list + + +class BackwardGroupedMLP_MegaCpp(FusedOperation): + """Experimental C++ grouped MLP backward for BF16/FP16. + + Weight gradients are computed in C++. Delayed wgrad is intentionally not + supported in this first implementation to keep ownership and lifetime rules + simple. + """ + + @classmethod + @functools.lru_cache(maxsize=None) + def is_supported(cls) -> bool: + if not torch.cuda.is_available(): + return False + if get_device_compute_capability()[0] < 10: + return False + return hasattr(tex, "megacpp_grouped_mlp_backward") + + def __init__( + self, + *, + fc1: GroupedLinear, + activation: Optional[FusibleOperation], + fc2: GroupedLinear, + ) -> None: + if activation is None: + raise TypeError("Expected a grouped MLP activation op.") + super().__init__((fc1, activation, fc2)) + _resolve_megacpp_grouped_mlp_config(fc1, activation, fc2) + if fc1._scale_bias or fc2._scale_bias: + raise RuntimeError("megacpp grouped MLP does not support scale_bias yet.") + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + **unused, # pylint: disable=unused-argument + ) -> tuple[ + torch.Tensor, + list[tuple[Optional[torch.Tensor], ...]], + list[tuple[()]], + ]: + fc1_op, activation_op, fc2_op = self.basic_ops + fc1_ctx, activation_ctx, fc2_ctx = basic_op_ctxs + num_groups = fc1_op.num_groups + dtype = fc1_ctx.dtype + + fc1_saved = fc1_ctx.saved_tensors + split_sizes, base_offsets, x_offsets, fc1_offsets = fc1_saved[:4] + x, fc1_activation_input = fc1_saved[4:6] + fc1_weight_arg, _ = _megacpp_saved_weight_arg( + fc1_saved[6:], + single_weight_arg=bool(getattr(fc1_ctx, "single_weight_arg", False)), + num_groups=num_groups, + ) + + activation_config = _megacpp_activation_config(activation_op) + _, act_scales = activation_ctx.saved_tensors + + fc2_saved = fc2_ctx.saved_tensors + fc2_offsets = fc2_saved[2] + fc2_dy_offsets = fc2_saved[3] + fc2_x = fc2_saved[4] + fc2_weight_arg, _ = _megacpp_saved_weight_arg( + fc2_saved[5:], + single_weight_arg=bool(getattr(fc2_ctx, "single_weight_arg", False)), + num_groups=num_groups, + ) + + ( + fc1_wgrad_output, + fc1_compute_wgrad, + fc1_accumulate_wgrad, + fc1_weight_grads, + ) = _prepare_cpp_wgrad_output( + fc1_op, + fc1_ctx, + num_groups=num_groups, + weight_shape=(fc1_op.out_features, fc1_op.in_features), + label="Grouped MLP megacpp backward (FC1)", + ) + ( + fc2_wgrad_output, + fc2_compute_wgrad, + fc2_accumulate_wgrad, + fc2_weight_grads, + ) = _prepare_cpp_wgrad_output( + fc2_op, + fc2_ctx, + num_groups=num_groups, + weight_shape=(fc2_op.out_features, fc2_op.in_features), + label="Grouped MLP megacpp backward (FC2)", + ) + ( + grad_input, + fc1_dy, + grad_act_scales, + fc1_owned_weight_grads, + fc2_owned_weight_grads, + ) = tex.megacpp_grouped_mlp_backward( + grad_output.to(dtype=dtype), + split_sizes, + x_offsets, + fc1_offsets, + fc2_offsets, + fc2_dy_offsets, + base_offsets, + x, + fc1_activation_input, + fc2_x, + act_scales, + fc1_weight_arg, + fc2_weight_arg, + fc1_wgrad_output, + fc1_compute_wgrad, + fc1_accumulate_wgrad, + fc2_wgrad_output, + fc2_compute_wgrad, + fc2_accumulate_wgrad, + activation_config.name, + activation_config.glu_interleave_size, + activation_config.limit, + activation_config.alpha, + activation_config.glu_linear_offset, + bool(activation_ctx.extra_input_requires_grad), + bool(fc1_ctx.input_requires_grad), + ) + if not fc1_ctx.input_requires_grad: + grad_input = None + + grad_output_2d = grad_output.reshape(-1, fc2_op.out_features).to(dtype=dtype) + fc2_bias_grads, fc2_bias_grad_packed = _compute_bias_grad_params( + fc2_op, + grad_output_2d, + base_offsets, + num_groups=num_groups, + dtype=dtype, + ) + fc1_bias_grads, fc1_bias_grad_packed = _compute_bias_grad_params( + fc1_op, + fc1_dy, + base_offsets, + num_groups=num_groups, + dtype=dtype, + ) + + # Wgrad ownership cases: + # 1. No weight grad: keep [None] placeholders prepared above. + # 2. Megatron-owned main_grad: C++ wrote into the provided buffer; + # keep dummy wgrads prepared above for autograd. + # 3. C++-owned allocation: replace the placeholder list with returned + # wgrads. Single grouped weight returns [packed], discrete weights + # return one tensor per expert. + if fc2_ctx.weight_requires_grad and not fc2_op._accumulate_into_main_grad: + expected_wgrads = 1 if fc2_op.single_grouped_weight else num_groups + if len(fc2_owned_weight_grads) != expected_wgrads: + raise RuntimeError(f"FC2 expected {expected_wgrads} owned wgrad tensors.") + fc2_weight_grads = fc2_owned_weight_grads + fc2_grad_params = _assemble_grad_params( + fc2_op, + fc2_weight_grads, + fc2_bias_grads, + fc2_bias_grad_packed, + num_groups=num_groups, + ) + clear_tensor_data(fc2_x) + + # Same ownership policy as FC2. Megatron-owned main_grad keeps the + # prepared dummy grads; C++-owned allocation uses the returned wgrads. + if fc1_ctx.weight_requires_grad and not fc1_op._accumulate_into_main_grad: + expected_wgrads = 1 if fc1_op.single_grouped_weight else num_groups + if len(fc1_owned_weight_grads) != expected_wgrads: + raise RuntimeError(f"FC1 expected {expected_wgrads} owned wgrad tensors.") + fc1_weight_grads = fc1_owned_weight_grads + fc1_grad_params = _assemble_grad_params( + fc1_op, + fc1_weight_grads, + fc1_bias_grads, + fc1_bias_grad_packed, + num_groups=num_groups, + ) + clear_tensor_data(x) + + activation_grad_extra = ( + (grad_act_scales.to(dtype=dtype),) + if activation_ctx.extra_input_requires_grad + else (None,) + ) + + return grad_input, [fc1_grad_params, (), fc2_grad_params], [ + (None,), + activation_grad_extra, + (None,), + ] + + +def fuse_backward_megacpp_ops( + ops: list[FusibleOperation], + *, + recipe: Optional[Recipe] = None, + **unused, # pylint: disable=unused-argument +) -> list[FusibleOperation]: + """Apply opt-in C++ grouped MLP backward fusion for BF16/FP16.""" + if not _megacpp_enabled(): + return ops + if not _megacpp_supports_recipe(recipe): + return ops + if not BackwardGroupedMLP_MegaCpp.is_supported(): + return ops + + out = [] + window, ops = ops[:3], ops[3:] + while len(window) == 3: + matches_pattern = True + if not ( + isinstance(window[0], GroupedLinear) + and isinstance(window[2], GroupedLinear) + ): + matches_pattern = False + elif ( + window[0]._scale_bias + or window[2]._scale_bias + ): + matches_pattern = False + else: + try: + _resolve_megacpp_grouped_mlp_config(window[0], window[1], window[2]) + except (TypeError, ValueError, RuntimeError): + matches_pattern = False + + if matches_pattern: + window = [ + BackwardGroupedMLP_MegaCpp( + fc1=window[0], + activation=window[1], + fc2=window[2], + ) + ] + else: + out.extend(window[:-2]) + window = window[-2:] + + out.extend(window[:-3]) + window = window[-3:] + while ops and len(window) < 3: + window.append(ops[0]) + ops = ops[1:] + + out.extend(window) + return out + + +# Use the same opt-in and recipe gate as forward. Unsupported recipes fall +# through unchanged so the matching recipe-specific backward fuser can run. +register_backward_fusion(fuse_backward_megacpp_ops, prepend=True) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp_megacpp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp_megacpp.py new file mode 100644 index 0000000000..bd05b6218f --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp_megacpp.py @@ -0,0 +1,382 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Mega C++ grouped MLP forward fuser.""" + +from __future__ import annotations +from collections.abc import Iterable +import functools +import os +from typing import Any, NamedTuple, Optional + +import torch + +import transformer_engine_torch as tex +from ...quantization import Recipe +from ...tensor import Quantizer +from ...utils import get_device_compute_capability +from ..basic import GroupedLinear, ScaledSReLU, ScaledClampedQGeGLU, ScaledSwiGLU +from ..fuser import register_forward_fusion +from ..op import FusedOperation, FusibleOperation, OperationContext + + +def _megacpp_enabled() -> bool: + """Whether the experimental grouped MLP C++ path is explicitly enabled.""" + return int(os.getenv("NVTE_MEGACPP_GROUPED_LINEAR", "0")) > 0 + + +def _megacpp_supports_recipe(recipe: Optional[Recipe]) -> bool: + """Whether megacpp is a valid candidate for the active quantization recipe. + + Today the C++ implementation is BF16/FP16-only, so only the no-recipe path + is supported. Returning False for FP8 recipes is intentional: it leaves the + op list unchanged so the existing MXFP8/NVFP4 CuTe DSL fusers can match. + Future MXFP8/NVFP4 support should be enabled by changing this predicate, + not by reordering fusion registrations. + """ + if recipe is None: + return True + if recipe.mxfp8() or recipe.nvfp4(): + return False + return False + + +class _MegaCppActivationConfig(NamedTuple): + """Activation semantics consumed by the C++ grouped MLP path.""" + + name: str + is_scaled: bool + is_gated: bool + glu_interleave_size: int + limit: float = 0.0 + alpha: float = 0.0 + glu_linear_offset: float = 0.0 + + +def _megacpp_activation_config(activation) -> _MegaCppActivationConfig: + """Return activation parameters consumed by the C++ grouped MLP path.""" + glu_interleave_size = int(getattr(activation, "glu_interleave_size", None) or 0) + if isinstance(activation, ScaledSwiGLU): + return _MegaCppActivationConfig("swiglu", True, True, glu_interleave_size) + if isinstance(activation, ScaledClampedQGeGLU): + return _MegaCppActivationConfig( + "clamped_swiglu", + True, + True, + glu_interleave_size, + float(activation._clamped.limit), + float(activation._clamped.alpha), + float(activation._clamped.glu_linear_offset), + ) + if isinstance(activation, ScaledSReLU): + return _MegaCppActivationConfig("srelu", True, False, 0) + if getattr(activation, "num_extra_inputs", 0) == 0: + return _MegaCppActivationConfig("plain_unsupported", False, False, 0) + raise TypeError( + "megacpp grouped MLP currently supports only ScaledSwiGLU, " + "ScaledClampedQGeGLU, and ScaledSReLU." + ) + + +def _resolve_megacpp_grouped_mlp_config( + fc1: GroupedLinear, + activation, + fc2: GroupedLinear, +) -> _MegaCppActivationConfig: + """Resolve megacpp activation config and validate grouped MLP support.""" + config = _megacpp_activation_config(activation) + if not config.is_scaled: + raise RuntimeError( + "megacpp grouped MLP keeps an optional-scale activation API, but plain " + f"{activation.__class__.__name__} is not supported yet." + ) + if fc1.in_features % 64 != 0 or fc1.out_features % 64 != 0: + raise ValueError( + f"Unsupported dims for FC1 (num_groups={fc1.num_groups}, " + f"in_features={fc1.in_features}, out_features={fc1.out_features})." + ) + if fc2.in_features % 64 != 0 or fc2.out_features % 64 != 0: + raise ValueError( + f"Unsupported dims for FC2 (num_groups={fc2.num_groups}, " + f"in_features={fc2.in_features}, out_features={fc2.out_features})." + ) + expected_fc1_out_features = 2 * fc2.in_features if config.is_gated else fc2.in_features + if fc1.out_features != expected_fc1_out_features or fc1.num_groups != fc2.num_groups: + raise ValueError( + f"FC1 (num_groups={fc1.num_groups}, in_features={fc1.in_features}, " + f"out_features={fc1.out_features}) " + f"and FC2 (num_groups={fc2.num_groups}, in_features={fc2.in_features}, " + f"out_features={fc2.out_features}) do not match." + ) + if config.glu_interleave_size and fc1.out_features % (2 * config.glu_interleave_size) != 0: + raise ValueError( + "GLU interleaving requires FC1 out_features to be divisible by " + f"2*glu_interleave_size, got out_features={fc1.out_features}, " + f"glu_interleave_size={config.glu_interleave_size}." + ) + return config + + +def _megacpp_weight_arg( + linear_op: GroupedLinear, + dtype: torch.dtype, + *, + input_requires_grad: bool, +) -> torch.Tensor | list[torch.Tensor]: + """Return GEMM-ready high-precision weights for the current C++ path. + + Keep the layout policy in GroupedLinear. This handles quantized weights the + same way as the Python grouped GEMM path: BF16/FP16 compute dequantizes when + needed, while a future quantized-compute path can preserve quantized weights + by switching ``with_quantized_compute``. + """ + with_quantized_compute = False + if linear_op.single_grouped_weight: + grouped_weight = linear_op._get_grouped_weight_for_gemm( + linear_op.weight, + [linear_op.get_quantizer("forward", 1)], + columnwise_usage=input_requires_grad, + with_quantized_compute=with_quantized_compute, + dtype=dtype, + ) + if grouped_weight.rowwise_data is None: + raise RuntimeError("megacpp grouped MLP expected dense grouped weight rowwise_data.") + # Keep single grouped weight packed. The C++ path wraps this as a + # uniform GroupedTensor and dispatches nvte_grouped_gemm instead of + # expanding it into per-expert discrete tensors. + return grouped_weight.rowwise_data.view( + linear_op.num_groups, + linear_op.out_features, + linear_op.in_features, + ) + return linear_op._get_discrete_weights_for_gemm( + [getattr(linear_op, f"weight{idx}") for idx in range(linear_op.num_groups)], + [linear_op.get_quantizer("forward", 2 * idx + 1) for idx in range(linear_op.num_groups)], + columnwise_usage=input_requires_grad, + with_quantized_compute=with_quantized_compute, + dtype=dtype, + ) + + +def _megacpp_bias_arg(linear_op: GroupedLinear, dtype: torch.dtype) -> Optional[torch.Tensor]: + """Return a packed [G, N] high-precision bias tensor or None.""" + grouped_bias = linear_op._get_grouped_bias_for_gemm(dtype) + if grouped_bias is None: + return None + return grouped_bias.rowwise_data.view(linear_op.num_groups, linear_op.out_features) + + +class ForwardGroupedMLP_MegaCpp(FusedOperation): + """Experimental BF16/FP16 grouped MLP forward implemented in C++. + + The C++ function returns plain tensors only. Python still owns autograd + context layout; delayed wgrad is rejected by the matching backward op. + """ + + @classmethod + @functools.lru_cache(maxsize=None) + def is_supported(cls) -> bool: + """Whether the C++ grouped MLP path can be dispatched.""" + if not torch.cuda.is_available(): + return False + if get_device_compute_capability()[0] < 10: + return False + return hasattr(tex, "megacpp_grouped_mlp_forward") + + def __init__( + self, + *, + fc1: GroupedLinear, + activation: Optional[FusibleOperation], + fc2: GroupedLinear, + ) -> None: + if activation is None: + raise TypeError("Expected a grouped MLP activation op.") + super().__init__((fc1, activation, fc2)) + _resolve_megacpp_grouped_mlp_config(fc1, activation, fc2) + if fc1._scale_bias or fc2._scale_bias: + raise RuntimeError("megacpp grouped MLP does not support scale_bias yet.") + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + del prev_op_grad_output_quantizer, next_op_input_quantizer, basic_op_kwargs + fc1_op, activation_op, fc2_op = self.basic_ops + fc1_ctx, activation_ctx, fc2_ctx = basic_op_ctxs + num_groups = fc1_op.num_groups + + split_sizes = basic_op_extra_inputs[0][0] + fc2_split_sizes = basic_op_extra_inputs[2][0] + if ( + split_sizes.size() != fc2_split_sizes.size() + or split_sizes.data_ptr() != fc2_split_sizes.data_ptr() + ): + raise RuntimeError(f"{self.__class__.__name__} got different split sizes for FC1/FC2.") + if int(split_sizes.numel()) != num_groups: + raise ValueError(f"Expected {num_groups} splits, got {int(split_sizes.numel())}.") + + activation_config = _megacpp_activation_config(activation_op) + act_scales = basic_op_extra_inputs[1][0] + fc1_weight_param = fc1_op.weight if fc1_op.single_grouped_weight else fc1_op.weight0 + fc2_weight_param = fc2_op.weight if fc2_op.single_grouped_weight else fc2_op.weight0 + dtype = ( + torch.get_autocast_dtype("cuda") + if torch.is_autocast_enabled() + else fc1_weight_param.dtype + ) + if dtype not in (torch.bfloat16, torch.float16): + raise RuntimeError(f"megacpp grouped MLP supports BF16/FP16 only, got {dtype}.") + + requires_grad = any(ctx.requires_grad for ctx in basic_op_ctxs) + input_requires_grad = requires_grad + fc1_weight_requires_grad = requires_grad and fc1_weight_param.requires_grad + fc2_weight_requires_grad = requires_grad and fc2_weight_param.requires_grad + + fc1_weights = _megacpp_weight_arg( + fc1_op, + dtype, + input_requires_grad=input_requires_grad, + ) + fc2_weights = _megacpp_weight_arg( + fc2_op, + dtype, + input_requires_grad=input_requires_grad, + ) + ( + fc2_out, + x, + split_sizes_i64, + base_split_offsets, + x_offsets, + fc1_offsets, + fc2_offsets, + fc2_dy_offsets, + fc1_activation_input, + fc2_x, + ) = tex.megacpp_grouped_mlp_forward( + input_.to(dtype=dtype), + split_sizes, + fc1_weights, + _megacpp_bias_arg(fc1_op, dtype), + fc2_weights, + _megacpp_bias_arg(fc2_op, dtype), + act_scales, + activation_config.name, + activation_config.glu_interleave_size, + activation_config.limit, + activation_config.alpha, + activation_config.glu_linear_offset, + ) + + if x.data_ptr() == input_.data_ptr(): + x._do_not_clear = True + + if requires_grad: + fc1_saved_weights = [fc1_weights] if isinstance(fc1_weights, torch.Tensor) else fc1_weights + fc2_saved_weights = [fc2_weights] if isinstance(fc2_weights, torch.Tensor) else fc2_weights + + fc1_ctx.save_for_backward( + split_sizes_i64, + base_split_offsets, + x_offsets, + fc1_offsets, + x, + fc1_activation_input, + *fc1_saved_weights, + ) + fc1_ctx.use_megacpp_grouped_mlp = True + fc1_ctx.dtype = dtype + fc1_ctx.input_requires_grad = input_requires_grad + fc1_ctx.weight_requires_grad = fc1_weight_requires_grad + fc1_ctx.single_weight_arg = isinstance(fc1_weights, torch.Tensor) + + activation_ctx.save_for_backward(fc1_activation_input, act_scales) + activation_ctx.extra_input_requires_grad = act_scales.requires_grad + activation_ctx.input_requires_grad = True + activation_ctx.dtype = dtype + + fc2_ctx.save_for_backward( + split_sizes_i64, + base_split_offsets, + fc2_offsets, + fc2_dy_offsets, + fc2_x, + *fc2_saved_weights, + ) + fc2_ctx.use_megacpp_grouped_mlp = True + fc2_ctx.dtype = dtype + fc2_ctx.input_requires_grad = input_requires_grad + fc2_ctx.weight_requires_grad = fc2_weight_requires_grad + fc2_ctx.single_weight_arg = isinstance(fc2_weights, torch.Tensor) + + return fc2_out, [(), (), ()] + + +def fuse_forward_megacpp_ops( + ops: list[FusibleOperation], + *, + recipe: Optional[Recipe] = None, + **unused, # pylint: disable=unused-argument +) -> list[FusibleOperation]: + """Apply opt-in C++ grouped MLP fusion for BF16/FP16.""" + if not _megacpp_enabled(): + return ops + if not _megacpp_supports_recipe(recipe): + return ops + if not ForwardGroupedMLP_MegaCpp.is_supported(): + return ops + + out = [] + window, ops = ops[:3], ops[3:] + while len(window) == 3: + matches_pattern = True + if not ( + isinstance(window[0], GroupedLinear) + and isinstance(window[2], GroupedLinear) + ): + matches_pattern = False + elif ( + window[0]._scale_bias + or window[2]._scale_bias + ): + matches_pattern = False + else: + try: + _resolve_megacpp_grouped_mlp_config(window[0], window[1], window[2]) + except (TypeError, ValueError, RuntimeError): + matches_pattern = False + + if matches_pattern: + window = [ + ForwardGroupedMLP_MegaCpp( + fc1=window[0], + activation=window[1], + fc2=window[2], + ) + ] + else: + out.extend(window[:-2]) + window = window[-2:] + + out.extend(window[:-3]) + window = window[-3:] + while ops and len(window) < 3: + window.append(ops[0]) + ops = ops[1:] + + out.extend(window) + return out + + +# Explicit env opt-in gives megacpp first chance. Unsupported recipes intentionally +# return the ops unchanged so lower-priority recipe-specific fusers remain the +# fallback path. +register_forward_fusion(fuse_forward_megacpp_ops, prepend=True) From 1120ec9e4182f4d33e5b9423ff144029f5ce200b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 6 Jun 2026 07:39:18 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/megacpp/test_grouped_mlp.py | 6 +- transformer_engine/pytorch/csrc/extensions.h | 18 ++- .../pytorch/csrc/extensions/pybind.cpp | 6 +- .../pytorch/csrc/megacpp/grouped_mlp.cpp | 122 ++++++++---------- .../ops/fused/backward_grouped_mlp_megacpp.py | 24 ++-- .../ops/fused/forward_grouped_mlp_megacpp.py | 18 ++- 6 files changed, 85 insertions(+), 109 deletions(-) diff --git a/tests/pytorch/megacpp/test_grouped_mlp.py b/tests/pytorch/megacpp/test_grouped_mlp.py index ddddcb7fc4..d3cc9cd04c 100644 --- a/tests/pytorch/megacpp/test_grouped_mlp.py +++ b/tests/pytorch/megacpp/test_grouped_mlp.py @@ -464,10 +464,10 @@ def test_megacpp_grouped_mlp_delay_wgrad_raises(monkeypatch): glu_interleave_size=None, single_grouped_param=False, ) - x = torch.randn(total_tokens, _HIDDEN_SIZE, device="cuda", dtype=torch.bfloat16).requires_grad_() - act_scales = torch.rand( - total_tokens, device="cuda", dtype=torch.bfloat16 + x = torch.randn( + total_tokens, _HIDDEN_SIZE, device="cuda", dtype=torch.bfloat16 ).requires_grad_() + act_scales = torch.rand(total_tokens, device="cuda", dtype=torch.bfloat16).requires_grad_() dy = torch.randn(total_tokens, _HIDDEN_SIZE, device="cuda", dtype=torch.bfloat16) monkeypatch.setenv("NVTE_MEGACPP_GROUPED_LINEAR", "1") diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index a59e85456d..1fd334ea69 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -198,16 +198,14 @@ std::vector megacpp_grouped_mlp_forward( py::tuple megacpp_grouped_mlp_backward( const at::Tensor &grad_output, const at::Tensor &split_sizes, const at::Tensor &x_offsets, - const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets, - const at::Tensor &fc2_dy_offsets, const at::Tensor &base_offsets, - const at::Tensor &x, const at::Tensor &fc1_activation_input, const at::Tensor &fc2_x, - const std::optional &act_scales, py::handle fc1_weight, - py::handle fc2_weight, - py::handle fc1_wgrad_output, bool fc1_compute_wgrad, bool fc1_accumulate_wgrad, - py::handle fc2_wgrad_output, bool fc2_compute_wgrad, bool fc2_accumulate_wgrad, - const std::string &activation, int64_t glu_interleave_size, double activation_limit, - double activation_alpha, double activation_glu_linear_offset, bool act_scales_requires_grad, - bool input_requires_grad); + const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets, const at::Tensor &fc2_dy_offsets, + const at::Tensor &base_offsets, const at::Tensor &x, const at::Tensor &fc1_activation_input, + const at::Tensor &fc2_x, const std::optional &act_scales, py::handle fc1_weight, + py::handle fc2_weight, py::handle fc1_wgrad_output, bool fc1_compute_wgrad, + bool fc1_accumulate_wgrad, py::handle fc2_wgrad_output, bool fc2_compute_wgrad, + bool fc2_accumulate_wgrad, const std::string &activation, int64_t glu_interleave_size, + double activation_limit, double activation_alpha, double activation_glu_linear_offset, + bool act_scales_requires_grad, bool input_requires_grad); /*************************************************************************************************** * Transpose diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 78c9e280f3..d70b6cb813 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -357,15 +357,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("te_general_grouped_gemm_for_discrete_out", &transformer_engine::pytorch::te_general_grouped_gemm_for_discrete_out, "Grouped GEMM for discrete output list"); - m.def("megacpp_grouped_mlp_forward", - &transformer_engine::pytorch::megacpp_grouped_mlp_forward, + m.def("megacpp_grouped_mlp_forward", &transformer_engine::pytorch::megacpp_grouped_mlp_forward, "Mega C++ grouped MLP forward", py::arg("input"), py::arg("split_sizes"), py::arg("fc1_weight"), py::arg("fc1_bias"), py::arg("fc2_weight"), py::arg("fc2_bias"), py::arg("act_scales"), py::arg("activation"), py::arg("glu_interleave_size"), py::arg("activation_limit") = 0.0, py::arg("activation_alpha") = 0.0, py::arg("activation_glu_linear_offset") = 0.0); - m.def("megacpp_grouped_mlp_backward", - &transformer_engine::pytorch::megacpp_grouped_mlp_backward, + m.def("megacpp_grouped_mlp_backward", &transformer_engine::pytorch::megacpp_grouped_mlp_backward, "Mega C++ grouped MLP backward", py::arg("grad_output"), py::arg("split_sizes"), py::arg("x_offsets"), py::arg("fc1_offsets"), py::arg("fc2_offsets"), py::arg("fc2_dy_offsets"), py::arg("base_offsets"), py::arg("x"), diff --git a/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp b/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp index 2f9a642041..f85837f40a 100644 --- a/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp +++ b/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp @@ -4,6 +4,8 @@ * See LICENSE for license information. ************************************************************************/ +#include +#include #include #include @@ -12,9 +14,6 @@ #include #include -#include -#include - #include "../extensions.h" #include "../pybind.h" #include "common/util/cuda_runtime.h" @@ -58,7 +57,8 @@ size_t num_groups_from_prepared_split_sizes(const at::Tensor &split_sizes, } GroupedTensorWrapper make_grouped_tensor(at::Tensor data, const at::Tensor &prepared_split_sizes, - const at::Tensor &tensor_offsets, int64_t logical_last_dim) { + const at::Tensor &tensor_offsets, + int64_t logical_last_dim) { const auto num_groups = static_cast(prepared_split_sizes.numel()); const auto total_tokens = static_cast(data.numel() / logical_last_dim); auto grouped = GroupedTensorWrapper( @@ -75,9 +75,8 @@ GroupedTensorWrapper make_grouped_tensor(at::Tensor data, const at::Tensor &prep GroupedTensorWrapper make_uniform_grouped_tensor(at::Tensor data, size_t num_groups, int64_t first_dim, int64_t last_dim) { auto grouped = GroupedTensorWrapper( - num_groups, - std::vector{num_groups * static_cast(first_dim), - static_cast(last_dim)}); + num_groups, std::vector{num_groups * static_cast(first_dim), + static_cast(last_dim)}); grouped.set_rowwise_data(data.data_ptr(), GetTransformerEngineDType(data.scalar_type()), tensor_shape_1d(data)); return grouped; @@ -94,9 +93,7 @@ struct GroupedWeightArg { int64_t rows = 0; int64_t cols = 0; - c10::Device device() const { - return is_grouped ? packed.device() : discrete[0].device(); - } + c10::Device device() const { return is_grouped ? packed.device() : discrete[0].device(); } }; GroupedWeightArg weight_arg_from_py(py::handle arg, size_t num_groups, at::ScalarType dtype, @@ -201,9 +198,9 @@ struct GroupedGemmResources { te_alpha(makeTransformerEngineTensor(alpha)), te_beta_zero(makeTransformerEngineTensor(beta_zero)), te_beta_one(makeTransformerEngineTensor(beta_one)), - te_setup(makeTransformerEngineTensor(setup.data_ptr(), - std::vector{static_cast(setup.numel())}, - DType::kByte)), + te_setup(makeTransformerEngineTensor( + setup.data_ptr(), std::vector{static_cast(setup.numel())}, + DType::kByte)), te_cublas(makeTransformerEngineTensor( cublas.data_ptr(), std::vector{static_cast(cublas.numel())}, DType::kByte)) { @@ -220,9 +217,7 @@ struct GroupedGemmResources { } } - NVTETensor beta(bool accumulate) { - return accumulate ? te_beta_one.data() : te_beta_zero.data(); - } + NVTETensor beta(bool accumulate) { return accumulate ? te_beta_one.data() : te_beta_zero.data(); } NVTEGroupedMatmulConfig config_data() { return config.has_value() ? static_cast(*config) : nullptr; @@ -243,14 +238,12 @@ void grouped_gemm(GroupedTensorWrapper *A, bool transa, GroupedTensorWrapper *B, nvte_grouped_gemm(A->data(), transa, B->data(), transb, D->data(), D->data(), resources->te_alpha.data(), resources->beta(accumulate), resources->te_setup.data(), resources->te_cublas.data(), - resources->config_data(), - at::cuda::getCurrentCUDAStream()); + resources->config_data(), at::cuda::getCurrentCUDAStream()); }); } std::vector output_tensor_list_from_arg(py::handle arg, size_t num_groups, - at::ScalarType dtype, - const std::string &name) { + at::ScalarType dtype, const std::string &name) { std::vector out; if (is_none(arg)) { return out; @@ -303,8 +296,8 @@ WgradOutput wgrad_output_from_arg(py::handle arg, bool compute_wgrad, size_t num // Cases 1 and 2: no external wgrad buffer was provided, so C++ owns the // allocation. Single grouped weight keeps this packed as [G, N, K]; // discrete weights split the same packed allocation into per-expert views. - out.packed = at::empty({static_cast(num_groups), rows, cols}, - at::device(device).dtype(dtype)); + out.packed = + at::empty({static_cast(num_groups), rows, cols}, at::device(device).dtype(dtype)); out.owns_storage = true; out.is_grouped = prefer_grouped_output; if (out.is_grouped) { @@ -345,9 +338,8 @@ void grouped_gemm_fwd_dgrad(GroupedWeightArg *weights, bool trans_weight, if (weights->is_grouped) { // Single grouped weight case: weights are packed as [G, N, K]. Wrap the // packed buffer as a uniform GroupedTensor and use the grouped-tensor GEMM. - auto grouped_weight = - make_uniform_grouped_tensor(weights->packed, input->num_tensors(), weights->rows, - weights->cols); + auto grouped_weight = make_uniform_grouped_tensor(weights->packed, input->num_tensors(), + weights->rows, weights->cols); grouped_gemm(&grouped_weight, trans_weight, input, trans_input, output, resources, false); } else { // Discrete weight case: weights are a list of per-expert tensors. Use the @@ -413,7 +405,8 @@ GroupedTensorWrapper make_grouped_bias(const at::Tensor &bias, size_t num_groups NVTE_CHECK(bias.defined(), "Bias tensor must be defined."); auto grouped = GroupedTensorWrapper( num_groups, std::vector{num_groups, static_cast(out_features)}); - grouped.set_rowwise_data(bias.data_ptr(), GetTransformerEngineDType(dtype), tensor_shape_1d(bias)); + grouped.set_rowwise_data(bias.data_ptr(), GetTransformerEngineDType(dtype), + tensor_shape_1d(bias)); return grouped; } @@ -498,7 +491,8 @@ at::Tensor activation_forward_impl(const at::Tensor &input, const std::string &a } else if (activation == "sreglu") { nvte_sreglu(te_input.data(), te_output.data(), stream); } else if (activation == "clamped_swiglu") { - nvte_clamped_swiglu_v2(te_input.data(), te_output.data(), static_cast(activation_limit), + nvte_clamped_swiglu_v2(te_input.data(), te_output.data(), + static_cast(activation_limit), static_cast(activation_alpha), static_cast(activation_glu_linear_offset), stream); } else if (activation == "srelu") { @@ -520,8 +514,7 @@ at::Tensor activation_forward_impl(const at::Tensor &input, const std::string &a at::Tensor activation_backward_impl(const at::Tensor &grad, const at::Tensor &input, const std::string &activation, double activation_limit, - double activation_alpha, - double activation_glu_linear_offset) { + double activation_alpha, double activation_glu_linear_offset) { auto output = at::empty_like(input); auto te_grad = makeTransformerEngineTensor(grad); auto te_input = makeTransformerEngineTensor(input); @@ -568,7 +561,7 @@ at::Tensor grouped_mlp_activation_forward( double activation_alpha, double activation_glu_linear_offset, at::ScalarType dtype) { auto activation_input = maybe_deinterleave_glu(input, glu_interleave_size); auto activation_output = activation_forward_impl(activation_input, activation, activation_limit, - activation_alpha, activation_glu_linear_offset); + activation_alpha, activation_glu_linear_offset); if (!act_scales.has_value()) { return activation_output; } @@ -607,10 +600,9 @@ ActivationBackwardResult grouped_mlp_activation_backward( } auto grad_activation_input = - activation_backward_impl(grad_activation_output, activation_input, activation, activation_limit, - activation_alpha, activation_glu_linear_offset); - return {maybe_reinterleave_glu_grad(grad_activation_input, glu_interleave_size), - grad_act_scales}; + activation_backward_impl(grad_activation_output, activation_input, activation, + activation_limit, activation_alpha, activation_glu_linear_offset); + return {maybe_reinterleave_glu_grad(grad_activation_input, glu_interleave_size), grad_act_scales}; } } // namespace @@ -653,8 +645,7 @@ std::vector megacpp_grouped_mlp_forward( split_sizes, x.device(), std::vector{1, in_features, fc1_out_features, fc2_in_features, fc2_out_features}, std::vector{true, true, true, true, true}, - std::vector{at::kLong, at::kLong, at::kLong, at::kLong, at::kLong}, - true); + std::vector{at::kLong, at::kLong, at::kLong, at::kLong, at::kLong}, true); // splits_to_offsets_multi returns the canonical int64 CUDA split sizes and // offsets in the same order as the stride list above. The CuTe path also asks // for int32 split_points, but cuBLAS grouped GEMM does not consume them. @@ -675,10 +666,9 @@ std::vector megacpp_grouped_mlp_forward( &gemm_resources); add_grouped_bias(&grouped_fc1_preact, fc1_bias_tensor, num_groups, dtype, fc1_out_features); - auto fc2_x = - grouped_mlp_activation_forward(fc1_preact, act_scales, activation, glu_interleave_size, - activation_limit, activation_alpha, - activation_glu_linear_offset, dtype); + auto fc2_x = grouped_mlp_activation_forward( + fc1_preact, act_scales, activation, glu_interleave_size, activation_limit, activation_alpha, + activation_glu_linear_offset, dtype); std::vector out_shape = input.sizes().vec(); out_shape.back() = fc2_out_features; @@ -692,22 +682,20 @@ std::vector megacpp_grouped_mlp_forward( &gemm_resources); add_grouped_bias(&grouped_output, fc2_bias_tensor, num_groups, dtype, fc2_out_features); - return {output, x, split_sizes_i64, base_offsets, x_offsets, fc1_offsets, fc2_offsets, - output_offsets, fc1_preact, fc2_x}; + return {output, x, split_sizes_i64, base_offsets, x_offsets, + fc1_offsets, fc2_offsets, output_offsets, fc1_preact, fc2_x}; } py::tuple megacpp_grouped_mlp_backward( const at::Tensor &grad_output, const at::Tensor &split_sizes, const at::Tensor &x_offsets, - const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets, - const at::Tensor &fc2_dy_offsets, const at::Tensor &base_offsets, - const at::Tensor &x, const at::Tensor &fc1_activation_input, const at::Tensor &fc2_x, - const std::optional &act_scales, py::handle fc1_weight, - py::handle fc2_weight, - py::handle fc1_wgrad_output, bool fc1_compute_wgrad, bool fc1_accumulate_wgrad, - py::handle fc2_wgrad_output, bool fc2_compute_wgrad, bool fc2_accumulate_wgrad, - const std::string &activation, int64_t glu_interleave_size, double activation_limit, - double activation_alpha, double activation_glu_linear_offset, bool act_scales_requires_grad, - bool input_requires_grad) { + const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets, const at::Tensor &fc2_dy_offsets, + const at::Tensor &base_offsets, const at::Tensor &x, const at::Tensor &fc1_activation_input, + const at::Tensor &fc2_x, const std::optional &act_scales, py::handle fc1_weight, + py::handle fc2_weight, py::handle fc1_wgrad_output, bool fc1_compute_wgrad, + bool fc1_accumulate_wgrad, py::handle fc2_wgrad_output, bool fc2_compute_wgrad, + bool fc2_accumulate_wgrad, const std::string &activation, int64_t glu_interleave_size, + double activation_limit, double activation_alpha, double activation_glu_linear_offset, + bool act_scales_requires_grad, bool input_requires_grad) { (void)base_offsets; NVTE_CHECK(grad_output.is_cuda(), "megacpp_grouped_mlp_backward requires CUDA grad_output."); at::cuda::CUDAGuard device_guard(grad_output.device()); @@ -737,23 +725,20 @@ py::tuple megacpp_grouped_mlp_backward( fc2_x_for_wgrad = fc2_x_for_wgrad.view({-1, fc2_in_features}); auto grouped_fc2_x_for_wgrad = make_grouped_tensor(fc2_x_for_wgrad.view({-1}), split_sizes, fc2_offsets, fc2_in_features); - fc2_wgrads = - grouped_gemm_wgrad(&grouped_fc2_x_for_wgrad, &grouped_dy, fc2_wgrad_output, - fc2_compute_wgrad, fc2_accumulate_wgrad, &gemm_resources, dtype, - fc2_out_features, fc2_in_features, "fc2_wgrad_output", - fc2_weights.is_grouped); + fc2_wgrads = grouped_gemm_wgrad(&grouped_fc2_x_for_wgrad, &grouped_dy, fc2_wgrad_output, + fc2_compute_wgrad, fc2_accumulate_wgrad, &gemm_resources, dtype, + fc2_out_features, fc2_in_features, "fc2_wgrad_output", + fc2_weights.is_grouped); } auto fc2_dx = at::empty({total_tokens, fc2_in_features}, dy.options()); auto grouped_fc2_dx = make_grouped_tensor(fc2_dx.view({-1}), split_sizes, fc2_offsets, fc2_in_features); - grouped_gemm_fwd_dgrad(&fc2_weights, false, &grouped_dy, false, &grouped_fc2_dx, - &gemm_resources); + grouped_gemm_fwd_dgrad(&fc2_weights, false, &grouped_dy, false, &grouped_fc2_dx, &gemm_resources); auto activation_grads = grouped_mlp_activation_backward( - fc2_dx, fc1_activation_input, act_scales, activation, glu_interleave_size, - activation_limit, activation_alpha, activation_glu_linear_offset, dtype, - act_scales_requires_grad); + fc2_dx, fc1_activation_input, act_scales, activation, glu_interleave_size, activation_limit, + activation_alpha, activation_glu_linear_offset, dtype, act_scales_requires_grad); auto fc1_dy = activation_grads.grad_input; auto grad_act_scales = activation_grads.grad_act_scales; auto grouped_fc1_dy = @@ -766,11 +751,10 @@ py::tuple megacpp_grouped_mlp_backward( x_for_wgrad = x_for_wgrad.view({-1, in_features}); auto grouped_x_for_wgrad = make_grouped_tensor(x_for_wgrad.view({-1}), split_sizes, x_offsets, in_features); - fc1_wgrads = - grouped_gemm_wgrad(&grouped_x_for_wgrad, &grouped_fc1_dy, fc1_wgrad_output, - fc1_compute_wgrad, fc1_accumulate_wgrad, &gemm_resources, dtype, - fc1_out_features, in_features, "fc1_wgrad_output", - fc1_weights.is_grouped); + fc1_wgrads = grouped_gemm_wgrad(&grouped_x_for_wgrad, &grouped_fc1_dy, fc1_wgrad_output, + fc1_compute_wgrad, fc1_accumulate_wgrad, &gemm_resources, dtype, + fc1_out_features, in_features, "fc1_wgrad_output", + fc1_weights.is_grouped); } at::Tensor grad_input; @@ -779,8 +763,8 @@ py::tuple megacpp_grouped_mlp_backward( grad_input_shape.back() = in_features; grad_input = at::empty(grad_input_shape, dy.options()); auto grad_input_2d = grad_input.view({-1, in_features}); - auto grouped_grad_input = make_grouped_tensor(grad_input_2d.view({-1}), split_sizes, - x_offsets, in_features); + auto grouped_grad_input = + make_grouped_tensor(grad_input_2d.view({-1}), split_sizes, x_offsets, in_features); grouped_gemm_fwd_dgrad(&fc1_weights, false, &grouped_fc1_dy, false, &grouped_grad_input, &gemm_resources); } else { diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp_megacpp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp_megacpp.py index ebaf30d075..a0a69c5804 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp_megacpp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp_megacpp.py @@ -324,11 +324,15 @@ def fuser_backward( else (None,) ) - return grad_input, [fc1_grad_params, (), fc2_grad_params], [ - (None,), - activation_grad_extra, - (None,), - ] + return ( + grad_input, + [fc1_grad_params, (), fc2_grad_params], + [ + (None,), + activation_grad_extra, + (None,), + ], + ) def fuse_backward_megacpp_ops( @@ -349,15 +353,9 @@ def fuse_backward_megacpp_ops( window, ops = ops[:3], ops[3:] while len(window) == 3: matches_pattern = True - if not ( - isinstance(window[0], GroupedLinear) - and isinstance(window[2], GroupedLinear) - ): + if not (isinstance(window[0], GroupedLinear) and isinstance(window[2], GroupedLinear)): matches_pattern = False - elif ( - window[0]._scale_bias - or window[2]._scale_bias - ): + elif window[0]._scale_bias or window[2]._scale_bias: matches_pattern = False else: try: diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp_megacpp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp_megacpp.py index bd05b6218f..61906e3714 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp_megacpp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp_megacpp.py @@ -280,8 +280,12 @@ def fuser_forward( x._do_not_clear = True if requires_grad: - fc1_saved_weights = [fc1_weights] if isinstance(fc1_weights, torch.Tensor) else fc1_weights - fc2_saved_weights = [fc2_weights] if isinstance(fc2_weights, torch.Tensor) else fc2_weights + fc1_saved_weights = ( + [fc1_weights] if isinstance(fc1_weights, torch.Tensor) else fc1_weights + ) + fc2_saved_weights = ( + [fc2_weights] if isinstance(fc2_weights, torch.Tensor) else fc2_weights + ) fc1_ctx.save_for_backward( split_sizes_i64, @@ -338,15 +342,9 @@ def fuse_forward_megacpp_ops( window, ops = ops[:3], ops[3:] while len(window) == 3: matches_pattern = True - if not ( - isinstance(window[0], GroupedLinear) - and isinstance(window[2], GroupedLinear) - ): + if not (isinstance(window[0], GroupedLinear) and isinstance(window[2], GroupedLinear)): matches_pattern = False - elif ( - window[0]._scale_bias - or window[2]._scale_bias - ): + elif window[0]._scale_bias or window[2]._scale_bias: matches_pattern = False else: try: From ad906e54b232b48c735b1f395cc9324b0e556eb6 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Tue, 9 Jun 2026 23:25:15 -0700 Subject: [PATCH 3/5] fixes for E2E run Signed-off-by: Zhongbo Zhu --- tests/pytorch/megacpp/test_grouped_mlp.py | 24 ++++++--- .../pytorch/csrc/megacpp/grouped_mlp.cpp | 50 +++++++++---------- 2 files changed, 40 insertions(+), 34 deletions(-) diff --git a/tests/pytorch/megacpp/test_grouped_mlp.py b/tests/pytorch/megacpp/test_grouped_mlp.py index d3cc9cd04c..b056af1978 100644 --- a/tests/pytorch/megacpp/test_grouped_mlp.py +++ b/tests/pytorch/megacpp/test_grouped_mlp.py @@ -98,7 +98,7 @@ def _copy_grouped_mlp_params(dst: te_ops.Sequential, src: te_ops.Sequential) -> ) -def _init_main_grads(module: te_ops.Sequential) -> None: +def _init_main_grads(module: te_ops.Sequential, dtype: torch.dtype) -> None: for linear in (module[0], module[2]): if linear.single_grouped_weight: linear.weight.main_grad = torch.zeros( @@ -106,12 +106,12 @@ def _init_main_grads(module: te_ops.Sequential) -> None: linear.out_features, linear.in_features, device="cuda", - dtype=torch.bfloat16, + dtype=dtype, ) else: for group_idx in range(linear.num_groups): weight = getattr(linear, f"weight{group_idx}") - weight.main_grad = torch.zeros_like(weight) + weight.main_grad = torch.zeros_like(weight, dtype=dtype) def _run_grouped_mlp( @@ -241,6 +241,7 @@ def _run_megacpp_against_python( activation_kind: str = "scaled_swiglu", single_grouped_param: bool = False, accumulate_into_main_grad: bool = False, + main_grad_dtype: torch.dtype | None = None, compare_zero_expert_grads: bool = True, monkeypatch, ) -> None: @@ -274,8 +275,10 @@ def _run_megacpp_against_python( ) _copy_grouped_mlp_params(test, ref) if accumulate_into_main_grad: - _init_main_grads(ref) - _init_main_grads(test) + if main_grad_dtype is None: + raise ValueError("main_grad_dtype must be set when using Megatron-owned main_grad.") + _init_main_grads(ref, main_grad_dtype) + _init_main_grads(test, main_grad_dtype) # Paged stashing passes a static physical buffer to the op while m_splits # describe only the valid prefix. Rows after sum(m_splits) are garbage and @@ -332,13 +335,17 @@ def _run_megacpp_against_python( ids=["discrete_weight", "packed_weight"], ) @pytest.mark.parametrize( - "accumulate_into_main_grad", - [False, True], - ids=["cpp_allocated_wgrad", "megatron_main_grad"], + "accumulate_into_main_grad,main_grad_dtype", + [ + pytest.param(False, None, id="cpp_allocated_wgrad"), + pytest.param(True, torch.bfloat16, id="megatron_main_grad_bf16"), + pytest.param(True, torch.float32, id="megatron_main_grad_fp32"), + ], ) def test_megacpp_grouped_mlp_wgrad_storage_matches_python( single_grouped_param, accumulate_into_main_grad, + main_grad_dtype, monkeypatch, ): torch.manual_seed(1234) @@ -349,6 +356,7 @@ def test_megacpp_grouped_mlp_wgrad_storage_matches_python( split_device="cuda", single_grouped_param=single_grouped_param, accumulate_into_main_grad=accumulate_into_main_grad, + main_grad_dtype=main_grad_dtype, monkeypatch=monkeypatch, ) diff --git a/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp b/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp index f85837f40a..e3ce05407d 100644 --- a/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp +++ b/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp @@ -243,36 +243,33 @@ void grouped_gemm(GroupedTensorWrapper *A, bool transa, GroupedTensorWrapper *B, } std::vector output_tensor_list_from_arg(py::handle arg, size_t num_groups, - at::ScalarType dtype, const std::string &name) { + int64_t rows, int64_t cols, + const std::string &name) { std::vector out; if (is_none(arg)) { return out; } out.reserve(num_groups); - if (py::isinstance(arg) || py::isinstance(arg)) { - auto seq = py::reinterpret_borrow(arg); - NVTE_CHECK(static_cast(seq.size()) == num_groups, name, " must have ", num_groups, - " tensors."); - for (size_t i = 0; i < num_groups; ++i) { - auto tensor = seq[i].cast(); - NVTE_CHECK(tensor.is_cuda(), name, " tensors must be CUDA tensors."); - NVTE_CHECK(tensor.scalar_type() == dtype, name, " tensors must have the requested dtype."); - NVTE_CHECK(tensor.dim() == 2, name, " tensors must be rank-2 wgrad buffers."); - check_contiguous(tensor, name); - out.emplace_back(tensor); - } - return out; - } - - auto packed = arg.cast(); - NVTE_CHECK(packed.is_cuda(), name, " must be a CUDA tensor."); - NVTE_CHECK(packed.scalar_type() == dtype, name, " must have the requested dtype."); - NVTE_CHECK(packed.dim() == 3, name, " must have shape [num_groups, rows, cols]."); - NVTE_CHECK(static_cast(packed.size(0)) == num_groups, name, " first dimension must be ", - num_groups, "."); - check_contiguous(packed, name); + // This helper is intentionally only for the discrete-weight external wgrad + // path, where Megatron provides one main_grad tensor per expert. The packed + // [G, rows, cols] external buffer used by single grouped weight is handled in + // wgrad_output_from_arg so it can stay packed and use grouped-tensor GEMM. + NVTE_CHECK(py::isinstance(arg) || py::isinstance(arg), name, + " must be a list or tuple of wgrad output tensors."); + auto seq = py::reinterpret_borrow(arg); + NVTE_CHECK(static_cast(seq.size()) == num_groups, name, " must have ", num_groups, + " tensors."); for (size_t i = 0; i < num_groups; ++i) { - out.emplace_back(packed.select(0, static_cast(i))); + auto tensor = seq[i].cast(); + NVTE_CHECK(tensor.is_cuda(), name, " tensors must be CUDA tensors."); + // Do not require tensor.scalar_type() == compute dtype. Caller-owned + // main_grad buffers are allocated by Megatron and may be FP32 even when TE + // grouped MLP compute is BF16. + NVTE_CHECK(tensor.dim() == 2, name, " tensors must be rank-2 wgrad buffers."); + NVTE_CHECK(tensor.size(0) == rows && tensor.size(1) == cols, name, + " tensors must have shape [rows, cols]."); + check_contiguous(tensor, name); + out.emplace_back(tensor); } return out; } @@ -315,7 +312,8 @@ WgradOutput wgrad_output_from_arg(py::handle arg, bool compute_wgrad, size_t num // should not receive a newly allocated grad tensor from this helper. out.packed = arg.cast(); NVTE_CHECK(out.packed.is_cuda(), name, " must be a CUDA tensor."); - NVTE_CHECK(out.packed.scalar_type() == dtype, name, " must have the requested dtype."); + // Do not require out.packed.scalar_type() == compute dtype. Caller-owned + // main_grad buffers keep the dtype chosen by Megatron's grad-buffer config. NVTE_CHECK(out.packed.dim() == 3, name, " must have shape [num_groups, rows, cols]."); NVTE_CHECK(static_cast(out.packed.size(0)) == num_groups, name, " first dimension must be ", num_groups, "."); @@ -328,7 +326,7 @@ WgradOutput wgrad_output_from_arg(py::handle arg, bool compute_wgrad, size_t num // Case 4: discrete weights with externally-owned per-expert buffers, e.g. // Megatron main_grad list. GEMM writes each tensor in-place and returns no // allocated grad list to Python. - out.tensors = output_tensor_list_from_arg(arg, num_groups, dtype, name); + out.tensors = output_tensor_list_from_arg(arg, num_groups, rows, cols, name); return out; } From b3847a2dec8e7cf2fe42eaffb63fd607368f0693 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Wed, 10 Jun 2026 15:29:39 -0700 Subject: [PATCH 4/5] micro optimizations Signed-off-by: Zhongbo Zhu --- transformer_engine/pytorch/csrc/extensions.h | 10 +- .../pytorch/csrc/extensions/pybind.cpp | 17 +- .../pytorch/csrc/megacpp/grouped_mlp.cpp | 171 +++++++++++------- .../ops/fused/backward_grouped_mlp_megacpp.py | 8 +- .../ops/fused/forward_grouped_mlp_megacpp.py | 43 ++++- 5 files changed, 171 insertions(+), 78 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 1fd334ea69..a66a35b27d 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -190,14 +190,16 @@ py::object te_general_grouped_gemm_for_discrete_out(py::handle A, bool transa, p **************************************************************************************************/ std::vector megacpp_grouped_mlp_forward( - const at::Tensor &input, const at::Tensor &split_sizes, py::handle fc1_weight, + const at::Tensor &input, at::ScalarType act_dtype, const at::Tensor &split_sizes, + py::handle fc1_weight, py::handle fc1_bias, py::handle fc2_weight, py::handle fc2_bias, const std::optional &act_scales, const std::string &activation, int64_t glu_interleave_size, double activation_limit, double activation_alpha, - double activation_glu_linear_offset); + double activation_glu_linear_offset, py::handle gemm_scratch); py::tuple megacpp_grouped_mlp_backward( - const at::Tensor &grad_output, const at::Tensor &split_sizes, const at::Tensor &x_offsets, + const at::Tensor &grad_output, at::ScalarType act_dtype, const at::Tensor &split_sizes, + const at::Tensor &x_offsets, const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets, const at::Tensor &fc2_dy_offsets, const at::Tensor &base_offsets, const at::Tensor &x, const at::Tensor &fc1_activation_input, const at::Tensor &fc2_x, const std::optional &act_scales, py::handle fc1_weight, @@ -205,7 +207,7 @@ py::tuple megacpp_grouped_mlp_backward( bool fc1_accumulate_wgrad, py::handle fc2_wgrad_output, bool fc2_compute_wgrad, bool fc2_accumulate_wgrad, const std::string &activation, int64_t glu_interleave_size, double activation_limit, double activation_alpha, double activation_glu_linear_offset, - bool act_scales_requires_grad, bool input_requires_grad); + bool act_scales_requires_grad, bool input_requires_grad, py::handle gemm_scratch); /*************************************************************************************************** * Transpose diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d70b6cb813..f870938441 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -358,22 +358,25 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &transformer_engine::pytorch::te_general_grouped_gemm_for_discrete_out, "Grouped GEMM for discrete output list"); m.def("megacpp_grouped_mlp_forward", &transformer_engine::pytorch::megacpp_grouped_mlp_forward, - "Mega C++ grouped MLP forward", py::arg("input"), py::arg("split_sizes"), - py::arg("fc1_weight"), py::arg("fc1_bias"), py::arg("fc2_weight"), py::arg("fc2_bias"), + "Mega C++ grouped MLP forward", py::arg("input"), py::arg("act_dtype"), + py::arg("split_sizes"), py::arg("fc1_weight"), py::arg("fc1_bias"), + py::arg("fc2_weight"), py::arg("fc2_bias"), py::arg("act_scales"), py::arg("activation"), py::arg("glu_interleave_size"), py::arg("activation_limit") = 0.0, py::arg("activation_alpha") = 0.0, - py::arg("activation_glu_linear_offset") = 0.0); + py::arg("activation_glu_linear_offset") = 0.0, + py::arg("gemm_scratch") = py::none()); m.def("megacpp_grouped_mlp_backward", &transformer_engine::pytorch::megacpp_grouped_mlp_backward, - "Mega C++ grouped MLP backward", py::arg("grad_output"), py::arg("split_sizes"), - py::arg("x_offsets"), py::arg("fc1_offsets"), py::arg("fc2_offsets"), - py::arg("fc2_dy_offsets"), py::arg("base_offsets"), py::arg("x"), + "Mega C++ grouped MLP backward", py::arg("grad_output"), py::arg("act_dtype"), + py::arg("split_sizes"), py::arg("x_offsets"), py::arg("fc1_offsets"), + py::arg("fc2_offsets"), py::arg("fc2_dy_offsets"), py::arg("base_offsets"), py::arg("x"), py::arg("fc1_activation_input"), py::arg("fc2_x"), py::arg("act_scales"), py::arg("fc1_weight"), py::arg("fc2_weight"), py::arg("fc1_wgrad_output"), py::arg("fc1_compute_wgrad"), py::arg("fc1_accumulate_wgrad"), py::arg("fc2_wgrad_output"), py::arg("fc2_compute_wgrad"), py::arg("fc2_accumulate_wgrad"), py::arg("activation"), py::arg("glu_interleave_size"), py::arg("activation_limit") = 0.0, py::arg("activation_alpha") = 0.0, py::arg("activation_glu_linear_offset") = 0.0, - py::arg("act_scales_requires_grad") = true, py::arg("input_requires_grad") = true); + py::arg("act_scales_requires_grad") = true, py::arg("input_requires_grad") = true, + py::arg("gemm_scratch") = py::none()); m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O", py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard()); diff --git a/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp b/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp index e3ce05407d..589fb56c5b 100644 --- a/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp +++ b/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp @@ -9,9 +9,11 @@ #include #include +#include #include #include #include +#include #include #include "../extensions.h" @@ -56,10 +58,13 @@ size_t num_groups_from_prepared_split_sizes(const at::Tensor &split_sizes, return static_cast(split_sizes.numel()); } -GroupedTensorWrapper make_grouped_tensor(at::Tensor data, const at::Tensor &prepared_split_sizes, +GroupedTensorWrapper make_grouped_tensor(const at::Tensor &data, + const at::Tensor &prepared_split_sizes, const at::Tensor &tensor_offsets, int64_t logical_last_dim) { const auto num_groups = static_cast(prepared_split_sizes.numel()); + NVTE_CHECK(data.numel() % logical_last_dim == 0, + "Grouped tensor storage is not divisible by logical last dimension."); const auto total_tokens = static_cast(data.numel() / logical_last_dim); auto grouped = GroupedTensorWrapper( num_groups, std::vector{total_tokens, static_cast(logical_last_dim)}); @@ -169,6 +174,36 @@ int grouped_gemm_math_sm_count(const c10::Device &device) { return sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); } +std::array grouped_gemm_scratch_from_arg(py::handle scratch, + const c10::Device &device, + size_t num_groups) { + const int64_t num_groups_i64 = static_cast(num_groups); + const int64_t setup_size = + static_cast(nvte_get_grouped_gemm_setup_workspace_size(num_groups)); + + if (is_none(scratch)) { + return { + at::ones({num_groups_i64}, at::device(device).dtype(at::kFloat)), + at::zeros({num_groups_i64}, at::device(device).dtype(at::kFloat)), + at::empty({setup_size}, at::device(device).dtype(at::kByte)), + at::empty({kGroupedGemmCublasWorkspaceSize}, at::device(device).dtype(at::kByte)), + }; + } + + NVTE_CHECK(py::isinstance(scratch) || py::isinstance(scratch), + "megacpp grouped MLP GEMM scratch must be None or a 4-tensor tuple/list."); + auto seq = py::reinterpret_borrow(scratch); + NVTE_CHECK(seq.size() == 4, "megacpp grouped MLP GEMM scratch must have 4 tensors."); + + std::array tensors = { + seq[0].cast(), + seq[1].cast(), + seq[2].cast(), + seq[3].cast(), + }; + return tensors; +} + struct GroupedGemmResources { c10::Device device; size_t num_groups; @@ -184,17 +219,15 @@ struct GroupedGemmResources { TensorWrapper te_cublas; std::optional config; - GroupedGemmResources(const c10::Device &device_, size_t num_groups_) + GroupedGemmResources(const c10::Device &device_, size_t num_groups_, + std::array scratch) : device(device_), num_groups(num_groups_), - alpha(at::ones({static_cast(num_groups_)}, at::device(device).dtype(at::kFloat))), - beta_zero( - at::zeros({static_cast(num_groups_)}, at::device(device).dtype(at::kFloat))), + alpha(std::move(scratch[0])), + beta_zero(std::move(scratch[1])), beta_one(alpha), - setup(at::empty( - {static_cast(nvte_get_grouped_gemm_setup_workspace_size(num_groups_))}, - at::device(device).dtype(at::kByte))), - cublas(at::empty({kGroupedGemmCublasWorkspaceSize}, at::device(device).dtype(at::kByte))), + setup(std::move(scratch[2])), + cublas(std::move(scratch[3])), te_alpha(makeTransformerEngineTensor(alpha)), te_beta_zero(makeTransformerEngineTensor(beta_zero)), te_beta_one(makeTransformerEngineTensor(beta_one)), @@ -204,12 +237,10 @@ struct GroupedGemmResources { te_cublas(makeTransformerEngineTensor( cublas.data_ptr(), std::vector{static_cast(cublas.numel())}, DType::kByte)) { - // These scratch tensors are intentionally local to one megacpp call. They - // are safe after this CPU function returns because every current cuBLAS - // grouped GEMM below is enqueued on at::cuda::getCurrentCUDAStream(), so - // PyTorch's caching allocator observes same-stream allocation/release - // ordering. If a future backend uses auxiliary streams, this helper must - // either record those streams on the tensors or extend resource lifetime. + // These scratch tensors may be cached by Python per CUDA stream. Every + // current megacpp grouped GEMM below is enqueued on at::cuda::getCurrentCUDAStream(), + // so same-stream ordering protects workspace reuse. If a future backend + // uses auxiliary streams, cache keys or stream recording must be revisited. const int math_sm_count = grouped_gemm_math_sm_count(device); if (math_sm_count > 0) { config.emplace(); @@ -225,11 +256,13 @@ struct GroupedGemmResources { }; GroupedGemmResources make_grouped_mlp_backend_resources(const c10::Device &device, - size_t num_groups) { + size_t num_groups, + py::handle scratch) { // Keep the backend resource policy private to megacpp. Today this is cuBLAS // grouped GEMM scratch; future backends can change this helper without // changing the Python or pybind contract. - return GroupedGemmResources(device, num_groups); + return GroupedGemmResources(device, num_groups, + grouped_gemm_scratch_from_arg(scratch, device, num_groups)); } void grouped_gemm(GroupedTensorWrapper *A, bool transa, GroupedTensorWrapper *B, bool transb, @@ -262,7 +295,7 @@ std::vector output_tensor_list_from_arg(py::handle arg, size_t num_g for (size_t i = 0; i < num_groups; ++i) { auto tensor = seq[i].cast(); NVTE_CHECK(tensor.is_cuda(), name, " tensors must be CUDA tensors."); - // Do not require tensor.scalar_type() == compute dtype. Caller-owned + // Do not require tensor.scalar_type() == dtype. Caller-owned // main_grad buffers are allocated by Megatron and may be FP32 even when TE // grouped MLP compute is BF16. NVTE_CHECK(tensor.dim() == 2, name, " tensors must be rank-2 wgrad buffers."); @@ -294,7 +327,8 @@ WgradOutput wgrad_output_from_arg(py::handle arg, bool compute_wgrad, size_t num // allocation. Single grouped weight keeps this packed as [G, N, K]; // discrete weights split the same packed allocation into per-expert views. out.packed = - at::empty({static_cast(num_groups), rows, cols}, at::device(device).dtype(dtype)); + at::empty({static_cast(num_groups), rows, cols}, + at::device(device).dtype(dtype)); out.owns_storage = true; out.is_grouped = prefer_grouped_output; if (out.is_grouped) { @@ -312,8 +346,8 @@ WgradOutput wgrad_output_from_arg(py::handle arg, bool compute_wgrad, size_t num // should not receive a newly allocated grad tensor from this helper. out.packed = arg.cast(); NVTE_CHECK(out.packed.is_cuda(), name, " must be a CUDA tensor."); - // Do not require out.packed.scalar_type() == compute dtype. Caller-owned - // main_grad buffers keep the dtype chosen by Megatron's grad-buffer config. + // Do not require out.packed.scalar_type() == dtype. Caller-owned + // main_grad buffers keep the precision chosen by Megatron's grad-buffer config. NVTE_CHECK(out.packed.dim() == 3, name, " must have shape [num_groups, rows, cols]."); NVTE_CHECK(static_cast(out.packed.size(0)) == num_groups, name, " first dimension must be ", num_groups, "."); @@ -355,12 +389,14 @@ void grouped_gemm_fwd_dgrad(GroupedWeightArg *weights, bool trans_weight, } std::vector grouped_gemm_wgrad(GroupedTensorWrapper *x, GroupedTensorWrapper *dy, - py::handle output, bool compute_wgrad, bool accumulate, - GroupedGemmResources *resources, at::ScalarType dtype, - int64_t rows, int64_t cols, const std::string &name, - bool prefer_grouped_output) { - auto prepared = wgrad_output_from_arg(output, compute_wgrad, resources->num_groups, dtype, - resources->device, rows, cols, name, prefer_grouped_output); + py::handle output, bool compute_wgrad, bool accumulate, + GroupedGemmResources *resources, + at::ScalarType dtype, int64_t rows, + int64_t cols, const std::string &name, + bool prefer_grouped_output) { + auto prepared = + wgrad_output_from_arg(output, compute_wgrad, resources->num_groups, dtype, + resources->device, rows, cols, name, prefer_grouped_output); NVTE_CHECK(!(prepared.owns_storage && accumulate), name, " cannot accumulate into a newly allocated wgrad buffer."); std::vector returned_wgrads; @@ -399,11 +435,11 @@ std::vector grouped_gemm_wgrad(GroupedTensorWrapper *x, GroupedTenso } GroupedTensorWrapper make_grouped_bias(const at::Tensor &bias, size_t num_groups, - at::ScalarType dtype, int64_t out_features) { + at::ScalarType bias_dtype, int64_t out_features) { NVTE_CHECK(bias.defined(), "Bias tensor must be defined."); auto grouped = GroupedTensorWrapper( num_groups, std::vector{num_groups, static_cast(out_features)}); - grouped.set_rowwise_data(bias.data_ptr(), GetTransformerEngineDType(dtype), + grouped.set_rowwise_data(bias.data_ptr(), GetTransformerEngineDType(bias_dtype), tensor_shape_1d(bias)); return grouped; } @@ -606,18 +642,25 @@ ActivationBackwardResult grouped_mlp_activation_backward( } // namespace std::vector megacpp_grouped_mlp_forward( - const at::Tensor &input, const at::Tensor &split_sizes, py::handle fc1_weight, + const at::Tensor &input, at::ScalarType act_dtype, const at::Tensor &split_sizes, + py::handle fc1_weight, py::handle fc1_bias, py::handle fc2_weight, py::handle fc2_bias, const std::optional &act_scales, const std::string &activation, int64_t glu_interleave_size, double activation_limit, double activation_alpha, - double activation_glu_linear_offset) { + double activation_glu_linear_offset, py::handle gemm_scratch) { NVTE_CHECK(input.is_cuda(), "megacpp_grouped_mlp_forward requires CUDA input."); at::cuda::CUDAGuard device_guard(input.device()); + // act_dtype is the requested activation/GEMM input dtype. The incoming + // tensor may have a different dtype, so canonicalize it once at the API + // boundary and use this tensor for all downstream grouped GEMMs. + const auto dtype = act_dtype; + auto x = maybe_cast_dtype(input, dtype); + check_contiguous(x, "input"); + const auto num_groups = static_cast(split_sizes.numel()); NVTE_CHECK(num_groups > 0, "megacpp grouped MLP requires at least one group."); - const auto dtype = input.scalar_type(); NVTE_CHECK(dtype == at::kBFloat16 || dtype == at::kHalf, "megacpp grouped MLP currently supports BF16/FP16 only."); @@ -636,9 +679,8 @@ std::vector megacpp_grouped_mlp_forward( auto fc2_bias_tensor = packed_bias_from_arg(fc2_bias, num_groups, dtype, fc2_out_features, "fc2_bias"); - auto x = maybe_cast_dtype(input, dtype); - check_contiguous(x, "input"); - x = x.view({-1, in_features}); + NVTE_CHECK(x.numel() % in_features == 0, "input last dimension is incompatible with FC1."); + const int64_t total_tokens = x.numel() / in_features; auto [split_sizes_i64, split_offsets] = splits_to_offsets_multi( split_sizes, x.device(), std::vector{1, in_features, fc1_out_features, fc2_in_features, fc2_out_features}, @@ -653,13 +695,12 @@ std::vector megacpp_grouped_mlp_forward( auto fc1_offsets = split_offsets[2]; auto fc2_offsets = split_offsets[3]; auto output_offsets = split_offsets[4]; - const int64_t total_tokens = x.size(0); - auto gemm_resources = make_grouped_mlp_backend_resources(x.device(), num_groups); + auto gemm_resources = make_grouped_mlp_backend_resources(x.device(), num_groups, gemm_scratch); auto fc1_preact = at::empty({total_tokens, fc1_out_features}, x.options()); - auto grouped_x = make_grouped_tensor(x.view({-1}), split_sizes_i64, x_offsets, in_features); + auto grouped_x = make_grouped_tensor(x, split_sizes_i64, x_offsets, in_features); auto grouped_fc1_preact = - make_grouped_tensor(fc1_preact.view({-1}), split_sizes_i64, fc1_offsets, fc1_out_features); + make_grouped_tensor(fc1_preact, split_sizes_i64, fc1_offsets, fc1_out_features); grouped_gemm_fwd_dgrad(&fc1_weights, true, &grouped_x, false, &grouped_fc1_preact, &gemm_resources); add_grouped_bias(&grouped_fc1_preact, fc1_bias_tensor, num_groups, dtype, fc1_out_features); @@ -671,11 +712,10 @@ std::vector megacpp_grouped_mlp_forward( std::vector out_shape = input.sizes().vec(); out_shape.back() = fc2_out_features; auto output = at::empty(out_shape, x.options()); - auto output_2d = output.view({-1, fc2_out_features}); auto grouped_fc2_x = - make_grouped_tensor(fc2_x.view({-1}), split_sizes_i64, fc2_offsets, fc2_in_features); + make_grouped_tensor(fc2_x, split_sizes_i64, fc2_offsets, fc2_in_features); auto grouped_output = - make_grouped_tensor(output_2d.view({-1}), split_sizes_i64, output_offsets, fc2_out_features); + make_grouped_tensor(output, split_sizes_i64, output_offsets, fc2_out_features); grouped_gemm_fwd_dgrad(&fc2_weights, true, &grouped_fc2_x, false, &grouped_output, &gemm_resources); add_grouped_bias(&grouped_output, fc2_bias_tensor, num_groups, dtype, fc2_out_features); @@ -685,7 +725,8 @@ std::vector megacpp_grouped_mlp_forward( } py::tuple megacpp_grouped_mlp_backward( - const at::Tensor &grad_output, const at::Tensor &split_sizes, const at::Tensor &x_offsets, + const at::Tensor &grad_output, at::ScalarType act_dtype, const at::Tensor &split_sizes, + const at::Tensor &x_offsets, const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets, const at::Tensor &fc2_dy_offsets, const at::Tensor &base_offsets, const at::Tensor &x, const at::Tensor &fc1_activation_input, const at::Tensor &fc2_x, const std::optional &act_scales, py::handle fc1_weight, @@ -693,13 +734,19 @@ py::tuple megacpp_grouped_mlp_backward( bool fc1_accumulate_wgrad, py::handle fc2_wgrad_output, bool fc2_compute_wgrad, bool fc2_accumulate_wgrad, const std::string &activation, int64_t glu_interleave_size, double activation_limit, double activation_alpha, double activation_glu_linear_offset, - bool act_scales_requires_grad, bool input_requires_grad) { + bool act_scales_requires_grad, bool input_requires_grad, py::handle gemm_scratch) { (void)base_offsets; NVTE_CHECK(grad_output.is_cuda(), "megacpp_grouped_mlp_backward requires CUDA grad_output."); at::cuda::CUDAGuard device_guard(grad_output.device()); + // act_dtype is the requested grouped-MLP compute dtype. Backward receives + // autograd's grad_output as-is, so canonicalize it here instead of requiring + // a Python-side aten::to before entering C++. + const auto dtype = act_dtype; + auto dy = maybe_cast_dtype(grad_output, dtype); + check_contiguous(dy, "grad_output"); + const auto num_groups = num_groups_from_prepared_split_sizes(split_sizes, grad_output.device()); - const auto dtype = grad_output.scalar_type(); auto fc1_weights = weight_arg_from_py(fc1_weight, num_groups, dtype, "fc1_weight"); auto fc2_weights = weight_arg_from_py(fc2_weight, num_groups, dtype, "fc2_weight"); @@ -708,30 +755,28 @@ py::tuple megacpp_grouped_mlp_backward( const int64_t fc2_out_features = fc2_weights.rows; const int64_t fc2_in_features = fc2_weights.cols; - auto dy = maybe_cast_dtype(grad_output, dtype); - check_contiguous(dy, "grad_output"); - dy = dy.view({-1, fc2_out_features}); - const int64_t total_tokens = dy.size(0); - auto gemm_resources = make_grouped_mlp_backend_resources(grad_output.device(), num_groups); + NVTE_CHECK(dy.numel() % fc2_out_features == 0, + "grad_output last dimension is incompatible with FC2."); + const int64_t total_tokens = dy.numel() / fc2_out_features; + auto gemm_resources = + make_grouped_mlp_backend_resources(grad_output.device(), num_groups, gemm_scratch); - auto grouped_dy = - make_grouped_tensor(dy.view({-1}), split_sizes, fc2_dy_offsets, fc2_out_features); + auto grouped_dy = make_grouped_tensor(dy, split_sizes, fc2_dy_offsets, fc2_out_features); std::vector fc2_wgrads; if (fc2_compute_wgrad) { auto fc2_x_for_wgrad = maybe_cast_dtype(fc2_x, dtype); check_contiguous(fc2_x_for_wgrad, "fc2_x"); - fc2_x_for_wgrad = fc2_x_for_wgrad.view({-1, fc2_in_features}); auto grouped_fc2_x_for_wgrad = - make_grouped_tensor(fc2_x_for_wgrad.view({-1}), split_sizes, fc2_offsets, fc2_in_features); + make_grouped_tensor(fc2_x_for_wgrad, split_sizes, fc2_offsets, fc2_in_features); fc2_wgrads = grouped_gemm_wgrad(&grouped_fc2_x_for_wgrad, &grouped_dy, fc2_wgrad_output, - fc2_compute_wgrad, fc2_accumulate_wgrad, &gemm_resources, dtype, - fc2_out_features, fc2_in_features, "fc2_wgrad_output", + fc2_compute_wgrad, fc2_accumulate_wgrad, &gemm_resources, + dtype, fc2_out_features, fc2_in_features, "fc2_wgrad_output", fc2_weights.is_grouped); } auto fc2_dx = at::empty({total_tokens, fc2_in_features}, dy.options()); auto grouped_fc2_dx = - make_grouped_tensor(fc2_dx.view({-1}), split_sizes, fc2_offsets, fc2_in_features); + make_grouped_tensor(fc2_dx, split_sizes, fc2_offsets, fc2_in_features); grouped_gemm_fwd_dgrad(&fc2_weights, false, &grouped_dy, false, &grouped_fc2_dx, &gemm_resources); auto activation_grads = grouped_mlp_activation_backward( @@ -740,18 +785,17 @@ py::tuple megacpp_grouped_mlp_backward( auto fc1_dy = activation_grads.grad_input; auto grad_act_scales = activation_grads.grad_act_scales; auto grouped_fc1_dy = - make_grouped_tensor(fc1_dy.view({-1}), split_sizes, fc1_offsets, fc1_out_features); + make_grouped_tensor(fc1_dy, split_sizes, fc1_offsets, fc1_out_features); std::vector fc1_wgrads; if (fc1_compute_wgrad) { auto x_for_wgrad = maybe_cast_dtype(x, dtype); check_contiguous(x_for_wgrad, "x"); - x_for_wgrad = x_for_wgrad.view({-1, in_features}); auto grouped_x_for_wgrad = - make_grouped_tensor(x_for_wgrad.view({-1}), split_sizes, x_offsets, in_features); + make_grouped_tensor(x_for_wgrad, split_sizes, x_offsets, in_features); fc1_wgrads = grouped_gemm_wgrad(&grouped_x_for_wgrad, &grouped_fc1_dy, fc1_wgrad_output, - fc1_compute_wgrad, fc1_accumulate_wgrad, &gemm_resources, dtype, - fc1_out_features, in_features, "fc1_wgrad_output", + fc1_compute_wgrad, fc1_accumulate_wgrad, &gemm_resources, + dtype, fc1_out_features, in_features, "fc1_wgrad_output", fc1_weights.is_grouped); } @@ -760,9 +804,8 @@ py::tuple megacpp_grouped_mlp_backward( std::vector grad_input_shape = grad_output.sizes().vec(); grad_input_shape.back() = in_features; grad_input = at::empty(grad_input_shape, dy.options()); - auto grad_input_2d = grad_input.view({-1, in_features}); auto grouped_grad_input = - make_grouped_tensor(grad_input_2d.view({-1}), split_sizes, x_offsets, in_features); + make_grouped_tensor(grad_input, split_sizes, x_offsets, in_features); grouped_gemm_fwd_dgrad(&fc1_weights, false, &grouped_fc1_dy, false, &grouped_grad_input, &gemm_resources); } else { diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp_megacpp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp_megacpp.py index a0a69c5804..3899b5c6ee 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp_megacpp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp_megacpp.py @@ -24,6 +24,7 @@ view_main_grad_as_grouped_buffer, ) from .forward_grouped_mlp_megacpp import ( + _grouped_gemm_scratch, _megacpp_activation_config, _megacpp_enabled, _megacpp_supports_recipe, @@ -235,7 +236,8 @@ def fuser_backward( fc1_owned_weight_grads, fc2_owned_weight_grads, ) = tex.megacpp_grouped_mlp_backward( - grad_output.to(dtype=dtype), + grad_output, + dtype, split_sizes, x_offsets, fc1_offsets, @@ -261,6 +263,7 @@ def fuser_backward( activation_config.glu_linear_offset, bool(activation_ctx.extra_input_requires_grad), bool(fc1_ctx.input_requires_grad), + _grouped_gemm_scratch(num_groups, grad_output.device), ) if not fc1_ctx.input_requires_grad: grad_input = None @@ -318,8 +321,9 @@ def fuser_backward( ) clear_tensor_data(x) + # d(act_scales) belongs to the extra input, so match act_scales.dtype activation_grad_extra = ( - (grad_act_scales.to(dtype=dtype),) + (grad_act_scales.to(dtype=act_scales.dtype),) if activation_ctx.extra_input_requires_grad else (None,) ) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp_megacpp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp_megacpp.py index 61906e3714..0b086aa56b 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp_megacpp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp_megacpp.py @@ -13,6 +13,10 @@ import torch import transformer_engine_torch as tex +from ...cpp_extensions.gemm import ( + get_cublas_workspace_size_bytes, + get_grouped_gemm_setup_workspace_size, +) from ...quantization import Recipe from ...tensor import Quantizer from ...utils import get_device_compute_capability @@ -42,6 +46,40 @@ def _megacpp_supports_recipe(recipe: Optional[Recipe]) -> bool: return False +@functools.lru_cache(maxsize=None) +def _cached_grouped_gemm_scratch( + num_groups: int, + device_index: int, + _stream_handle: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Cached cuBLAS grouped GEMM scratch for one CUDA stream. + + ``_stream_handle`` is intentionally part of the cache key. The workspace is + reused without recording extra streams, so it must not be shared by + concurrent streams. + """ + device = torch.device("cuda", device_index) + with torch.cuda.device(device): + setup_size = get_grouped_gemm_setup_workspace_size(num_groups) + cublas_size = get_cublas_workspace_size_bytes() + return ( + torch.ones(num_groups, dtype=torch.float32, device=device), + torch.zeros(num_groups, dtype=torch.float32, device=device), + torch.empty(setup_size, dtype=torch.uint8, device=device), + torch.empty(cublas_size, dtype=torch.uint8, device=device), + ) + + +def _grouped_gemm_scratch( + num_groups: int, + device: torch.device, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Return cached GEMM resources for the current stream on ``device``.""" + device_index = torch.cuda.current_device() if device.index is None else device.index + stream_handle = int(torch.cuda.current_stream(device_index).cuda_stream) + return _cached_grouped_gemm_scratch(num_groups, device_index, stream_handle) + + class _MegaCppActivationConfig(NamedTuple): """Activation semantics consumed by the C++ grouped MLP path.""" @@ -250,6 +288,7 @@ def fuser_forward( dtype, input_requires_grad=input_requires_grad, ) + gemm_scratch = _grouped_gemm_scratch(num_groups, input_.device) ( fc2_out, x, @@ -262,7 +301,8 @@ def fuser_forward( fc1_activation_input, fc2_x, ) = tex.megacpp_grouped_mlp_forward( - input_.to(dtype=dtype), + input_, + dtype, split_sizes, fc1_weights, _megacpp_bias_arg(fc1_op, dtype), @@ -274,6 +314,7 @@ def fuser_forward( activation_config.limit, activation_config.alpha, activation_config.glu_linear_offset, + gemm_scratch, ) if x.data_ptr() == input_.data_ptr(): From 7ab8bc65fac5d4d8787921e2db67efcee533a61b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Jun 2026 22:31:05 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/extensions.h | 21 +++--- .../pytorch/csrc/extensions/pybind.cpp | 9 ++- .../pytorch/csrc/megacpp/grouped_mlp.cpp | 66 ++++++++----------- 3 files changed, 42 insertions(+), 54 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index a66a35b27d..3561254a7c 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -191,23 +191,22 @@ py::object te_general_grouped_gemm_for_discrete_out(py::handle A, bool transa, p std::vector megacpp_grouped_mlp_forward( const at::Tensor &input, at::ScalarType act_dtype, const at::Tensor &split_sizes, - py::handle fc1_weight, - py::handle fc1_bias, py::handle fc2_weight, py::handle fc2_bias, + py::handle fc1_weight, py::handle fc1_bias, py::handle fc2_weight, py::handle fc2_bias, const std::optional &act_scales, const std::string &activation, int64_t glu_interleave_size, double activation_limit, double activation_alpha, double activation_glu_linear_offset, py::handle gemm_scratch); py::tuple megacpp_grouped_mlp_backward( const at::Tensor &grad_output, at::ScalarType act_dtype, const at::Tensor &split_sizes, - const at::Tensor &x_offsets, - const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets, const at::Tensor &fc2_dy_offsets, - const at::Tensor &base_offsets, const at::Tensor &x, const at::Tensor &fc1_activation_input, - const at::Tensor &fc2_x, const std::optional &act_scales, py::handle fc1_weight, - py::handle fc2_weight, py::handle fc1_wgrad_output, bool fc1_compute_wgrad, - bool fc1_accumulate_wgrad, py::handle fc2_wgrad_output, bool fc2_compute_wgrad, - bool fc2_accumulate_wgrad, const std::string &activation, int64_t glu_interleave_size, - double activation_limit, double activation_alpha, double activation_glu_linear_offset, - bool act_scales_requires_grad, bool input_requires_grad, py::handle gemm_scratch); + const at::Tensor &x_offsets, const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets, + const at::Tensor &fc2_dy_offsets, const at::Tensor &base_offsets, const at::Tensor &x, + const at::Tensor &fc1_activation_input, const at::Tensor &fc2_x, + const std::optional &act_scales, py::handle fc1_weight, py::handle fc2_weight, + py::handle fc1_wgrad_output, bool fc1_compute_wgrad, bool fc1_accumulate_wgrad, + py::handle fc2_wgrad_output, bool fc2_compute_wgrad, bool fc2_accumulate_wgrad, + const std::string &activation, int64_t glu_interleave_size, double activation_limit, + double activation_alpha, double activation_glu_linear_offset, bool act_scales_requires_grad, + bool input_requires_grad, py::handle gemm_scratch); /*************************************************************************************************** * Transpose diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index f870938441..34ad560ae0 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -359,11 +359,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Grouped GEMM for discrete output list"); m.def("megacpp_grouped_mlp_forward", &transformer_engine::pytorch::megacpp_grouped_mlp_forward, "Mega C++ grouped MLP forward", py::arg("input"), py::arg("act_dtype"), - py::arg("split_sizes"), py::arg("fc1_weight"), py::arg("fc1_bias"), - py::arg("fc2_weight"), py::arg("fc2_bias"), - py::arg("act_scales"), py::arg("activation"), py::arg("glu_interleave_size"), - py::arg("activation_limit") = 0.0, py::arg("activation_alpha") = 0.0, - py::arg("activation_glu_linear_offset") = 0.0, + py::arg("split_sizes"), py::arg("fc1_weight"), py::arg("fc1_bias"), py::arg("fc2_weight"), + py::arg("fc2_bias"), py::arg("act_scales"), py::arg("activation"), + py::arg("glu_interleave_size"), py::arg("activation_limit") = 0.0, + py::arg("activation_alpha") = 0.0, py::arg("activation_glu_linear_offset") = 0.0, py::arg("gemm_scratch") = py::none()); m.def("megacpp_grouped_mlp_backward", &transformer_engine::pytorch::megacpp_grouped_mlp_backward, "Mega C++ grouped MLP backward", py::arg("grad_output"), py::arg("act_dtype"), diff --git a/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp b/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp index 589fb56c5b..4292adb349 100644 --- a/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp +++ b/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp @@ -256,8 +256,7 @@ struct GroupedGemmResources { }; GroupedGemmResources make_grouped_mlp_backend_resources(const c10::Device &device, - size_t num_groups, - py::handle scratch) { + size_t num_groups, py::handle scratch) { // Keep the backend resource policy private to megacpp. Today this is cuBLAS // grouped GEMM scratch; future backends can change this helper without // changing the Python or pybind contract. @@ -275,9 +274,8 @@ void grouped_gemm(GroupedTensorWrapper *A, bool transa, GroupedTensorWrapper *B, }); } -std::vector output_tensor_list_from_arg(py::handle arg, size_t num_groups, - int64_t rows, int64_t cols, - const std::string &name) { +std::vector output_tensor_list_from_arg(py::handle arg, size_t num_groups, int64_t rows, + int64_t cols, const std::string &name) { std::vector out; if (is_none(arg)) { return out; @@ -327,8 +325,7 @@ WgradOutput wgrad_output_from_arg(py::handle arg, bool compute_wgrad, size_t num // allocation. Single grouped weight keeps this packed as [G, N, K]; // discrete weights split the same packed allocation into per-expert views. out.packed = - at::empty({static_cast(num_groups), rows, cols}, - at::device(device).dtype(dtype)); + at::empty({static_cast(num_groups), rows, cols}, at::device(device).dtype(dtype)); out.owns_storage = true; out.is_grouped = prefer_grouped_output; if (out.is_grouped) { @@ -389,14 +386,12 @@ void grouped_gemm_fwd_dgrad(GroupedWeightArg *weights, bool trans_weight, } std::vector grouped_gemm_wgrad(GroupedTensorWrapper *x, GroupedTensorWrapper *dy, - py::handle output, bool compute_wgrad, bool accumulate, - GroupedGemmResources *resources, - at::ScalarType dtype, int64_t rows, - int64_t cols, const std::string &name, - bool prefer_grouped_output) { - auto prepared = - wgrad_output_from_arg(output, compute_wgrad, resources->num_groups, dtype, - resources->device, rows, cols, name, prefer_grouped_output); + py::handle output, bool compute_wgrad, bool accumulate, + GroupedGemmResources *resources, at::ScalarType dtype, + int64_t rows, int64_t cols, const std::string &name, + bool prefer_grouped_output) { + auto prepared = wgrad_output_from_arg(output, compute_wgrad, resources->num_groups, dtype, + resources->device, rows, cols, name, prefer_grouped_output); NVTE_CHECK(!(prepared.owns_storage && accumulate), name, " cannot accumulate into a newly allocated wgrad buffer."); std::vector returned_wgrads; @@ -643,8 +638,7 @@ ActivationBackwardResult grouped_mlp_activation_backward( std::vector megacpp_grouped_mlp_forward( const at::Tensor &input, at::ScalarType act_dtype, const at::Tensor &split_sizes, - py::handle fc1_weight, - py::handle fc1_bias, py::handle fc2_weight, py::handle fc2_bias, + py::handle fc1_weight, py::handle fc1_bias, py::handle fc2_weight, py::handle fc2_bias, const std::optional &act_scales, const std::string &activation, int64_t glu_interleave_size, double activation_limit, double activation_alpha, double activation_glu_linear_offset, py::handle gemm_scratch) { @@ -712,8 +706,7 @@ std::vector megacpp_grouped_mlp_forward( std::vector out_shape = input.sizes().vec(); out_shape.back() = fc2_out_features; auto output = at::empty(out_shape, x.options()); - auto grouped_fc2_x = - make_grouped_tensor(fc2_x, split_sizes_i64, fc2_offsets, fc2_in_features); + auto grouped_fc2_x = make_grouped_tensor(fc2_x, split_sizes_i64, fc2_offsets, fc2_in_features); auto grouped_output = make_grouped_tensor(output, split_sizes_i64, output_offsets, fc2_out_features); grouped_gemm_fwd_dgrad(&fc2_weights, true, &grouped_fc2_x, false, &grouped_output, @@ -726,15 +719,15 @@ std::vector megacpp_grouped_mlp_forward( py::tuple megacpp_grouped_mlp_backward( const at::Tensor &grad_output, at::ScalarType act_dtype, const at::Tensor &split_sizes, - const at::Tensor &x_offsets, - const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets, const at::Tensor &fc2_dy_offsets, - const at::Tensor &base_offsets, const at::Tensor &x, const at::Tensor &fc1_activation_input, - const at::Tensor &fc2_x, const std::optional &act_scales, py::handle fc1_weight, - py::handle fc2_weight, py::handle fc1_wgrad_output, bool fc1_compute_wgrad, - bool fc1_accumulate_wgrad, py::handle fc2_wgrad_output, bool fc2_compute_wgrad, - bool fc2_accumulate_wgrad, const std::string &activation, int64_t glu_interleave_size, - double activation_limit, double activation_alpha, double activation_glu_linear_offset, - bool act_scales_requires_grad, bool input_requires_grad, py::handle gemm_scratch) { + const at::Tensor &x_offsets, const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets, + const at::Tensor &fc2_dy_offsets, const at::Tensor &base_offsets, const at::Tensor &x, + const at::Tensor &fc1_activation_input, const at::Tensor &fc2_x, + const std::optional &act_scales, py::handle fc1_weight, py::handle fc2_weight, + py::handle fc1_wgrad_output, bool fc1_compute_wgrad, bool fc1_accumulate_wgrad, + py::handle fc2_wgrad_output, bool fc2_compute_wgrad, bool fc2_accumulate_wgrad, + const std::string &activation, int64_t glu_interleave_size, double activation_limit, + double activation_alpha, double activation_glu_linear_offset, bool act_scales_requires_grad, + bool input_requires_grad, py::handle gemm_scratch) { (void)base_offsets; NVTE_CHECK(grad_output.is_cuda(), "megacpp_grouped_mlp_backward requires CUDA grad_output."); at::cuda::CUDAGuard device_guard(grad_output.device()); @@ -769,14 +762,13 @@ py::tuple megacpp_grouped_mlp_backward( auto grouped_fc2_x_for_wgrad = make_grouped_tensor(fc2_x_for_wgrad, split_sizes, fc2_offsets, fc2_in_features); fc2_wgrads = grouped_gemm_wgrad(&grouped_fc2_x_for_wgrad, &grouped_dy, fc2_wgrad_output, - fc2_compute_wgrad, fc2_accumulate_wgrad, &gemm_resources, - dtype, fc2_out_features, fc2_in_features, "fc2_wgrad_output", + fc2_compute_wgrad, fc2_accumulate_wgrad, &gemm_resources, dtype, + fc2_out_features, fc2_in_features, "fc2_wgrad_output", fc2_weights.is_grouped); } auto fc2_dx = at::empty({total_tokens, fc2_in_features}, dy.options()); - auto grouped_fc2_dx = - make_grouped_tensor(fc2_dx, split_sizes, fc2_offsets, fc2_in_features); + auto grouped_fc2_dx = make_grouped_tensor(fc2_dx, split_sizes, fc2_offsets, fc2_in_features); grouped_gemm_fwd_dgrad(&fc2_weights, false, &grouped_dy, false, &grouped_fc2_dx, &gemm_resources); auto activation_grads = grouped_mlp_activation_backward( @@ -784,8 +776,7 @@ py::tuple megacpp_grouped_mlp_backward( activation_alpha, activation_glu_linear_offset, dtype, act_scales_requires_grad); auto fc1_dy = activation_grads.grad_input; auto grad_act_scales = activation_grads.grad_act_scales; - auto grouped_fc1_dy = - make_grouped_tensor(fc1_dy, split_sizes, fc1_offsets, fc1_out_features); + auto grouped_fc1_dy = make_grouped_tensor(fc1_dy, split_sizes, fc1_offsets, fc1_out_features); std::vector fc1_wgrads; if (fc1_compute_wgrad) { @@ -794,8 +785,8 @@ py::tuple megacpp_grouped_mlp_backward( auto grouped_x_for_wgrad = make_grouped_tensor(x_for_wgrad, split_sizes, x_offsets, in_features); fc1_wgrads = grouped_gemm_wgrad(&grouped_x_for_wgrad, &grouped_fc1_dy, fc1_wgrad_output, - fc1_compute_wgrad, fc1_accumulate_wgrad, &gemm_resources, - dtype, fc1_out_features, in_features, "fc1_wgrad_output", + fc1_compute_wgrad, fc1_accumulate_wgrad, &gemm_resources, dtype, + fc1_out_features, in_features, "fc1_wgrad_output", fc1_weights.is_grouped); } @@ -804,8 +795,7 @@ py::tuple megacpp_grouped_mlp_backward( std::vector grad_input_shape = grad_output.sizes().vec(); grad_input_shape.back() = in_features; grad_input = at::empty(grad_input_shape, dy.options()); - auto grouped_grad_input = - make_grouped_tensor(grad_input, split_sizes, x_offsets, in_features); + auto grouped_grad_input = make_grouped_tensor(grad_input, split_sizes, x_offsets, in_features); grouped_gemm_fwd_dgrad(&fc1_weights, false, &grouped_fc1_dy, false, &grouped_grad_input, &gemm_resources); } else {