Skip to content

Commit 893710f

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 3a1929f commit 893710f

7 files changed

Lines changed: 48 additions & 74 deletions

File tree

tests/pytorch/test_backward_override.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,9 @@ def _maybe_skip_unsupported_recipe_shape(
200200
" by 32."
201201
)
202202
return
203-
if recipe_name in ("nvfp4", "nvfp4_pertoken") and (flat_first_dim % 16 != 0 or last_dim % 16 != 0):
203+
if recipe_name in ("nvfp4", "nvfp4_pertoken") and (
204+
flat_first_dim % 16 != 0 or last_dim % 16 != 0
205+
):
204206
pytest.skip(
205207
"Linear/LayerNormLinear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible"
206208
" by 16."
@@ -225,7 +227,9 @@ def _maybe_skip_unsupported_recipe_shape(
225227
pytest.skip(
226228
"te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32."
227229
)
228-
if recipe_name in ("nvfp4", "nvfp4_pertoken") and (flat_first_dim % 16 != 0 or last_dim % 16 != 0):
230+
if recipe_name in ("nvfp4", "nvfp4_pertoken") and (
231+
flat_first_dim % 16 != 0 or last_dim % 16 != 0
232+
):
229233
pytest.skip(
230234
"te_ops.Linear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible by 16."
231235
)

tests/pytorch/test_nvfp4_pertoken_quant.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,9 @@ def test_output_shapes(self, num_rows, num_cols, dtype):
9595

9696
assert data.shape == (num_rows, num_cols // 2), f"data shape: {data.shape}"
9797
assert scales.shape == (num_rows, num_cols // 16), f"scales shape: {scales.shape}"
98-
assert per_token_scales.shape == (num_rows,), f"per_token_scales shape: {per_token_scales.shape}"
98+
assert per_token_scales.shape == (
99+
num_rows,
100+
), f"per_token_scales shape: {per_token_scales.shape}"
99101
assert data.dtype == torch.uint8
100102
assert scales.dtype == torch.uint8
101103
assert per_token_scales.dtype == torch.float32

transformer_engine/common/cast/cast.cu

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -148,13 +148,9 @@ void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *out
148148
input, outputs, split_sections, num_tensors, quant_config, stream);
149149
}
150150

151-
void nvte_quantize_nvfp4_pertoken(const NVTETensor input,
152-
NVTETensor output_data,
153-
NVTETensor output_scales,
154-
NVTETensor output_per_token_scales,
155-
size_t num_rows,
156-
size_t num_cols,
157-
cudaStream_t stream) {
151+
void nvte_quantize_nvfp4_pertoken(const NVTETensor input, NVTETensor output_data,
152+
NVTETensor output_scales, NVTETensor output_per_token_scales,
153+
size_t num_rows, size_t num_cols, cudaStream_t stream) {
158154
NVTE_API_CALL(nvte_quantize_nvfp4_pertoken);
159155
using namespace transformer_engine;
160156

@@ -170,24 +166,21 @@ void nvte_quantize_nvfp4_pertoken(const NVTETensor input,
170166

171167
if (itype == DType::kBFloat16) {
172168
dispatch::nvfp4::quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4<__nv_bfloat16>(
173-
num_rows, num_cols,
174-
reinterpret_cast<const __nv_bfloat16 *>(input_tensor.data.dptr),
169+
num_rows, num_cols, reinterpret_cast<const __nv_bfloat16 *>(input_tensor.data.dptr),
175170
nullptr, // row_offsets
176171
reinterpret_cast<uint8_t *>(data_tensor->data.dptr),
177172
reinterpret_cast<fp8e4m3 *>(scales_tensor->data.dptr),
178-
reinterpret_cast<float *>(pertoken_tensor->data.dptr),
179-
stream);
173+
reinterpret_cast<float *>(pertoken_tensor->data.dptr), stream);
180174
} else if (itype == DType::kFloat16) {
181175
dispatch::nvfp4::quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4<half>(
182-
num_rows, num_cols,
183-
reinterpret_cast<const half *>(input_tensor.data.dptr),
176+
num_rows, num_cols, reinterpret_cast<const half *>(input_tensor.data.dptr),
184177
nullptr, // row_offsets
185178
reinterpret_cast<uint8_t *>(data_tensor->data.dptr),
186179
reinterpret_cast<fp8e4m3 *>(scales_tensor->data.dptr),
187-
reinterpret_cast<float *>(pertoken_tensor->data.dptr),
188-
stream);
180+
reinterpret_cast<float *>(pertoken_tensor->data.dptr), stream);
189181
} else {
190-
NVTE_ERROR("Unsupported input dtype for per-token NVFP4 quantization. "
191-
"Expected BFloat16 or Float16.");
182+
NVTE_ERROR(
183+
"Unsupported input dtype for per-token NVFP4 quantization. "
184+
"Expected BFloat16 or Float16.");
192185
}
193186
}

transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh

Lines changed: 17 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
#include <cuda.h>
2727
#include <cuda_runtime.h>
28+
2829
#include <cub/cub.cuh>
2930

3031
#include "../../common.h"
@@ -74,24 +75,20 @@ __global__ void
7475
__launch_bounds__(BLOCK_SIZE)
7576
#endif
7677
quantize_pertoken_nvfp4_kernel(
77-
const int num_rows,
78-
const int num_cols,
79-
const IType *__restrict__ input,
78+
const int num_rows, const int num_cols, const IType *__restrict__ input,
8079
const int *__restrict__ row_offsets, // optional: nullptr for identity mapping
81-
uint8_t *__restrict__ output_data,
82-
fp8e4m3 *__restrict__ output_scales,
83-
float *__restrict__ output_per_token_scales,
84-
const int scale_stride) {
80+
uint8_t *__restrict__ output_data, fp8e4m3 *__restrict__ output_scales,
81+
float *__restrict__ output_per_token_scales, const int scale_stride) {
8582
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
8683

8784
using namespace detail;
88-
constexpr float fp8_max = TypeExtrema<fp8e4m3>::max; // 448.0f
89-
constexpr float fp4_max = TypeExtrema<fp4e2m1>::max; // 6.0f
85+
constexpr float fp8_max = TypeExtrema<fp8e4m3>::max; // 448.0f
86+
constexpr float fp4_max = TypeExtrema<fp4e2m1>::max; // 6.0f
9087
constexpr float fp4_max_inv = 1.0f / fp4_max;
9188

9289
// Packed type: 4 elements per float2 pair for FP4 conversion
93-
using IType2 = typename std::conditional<std::is_same<IType, half>::value,
94-
half2, __nv_bfloat162>::type;
90+
using IType2 =
91+
typename std::conditional<std::is_same<IType, half>::value, half2, __nv_bfloat162>::type;
9592

9693
const int row_idx = blockIdx.x;
9794
if (row_idx >= num_rows) return;
@@ -167,9 +164,7 @@ __launch_bounds__(BLOCK_SIZE)
167164
output_scales[row_idx * scale_stride + sf_idx] = S_dec_b;
168165

169166
// Compute inverse block scale for quantization
170-
float block_encode_scale = (S_dec_b_f != 0.0f)
171-
? __fdividef(S_enc, S_dec_b_f)
172-
: 0.0f;
167+
float block_encode_scale = (S_dec_b_f != 0.0f) ? __fdividef(S_enc, S_dec_b_f) : 0.0f;
173168

174169
// Quantize 16 elements to FP4 and pack into 8 bytes
175170
uint8_t *out_ptr = output_data + actual_row * (num_cols / 2) + col_start / 2;
@@ -190,30 +185,22 @@ __launch_bounds__(BLOCK_SIZE)
190185
* Host-side launcher for per-token NVFP4 quantization.
191186
*/
192187
template <typename IType>
193-
void launch_quantize_pertoken_nvfp4(
194-
const int num_rows,
195-
const int num_cols,
196-
const IType *input,
197-
const int *row_offsets,
198-
uint8_t *output_data,
199-
fp8e4m3 *output_scales,
200-
float *output_per_token_scales,
201-
cudaStream_t stream) {
188+
void launch_quantize_pertoken_nvfp4(const int num_rows, const int num_cols, const IType *input,
189+
const int *row_offsets, uint8_t *output_data,
190+
fp8e4m3 *output_scales, float *output_per_token_scales,
191+
cudaStream_t stream) {
202192
if (num_rows == 0 || num_cols == 0) return;
203193

204-
NVTE_CHECK(num_cols % PERTOKEN_SF_VEC_SIZE == 0,
205-
"num_cols must be a multiple of ", PERTOKEN_SF_VEC_SIZE,
206-
" for per-token NVFP4 quantization, got ", num_cols);
194+
NVTE_CHECK(num_cols % PERTOKEN_SF_VEC_SIZE == 0, "num_cols must be a multiple of ",
195+
PERTOKEN_SF_VEC_SIZE, " for per-token NVFP4 quantization, got ", num_cols);
207196

208197
const int scale_stride = num_cols / PERTOKEN_SF_VEC_SIZE;
209198
dim3 grid(num_rows);
210199
dim3 block(PERTOKEN_BLOCK_SIZE);
211200

212201
quantize_pertoken_nvfp4_kernel<IType, PERTOKEN_BLOCK_SIZE>
213-
<<<grid, block, 0, stream>>>(
214-
num_rows, num_cols, input, row_offsets,
215-
output_data, output_scales, output_per_token_scales,
216-
scale_stride);
202+
<<<grid, block, 0, stream>>>(num_rows, num_cols, input, row_offsets, output_data,
203+
output_scales, output_per_token_scales, scale_stride);
217204
NVTE_CHECK_CUDA(cudaGetLastError());
218205
}
219206

transformer_engine/common/include/transformer_engine/cast.h

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -466,13 +466,9 @@ void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *out
466466
* \param[in] num_cols Number of columns (must be multiple of 16).
467467
* \param[in] stream CUDA stream.
468468
*/
469-
void nvte_quantize_nvfp4_pertoken(const NVTETensor input,
470-
NVTETensor output_data,
471-
NVTETensor output_scales,
472-
NVTETensor output_per_token_scales,
473-
size_t num_rows,
474-
size_t num_cols,
475-
cudaStream_t stream);
469+
void nvte_quantize_nvfp4_pertoken(const NVTETensor input, NVTETensor output_data,
470+
NVTETensor output_scales, NVTETensor output_per_token_scales,
471+
size_t num_rows, size_t num_cols, cudaStream_t stream);
476472

477473
#ifdef __cplusplus
478474
} // extern "C"

transformer_engine/pytorch/csrc/extensions/cast.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1556,8 +1556,7 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
15561556
return output_py_list;
15571557
}
15581558

1559-
std::tuple<at::Tensor, at::Tensor, at::Tensor> quantize_nvfp4_pertoken(
1560-
at::Tensor input) {
1559+
std::tuple<at::Tensor, at::Tensor, at::Tensor> quantize_nvfp4_pertoken(at::Tensor input) {
15611560
// Input validation
15621561
NVTE_CHECK(input.dim() == 2, "Input must be 2D (num_rows, num_cols)");
15631562
NVTE_CHECK(input.is_cuda(), "Input must be on CUDA device");
@@ -1574,8 +1573,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> quantize_nvfp4_pertoken(
15741573

15751574
// Allocate outputs
15761575
auto output_data = at::empty({num_rows, num_cols / 2}, options.dtype(at::kByte));
1577-
auto output_scales = at::empty(
1578-
{num_rows, (num_cols + 15) / 16}, options.dtype(at::kByte));
1576+
auto output_scales = at::empty({num_rows, (num_cols + 15) / 16}, options.dtype(at::kByte));
15791577
auto output_per_token_scales = at::empty({num_rows}, options.dtype(at::kFloat));
15801578

15811579
// Wrap as NVTETensors
@@ -1586,9 +1584,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> quantize_nvfp4_pertoken(
15861584

15871585
auto stream = at::cuda::getCurrentCUDAStream().stream();
15881586

1589-
nvte_quantize_nvfp4_pertoken(
1590-
te_input.data(), te_data.data(), te_scales.data(), te_pertoken.data(),
1591-
num_rows, num_cols, stream);
1587+
nvte_quantize_nvfp4_pertoken(te_input.data(), te_data.data(), te_scales.data(),
1588+
te_pertoken.data(), num_rows, num_cols, stream);
15921589

15931590
return {output_data, output_scales, output_per_token_scales};
15941591
}

transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -751,9 +751,7 @@ def fuser_forward(
751751
grouped_fc1_x = input_
752752
else:
753753
fc1_x = maybe_dequantize(input_, dtype)
754-
grouped_fc1_x = tex.group_quantize(
755-
fc1_x, fc1_input_quantizer, num_groups, split_sizes
756-
)
754+
grouped_fc1_x = tex.group_quantize(fc1_x, fc1_input_quantizer, num_groups, split_sizes)
757755

758756
# Pack data tensors for cuDNN kernel
759757
# NVFP4: data is uint8 (packed FP4), reinterpret as float4_e2m1fn_x2
@@ -785,7 +783,8 @@ def fuser_forward(
785783
global_scale_tensor = None
786784
try:
787785
_, _, fc1_per_token_scales = tex.quantize_nvfp4_pertoken(
788-
fc1_x.reshape(in_shape[0], in_shape[1]) if not isinstance(input_, GroupedTensor)
786+
fc1_x.reshape(in_shape[0], in_shape[1])
787+
if not isinstance(input_, GroupedTensor)
789788
else input_.dequantize(dtype=dtype).reshape(in_shape[0], in_shape[1])
790789
)
791790
global_scale_tensor = fc1_per_token_scales.reshape(-1, 1, 1)
@@ -831,9 +830,7 @@ def fuser_forward(
831830

832831
fc1_w_data = fc1_weight_for_gemm.rowwise_data
833832
fc1_w_data = fc1_w_data.view(dtype=torch.float4_e2m1fn_x2)
834-
fc1_w_data = fc1_w_data.view(
835-
num_groups, fc1_weight_shape[0], fc1_weight_shape[1] // 2
836-
)
833+
fc1_w_data = fc1_w_data.view(num_groups, fc1_weight_shape[0], fc1_weight_shape[1] // 2)
837834
fc1_w_data = fc1_w_data.permute(1, 2, 0)
838835
fc1_w_scales = fc1_weight_for_gemm.scale_inv.view(dtype=torch.float8_e4m3fn)
839836
fc1_w_scales = fc1_w_scales.view(
@@ -930,9 +927,7 @@ def fuser_forward(
930927

931928
fc2_w_data = fc2_weight_for_gemm.rowwise_data
932929
fc2_w_data = fc2_w_data.view(dtype=torch.float4_e2m1fn_x2)
933-
fc2_w_data = fc2_w_data.view(
934-
num_groups, fc2_weight_shape[0], fc2_weight_shape[1] // 2
935-
)
930+
fc2_w_data = fc2_w_data.view(num_groups, fc2_weight_shape[0], fc2_weight_shape[1] // 2)
936931
fc2_w_data = fc2_w_data.permute(1, 2, 0)
937932

938933
fc2_w_scales = fc2_weight_for_gemm.scale_inv.view(dtype=torch.float8_e4m3fn)

0 commit comments

Comments
 (0)