Skip to content

Commit b345941

Browse files
ksivamanzhongbozhuvthumbe1503Oleg-Goncharov
authored
[PyTorch] GroupedTensor integration (#2600)
* Python GroupedTensor and contiguous weights for GroupedLinear Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Graph safe C API for grouped RHT, needs testing Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com> Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com> * C++ utils, untested Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Pytorch Binding for GroupedTensor APIs (#13) * changes for pytoch extension; but everything seems to be broken probably unrelated to my changes Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * fix the issues Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * comment nvte API since Oleg's PR is not merged Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * test for all cases: Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * tensor attributes should be set later Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> --------- Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix make grouped tensor api Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fixes to tests Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * PyTorch-Python GroupedTensor Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix test Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * All tests pass Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/pytorch/tensor/storage/grouped_tensor.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Remove mxfp8 gq test Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * C++ PyTorch GroupedTensor changes WIP Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Compiles Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix runtime failure for test Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix IMA in mxfp8 GQ Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Add CG test for grouped_quantize Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix recipe tests and FP8 weights Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix recipe tests and FP8 weights Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix device test Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Disable grouped weights for unsupported recipes Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Integrate NVFP4 Graph Safe Group Quantize (#14) * nvfp4 grouped quantize Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com> * fix for paged stashing Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com> * pass all edge cases Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com> * clean up Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com> * fix for other recipes Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com> --------- Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com> * improve mxfp8 unit test Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com> * pre-swizzle nvfp4 mxfp8 for MoE Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com> * avoid having nvte_get_grouped_tensor_param_v2 Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com> * more tests Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com> * fix group quantize mxfp8 kernel Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com> * Relaxed restriction for the last dim to be a multiple of 128 Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com> Signed-off-by: vthumbe1503 <vthumbe@nvidia.com> Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com> Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com> Co-authored-by: vthumbe1503 <vthumbe@nvidia.com> Co-authored-by: Oleg Goncharov <ogoncharov@nvidia.com>
1 parent ad56283 commit b345941

26 files changed

Lines changed: 2761 additions & 309 deletions

qa/L0_pytorch_unittest/test.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED
3232
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
3333
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
3434
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4"
35+
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_mxfp8.xml $TE_PATH/tests/pytorch/mxfp8 || test_fail "test_mxfp8"
3536
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor.xml $TE_PATH/tests/pytorch/test_quantized_tensor.py || test_fail "test_quantized_tensor.py"
3637
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
3738
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"

tests/cpp/operator/test_cast_mxfp8_grouped.cu

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -385,28 +385,41 @@ void performTest(const ProcessingMethod processing_method,
385385

386386
NVTEBasicTensor grad_data_tensor = {grad_data_d, static_cast<NVTEDType>(itype), logical_shape_};
387387
NVTEBasicTensor in_data_tensor = {in_data_d, static_cast<NVTEDType>(itype), logical_shape_};
388-
nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &in_data_tensor);
389-
nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &grad_data_tensor);
388+
nvte_set_grouped_tensor_param(in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData,
389+
&in_data_tensor, sizeof(in_data_tensor));
390+
nvte_set_grouped_tensor_param(grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData,
391+
&grad_data_tensor, sizeof(grad_data_tensor));
390392

391393
if ((shape_rep == VARYING_FIRST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) {
392394
NVTEBasicTensor first_dims_tensor = {first_dims_d, kNVTEInt64, first_dims_shape_};
393-
nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor);
394-
nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor);
395-
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor);
395+
nvte_set_grouped_tensor_param(grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims,
396+
&first_dims_tensor, sizeof(first_dims_tensor));
397+
nvte_set_grouped_tensor_param(in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims,
398+
&first_dims_tensor, sizeof(first_dims_tensor));
399+
nvte_set_grouped_tensor_param(out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims,
400+
&first_dims_tensor, sizeof(first_dims_tensor));
396401
}
397402

398403
if ((shape_rep == VARYING_LAST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) {
399404
NVTEBasicTensor last_dims_tensor = {last_dims_d, kNVTEInt64, last_dims_shape_};
400-
nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor);
401-
nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor);
402-
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor);
405+
nvte_set_grouped_tensor_param(grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims,
406+
&last_dims_tensor, sizeof(last_dims_tensor));
407+
nvte_set_grouped_tensor_param(in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims,
408+
&last_dims_tensor, sizeof(last_dims_tensor));
409+
nvte_set_grouped_tensor_param(out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims,
410+
&last_dims_tensor, sizeof(last_dims_tensor));
403411
}
404412

405413
if (shape_rep != SAME_BOTH_DIMS) {
406414
NVTEBasicTensor offsets_tensor = {offsets_d, kNVTEInt64, offsets_shape_};
407-
nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor);
408-
nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor);
409-
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor);
415+
nvte_set_grouped_tensor_param(grad_group_tensor,
416+
NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets,
417+
&offsets_tensor, sizeof(offsets_tensor));
418+
nvte_set_grouped_tensor_param(in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets,
419+
&offsets_tensor, sizeof(offsets_tensor));
420+
nvte_set_grouped_tensor_param(out_group_tensor,
421+
NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets,
422+
&offsets_tensor, sizeof(offsets_tensor));
410423
}
411424

412425
if (rowwise) {
@@ -417,8 +430,11 @@ void performTest(const ProcessingMethod processing_method,
417430
NVTEBasicTensor out_data_rowwise_tensor = {out_data_rowwise_d, static_cast<NVTEDType>(otype), logical_shape_};
418431
NVTEShape scales_rowwise_shape_ = nvte_make_shape(scales_rowwise_shape.data(), scales_rowwise_shape.size());
419432
NVTEBasicTensor out_scales_rowwise_tensor = {out_scales_rowwise_d, NVTEDType::kNVTEFloat8E8M0, scales_rowwise_shape_};
420-
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &out_data_rowwise_tensor);
421-
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseScaleInv, &out_scales_rowwise_tensor);
433+
nvte_set_grouped_tensor_param(out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData,
434+
&out_data_rowwise_tensor, sizeof(out_data_rowwise_tensor));
435+
nvte_set_grouped_tensor_param(out_group_tensor,
436+
NVTEGroupedTensorParam::kNVTEGroupedRowwiseScaleInv,
437+
&out_scales_rowwise_tensor, sizeof(out_scales_rowwise_tensor));
422438
}
423439

424440
if (colwise) {
@@ -429,8 +445,12 @@ void performTest(const ProcessingMethod processing_method,
429445
NVTEBasicTensor out_data_colwise_tensor = {out_data_colwise_d, static_cast<NVTEDType>(otype), logical_shape_};
430446
NVTEShape scales_colwise_shape_ = nvte_make_shape(scales_colwise_shape.data(), scales_colwise_shape.size());
431447
NVTEBasicTensor out_scales_colwise_tensor = {out_scales_colwise_d, NVTEDType::kNVTEFloat8E8M0, scales_colwise_shape_};
432-
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseData, &out_data_colwise_tensor);
433-
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseScaleInv, &out_scales_colwise_tensor);
448+
nvte_set_grouped_tensor_param(out_group_tensor,
449+
NVTEGroupedTensorParam::kNVTEGroupedColumnwiseData,
450+
&out_data_colwise_tensor, sizeof(out_data_colwise_tensor));
451+
nvte_set_grouped_tensor_param(out_group_tensor,
452+
NVTEGroupedTensorParam::kNVTEGroupedColumnwiseScaleInv,
453+
&out_scales_colwise_tensor, sizeof(out_scales_colwise_tensor));
434454
}
435455

436456
Tensor output_dbias("output_dbias", std::vector<size_t>{ cols }, itype);
@@ -695,7 +715,10 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) {
695715
}
696716
offsets[t+1] = offsets[t] + first_dims[t] * last_dims[t];
697717
// Skips tests if tensor shape is not as required by the kernel
698-
if ((first_dims[t] % 128 != 0) || (last_dims[t] % 32 != 0)) {
718+
if (first_dims[t] % 128 != 0) {
719+
GTEST_SKIP();
720+
}
721+
if (!is_single_tensor && (last_dims[t] % 128 != 0)) {
699722
GTEST_SKIP();
700723
}
701724
}

tests/cpp/test_common.cu

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1157,7 +1157,7 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
11571157

11581158
NVTEBasicTensor data_tensor{grouped.data.get(), static_cast<NVTEDType>(dtype), grouped.logical_shape};
11591159
NVTEGroupedTensor h = grouped.handle.get();
1160-
nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseData, &data_tensor);
1160+
nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseData, &data_tensor, sizeof(data_tensor));
11611161

11621162
const bool include_columnwise = isFp8Type(dtype) || isFp4Type(dtype);
11631163
if (include_columnwise) {
@@ -1172,7 +1172,7 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
11721172
NVTEBasicTensor col_tensor{grouped.columnwise_data.get(),
11731173
static_cast<NVTEDType>(dtype),
11741174
grouped.logical_shape};
1175-
nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseData, &col_tensor);
1175+
nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseData, &col_tensor, sizeof(col_tensor));
11761176
}
11771177

11781178
if (!same_first) {
@@ -1181,7 +1181,7 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
11811181
num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice));
11821182
NVTEShape fd_shape = nvte_make_shape(&num_tensors, 1);
11831183
NVTEBasicTensor fd_tensor{grouped.first_dims_dev.get(), kNVTEInt64, fd_shape};
1184-
nvte_set_grouped_tensor_param(&h, kNVTEGroupedFirstDims, &fd_tensor);
1184+
nvte_set_grouped_tensor_param(h, kNVTEGroupedFirstDims, &fd_tensor, sizeof(fd_tensor));
11851185
}
11861186

11871187
if (!same_last) {
@@ -1190,7 +1190,7 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
11901190
num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice));
11911191
NVTEShape ld_shape = nvte_make_shape(&num_tensors, 1);
11921192
NVTEBasicTensor ld_tensor{grouped.last_dims_dev.get(), kNVTEInt64, ld_shape};
1193-
nvte_set_grouped_tensor_param(&h, kNVTEGroupedLastDims, &ld_tensor);
1193+
nvte_set_grouped_tensor_param(h, kNVTEGroupedLastDims, &ld_tensor, sizeof(ld_tensor));
11941194
}
11951195

11961196
if (!same_first || !same_last) {
@@ -1199,7 +1199,7 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
11991199
num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice));
12001200
NVTEShape off_shape = nvte_make_shape(&num_tensors, 1);
12011201
NVTEBasicTensor off_tensor{grouped.offsets_dev.get(), kNVTEInt64, off_shape};
1202-
nvte_set_grouped_tensor_param(&h, kNVTEGroupedTensorOffsets, &off_tensor);
1202+
nvte_set_grouped_tensor_param(h, kNVTEGroupedTensorOffsets, &off_tensor, sizeof(off_tensor));
12031203
}
12041204

12051205
if (isFp8Type(dtype)) {
@@ -1213,8 +1213,10 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
12131213
sizeof(float) * num_tensors, cudaMemcpyHostToDevice));
12141214
NVTEShape scale_shape = nvte_make_shape(&num_tensors, 1);
12151215
NVTEBasicTensor scale_tensor{grouped.scale_inv.get(), kNVTEFloat32, scale_shape};
1216-
nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseScaleInv, &scale_tensor);
1217-
nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor);
1216+
nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseScaleInv, &scale_tensor,
1217+
sizeof(scale_tensor));
1218+
nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor,
1219+
sizeof(scale_tensor));
12181220
}
12191221

12201222
return grouped;

tests/pytorch/mxfp8/mxfp8_utils.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# See LICENSE for license information.
4+
5+
import torch
6+
import math
7+
8+
9+
# Calculate the shape of the scaling tensor for MXFP8 1D blockwise quantization without padding
10+
def get_mxfp8_scale_shape_no_padding(shape, columnwise):
11+
M, K = 1, 1
12+
M = math.prod(shape[:-1])
13+
K = shape[-1]
14+
15+
if columnwise:
16+
outer = M // 32
17+
inner = K
18+
return (outer, inner)
19+
# rowwise
20+
outer = M
21+
inner = K // 32
22+
return (outer, inner)
23+
24+
25+
def _rowwise_swizzle_mxfp8_scale(input_M, input_N, scale: torch.Tensor) -> torch.Tensor:
26+
assert scale.dim() == 2
27+
assert input_M == scale.shape[0]
28+
assert input_N // 32 == scale.shape[1]
29+
30+
x = scale.view(input_M // 128, 4, 32, input_N // 128, 4)
31+
x = x.permute(0, 3, 2, 1, 4)
32+
x = x.contiguous()
33+
# View back as original 2D shape
34+
x = x.view(input_M, input_N // 32)
35+
return x
36+
37+
38+
def _columnwise_swizzle_mxfp8_scale(input_M, input_N, scale: torch.Tensor) -> torch.Tensor:
39+
assert scale.dim() == 2
40+
assert input_M // 32 == scale.shape[0]
41+
assert input_N == scale.shape[1]
42+
43+
x = scale.view(input_M // 128, 4, input_N // 128, 4, 32)
44+
x = x.permute(2, 0, 4, 3, 1)
45+
x = x.contiguous()
46+
47+
# alternative way: transpose the scale and do rowwise swizzle with M, N swapped
48+
x1 = _rowwise_swizzle_mxfp8_scale(input_N, input_M, scale.transpose(0, 1).contiguous())
49+
torch.testing.assert_close(
50+
x.view(-1), x1.view(-1), atol=0.0, rtol=0.0, msg="columnwise swizzle sanity check failed"
51+
)
52+
53+
# View back as original 2D shape
54+
x = x.view(input_M // 32, input_N)
55+
return x
56+
57+
58+
def swizzle_mxfp8_scale(input_M, input_N, scale: torch.Tensor, columnwise: bool) -> torch.Tensor:
59+
if not columnwise:
60+
return _rowwise_swizzle_mxfp8_scale(input_M, input_N, scale)
61+
else:
62+
return _columnwise_swizzle_mxfp8_scale(input_M, input_N, scale)

0 commit comments

Comments
 (0)