Skip to content

Commit 864c484

Browse files
committed
Integrate NVFP4 Graph Safe Group Quantize (#14)
* nvfp4 grouped quantize Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com> * fix for paged stashing Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com> * pass all edge cases Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com> * clean up Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com> * fix for other recipes Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com> --------- Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
1 parent e3278dd commit 864c484

8 files changed

Lines changed: 714 additions & 63 deletions

File tree

tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py

Lines changed: 515 additions & 0 deletions
Large diffs are not rendered by default.

transformer_engine/common/cast/cast.cu

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,8 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
124124
}
125125

126126
// Group quantize assumes contiguous inputs and outputs in memory allocation
127-
// TODO (zhongbo): find a better way to make it a more generalized API
127+
// Note: this API assumes knowing split sections from the host, if split information
128+
// comes from D2H copy, it will break cuda graph capture
128129
void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *outputs,
129130
const size_t *split_sections, const size_t num_tensors,
130131
const NVTEQuantizationConfig quant_config,
@@ -134,6 +135,6 @@ void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *out
134135

135136
constexpr bool IS_ACT = false;
136137

137-
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, outputs, split_sections,
138-
num_tensors, quant_config, stream);
138+
dispatch::group_quantize_fwd_host_aware_helper<IS_ACT, Empty, nullptr>(
139+
input, outputs, split_sections, num_tensors, quant_config, stream);
139140
}

transformer_engine/common/cast/dispatch/quantize.cuh

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -308,10 +308,12 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens
308308
}
309309
}
310310

311+
// Host-aware and not graph-safe: group quantization with split section info from the host.
311312
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
312-
void group_quantize_fwd_helper(const NVTETensor input, NVTETensor *outputs,
313-
const size_t *split_sections, const size_t num_tensors,
314-
const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
313+
void group_quantize_fwd_host_aware_helper(const NVTETensor input, NVTETensor *outputs,
314+
const size_t *split_sections, const size_t num_tensors,
315+
const NVTEQuantizationConfig quant_config,
316+
cudaStream_t stream) {
315317
using namespace detail;
316318

317319
const Tensor *input_tensor = convertNVTETensorCheck(input);

transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,11 @@ __global__ void GraphSafeGroupHadamardAmaxTmaKernel(
251251

252252
// calculate the global offset to get tensor id
253253
size_t global_offset = blockIdx.y * CHUNK_DIM_Y * last_logical_dim;
254+
// paged stashing: will have input buffer [M, N], where M is larger than sum(first_dims)
255+
// also need to early return if this CTA is processing a region larger than the last offsets[num_tensors]
256+
if (global_offset >= offsets_ptr[num_tensors]) {
257+
return;
258+
}
254259
int tensor_id = get_current_tensor_id(shape_rep, num_tensors, global_offset, first_logical_dim,
255260
last_logical_dim, offsets_ptr);
256261
output_pre_rht_amax_ptr = static_cast<float*>(amax_rowwise_ptr) + tensor_id;
@@ -440,9 +445,8 @@ void group_hadamard_transform_amax_graph_safe(const GroupedTensor* input, Groupe
440445
float* const amax_rowwise_ptr = reinterpret_cast<float*>(output->amax.dptr);
441446
float* const amax_colwise_ptr = reinterpret_cast<float*>(output->columnwise_amax.dptr);
442447

443-
const int64_t* const offsets_ptr = reinterpret_cast<const int64_t*>(input->tensor_offsets.dptr);
444-
const int64_t* const first_dims_ptr = reinterpret_cast<const int64_t*>(input->first_dims.dptr);
445-
// const int64_t *const last_dims_ptr = reinterpret_cast<const int64_t *>(input->last_dims.dptr);
448+
const int64_t* const offsets_ptr = reinterpret_cast<const int64_t*>(output->tensor_offsets.dptr);
449+
const int64_t* const first_dims_ptr = reinterpret_cast<const int64_t*>(output->first_dims.dptr);
446450

447451
// some sanity checks
448452
if (all_return_pre_rht_amax) {

transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1428,9 +1428,8 @@ void group_hadamard_transform_cast_fusion_graph_safe(const GroupedTensor *input,
14281428
float *const amax_rowwise_base_ptr = reinterpret_cast<float *>(output->amax.dptr);
14291429
float *const amax_colwise_base_ptr = reinterpret_cast<float *>(output->columnwise_amax.dptr);
14301430

1431-
const int64_t *const offsets_ptr = reinterpret_cast<const int64_t *>(input->tensor_offsets.dptr);
1432-
const int64_t *const first_dims_ptr = reinterpret_cast<const int64_t *>(input->first_dims.dptr);
1433-
// const int64_t *const last_dims_ptr = reinterpret_cast<const int64_t *>(input->last_dims.dptr);
1431+
const int64_t *const offsets_ptr = reinterpret_cast<const int64_t *>(output->tensor_offsets.dptr);
1432+
const int64_t *const first_dims_ptr = reinterpret_cast<const int64_t *>(output->first_dims.dptr);
14341433

14351434
const bool is_const_last_dim = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS ||
14361435
shape_rep == ShapeRepresentation::VARYING_FIRST_DIM);

transformer_engine/common/transformer_engine.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,8 +1145,8 @@ NVTEGroupedTensor nvte_create_grouped_tensor(NVTEScalingMode scaling_mode, size_
11451145
NVTEShape logical_shape) {
11461146
NVTE_CHECK(num_tensors > 0, "Number of tensors must be greater than 0");
11471147
NVTE_CHECK(logical_shape.ndim == 2, "Logical shape must be 2D");
1148-
NVTE_CHECK(logical_shape.data[0] > 0 && logical_shape.data[1] > 0,
1149-
"Logical shape must have positive dimensions");
1148+
// NVTE_CHECK(logical_shape.data[0] > 0 && logical_shape.data[1] > 0,
1149+
// "Logical shape must have positive dimensions");
11501150
NVTEGroupedTensor ret = transformer_engine::GroupedTensorAllocator::instance().Allocate(
11511151
scaling_mode, num_tensors, logical_shape);
11521152
return ret;

transformer_engine/pytorch/csrc/extensions/cast.cpp

Lines changed: 119 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,82 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob
8080
return output_py;
8181
}
8282

83+
namespace {
84+
85+
// helper functions for NVFP4 grouped quantization (cuda graph safe with shapes stored in device without D2H copy)
86+
void group_quantize_nvfp4_impl(const GroupedTensorWrapper &grouped_input_tensor,
87+
GroupedTensorWrapper &grouped_output_tensor,
88+
NVFP4Quantizer *nvfp4_quantizer_cpp, cudaStream_t stream) {
89+
size_t num_tensors = grouped_input_tensor.num_tensors();
90+
91+
// assert the 2D scaling case, since 2D scaling grouped quant kernel is not ready yet
92+
NVTE_CHECK(!nvfp4_quantizer_cpp->with_2d_quantization,
93+
"2D scaling grouped quant kernel is not ready yet");
94+
95+
auto quant_config_cpp = QuantizationConfigWrapper();
96+
97+
// stochastic rounding
98+
bool need_stochastic_rounding = nvfp4_quantizer_cpp->stochastic_rounding;
99+
auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA);
100+
at::Tensor rng_states_tensor; // Declare tensor outside, do not allocate yet
101+
TensorWrapper te_rng_state;
102+
103+
if (need_stochastic_rounding) {
104+
// in fused kernel, one rng state will be used by the grouped kernel to generate random
105+
// number for different tensors in the group, so we only need to allocate one rng state
106+
const size_t rng_elts_per_thread = 1024 * num_tensors;
107+
rng_states_tensor = torch::empty({2}, opts);
108+
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
109+
std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
110+
at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
111+
philox_unpack(philox_args, static_cast<int64_t *>(rng_states_tensor.data_ptr()));
112+
113+
te_rng_state = makeTransformerEngineTensor(rng_states_tensor);
114+
quant_config_cpp.set_rng_state(te_rng_state.data());
115+
quant_config_cpp.set_stochastic_rounding(true);
116+
}
117+
118+
// fast math
119+
const auto use_fast_math = transformer_engine::getenv<bool>("NVTE_USE_FAST_MATH");
120+
if (use_fast_math) {
121+
quant_config_cpp.set_use_fast_math(true);
122+
}
123+
124+
// so far, only the RHT path has grouped kernel support
125+
// grouped kernels for non-RHT path will be added later
126+
127+
if (nvfp4_quantizer_cpp->with_rht) {
128+
// post-RHT amax or not
129+
if (nvfp4_quantizer_cpp->with_post_rht_amax) {
130+
NVTE_SCOPED_GIL_RELEASE({
131+
nvte_group_hadamard_transform_amax_graph_safe(
132+
grouped_input_tensor.data(), grouped_output_tensor.data(), 0,
133+
nvfp4_quantizer_cpp->rht_matrix_random_sign_mask_t, stream);
134+
});
135+
} else {
136+
NVTE_ERROR("graph safe grouped quant kernel for non-RHT path is not ready yet");
137+
}
138+
139+
// RHT cast fusion
140+
auto tile_scheduler_workspace_torch =
141+
at::empty({1}, at::device(at::kCUDA).dtype(torch::kInt32));
142+
auto nvte_tile_scheduler_workspace =
143+
makeTransformerEngineTensor(tile_scheduler_workspace_torch);
144+
145+
auto rht_matrix_nvte = makeTransformerEngineTensor(nvfp4_quantizer_cpp->rht_matrix);
146+
NVTE_SCOPED_GIL_RELEASE({
147+
nvte_group_hadamard_transform_cast_fusion_graph_safe(
148+
grouped_input_tensor.data(), grouped_output_tensor.data(), rht_matrix_nvte.data(),
149+
quant_config_cpp, nvte_tile_scheduler_workspace.data(), stream);
150+
});
151+
152+
} else {
153+
NVTE_ERROR("graph safe grouped quant kernel for non-RHT path is not ready yet");
154+
}
155+
}
156+
157+
} // namespace
158+
83159
// NOTE: Only supports varying first dim.
84160
py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const size_t num_tensors,
85161
std::optional<at::Tensor> first_dims) {
@@ -95,6 +171,8 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const
95171
const auto logical_first_dim = logical_shape[0];
96172
const auto logical_last_dim = logical_shape[1];
97173

174+
bool empty_input_buffer = logical_first_dim == 0 || logical_last_dim == 0;
175+
98176
auto quantizer_cpp = convert_quantizer(quantizer);
99177

100178
// Create input GroupedTensor.
@@ -108,10 +186,47 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const
108186
py::reinterpret_borrow<py::object>(quantizer), first_dims, logical_first_dim,
109187
logical_last_dim);
110188

111-
NVTE_SCOPED_GIL_RELEASE({
112-
nvte_group_quantize(grouped_input_tensor.data(), grouped_output_tensor_cpp.data(),
113-
at::cuda::getCurrentCUDAStream());
114-
});
189+
// dispatch to scaling methods
190+
enum class GroupedQuantizationMode {
191+
MXFP8_GROUPED_QUANTIZE,
192+
NVFP4_GROUPED_QUANTIZE,
193+
INVALID_FOR_GROUPED_QUANTIZE
194+
};
195+
GroupedQuantizationMode grouped_quantization_mode =
196+
GroupedQuantizationMode::INVALID_FOR_GROUPED_QUANTIZE;
197+
if (detail::IsMXFP8Quantizers(quantizer.ptr())) {
198+
grouped_quantization_mode = GroupedQuantizationMode::MXFP8_GROUPED_QUANTIZE;
199+
} else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
200+
grouped_quantization_mode = GroupedQuantizationMode::NVFP4_GROUPED_QUANTIZE;
201+
}
202+
203+
if (empty_input_buffer) {
204+
// early return for empty input buffer
205+
// just return the output tensor as is
206+
// no need to quantize
207+
return py::reinterpret_borrow<py::object>(grouped_output_py);
208+
}
209+
210+
switch (grouped_quantization_mode) {
211+
case GroupedQuantizationMode::NVFP4_GROUPED_QUANTIZE: {
212+
// NVFP4 grouped quantization
213+
NVFP4Quantizer *nvfp4_quantizer_cpp = static_cast<NVFP4Quantizer *>(quantizer_cpp.get());
214+
group_quantize_nvfp4_impl(grouped_input_tensor, grouped_output_tensor_cpp,
215+
nvfp4_quantizer_cpp, at::cuda::getCurrentCUDAStream());
216+
break;
217+
}
218+
case GroupedQuantizationMode::MXFP8_GROUPED_QUANTIZE: {
219+
NVTE_SCOPED_GIL_RELEASE({
220+
nvte_group_quantize(grouped_input_tensor.data(), grouped_output_tensor_cpp.data(),
221+
at::cuda::getCurrentCUDAStream());
222+
});
223+
break;
224+
}
225+
case GroupedQuantizationMode::INVALID_FOR_GROUPED_QUANTIZE:
226+
default:
227+
NVTE_ERROR("group_quantize: only support NVFP4 or MXFP8 quantizer.");
228+
break;
229+
}
115230

116231
return py::reinterpret_borrow<py::object>(grouped_output_py);
117232
}

0 commit comments

Comments
 (0)