Skip to content

Commit 8a3a36d

Browse files
pre-commit-ci[bot]YigongQin
authored andcommitted
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci Signed-off-by: YigongQin <qqqyyy1233@outlook.com>
1 parent 8b1c88f commit 8a3a36d

4 files changed

Lines changed: 13 additions & 22 deletions

File tree

tests/pytorch/test_nvfp4_pertoken_quant.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,9 @@ def test_zero_input(self, dtype):
135135

136136
# When amax=0, compute_global_encode_scaling_factor_FP4 returns 1.0
137137
# so global_scale = 1/S_enc = 1/1 = 1.0
138-
assert (per_token_scales == 1.0).all(), (
139-
f"Zero input should give global_scale=1.0 (S_enc fallback), got {per_token_scales}"
140-
)
138+
assert (
139+
per_token_scales == 1.0
140+
).all(), f"Zero input should give global_scale=1.0 (S_enc fallback), got {per_token_scales}"
141141

142142
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
143143
def test_uniform_rows_same_scale(self, dtype):

transformer_engine/common/cast/cast.cu

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -166,22 +166,14 @@ void nvte_quantize_nvfp4_pertoken(const NVTETensor input, NVTETensor output_data
166166

167167
if (itype == NVTEDType::kNVTEBFloat16) {
168168
dispatch::nvfp4::quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4<__nv_bfloat16>(
169-
num_rows, num_cols,
170-
reinterpret_cast<const __nv_bfloat16 *>(input_ptr),
171-
nullptr,
172-
reinterpret_cast<uint8_t *>(data_ptr),
173-
reinterpret_cast<fp8e4m3 *>(scales_ptr),
174-
reinterpret_cast<float *>(pertoken_ptr),
175-
stream);
169+
num_rows, num_cols, reinterpret_cast<const __nv_bfloat16 *>(input_ptr), nullptr,
170+
reinterpret_cast<uint8_t *>(data_ptr), reinterpret_cast<fp8e4m3 *>(scales_ptr),
171+
reinterpret_cast<float *>(pertoken_ptr), stream);
176172
} else if (itype == NVTEDType::kNVTEFloat16) {
177173
dispatch::nvfp4::quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4<half>(
178-
num_rows, num_cols,
179-
reinterpret_cast<const half *>(input_ptr),
180-
nullptr,
181-
reinterpret_cast<uint8_t *>(data_ptr),
182-
reinterpret_cast<fp8e4m3 *>(scales_ptr),
183-
reinterpret_cast<float *>(pertoken_ptr),
184-
stream);
174+
num_rows, num_cols, reinterpret_cast<const half *>(input_ptr), nullptr,
175+
reinterpret_cast<uint8_t *>(data_ptr), reinterpret_cast<fp8e4m3 *>(scales_ptr),
176+
reinterpret_cast<float *>(pertoken_ptr), stream);
185177
} else {
186178
NVTE_ERROR(
187179
"Unsupported input dtype for per-token NVFP4 quantization. "

transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,8 @@ __launch_bounds__(BLOCK_SIZE)
119119
// Block-wide max reduction
120120
using BlockReduce = cub::BlockReduce<float, BLOCK_SIZE>;
121121
__shared__ typename BlockReduce::TempStorage temp_storage;
122-
float row_amax = BlockReduce(temp_storage).Reduce(thread_max,
123-
[](float a, float b) { return fmaxf(a, b); });
122+
float row_amax =
123+
BlockReduce(temp_storage).Reduce(thread_max, [](float a, float b) { return fmaxf(a, b); });
124124

125125
// Compute and store per-token global scale
126126
// global_scale = row_amax / (fp8_max * fp4_max)

transformer_engine/pytorch/csrc/extensions/cast.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1592,9 +1592,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> quantize_nvfp4_pertoken(at::Tenso
15921592
auto te_scales = makeTransformerEngineTensor(output_scales);
15931593
auto te_pertoken = makeTransformerEngineTensor(output_per_token_scales);
15941594

1595-
nvte_quantize_nvfp4_pertoken(
1596-
te_input.data(), te_data.data(), te_scales.data(), te_pertoken.data(),
1597-
num_rows, num_cols, stream);
1595+
nvte_quantize_nvfp4_pertoken(te_input.data(), te_data.data(), te_scales.data(),
1596+
te_pertoken.data(), num_rows, num_cols, stream);
15981597

15991598
return {output_data, output_scales, output_per_token_scales};
16001599
}

0 commit comments

Comments
 (0)