Skip to content
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
d0dbe66
rowwise colwise RHT group quant v1
zhongbozhu Nov 26, 2025
b345534
remove local array RW
zhongbozhu Dec 2, 2025
2eb23b3
change wait_barrier
zhongbozhu Dec 2, 2025
004e529
fast math options
zhongbozhu Dec 3, 2025
d9a6c24
use mult to replace div
zhongbozhu Dec 3, 2025
9b9efb8
format
zhongbozhu Dec 3, 2025
a9d0fc5
bulk move random states
zhongbozhu Dec 3, 2025
1af82af
greptile
zhongbozhu Dec 3, 2025
b4515d2
lint
zhongbozhu Dec 3, 2025
626e3fe
revert to use divides
zhongbozhu Dec 4, 2025
fc6f7f2
avoid fp32 bf16 round-trip in RHT cast fusion
zhongbozhu Dec 4, 2025
48e5d75
trigger fastmath by toggle NVTE_RHT_CAST_FUSION_USE_FAST_MATH
zhongbozhu Dec 5, 2025
3d07a9b
integrate row col rht fusion, functional
zhongbozhu Dec 12, 2025
70523c8
numerics aligned
zhongbozhu Dec 12, 2025
0388466
style
zhongbozhu Dec 12, 2025
27f1047
remove device sync
zhongbozhu Dec 13, 2025
380a116
128 padding
zhongbozhu Dec 13, 2025
f61979a
revert colwise rng state creation because of row-col fused kernel
zhongbozhu Dec 16, 2025
6f38c78
fix CI, linter
zhongbozhu Dec 16, 2025
badcf74
refactor RS for generating two random values
zhongbozhu Dec 16, 2025
0d245ae
Avoid invalid configs with templated kernel
timmoon10 Dec 18, 2025
83e7bf2
fix acc pipeline init with 0 arrival count
zhongbozhu Dec 18, 2025
b554bef
restore rowwise-only mode
zhongbozhu Dec 18, 2025
247a20b
switch to dynamic atomic scheduler
zhongbozhu Dec 19, 2025
4df34ce
Avoid instantiating group RHT+cast kernel without row-wise or col-wis…
timmoon10 Dec 19, 2025
cbdda20
Include fast math option in quantization config
timmoon10 Dec 19, 2025
0ac4d74
Fix linter warnings and review nits
timmoon10 Dec 19, 2025
d98b732
Merge branch 'main' into zhongbo/multi_rht_cast_colwise_fuse
timmoon10 Dec 19, 2025
c14b156
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 19, 2025
40ae64c
Use TE license
timmoon10 Dec 19, 2025
15e1edb
Fix bug where kernel is always launched on stream
timmoon10 Dec 20, 2025
79cc660
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 20, 2025
8534c38
Restore BF16 intermediate downcast in fused RHT-cast kernels
timmoon10 Dec 20, 2025
57db30f
fix numerical test of grouped kernel
zhongbozhu Dec 20, 2025
b258ca9
Make sure row-wise and col-wise quantization use different RNG seeds
timmoon10 Dec 20, 2025
d79c2ac
Merge branch 'main' into zhongbo/multi_rht_cast_colwise_fuse
timmoon10 Dec 20, 2025
66ac756
Restore autoformatter
timmoon10 Dec 20, 2025
376687c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions benchmarks/linear/benchmark_grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
--set=full \
--kernel-name "GroupHadamardAmaxTmaKernel" \
-s 5 -c 5 \
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4 --profile
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4

"""

Expand Down Expand Up @@ -173,7 +173,9 @@ def benchmark_linear(
return timing_ms


def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None):
def run_benchmark_linear(
mkns, recipe_name, use_bias, num_gemms=4, m_splits_provided=None, fwd_only=False
):
data = []
assert not use_bias, "Bias is not supported for GroupedLinear benchmark"

Expand All @@ -182,22 +184,22 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None
device = "cuda"
x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True)
ws = [torch.randn((n, k), dtype=torch.bfloat16, device=device) for _ in range(num_gemms)]
assert m % num_gemms == 0
m_splits = [m // num_gemms] * num_gemms if m_splits is None else m_splits
m_splits = [m // num_gemms] * num_gemms if m_splits_provided is None else m_splits_provided
# Bias is not supported for GroupedLinear benchmark
bias = None

# Run the benchmark
print(f"fwd_m={m}, fwd_k={k}, fwd_n={n}")
print(f"m_splits: {m_splits}")
print(f"fwd_only: {fwd_only}")

grouped_fwd_bwd_timing_ms = benchmark_linear(
x,
ws,
m_splits,
bias,
recipe_name,
mode="fwd_bwd",
mode="fwd_only" if fwd_only else "fwd_bwd",
num_gemms=num_gemms,
)

Expand All @@ -213,6 +215,8 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None
]
)

timing_notation = "grouped_fwd_time_ms" if fwd_only else "grouped_fwd_bwd_time_ms"

df = pd.DataFrame(
data=data,
columns=[
Expand All @@ -221,7 +225,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None
"n",
"recipe",
"num_gemms",
"grouped_fwd_bwd_time_ms",
timing_notation,
],
)

Expand All @@ -234,7 +238,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None
parser = argparse.ArgumentParser()
parser.add_argument("--profile", action="store_true", help="Enable profiling mode")
parser.add_argument(
"--output_dir",
"--output-dir",
type=str,
default="benchmark_output/",
help="output path for report",
Expand Down Expand Up @@ -266,6 +270,12 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None
default=2048,
help="Output dimension to use, default is 2048",
)
parser.add_argument(
"--fwd-only",
action="store_true",
default=False,
help="Run forward pass only, default is both forward and backward passes",
)
args = parser.parse_args()

jagged_input_splits = None
Expand Down Expand Up @@ -297,7 +307,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None
if jagged_input_splits is not None:
num_gemms_list = [len(jagged_input_splits)]

token_dim_list = [65536]
token_dim_list = [16384, 32768, 65536, 98304]
hidden_dim_list = [7168]
output_dim_list = [2048]

Expand Down Expand Up @@ -371,7 +381,8 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None
recipe_name,
use_bias,
num_gemms=num_gemms,
m_splits=jagged_input_splits,
m_splits_provided=jagged_input_splits,
fwd_only=args.fwd_only,
)
df_linears = pd.concat([df_linears, df])

Expand Down
5 changes: 3 additions & 2 deletions tests/pytorch/nvfp4/test_nvfp4_group_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def check_group_quantization_nvfp4_versus_reference(

for i in range(len(x_qx)):
if split_sections[i] == 0:
# then just assert the same same and dtype because the buffer won't be zero out
# then just assert the same shape and dtype because the buffer won't be zeroed out
assert_same_shape_and_dtype(x_amax_rowwise[i], x_amax_rowwise_ref[i])
assert_same_shape_and_dtype(x_qx[i], x_qx_ref[i])
assert_same_shape_and_dtype(x_sx[i], x_sx_ref[i])
Expand All @@ -221,7 +221,7 @@ def check_group_quantization_nvfp4_versus_reference(
# assert with zero tolerance
for i in range(len(x_qx_t)):
if split_sections[i] == 0:
# then just assert the same same and dtype because the buffer won't be zero out
# then just assert the same shape and dtype because the buffer won't be zeroed out
assert_same_shape_and_dtype(x_amax_colwise[i], x_amax_colwise_ref[i])
assert_same_shape_and_dtype(x_qx_t[i], x_qx_t_ref[i])
assert_same_shape_and_dtype(x_sx_t[i], x_sx_t_ref[i])
Expand All @@ -247,6 +247,7 @@ def check_group_quantization_nvfp4_versus_reference(
(1024, 256),
# larger sizes
(8192, 1024),
(16384, 8192),
(16384, 16384),
],
)
Expand Down
2 changes: 2 additions & 0 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ list(APPEND transformer_engine_cuda_arch_specific_sources
hadamard_transform/group_hadamard_transform.cu
hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu
hadamard_transform/group_hadamard_transform_cast_fusion.cu
hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
multi_tensor/compute_scale.cu
recipe/mxfp8_scaling.cu
transpose/quantize_transpose_square_blockwise.cu
Expand Down
15 changes: 15 additions & 0 deletions transformer_engine/common/cast/cast.cu
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,18 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s)));
}
}

// C API entry point for grouped NVFP4 quantization.
//
// Forwards all arguments to dispatch::group_quantize_fwd_helper with the
// activation path disabled (no activation op, no op parameters).
//
// Group quantize assumes contiguous inputs and outputs in memory allocation.
// TODO (zhongbo): find a better way to make it a more generalized API
void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *outputs,
                                         const size_t *split_sections, const size_t num_tensors,
                                         const NVTEQuantizationConfig quant_config,
                                         cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_nvfp4_quantize_with_amax);
  using namespace transformer_engine;

  // This entry point never fuses an activation into the quantization.
  constexpr bool kFuseActivation = false;

  dispatch::group_quantize_fwd_helper<kFuseActivation, Empty, nullptr>(
      input, outputs, split_sections, num_tensors, quant_config, stream);
}
65 changes: 65 additions & 0 deletions transformer_engine/common/cast/dispatch/quantize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "../core/common.cuh"
#include "../fp8/quantize_fp8.cuh"
#include "../mxfp8/quantize_mxfp8.cuh"
#include "../nvfp4/group_quantize_transpose_nvfp4.cuh"
#include "../nvfp4/quantize_nvfp4.cuh"
#include "../nvfp4/quantize_transpose_nvfp4.cuh"

Expand Down Expand Up @@ -320,6 +321,70 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens
}
}

template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void group_quantize_fwd_helper(const NVTETensor input, NVTETensor *outputs,
Comment on lines +324 to +325

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're copy-pasting the templated infrastructure for single-tensor quantization, and inheriting all of its complexity to handle unnecessary features.

const size_t *split_sections, const size_t num_tensors,
const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
using namespace detail;

const Tensor *input_tensor = convertNVTETensorCheck(input);
std::vector<Tensor *> output_tensors;
for (size_t i = 0; i < num_tensors; ++i) {
output_tensors.push_back(convertNVTETensorCheck(outputs[i]));
}

// Quantization config
QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}

// Noop flag
Tensor dummy_tensor;
Tensor *noop_tensor = &dummy_tensor;
if (quant_config_cpp.noop_tensor != nullptr) {
noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
}

// Check for unsupported options
if (quant_config_cpp.stochastic_rounding) {
NVTE_CHECK(output_tensors[0]->scaling_mode == NVTE_NVFP4_1D_SCALING,
"Stochastic rounding is only supported for NVFP4 quantization.");
}

// Take the scaling mode of the first output tensor
auto scaling_mode = output_tensors[0]->scaling_mode;
Comment thread
zhongbozhu marked this conversation as resolved.

// Dispatch to quantization kernel depending on data format
switch (scaling_mode) {
case NVTE_NVFP4_1D_SCALING: {
NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING");

// Check tensors
CheckNoopTensor(*noop_tensor, "cast_noop");
CheckInputTensor(*input_tensor, "input");
// Skip checking output tensor list
// output list here is allowed to have empty tensor

// Choose kernel
int32_t rows = input_tensor->flat_first_dim();
int32_t cols = input_tensor->flat_last_dim();
auto dtype = input_tensor->dtype();

NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization,
"2D quantization is not supported for group quantize.");

// Launch NVFP4 group quantize kernel
nvfp4::group_quantize_transpose</*use_2d_quantization*/ false>(
*input_tensor, noop_tensor, output_tensors, split_sections, num_tensors,
&quant_config_cpp, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + ".");
}
}

} // namespace dispatch
} // namespace transformer_engine

Expand Down
Loading
Loading