Skip to content

Commit 4130d73

Browse files
yaox12, pre-commit-ci[bot], and vthumbe1503
authored
[PyTorch] Update cuBLASLt grouped gemm filter (#3119)
* update cublaslt grouped gemm filter
  Signed-off-by: Xin Yao <xiny@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* update nvfp4 filter and tests
  Signed-off-by: Xin Yao <xiny@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* test correctness
  Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
* better test
  Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci

---------

Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Varun Thumbe <vthumbe@nvidia.com>
1 parent: 547d284; commit: 4130d73

4 files changed

Lines changed: 180 additions & 43 deletions

File tree

tests/pytorch/test_grouped_linear.py

Lines changed: 123 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1497,6 +1497,7 @@ def test_fp8_grouped_gemm(shape, accumulate):
14971497
_FUSED_GROUPED_GEMM_ENV = "NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM"
14981498
_ALL_BOOLEAN = all_boolean
14991499
_mxfp8_available, _reason_for_no_mxfp8 = mxfp8_available, reason_for_no_mxfp8
1500+
_nvfp4_available, _reason_for_no_nvfp4 = nvfp4_available, reason_for_no_nvfp4
15001501

15011502

15021503
@pytest.fixture(autouse=True)
@@ -1580,26 +1581,40 @@ def _run_grouped_linear_path(
15801581
recipe.MXFP8BlockScaling(),
15811582
marks=pytest.mark.skipif(not _mxfp8_available, reason=_reason_for_no_mxfp8),
15821583
),
1584+
pytest.param(
1585+
recipe.NVFP4BlockScaling(disable_stochastic_rounding=True),
1586+
marks=pytest.mark.skipif(not _nvfp4_available, reason=_reason_for_no_nvfp4),
1587+
),
15831588
],
1584-
ids=["bf16", "mxfp8"],
1589+
ids=["bf16", "mxfp8", "nvfp4"],
15851590
)
15861591
@pytest.mark.parametrize("bias", _ALL_BOOLEAN)
15871592
@pytest.mark.parametrize("fp8_model_params", _ALL_BOOLEAN)
15881593
@pytest.mark.parametrize("delay_wgrad_compute", _ALL_BOOLEAN)
15891594
def test_grouped_linear_grouped_tensor_path_matches_legacy(
15901595
fp8_recipe, bias, fp8_model_params, delay_wgrad_compute, monkeypatch
15911596
):
1592-
if torch.cuda.get_device_capability() < (10, 0):
1593-
pytest.skip("GroupedTensor grouped GEMM path requires SM100+")
1594-
15951597
use_fp8 = fp8_recipe is not None
1598+
device_capability = torch.cuda.get_device_capability()
1599+
if not (9, 0) <= device_capability <= (11, 0):
1600+
pytest.skip(
1601+
"GroupedTensor grouped GEMM path requires Hopper (SM90) or Blackwell (SM10x and SM110)."
1602+
)
1603+
if use_fp8 and device_capability < (10, 0):
1604+
pytest.skip("Quantized GroupedTensor grouped GEMM path requires Blackwell (SM100+).")
1605+
cublaslt_version = tex.get_cublasLt_version()
1606+
if device_capability < (10, 0) and cublaslt_version < 130400:
1607+
pytest.skip("Grouped GEMM on Hopper requires cuBLAS 13.4+.")
1608+
if cublaslt_version < 130300:
1609+
pytest.skip("Grouped GEMM requires cuBLAS 13.3+.")
1610+
15961611
if fp8_model_params and not use_fp8:
15971612
pytest.skip("fp8_model_params requires FP8")
15981613

15991614
dtype = torch.bfloat16
16001615
num_gemms = 3
1601-
in_features = 64
1602-
out_features = 64
1616+
in_features = 128
1617+
out_features = 128
16031618
m_splits = [128, 256, 384]
16041619
total_tokens = sum(m_splits)
16051620

@@ -1683,6 +1698,90 @@ def test_grouped_linear_grouped_tensor_path_single_grouped_bias_delay_wgrad(monk
16831698
grouped_linear.backward_dw()
16841699

16851700

1701+
@pytest.mark.skipif(not _nvfp4_available, reason=_reason_for_no_nvfp4)
1702+
def test_grouped_linear_grouped_tensor_path_skips_non_rht_nvfp4(monkeypatch):
1703+
"""Non-RHT NVFP4 falls back to the legacy path; check it stays numerically correct.
1704+
1705+
Graph-safe grouped quantization currently requires RHT, so requesting NVFP4 with
1706+
``disable_rht=True`` while the fused grouped-tensor path is enabled falls back to the
1707+
legacy path internally. We verify the output and gradients against a reference built from
1708+
per-GEMM ``te.Linear`` modules that share the same weights and use the same NVFP4 recipe;
1709+
the grouped GEMM should match the loop of single GEMMs.
1710+
"""
1711+
if torch.cuda.get_device_capability() < (10, 0):
1712+
pytest.skip("NVFP4 GroupedTensor grouped GEMM path requires SM100+")
1713+
1714+
monkeypatch.setenv(_FUSED_GROUPED_GEMM_ENV, "1")
1715+
FP8GlobalStateManager.reset()
1716+
1717+
dtype = torch.bfloat16
1718+
num_gemms = 3
1719+
in_features = 128
1720+
out_features = 128
1721+
m_splits = [128, 256, 384]
1722+
total_tokens = sum(m_splits)
1723+
1724+
torch.manual_seed(1234)
1725+
x_base = (0.1 * torch.randn(total_tokens, in_features, device="cuda")).to(dtype)
1726+
dy = (0.1 * torch.randn(total_tokens, out_features, device="cuda")).to(dtype)
1727+
weights = [
1728+
(0.1 * torch.randn(out_features, in_features, device="cuda")).to(dtype)
1729+
for _ in range(num_gemms)
1730+
]
1731+
1732+
fp8_recipe = recipe.NVFP4BlockScaling(
1733+
disable_rht=True,
1734+
disable_stochastic_rounding=True,
1735+
)
1736+
1737+
# Grouped path: fused path enabled, but non-RHT NVFP4 falls back to legacy internally.
1738+
grouped_linear = GroupedLinear(
1739+
num_gemms,
1740+
in_features,
1741+
out_features,
1742+
bias=False,
1743+
params_dtype=dtype,
1744+
device="cuda",
1745+
)
1746+
with torch.no_grad():
1747+
for i in range(num_gemms):
1748+
getattr(grouped_linear, f"weight{i}").copy_(weights[i])
1749+
1750+
x = x_base.detach().clone().requires_grad_(True)
1751+
with autocast(enabled=True, recipe=fp8_recipe):
1752+
y = grouped_linear(x, m_splits)
1753+
y.backward(dy)
1754+
1755+
# Reference: one te.Linear per GEMM sharing the same weights and NVFP4 recipe.
1756+
ref_linears = torch.nn.ModuleList(
1757+
[
1758+
Linear(in_features, out_features, bias=False, params_dtype=dtype, device="cuda")
1759+
for _ in range(num_gemms)
1760+
]
1761+
)
1762+
with torch.no_grad():
1763+
for i in range(num_gemms):
1764+
ref_linears[i].weight.copy_(weights[i])
1765+
1766+
x_ref = x_base.detach().clone().requires_grad_(True)
1767+
with autocast(enabled=True, recipe=fp8_recipe):
1768+
y_ref = torch.cat(
1769+
[ref_linears[i](x_i) for i, x_i in enumerate(torch.split(x_ref, m_splits))]
1770+
)
1771+
y_ref.backward(dy)
1772+
1773+
# cuBLAS grouped GEMM should match the loop of single GEMMs bit-for-bit.
1774+
tols = dict(rtol=0, atol=0)
1775+
torch.testing.assert_close(y.float(), y_ref.float(), **tols)
1776+
torch.testing.assert_close(x.grad.float(), x_ref.grad.float(), **tols)
1777+
for i in range(num_gemms):
1778+
torch.testing.assert_close(
1779+
getattr(grouped_linear, f"weight{i}").grad.float(),
1780+
ref_linears[i].weight.grad.float(),
1781+
**tols,
1782+
)
1783+
1784+
16861785
@pytest.mark.parametrize(
16871786
"fp8_recipe",
16881787
[
@@ -1691,19 +1790,33 @@ def test_grouped_linear_grouped_tensor_path_single_grouped_bias_delay_wgrad(monk
16911790
recipe.MXFP8BlockScaling(),
16921791
marks=pytest.mark.skipif(not _mxfp8_available, reason=_reason_for_no_mxfp8),
16931792
),
1793+
pytest.param(
1794+
recipe.NVFP4BlockScaling(disable_stochastic_rounding=True),
1795+
marks=pytest.mark.skipif(not _nvfp4_available, reason=_reason_for_no_nvfp4),
1796+
),
16941797
],
1695-
ids=["bf16", "mxfp8"],
1798+
ids=["bf16", "mxfp8", "nvfp4"],
16961799
)
16971800
@pytest.mark.parametrize("bias", _ALL_BOOLEAN)
16981801
def test_grouped_linear_fused_path_cuda_graph_safe(fp8_recipe, bias, monkeypatch):
16991802
"""Fused GroupedTensor GEMM path should be CUDA graph capturable."""
1700-
if torch.cuda.get_device_capability() < (10, 0):
1701-
pytest.skip("GroupedTensor grouped GEMM path requires SM100+")
1803+
use_fp8 = fp8_recipe is not None
1804+
device_capability = torch.cuda.get_device_capability()
1805+
if not (9, 0) <= device_capability <= (11, 0):
1806+
pytest.skip(
1807+
"GroupedTensor grouped GEMM path requires Hopper (SM90) or Blackwell (SM10x and SM110)."
1808+
)
1809+
if use_fp8 and device_capability < (10, 0):
1810+
pytest.skip("Quantized GroupedTensor grouped GEMM path requires Blackwell (SM100+).")
1811+
cublaslt_version = tex.get_cublasLt_version()
1812+
if device_capability < (10, 0) and cublaslt_version < 130400:
1813+
pytest.skip("Grouped GEMM on Hopper requires cuBLAS 13.4+.")
1814+
if cublaslt_version < 130300:
1815+
pytest.skip("Grouped GEMM requires cuBLAS 13.3+.")
17021816

17031817
monkeypatch.setenv(_FUSED_GROUPED_GEMM_ENV, "1")
17041818
FP8GlobalStateManager.reset()
17051819

1706-
use_fp8 = fp8_recipe is not None
17071820
dtype = torch.bfloat16
17081821
device = "cuda"
17091822
num_gemms = 3

transformer_engine/common/gemm/cublaslt_grouped_gemm.cu

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ struct GroupedGemmSetupWorkspace {
228228
}
229229
};
230230

231-
inline bool grouped_gemm_supports_per_group_alpha_beta(int sm) { return sm >= 100; }
231+
inline bool grouped_gemm_supports_per_group_alpha_beta(int sm) { return sm >= 100 && sm <= 110; }
232232

233233
inline size_t validate_grouped_gemm_inputs(
234234
size_t num_tensors, std::initializer_list<const transformer_engine::GroupedTensor *> inputs,
@@ -335,7 +335,8 @@ inline void check_grouped_gemm_requirements(const char *api_name) {
335335
const int sm = transformer_engine::cuda::sm_arch(current_device);
336336
const int cublas_ver = transformer_engine::cuda::cublas_version();
337337
#if CUBLAS_VERSION >= CUBLAS_GROUPED_GEMM_HOPPER_VERSION
338-
NVTE_CHECK(sm >= 90, api_name, " requires Hopper (SM90) or newer architecture.");
338+
NVTE_CHECK(sm >= 90 && sm <= 110, api_name,
339+
" requires Hopper (SM90) or Blackwell (SM10x and SM110).");
339340
NVTE_CHECK(cublas_ver >= CUBLAS_GROUPED_GEMM_VERSION, api_name,
340341
" requires cuBLAS 13.3+, but run-time cuBLAS version is ", cublas_ver);
341342
if (sm < 100) {
@@ -344,7 +345,7 @@ inline void check_grouped_gemm_requirements(const char *api_name) {
344345
cublas_ver);
345346
}
346347
#else
347-
NVTE_CHECK(sm >= 100, api_name, " requires Blackwell (SM100) or newer architecture.");
348+
NVTE_CHECK(sm >= 100 && sm <= 110, api_name, " requires Blackwell (SM10x and SM110).");
348349
NVTE_CHECK(cublas_ver >= CUBLAS_GROUPED_GEMM_VERSION, api_name,
349350
" requires cuBLAS 13.3+, but run-time cuBLAS version is ", cublas_ver);
350351
#endif
@@ -400,7 +401,7 @@ inline void validate_fp8_block_grouped_gemm_support(const GroupedOperandSelectio
400401
"Grouped GEMM: A and B must both use FP8 block scaling or both not.");
401402
NVTE_CHECK(sm == 90,
402403
"Grouped GEMM: FP8 block scaling is only supported on Hopper (SM90); "
403-
"use MXFP8 on Blackwell (SM100) or newer.");
404+
"use MXFP8 on Blackwell (SM10x and SM110).");
404405
}
405406

406407
inline bool is_compatible_grouped_scaling_mode(NVTEScalingMode a_mode, NVTEScalingMode b_mode) {
@@ -1567,7 +1568,7 @@ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedT
15671568
NVTE_API_CALL(nvte_grouped_gemm);
15681569
using namespace transformer_engine;
15691570

1570-
// Grouped GEMM requires Blackwell (SM100) or newer with cuBLAS 13.3+,
1571+
// Grouped GEMM requires Blackwell (SM10x and SM110) with cuBLAS 13.3+,
15711572
// or Hopper (SM90) with cuBLAS 13.4+.
15721573
check_grouped_gemm_requirements("nvte_grouped_gemm");
15731574

@@ -1650,7 +1651,7 @@ void nvte_grouped_gemm_with_discrete_inputA(const NVTETensor *A_list, size_t num
16501651
NVTE_API_CALL(nvte_grouped_gemm_with_discrete_inputA);
16511652
using namespace transformer_engine;
16521653

1653-
// Grouped GEMM requires Blackwell (SM100) or newer with cuBLAS 13.3+,
1654+
// Grouped GEMM requires Blackwell (SM10x and SM110) with cuBLAS 13.3+,
16541655
// or Hopper (SM90) with cuBLAS 13.4+.
16551656
check_grouped_gemm_requirements("nvte_grouped_gemm_with_discrete_inputA");
16561657

@@ -1801,7 +1802,7 @@ void nvte_grouped_gemm_with_discrete_out(const NVTEGroupedTensor A, int transa,
18011802
NVTE_API_CALL(nvte_grouped_gemm_with_discrete_out);
18021803
using namespace transformer_engine;
18031804

1804-
// Grouped GEMM requires Blackwell (SM100) or newer with cuBLAS 13.3+,
1805+
// Grouped GEMM requires Blackwell (SM10x and SM110) with cuBLAS 13.3+,
18051806
// or Hopper (SM90) with cuBLAS 13.4+.
18061807
check_grouped_gemm_requirements("nvte_grouped_gemm_with_discrete_out");
18071808

transformer_engine/pytorch/module/grouped_linear.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,7 @@
5555
from ..cpu_offload import is_cpu_offload_enabled, mark_not_offload, start_offload
5656
from ..triton.grouped_dbias_dscales import compute_grouped_dbias
5757

58-
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
59-
from ..tensor.mxfp8_tensor import MXFP8Quantizer
58+
from ..tensor import Float8CurrentScalingQuantizer, Float8Quantizer, MXFP8Quantizer, NVFP4Quantizer
6059
from ..quantized_tensor import (
6160
QuantizedTensorStorage,
6261
Quantizer,
@@ -95,19 +94,29 @@ def _is_grouped_tensor_path_supported(
9594
save_original_input: bool,
9695
activation_dtype: torch.dtype,
9796
input_quantizers: List[Optional[Quantizer]],
98-
weight_quantizers: List[Optional[Quantizer]],
9997
output_quantizers: List[Optional[Quantizer]],
100-
grad_output_quantizers: List[Optional[Quantizer]],
10198
) -> bool:
102-
"""Whether to use cublasLt grouped GEMM through GroupedTensor metadata.
99+
"""Whether to use cuBLASLt grouped GEMM through GroupedTensor metadata.
103100
104101
There are no checks whether split sizes are supported. Splits
105102
may be in a CUDA tensor, so checking would hurt performance
106103
and be incompatible with CUDA Graphs.
107104
105+
Supported Compute Capability (CC) and precisions:
106+
* Hopper (CC 9.0): BF16/FP16.
107+
* Blackwell (CC 10.x and 11.0): BF16/FP16/MXFP8/NVFP4 with RHT.
108+
FP8 delayed / current scaling, and FP8 block scaling are not supported because the
109+
corresponding grouped quantization kernels are missing.
110+
Non-RHT NVFP4 falls back to the legacy path because graph-safe grouped quantization
111+
currently requires RHT.
112+
113+
Input/weight/grad_output quantizers are assumed to be of the same type, otherwise it would
114+
trigger a fatal error in the cuBLASLt grouped GEMM check.
108115
"""
116+
# 1. Filter by environment variable
109117
if not bool(int(os.getenv("NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM", "0"))):
110118
return False
119+
# 2. Filter out advanced features
111120
if (
112121
debug
113122
or cpu_offloading
@@ -116,16 +125,18 @@ def _is_grouped_tensor_path_supported(
116125
or save_original_input
117126
):
118127
return False
119-
if get_device_compute_capability() < (10, 0):
128+
# 3. Filter by compute capability
129+
if not (9, 0) <= get_device_compute_capability() <= (11, 0):
120130
return False
131+
# 4. Output quantization is not supported.
121132
if any(q is not None for q in output_quantizers):
122133
return False
134+
# 5. Filter by quantization recipes.
123135
if fp8:
124-
return (
125-
activation_dtype in (torch.bfloat16, torch.float16)
126-
and all(isinstance(q, MXFP8Quantizer) for q in input_quantizers)
127-
and all(isinstance(q, MXFP8Quantizer) for q in weight_quantizers)
128-
and all(q is None or isinstance(q, MXFP8Quantizer) for q in grad_output_quantizers)
136+
if not (10, 0) <= get_device_compute_capability() <= (11, 0):
137+
return False
138+
return all(isinstance(q, MXFP8Quantizer) for q in input_quantizers) or all(
139+
isinstance(q, NVFP4Quantizer) and q.with_rht for q in input_quantizers
129140
)
130141
return activation_dtype in (torch.bfloat16, torch.float16)
131142

@@ -234,7 +245,7 @@ def _forward_grouped_tensor(
234245
weights: Tuple[torch.Tensor, ...],
235246
biases: Tuple[torch.Tensor, ...],
236247
) -> Tuple[torch.Tensor, list]:
237-
"""Forward path backed by GroupedTensor + cublasLt grouped GEMM."""
248+
"""Forward path backed by GroupedTensor + cuBLASLt grouped GEMM."""
238249
num_gemms = len(m_splits)
239250
device = inp.device
240251
in_features = weights[0].size(-1)
@@ -491,9 +502,7 @@ def forward(
491502
save_original_input=save_original_input,
492503
activation_dtype=activation_dtype,
493504
input_quantizers=input_quantizers,
494-
weight_quantizers=weight_quantizers,
495505
output_quantizers=output_quantizers,
496-
grad_output_quantizers=grad_output_quantizers,
497506
):
498507
return _GroupedLinear._forward_grouped_tensor(
499508
ctx,
@@ -745,7 +754,7 @@ def _backward_grouped_tensor(
745754
columnwise=ctx.weights_requires_grad,
746755
)
747756
grad_output_quantizer.optimize_for_gemm = True
748-
if ctx.use_bias:
757+
if ctx.use_bias and isinstance(grad_output_quantizer, MXFP8Quantizer):
749758
grouped_dy, dbias_packed = tex.bgrad_group_quantize(
750759
dy_2d,
751760
grad_output_quantizer,

0 commit comments

Comments
 (0)