@@ -80,6 +80,82 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob
8080 return output_py;
8181}
8282
83+ namespace {
84+
85+ // helper functions for NVFP4 grouped quantization (cuda graph safe with shapes stored in device without D2H copy)
86+ void group_quantize_nvfp4_impl (const GroupedTensorWrapper &grouped_input_tensor,
87+ GroupedTensorWrapper &grouped_output_tensor,
88+ NVFP4Quantizer *nvfp4_quantizer_cpp, cudaStream_t stream) {
89+ size_t num_tensors = grouped_input_tensor.num_tensors ();
90+
91+ // assert the 2D scaling case, since 2D scaling grouped quant kernel is not ready yet
92+ NVTE_CHECK (!nvfp4_quantizer_cpp->with_2d_quantization ,
93+ " 2D scaling grouped quant kernel is not ready yet" );
94+
95+ auto quant_config_cpp = QuantizationConfigWrapper ();
96+
97+ // stochastic rounding
98+ bool need_stochastic_rounding = nvfp4_quantizer_cpp->stochastic_rounding ;
99+ auto opts = at::TensorOptions ().dtype (torch::kInt64 ).device (torch::kCUDA );
100+ at::Tensor rng_states_tensor; // Declare tensor outside, do not allocate yet
101+ TensorWrapper te_rng_state;
102+
103+ if (need_stochastic_rounding) {
104+ // in fused kernel, one rng state will be used by the grouped kernel to generate random
105+ // number for different tensors in the group, so we only need to allocate one rng state
106+ const size_t rng_elts_per_thread = 1024 * num_tensors;
107+ rng_states_tensor = torch::empty ({2 }, opts);
108+ auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
109+ std::nullopt , at::cuda::detail::getDefaultCUDAGenerator ());
110+ at::PhiloxCudaState philox_args = init_philox_state (gen, rng_elts_per_thread);
111+ philox_unpack (philox_args, static_cast <int64_t *>(rng_states_tensor.data_ptr ()));
112+
113+ te_rng_state = makeTransformerEngineTensor (rng_states_tensor);
114+ quant_config_cpp.set_rng_state (te_rng_state.data ());
115+ quant_config_cpp.set_stochastic_rounding (true );
116+ }
117+
118+ // fast math
119+ const auto use_fast_math = transformer_engine::getenv<bool >(" NVTE_USE_FAST_MATH" );
120+ if (use_fast_math) {
121+ quant_config_cpp.set_use_fast_math (true );
122+ }
123+
124+ // so far, only the RHT path has grouped kernel support
125+ // grouped kernels for non-RHT path will be added later
126+
127+ if (nvfp4_quantizer_cpp->with_rht ) {
128+ // post-RHT amax or not
129+ if (nvfp4_quantizer_cpp->with_post_rht_amax ) {
130+ NVTE_SCOPED_GIL_RELEASE ({
131+ nvte_group_hadamard_transform_amax_graph_safe (
132+ grouped_input_tensor.data (), grouped_output_tensor.data (), 0 ,
133+ nvfp4_quantizer_cpp->rht_matrix_random_sign_mask_t , stream);
134+ });
135+ } else {
136+ NVTE_ERROR (" graph safe grouped quant kernel for non-RHT path is not ready yet" );
137+ }
138+
139+ // RHT cast fusion
140+ auto tile_scheduler_workspace_torch =
141+ at::empty ({1 }, at::device (at::kCUDA ).dtype (torch::kInt32 ));
142+ auto nvte_tile_scheduler_workspace =
143+ makeTransformerEngineTensor (tile_scheduler_workspace_torch);
144+
145+ auto rht_matrix_nvte = makeTransformerEngineTensor (nvfp4_quantizer_cpp->rht_matrix );
146+ NVTE_SCOPED_GIL_RELEASE ({
147+ nvte_group_hadamard_transform_cast_fusion_graph_safe (
148+ grouped_input_tensor.data (), grouped_output_tensor.data (), rht_matrix_nvte.data (),
149+ quant_config_cpp, nvte_tile_scheduler_workspace.data (), stream);
150+ });
151+
152+ } else {
153+ NVTE_ERROR (" graph safe grouped quant kernel for non-RHT path is not ready yet" );
154+ }
155+ }
156+
157+ } // namespace
158+
83159// NOTE: Only supports varying first dim.
84160py::object group_quantize (const at::Tensor &tensor, py::handle quantizer, const size_t num_tensors,
85161 std::optional<at::Tensor> first_dims) {
@@ -95,6 +171,8 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const
95171 const auto logical_first_dim = logical_shape[0 ];
96172 const auto logical_last_dim = logical_shape[1 ];
97173
174+ bool empty_input_buffer = logical_first_dim == 0 || logical_last_dim == 0 ;
175+
98176 auto quantizer_cpp = convert_quantizer (quantizer);
99177
100178 // Create input GroupedTensor.
@@ -108,10 +186,47 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const
108186 py::reinterpret_borrow<py::object>(quantizer), first_dims, logical_first_dim,
109187 logical_last_dim);
110188
111- NVTE_SCOPED_GIL_RELEASE ({
112- nvte_group_quantize (grouped_input_tensor.data (), grouped_output_tensor_cpp.data (),
113- at::cuda::getCurrentCUDAStream ());
114- });
189+ // dispatch to scaling methods
190+ enum class GroupedQuantizationMode {
191+ MXFP8_GROUPED_QUANTIZE ,
192+ NVFP4_GROUPED_QUANTIZE ,
193+ INVALID_FOR_GROUPED_QUANTIZE
194+ };
195+ GroupedQuantizationMode grouped_quantization_mode =
196+ GroupedQuantizationMode::INVALID_FOR_GROUPED_QUANTIZE ;
197+ if (detail::IsMXFP8Quantizers (quantizer.ptr ())) {
198+ grouped_quantization_mode = GroupedQuantizationMode::MXFP8_GROUPED_QUANTIZE ;
199+ } else if (detail::IsNVFP4Quantizers (quantizer.ptr ())) {
200+ grouped_quantization_mode = GroupedQuantizationMode::NVFP4_GROUPED_QUANTIZE ;
201+ }
202+
203+ if (empty_input_buffer) {
204+ // early return for empty input buffer
205+ // just return the output tensor as is
206+ // no need to quantize
207+ return py::reinterpret_borrow<py::object>(grouped_output_py);
208+ }
209+
210+ switch (grouped_quantization_mode) {
211+ case GroupedQuantizationMode::NVFP4_GROUPED_QUANTIZE : {
212+ // NVFP4 grouped quantization
213+ NVFP4Quantizer *nvfp4_quantizer_cpp = static_cast <NVFP4Quantizer *>(quantizer_cpp.get ());
214+ group_quantize_nvfp4_impl (grouped_input_tensor, grouped_output_tensor_cpp,
215+ nvfp4_quantizer_cpp, at::cuda::getCurrentCUDAStream ());
216+ break ;
217+ }
218+ case GroupedQuantizationMode::MXFP8_GROUPED_QUANTIZE : {
219+ NVTE_SCOPED_GIL_RELEASE ({
220+ nvte_group_quantize (grouped_input_tensor.data (), grouped_output_tensor_cpp.data (),
221+ at::cuda::getCurrentCUDAStream ());
222+ });
223+ break ;
224+ }
225+ case GroupedQuantizationMode::INVALID_FOR_GROUPED_QUANTIZE :
226+ default :
227+ NVTE_ERROR (" group_quantize: only support NVFP4 or MXFP8 quantizer." );
228+ break ;
229+ }
115230
116231 return py::reinterpret_borrow<py::object>(grouped_output_py);
117232}
0 commit comments