diff --git a/benchmarks/linear/benchmark_graph_safe_grouped_linear.py b/benchmarks/linear/benchmark_graph_safe_grouped_linear.py new file mode 100644 index 0000000000..d8230c38fe --- /dev/null +++ b/benchmarks/linear/benchmark_graph_safe_grouped_linear.py @@ -0,0 +1,380 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Benchmark MXFP8 graph-safe grouped MLP. + +This mirrors ``benchmark_grouped_linear.py`` but targets the graph-safe TE ops +path used by grouped MLP: + + GroupedLinear -> ScaledSwiGLU -> GroupedLinear + +The benchmark intentionally uses CUDA-device ``m_splits`` and MXFP8 only. + +Example: + + python benchmarks/linear/benchmark_graph_safe_grouped_linear.py + +Forward-only: + + python benchmarks/linear/benchmark_graph_safe_grouped_linear.py --fwd-only + +Nsight Systems: + + (optionally: unset DEBUGINFOD_URLS) + + nsys profile \ + --output=./benchmarks/linear/graph_safe_grouped_linear_mxfp8 \ + --force-overwrite true \ + --trace=cuda,nvtx,cudnn,cublas \ + python benchmarks/linear/benchmark_graph_safe_grouped_linear.py --profile +""" + +# Match the Qwen MXFP8 SFT launch toggles before importing TE. +import os + +os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1") +os.environ.setdefault("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1") +os.environ.setdefault("NVTE_CUTEDSL_FUSED_GROUPED_MLP", "1") +os.environ.setdefault("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") + +import argparse +from contextlib import nullcontext + +import pandas as pd +import torch +import torch.utils.benchmark as benchmark + +import transformer_engine.pytorch as te +import transformer_engine.pytorch.ops as te_ops +from transformer_engine.common.recipe import MXFP8BlockScaling +from transformer_engine.pytorch.quantization import FP8GlobalStateManager + + +MXFP8_AVAILABLE, REASON_FOR_NO_MXFP8 = FP8GlobalStateManager.is_mxfp8_available() + + +def parse_int_list(value: str) -> list[int]: + """Parse comma-separated integers.""" + return [int(x) for x in value.split(",") if x] + + +def make_uniform_splits(total_tokens: int, num_groups: int) -> list[int]: + """Split tokens uniformly across groups.""" + if total_tokens % num_groups != 0: + raise ValueError( + "Uniform split requires total_tokens divisible by num_groups, " + f"got total_tokens={total_tokens}, num_groups={num_groups}" + ) + return [total_tokens // num_groups] * num_groups + + +def build_grouped_mlp( + *, + num_groups: int, + hidden_dim: int, + ffn_hidden_dim: int, + dtype: torch.dtype, + single_grouped_weight: bool, + accumulate_into_main_grad: bool, + glu_interleave_size: int, +) -> te_ops.Sequential: + """Build graph-safe grouped MLP ops sequence.""" + recipe = MXFP8BlockScaling() + with te.quantized_model_init(enabled=True, recipe=recipe): + fc1 = te_ops.GroupedLinear( + num_groups, + hidden_dim, + 2 * ffn_hidden_dim, + bias=False, + device="cuda", + dtype=dtype, + single_grouped_weight=single_grouped_weight, + accumulate_into_main_grad=accumulate_into_main_grad, + ) + fc2 = te_ops.GroupedLinear( + num_groups, + ffn_hidden_dim, + hidden_dim, + bias=False, + device="cuda", + dtype=dtype, + single_grouped_weight=single_grouped_weight, + accumulate_into_main_grad=accumulate_into_main_grad, + ) + return te_ops.Sequential( + fc1, + te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size), + fc2, + ) + + +def init_main_grads(module: torch.nn.Module, value: float = 0.0) -> None: + """Initialize Megatron-style main_grad buffers for accumulate_into_main_grad.""" + with torch.no_grad(): + for param in module.parameters(): + if getattr(param, "main_grad", None) is None: + param.main_grad = torch.empty( + param.size(), device=param.device, dtype=torch.float32 + ) + param.main_grad.fill_(value) + + +def zero_grads(module: torch.nn.Module, x: torch.Tensor, scales: torch.Tensor) -> None: + """Reset gradients without changing allocated main_grad buffers.""" + module.zero_grad(set_to_none=True) + x.grad = None + scales.grad = None + + +def run_grouped_mlp_steps( + module: torch.nn.Module, + x: torch.Tensor, + split_sizes: torch.Tensor, + scales: torch.Tensor, + grad_output: torch.Tensor, + *, + recipe: MXFP8BlockScaling, + fwd_only: bool, + num_steps: int, + accumulate_into_main_grad: bool, +) -> torch.Tensor: + """Run eager grouped MLP for a number of synthetic microbatches.""" + quantization_context = te.autocast(enabled=True, recipe=recipe) + + if fwd_only: + with torch.no_grad(), quantization_context: + for _ in range(num_steps): + out = module(x, split_sizes, scales, split_sizes) + return out + + zero_grads(module, x, scales) + if accumulate_into_main_grad: + init_main_grads(module) + + with quantization_context: + for step in range(num_steps): + torch.cuda.nvtx.range_push(f"step_{step}") + out = module(x, split_sizes, scales, split_sizes) + out.backward(grad_output) + torch.cuda.nvtx.range_pop() + return out + + +def benchmark_case( + *, + total_tokens: int, + hidden_dim: int, + ffn_hidden_dim: int, + num_groups: int, + dtype: torch.dtype, + fwd_only: bool, + single_grouped_weight: bool, + accumulate_into_main_grad: bool, + glu_interleave_size: int, + num_microbatches: int, + min_run_time: float, + profile: bool, +) -> float: + """Benchmark one grouped MLP shape.""" + split_sizes_list = make_uniform_splits(total_tokens, num_groups) + split_sizes = torch.tensor(split_sizes_list, dtype=torch.int64, device="cuda") + x = torch.randn( + (total_tokens, hidden_dim), + dtype=dtype, + device="cuda", + requires_grad=not fwd_only, + ) + scales = torch.ones( + (total_tokens,), + dtype=dtype, + device="cuda", + requires_grad=not fwd_only, + ) + grad_output = torch.ones((total_tokens, hidden_dim), dtype=dtype, device="cuda") + + module = build_grouped_mlp( + num_groups=num_groups, + hidden_dim=hidden_dim, + ffn_hidden_dim=ffn_hidden_dim, + dtype=dtype, + single_grouped_weight=single_grouped_weight, + accumulate_into_main_grad=accumulate_into_main_grad, + glu_interleave_size=glu_interleave_size, + ) + recipe = MXFP8BlockScaling() + + print( + "case:", + f"tokens={total_tokens}", + f"hidden={hidden_dim}", + f"ffn_hidden={ffn_hidden_dim}", + f"num_groups={num_groups}", + f"fwd_only={fwd_only}", + f"single_grouped_weight={single_grouped_weight}", + f"accumulate_into_main_grad={accumulate_into_main_grad}", + f"glu_interleave_size={glu_interleave_size}", + ) + print(f"m_splits: {split_sizes_list}") + + # Warmup also forces the op-fuser to materialize the expected fused ops. + run_grouped_mlp_steps( + module, + x, + split_sizes, + scales, + grad_output, + recipe=recipe, + fwd_only=fwd_only, + num_steps=128, + accumulate_into_main_grad=accumulate_into_main_grad, + ) + torch.cuda.synchronize() + + forward_ops = module._module_groups[0]._forward_ops + print("forward fused op:", type(forward_ops[0][0]).__name__ if forward_ops else "none") + if not fwd_only: + backward_ops = module._module_groups[0]._backward_ops + print("backward fused op:", type(backward_ops[0][0]).__name__ if backward_ops else "none") + + label = "graph_safe_grouped_mlp_mxfp8_swiglu" + timing_context = ( + torch.autograd.profiler.emit_nvtx(record_shapes=True) if profile else nullcontext() + ) + with timing_context: + torch.cuda.nvtx.range_push(label) + timing = benchmark.Timer( + stmt=( + "run_grouped_mlp_steps(" + "module, x, split_sizes, scales, grad_output, " + "recipe=recipe, fwd_only=fwd_only, num_steps=num_microbatches, " + "accumulate_into_main_grad=accumulate_into_main_grad)" + ), + globals={ + "run_grouped_mlp_steps": run_grouped_mlp_steps, + "module": module, + "x": x, + "split_sizes": split_sizes, + "scales": scales, + "grad_output": grad_output, + "recipe": recipe, + "fwd_only": fwd_only, + "num_microbatches": num_microbatches, + "accumulate_into_main_grad": accumulate_into_main_grad, + }, + num_threads=1, + ).blocked_autorange(min_run_time=min_run_time) + torch.cuda.nvtx.range_pop() + + print(f"mxfp8_swiglu: {timing}\n") + return timing.median * 1000 / num_microbatches + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--profile", action="store_true", help="Enable NVTX profiling annotations") + parser.add_argument( + "--fwd-only", + action="store_true", + default=False, + help="Benchmark forward only. Default benchmarks forward + backward.", + ) + parser.add_argument( + "--num-groups", + type=str, + default="8", + help="Comma-separated local grouped GEMM/expert counts.", + ) + parser.add_argument( + "--token-dims", + type=str, + default="65536", + help="Comma-separated total token counts to benchmark.", + ) + parser.add_argument("--hidden-dim", type=int, default=7168) + parser.add_argument("--ffn-hidden-dim", type=int, default=2048) + parser.add_argument("--num-microbatches", type=int, default=32) + parser.add_argument("--min-run-time", type=float, default=10.0) + parser.add_argument("--glu-interleave-size", type=int, default=32) + parser.add_argument( + "--single-grouped-weight", + action="store_true", + default=False, + help="Use one GroupedTensor parameter for each grouped linear.", + ) + args = parser.parse_args() + + if not MXFP8_AVAILABLE: + raise RuntimeError(f"MXFP8 is not available: {REASON_FOR_NO_MXFP8}") + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for this benchmark.") + + dtype = torch.bfloat16 + accumulate_into_main_grad = True + token_dims = parse_int_list(args.token_dims) + num_groups_list = parse_int_list(args.num_groups) + + print("Environment toggles:") + for name in ( + "CUDA_DEVICE_MAX_CONNECTIONS", + "NVTE_ALLOW_NONDETERMINISTIC_ALGO", + "NVTE_CUTEDSL_FUSED_GROUPED_MLP", + "CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", + ): + print(f" {name}={os.environ.get(name)}") + print("Recipe: MXFP8BlockScaling") + print("Activation: ScaledSwiGLU") + print(f"Default GLU interleave size: {args.glu_interleave_size}") + print() + + data = [] + for num_groups in num_groups_list: + for total_tokens in token_dims: + timing_ms = benchmark_case( + total_tokens=total_tokens, + hidden_dim=args.hidden_dim, + ffn_hidden_dim=args.ffn_hidden_dim, + num_groups=num_groups, + dtype=dtype, + fwd_only=args.fwd_only, + single_grouped_weight=args.single_grouped_weight, + accumulate_into_main_grad=accumulate_into_main_grad, + glu_interleave_size=args.glu_interleave_size, + num_microbatches=args.num_microbatches, + min_run_time=args.min_run_time, + profile=args.profile, + ) + data.append( + [ + total_tokens, + args.hidden_dim, + args.ffn_hidden_dim, + num_groups, + args.glu_interleave_size, + args.single_grouped_weight, + accumulate_into_main_grad, + "fwd" if args.fwd_only else "fwd_bwd", + timing_ms, + ] + ) + + timing_col = "time_per_microbatch_ms" + df = pd.DataFrame( + data=data, + columns=[ + "tokens", + "hidden_dim", + "ffn_hidden_dim", + "num_groups", + "glu_interleave_size", + "single_grouped_weight", + "accumulate_into_main_grad", + "mode", + timing_col, + ], + ) + print(df) + + +if __name__ == "__main__": + main() diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py index c54c9758ff..7ccb6802d5 100644 --- a/tests/pytorch/test_grouped_tensor.py +++ b/tests/pytorch/test_grouped_tensor.py @@ -20,6 +20,9 @@ from transformer_engine.pytorch.constants import TE_DType_To_Torch import transformer_engine_torch as tex +# Import test utilities +from utils import assert_close + # Check available recipes fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available( @@ -163,6 +166,93 @@ def test_basic_construction_varying_first_dim(self) -> None: shape[0][1], ) # sum of first dims + @pytest.mark.parametrize( + "split_sizes_list,logical_last_dim", + [ + pytest.param([3, 4, 5, 2], 7, id="all_nonzero"), + pytest.param([3, 0, 5, 2], 7, id="zero_middle"), + pytest.param([0, 3, 5, 0], 11, id="zero_edges"), + pytest.param([1], 17, id="single_group"), + pytest.param([1, 2, 3, 4, 5, 6, 7, 8], 13, id="many_groups"), + # MoE-style group counts. ``split_points`` (an int32[num_groups] + # tensor packed into a shared buffer alongside int64 outputs) used + # to land at an 8-byte-aligned offset for these counts, which + # tripped cuDNN's 16-byte alignment requirement in grouped GEMM. + pytest.param([8192] * 8, 2048, id="num_groups_8_uniform"), + pytest.param([4096] * 16, 4096, id="num_groups_16_uniform"), + pytest.param([2048] * 32, 7168, id="num_groups_32_uniform"), + pytest.param([1024] * 64, 7168, id="num_groups_64_uniform"), + pytest.param([512] * 128, 7168, id="num_groups_128_uniform"), + # Non-uniform with large totals to also exercise tensor_offsets > 2^31. + pytest.param( + [12345, 0, 8192, 1, 65536, 100, 131072, 7], + 7168, + id="non_uniform_large_totals", + ), + ], + ) + @pytest.mark.parametrize("input_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) + @pytest.mark.parametrize("input_device", ["cuda", "cpu"], ids=["cuda", "cpu"]) + @pytest.mark.parametrize("bulk_allocate", [False, True], ids=["separate", "bulk"]) + def test_splits_to_offsets_multi( + self, + bulk_allocate: bool, + input_device: str, + input_dtype: torch.dtype, + split_sizes_list: List[int], + logical_last_dim: int, + ) -> None: + """Test fused grouped split metadata preparation.""" + device = torch.device("cuda") + split_sizes = torch.tensor(split_sizes_list, dtype=input_dtype, device=input_device) + + # Exercise the grouped-MLP-shaped call: mix of int32 (no leading zero) + # and int64 (with leading zero) outputs, several strides. + strides = [1, 1, logical_last_dim, 0, logical_last_dim + 17] + include_leading_zero = [False, True, True, True, True] + dtypes = [torch.int32, torch.int64, torch.int64, torch.int64, torch.int64] + split_sizes_out, outputs = tex.splits_to_offsets_multi( + split_sizes, + device, + strides=strides, + include_leading_zero=include_leading_zero, + dtypes=dtypes, + bulk_allocate=bulk_allocate, + ) + + # Reference implementation. + expected_split_sizes_i64 = split_sizes.to(device=device, dtype=torch.int64) + expected_base_offsets = torch.cat( + ( + torch.zeros(1, dtype=torch.int64, device=device), + torch.cumsum(expected_split_sizes_i64, dim=0), + ) + ) + + # Check output split_sizes: always int64, always on the target device. + assert split_sizes_out.device.type == "cuda" + assert split_sizes_out.dtype == torch.int64 + assert_close(split_sizes_out, expected_split_sizes_i64) + + # Check output offsets. + assert len(outputs) == len(strides) + for output, stride, with_zero, dtype in zip(outputs, strides, include_leading_zero, dtypes): + assert output.dtype == dtype + assert output.device.type == "cuda" + expected_length = split_sizes.numel() + (1 if with_zero else 0) + assert output.numel() == expected_length + expected = expected_base_offsets * stride + if not with_zero: + expected = expected[1:] + assert_close(output, expected) + + # Check pointer alignment: cuDNN CuTe-DSL grouped GEMM kernels + # require 16-byte-aligned data pointers. + for idx, output in enumerate(outputs): + assert ( + output.data_ptr() % 16 == 0 + ), f"outputs[{idx}] data_ptr is not 16-byte aligned: {output.data_ptr():#x}" + def test_split_into_quantized_tensors_no_quantization(self) -> None: """Test split_into_quantized_tensors for unquantized tensors""" num_tensors = 3 @@ -410,6 +500,64 @@ def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]], output_dbias expected_dbias = torch.stack([t.sum(dim=0) for t in input_tensors]) assert torch.allclose(dbias, expected_dbias) + @pytest.mark.parametrize("output_dbias", [False, True]) + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_group_quantize_precomputed_offsets(self, output_dbias: bool) -> None: + """Test grouped quantization can reuse caller-provided tensor offsets.""" + num_tensors = 2 + last_dim = 1024 + split_sizes_list = [512, 512] + input_tensors = [ + torch.randn(split, last_dim, dtype=torch.bfloat16, device="cuda") + for split in split_sizes_list + ] + grouped_input = torch.cat(input_tensors, dim=0) + + quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + quantizer.set_usage(rowwise=True, columnwise=False) + split_sizes = torch.tensor(split_sizes_list, dtype=torch.int64, device="cuda") + split_sizes, (tensor_offsets,) = tex.splits_to_offsets_multi( + split_sizes, + torch.device("cuda"), + strides=[last_dim], + include_leading_zero=[True], + dtypes=[torch.int64], + ) + + if output_dbias: + grouped_output, dbias = tex.bgrad_group_quantize( + grouped_input, + quantizer, + num_tensors, + split_sizes, + tensor_offsets=tensor_offsets, + ) + expected_output, expected_dbias = tex.bgrad_group_quantize( + grouped_input, + quantizer, + num_tensors, + split_sizes, + ) + assert torch.allclose(dbias, expected_dbias) + else: + grouped_output = tex.group_quantize( + grouped_input, + quantizer, + num_tensors, + split_sizes, + tensor_offsets=tensor_offsets, + ) + expected_output = tex.group_quantize( + grouped_input, + quantizer, + num_tensors, + split_sizes, + ) + + assert grouped_output.tensor_offsets.data_ptr() == tensor_offsets.data_ptr() + assert torch.equal(grouped_output.rowwise_data, expected_output.rowwise_data) + assert torch.equal(grouped_output.scale_inv, expected_output.scale_inv) + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) def test_bgrad_group_quantize_zero_size_tensor(self) -> None: """Test bgrad_group_quantize handles zero-row input without error.""" diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 06d85b6d84..9ad5b63421 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -194,6 +194,7 @@ list(APPEND transformer_engine_cuda_sources permutation/permutation.cu util/utils.cu util/padding.cu + util/splits_to_offsets.cu util/topk.cu swizzle/swizzle.cu swizzle/swizzle_block_scaling.cu diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index 1bdd80a369..ef757bac5f 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -61,6 +61,7 @@ void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) { namespace { constexpr size_t kThreadsPerBlock = 256; + template __global__ void __launch_bounds__(kThreadsPerBlock) memset_kernel(void *__restrict__ ptr, int value, size_t size_in_bytes) { @@ -87,48 +88,6 @@ __global__ void __launch_bounds__(kThreadsPerBlock) reinterpret_cast(ptr)[idx] = data.value; } -__global__ void __launch_bounds__(kThreadsPerBlock) - splits_to_offsets_kernel(const int64_t *__restrict__ first_dims, int64_t *__restrict__ output, - size_t num_tensors, int64_t logical_last_dim) { - __shared__ int64_t block_scan[kThreadsPerBlock]; - __shared__ int64_t chunk_prefix; - - const size_t tid = threadIdx.x; - if (tid == 0) { - output[0] = 0; - chunk_prefix = 0; - } - __syncthreads(); - - for (size_t chunk_start = 0; chunk_start < num_tensors; chunk_start += kThreadsPerBlock) { - const size_t idx = chunk_start + tid; - int64_t value = 0; - if (idx < num_tensors) { - value = first_dims[idx] * logical_last_dim; - } - block_scan[tid] = value; - __syncthreads(); - - // Inclusive scan in shared memory. - for (size_t offset = 1; offset < kThreadsPerBlock; offset <<= 1) { - const int64_t addend = (tid >= offset) ? block_scan[tid - offset] : 0; - __syncthreads(); - block_scan[tid] += addend; - __syncthreads(); - } - - if (idx < num_tensors) { - output[idx + 1] = chunk_prefix + block_scan[tid]; - } - __syncthreads(); - - if (tid == kThreadsPerBlock - 1) { - chunk_prefix += block_scan[tid]; - } - __syncthreads(); - } -} - } // namespace #define MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, vectorizedType, stream) \ @@ -159,18 +118,6 @@ void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, uint8_t, stream); } -void nvte_splits_to_offsets(const int64_t *first_dims, int64_t *output, size_t num_tensors, - int64_t logical_last_dim, cudaStream_t stream) { - NVTE_API_CALL(nvte_splits_to_offsets); - NVTE_CHECK(output != nullptr, "Output pointer must be allocated."); - NVTE_CHECK(num_tensors > 0, "num_tensors must be greater than 0."); - NVTE_CHECK(first_dims != nullptr, "first_dims pointer must be allocated."); - NVTE_CHECK(logical_last_dim > 0, "logical_last_dim must be greater than 0."); - - splits_to_offsets_kernel<<<1, kThreadsPerBlock, 0, stream>>>(first_dims, output, num_tensors, - logical_last_dim); - NVTE_CHECK_CUDA(cudaGetLastError()); -} } // extern "C" void checkCuDriverContext(CUstream stream) { diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index f675b2f535..c32a561fb7 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -497,17 +497,36 @@ void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream * * Computes: * output[0] = 0 - * output[i + 1] = sum_{j=0..i}(first_dims[j] * logical_last_dim) - * for i in [0, num_tensors - 1]. + * output[i + 1] = sum_{j=0..i}(split_sizes[j] * stride) + * for i in [0, num_splits - 1]. * - * \param[in] first_dims Pointer to device int64 array of size num_tensors. - * \param[out] output Pointer to device int64 array of size num_tensors + 1. - * \param[in] num_tensors Number of entries in first_dims. - * \param[in] logical_last_dim Scale factor applied to each first_dims entry. + * \param[in] split_sizes Pointer to device int64 array of size num_splits. + * \param[out] output Pointer to device int64 array of size num_splits + 1. + * \param[in] num_splits Number of entries in split_sizes. + * \param[in] stride Scale factor applied to each split_sizes entry. * \param[in] stream CUDA stream to use for the operation. */ -void nvte_splits_to_offsets(const int64_t *first_dims, int64_t *output, size_t num_tensors, - int64_t logical_last_dim, cudaStream_t stream); +void nvte_splits_to_offsets(const int64_t *split_sizes, int64_t *output, size_t num_splits, + int64_t stride, cudaStream_t stream); + +/*! \brief Compute multiple scaled prefix-sum offsets for grouped tensors. + * + * Computes a prefix-sum over the values in split_sizes, and for each + * output multiplies the prefix-sum by a stride. Inputs and outputs + * can be any combination of int32 and int64 tensors. + * + * \param[in] split_sizes Device int32/int64 split sizes with shape [N]. + * \param[out] outputs Array of int32/int64 1D output tensors, one per scan. + * \param[in] strides Per-output scale factor. Length num_outputs. + * \param[in] include_leading_zero Per-output flag: 0 if outputs[i] has length N + * (inclusive scan), nonzero if outputs[i] has length N + 1 (inclusive + * scan prepended with zero). Length num_outputs. + * \param[in] num_outputs Number of output tensors. + * \param[in] stream CUDA stream to use for the operation. + */ +void nvte_splits_to_offsets_multi(NVTETensor split_sizes, NVTETensor *outputs, + const int64_t *strides, const int *include_leading_zero, + size_t num_outputs, cudaStream_t stream); /*! \brief TE Grouped Tensor type * diff --git a/transformer_engine/common/util/splits_to_offsets.cu b/transformer_engine/common/util/splits_to_offsets.cu new file mode 100644 index 0000000000..721f88c074 --- /dev/null +++ b/transformer_engine/common/util/splits_to_offsets.cu @@ -0,0 +1,212 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "../utils.cuh" + +namespace transformer_engine::splits_to_offsets { + +namespace { + +constexpr size_t kThreadsPerBlock = 256; + +struct KernelArgs { + static constexpr size_t kMaxNumOutputs = 8; + const void *split_sizes = nullptr; + DType split_sizes_dtype = DType::kNumTypes; + size_t num_splits = 0; + void *outputs[kMaxNumOutputs] = {}; + DType outputs_dtype[kMaxNumOutputs] = {}; + int64_t strides[kMaxNumOutputs] = {}; + bool include_leading_zero[kMaxNumOutputs] = {}; + size_t num_outputs = 0; +}; + +__device__ __forceinline__ int64_t load_split_size(const void *__restrict__ split_sizes, + DType dtype, size_t idx) { + switch (dtype) { + case DType::kInt32: + return static_cast(static_cast(split_sizes)[idx]); + case DType::kInt64: + return static_cast(split_sizes)[idx]; + default: + NVTE_DEVICE_ERROR("Unsupported dtype for split_sizes (expected int32 or int64)."); + return 0; + } +} + +__device__ __forceinline__ void store_output(void *__restrict__ output, DType dtype, size_t idx, + int64_t value) { + switch (dtype) { + case DType::kInt32: + static_cast(output)[idx] = static_cast(value); + return; + case DType::kInt64: + static_cast(output)[idx] = value; + return; + default: + NVTE_DEVICE_ERROR("Unsupported dtype for output (expected int32 or int64)."); + } +} + +__global__ void __launch_bounds__(kThreadsPerBlock) kernel(KernelArgs args) { + const size_t tid = threadIdx.x; + + // Fill leading zeros if needed + if (tid == 0) { + for (size_t out_idx = 0; out_idx < args.num_outputs; ++out_idx) { + if (args.include_leading_zero[out_idx]) { + store_output(args.outputs[out_idx], args.outputs_dtype[out_idx], 0, 0); + } + } + } + + // Workspace for prefix sum chunk + __shared__ int64_t block_scan[kThreadsPerBlock]; + + // Sum from previous chunks + __shared__ int64_t chunk_prefix; + if (tid == 0) { + chunk_prefix = 0; + } + __syncthreads(); + + // Perform prefix sum in chunks + for (size_t chunk_start = 0; chunk_start < args.num_splits; chunk_start += kThreadsPerBlock) { + const size_t idx = chunk_start + tid; + + // Load input from global memory into shared memory + if (idx < args.num_splits) { + block_scan[tid] = load_split_size(args.split_sizes, args.split_sizes_dtype, idx); + } else { + block_scan[tid] = 0; + } + __syncthreads(); + + // Prefix sum in shared memory + for (size_t offset = 1; offset < kThreadsPerBlock; offset <<= 1) { + const int64_t addend = (tid >= offset) ? block_scan[tid - offset] : 0; + __syncthreads(); + block_scan[tid] += addend; + __syncthreads(); + } + + // Compute global prefix sum, apply strides, and store to output + if (idx < args.num_splits) { + const int64_t prefix = chunk_prefix + block_scan[tid]; + for (size_t out_idx = 0; out_idx < args.num_outputs; ++out_idx) { + const size_t write_idx = idx + (args.include_leading_zero[out_idx] ? 1 : 0); + store_output(args.outputs[out_idx], args.outputs_dtype[out_idx], write_idx, + prefix * args.strides[out_idx]); + } + } + + // Update sum for later chunks + __syncthreads(); + if (tid == kThreadsPerBlock - 1) { + chunk_prefix += block_scan[tid]; + } + __syncthreads(); + } +} + +} // namespace + +} // namespace transformer_engine::splits_to_offsets + +void nvte_splits_to_offsets(const int64_t *split_sizes, int64_t *output, size_t num_splits, + int64_t stride, cudaStream_t stream) { + NVTE_API_CALL(nvte_splits_to_offsets); + NVTE_CHECK(output != nullptr, "Output pointer is NULL."); + NVTE_CHECK(num_splits > 0, "num_splits must be greater than 0."); + NVTE_CHECK(split_sizes != nullptr, "split_sizes pointer is NULL."); + NVTE_CHECK(stride > 0, "stride must be greater than 0."); + + using namespace transformer_engine; + namespace s2o = transformer_engine::splits_to_offsets; + + s2o::KernelArgs args = {}; + args.split_sizes = split_sizes; + args.split_sizes_dtype = DType::kInt64; + args.num_splits = num_splits; + args.outputs[0] = output; + args.outputs_dtype[0] = DType::kInt64; + args.strides[0] = stride; + args.include_leading_zero[0] = true; + args.num_outputs = 1; + s2o::kernel<<<1, s2o::kThreadsPerBlock, 0, stream>>>(args); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +void nvte_splits_to_offsets_multi(NVTETensor split_sizes, NVTETensor *outputs, + const int64_t *strides, const int *include_leading_zero, + size_t num_outputs, cudaStream_t stream) { + NVTE_API_CALL(nvte_splits_to_offsets_multi); + using namespace transformer_engine; + namespace s2o = transformer_engine::splits_to_offsets; + + if (num_outputs == 0) { + return; + } + NVTE_CHECK(outputs != nullptr, "outputs is NULL."); + NVTE_CHECK(strides != nullptr, "strides is NULL."); + NVTE_CHECK(include_leading_zero != nullptr, "include_leading_zero is NULL."); + + // Check if dtype is supported + const auto is_integer_dtype = [](DType dtype) { + return dtype == DType::kInt32 || dtype == DType::kInt64; + }; + + // Check input tensor + const auto *split_sizes_tensor = convertNVTETensorCheck(split_sizes); + const auto split_sizes_dtype = split_sizes_tensor->dtype(); + const auto num_splits = split_sizes_tensor->numel(); + NVTE_CHECK(num_splits > 0 && split_sizes_tensor->dim() == 1, + "split_sizes must be a non-empty 1D tensor, but got shape=", + split_sizes_tensor->shape(), "."); + NVTE_CHECK(is_integer_dtype(split_sizes_dtype), + "split_sizes must be an int32/int64 tensor, but got dtype=", split_sizes_dtype, "."); + + // Check output tensors + std::vector output_tensors(num_outputs); + for (size_t i = 0; i < num_outputs; ++i) { + const auto *out_tensor = convertNVTETensorCheck(outputs[i]); + const auto out_dtype = out_tensor->dtype(); + const bool has_leading_zero = include_leading_zero[i] != 0; + const Shape expected_shape = {num_splits + (has_leading_zero ? 1 : 0)}; + NVTE_CHECK(out_tensor->shape() == expected_shape, "Expected outputs[", i, + "] to have shape=", expected_shape, ", but got shape=", out_tensor->shape(), "."); + NVTE_CHECK(is_integer_dtype(out_dtype), "Expected outputs[", i, + "] to be an int32/int64 tensor, but got dtype=", out_dtype, "."); + output_tensors[i] = out_tensor; + } + + // Chunk outputs to fit in kernel arguments and launch kernels + for (size_t chunk_start = 0; chunk_start < num_outputs; + chunk_start += s2o::KernelArgs::kMaxNumOutputs) { + const size_t chunk_size = std::min(s2o::KernelArgs::kMaxNumOutputs, num_outputs - chunk_start); + s2o::KernelArgs args = {}; + args.split_sizes = split_sizes_tensor->data.dptr; + args.split_sizes_dtype = split_sizes_dtype; + args.num_splits = num_splits; + args.num_outputs = chunk_size; + for (size_t i = 0; i < chunk_size; ++i) { + const size_t out_idx = chunk_start + i; + args.outputs[i] = output_tensors[out_idx]->data.dptr; + args.outputs_dtype[i] = output_tensors[out_idx]->dtype(); + args.strides[i] = strides[out_idx]; + args.include_leading_zero[i] = include_leading_zero[out_idx] != 0; + } + s2o::kernel<<<1, s2o::kThreadsPerBlock, 0, stream>>>(args); + NVTE_CHECK_CUDA(cudaGetLastError()); + } +} diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index b376b3022d..cb8af0fcd6 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -106,10 +106,19 @@ class Quantizer { const std::vector& shape, DType dtype, std::optional device = std::nullopt, bool pin_memory = false) const = 0; - /*! @brief Construct a grouped tensor with uninitialized data */ + /*! @brief Construct a grouped tensor with uninitialized data + * + * @param tensor_offsets If provided, the precomputed inclusive scan of + * ``first_dims * logical_last_dim`` with a leading zero, used to locate + * each per-group sub-tensor in the shared backing buffer. If null, the + * offsets are computed from ``first_dims`` on demand. Passing this in lets + * callers that already have the scan (e.g. from + * ``tex.splits_to_offsets_multi``) skip the redundant kernel launch. + */ virtual std::pair create_grouped_tensor( size_t num_tensors, const std::vector& logical_shape, DType dtype, - py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + py::object quantizer, const std::optional& first_dims, + const std::optional& tensor_offsets, size_t logical_first_dim, size_t logical_last_dim) const = 0; /*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor @@ -151,7 +160,8 @@ class NoneQuantizer : public Quantizer { std::pair create_grouped_tensor( size_t num_tensors, const std::vector& logical_shape, DType dtype, - py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + py::object quantizer, const std::optional& first_dims, + const std::optional& tensor_offsets, size_t logical_first_dim, size_t logical_last_dim) const override; /*! @brief Construct a tensor with pre-initialized data */ @@ -182,7 +192,8 @@ class Float8Quantizer : public Quantizer { std::pair create_grouped_tensor( size_t num_tensors, const std::vector& logical_shape, DType dtype, - py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + py::object quantizer, const std::optional& first_dims, + const std::optional& tensor_offsets, size_t logical_first_dim, size_t logical_last_dim) const override; /*! @brief Construct a tensor with pre-initialized data */ @@ -217,7 +228,8 @@ class Float8CurrentScalingQuantizer : public Quantizer { std::pair create_grouped_tensor( size_t num_tensors, const std::vector& logical_shape, DType dtype, - py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + py::object quantizer, const std::optional& first_dims, + const std::optional& tensor_offsets, size_t logical_first_dim, size_t logical_last_dim) const override; /*! @brief Construct an unquantized tensor with a freshly allocated amax buffer. @@ -280,7 +292,8 @@ class Float8BlockQuantizer : public Quantizer { std::pair create_grouped_tensor( size_t num_tensors, const std::vector& logical_shape, DType dtype, - py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + py::object quantizer, const std::optional& first_dims, + const std::optional& tensor_offsets, size_t logical_first_dim, size_t logical_last_dim) const override; std::pair convert_and_update_tensor(py::object shape) const override; @@ -305,7 +318,8 @@ class MXFP8Quantizer : public Quantizer { std::pair create_grouped_tensor( size_t num_tensors, const std::vector& logical_shape, DType dtype, - py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + py::object quantizer, const std::optional& first_dims, + const std::optional& tensor_offsets, size_t logical_first_dim, size_t logical_last_dim) const override; std::pair convert_and_update_tensor(py::object shape) const override; @@ -349,7 +363,8 @@ class NVFP4Quantizer : public Quantizer { std::pair create_grouped_tensor( size_t num_tensors, const std::vector& logical_shape, DType dtype, - py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + py::object quantizer, const std::optional& first_dims, + const std::optional& tensor_offsets, size_t logical_first_dim, size_t logical_last_dim) const override; /*! @brief Construct an unquantized tensor that shares NVFP4 tensor's amax pointer diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index b93c153443..316d2114de 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -334,12 +334,14 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob py::object dequantize(const py::handle &input, DType otype); py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const size_t num_tensors, - std::optional first_dims); + std::optional first_dims, + std::optional tensor_offsets); py::object group_dequantize(const py::handle &input, DType otype); py::object bgrad_group_quantize(const at::Tensor &tensor, py::handle quantizer, - const size_t num_tensors, std::optional first_dims); + const size_t num_tensors, std::optional first_dims, + std::optional tensor_offsets); std::vector multi_tensor_quantize(const std::vector &tensor_list, std::vector quantizer_list); @@ -487,6 +489,10 @@ size_t get_cublasLt_version(); size_t get_cudnn_version(); at::Tensor splits_to_offsets(const at::Tensor &first_dims, int64_t logical_last_dim); +std::tuple> splits_to_offsets_multi( + const at::Tensor &split_sizes, const c10::Device &device, const std::vector &strides, + const std::vector &include_leading_zero, const std::vector &dtypes, + bool bulk_allocate_outputs); at::Tensor copy_data_ptrs_to_device(const std::vector &tensors, const c10::Device &device); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index e4110131ea..bbadff10ae 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -159,7 +159,8 @@ void group_quantize_nvfp4_impl(const GroupedTensorWrapper &grouped_input_tensor, // NOTE: Only supports varying first dim. py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const size_t num_tensors, - std::optional first_dims) { + std::optional first_dims, + std::optional tensor_offsets) { using namespace transformer_engine::pytorch::detail; init_extension(); @@ -184,7 +185,7 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const // Create output GroupedTensor. auto [grouped_output_tensor_cpp, grouped_output_py] = quantizer_cpp->create_grouped_tensor( num_tensors, logical_shape, GetTransformerEngineDType(tensor.scalar_type()), - py::reinterpret_borrow(quantizer), first_dims, logical_first_dim, + py::reinterpret_borrow(quantizer), first_dims, tensor_offsets, logical_first_dim, logical_last_dim); // dispatch to scaling methods @@ -234,7 +235,8 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const } py::object bgrad_group_quantize(const at::Tensor &tensor, py::handle quantizer, - const size_t num_tensors, std::optional first_dims) { + const size_t num_tensors, std::optional first_dims, + std::optional tensor_offsets) { using namespace transformer_engine::pytorch::detail; init_extension(); @@ -260,7 +262,7 @@ py::object bgrad_group_quantize(const at::Tensor &tensor, py::handle quantizer, auto [grouped_output_tensor_cpp, grouped_output_py] = quantizer_cpp->create_grouped_tensor( num_tensors, logical_shape, GetTransformerEngineDType(tensor.scalar_type()), - py::reinterpret_borrow(quantizer), first_dims, logical_first_dim, + py::reinterpret_borrow(quantizer), first_dims, tensor_offsets, logical_first_dim, logical_last_dim); if (empty_input_buffer) { @@ -349,7 +351,7 @@ py::object group_dequantize(const py::handle &input, transformer_engine::DType o NoneQuantizer q{py::none()}; auto [out_cpp, out_py] = q.create_grouped_tensor(num_tensors, logical_shape, otype, py::none(), first_dims, - logical_first_dim, logical_last_dim); + tensor_offsets, logical_first_dim, logical_last_dim); return py::reinterpret_borrow(out_py); } @@ -387,8 +389,9 @@ py::object group_dequantize(const py::handle &input, transformer_engine::DType o // Create output GroupedTensor using NoneQuantizer. NoneQuantizer q{py::none()}; - auto [out_cpp, out_py] = q.create_grouped_tensor(num_tensors, logical_shape, otype, py::none(), - first_dims, logical_first_dim, logical_last_dim); + auto [out_cpp, out_py] = + q.create_grouped_tensor(num_tensors, logical_shape, otype, py::none(), first_dims, + tensor_offsets, logical_first_dim, logical_last_dim); NVTE_SCOPED_GIL_RELEASE({ nvte_group_dequantize(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); diff --git a/transformer_engine/pytorch/csrc/extensions/misc.cpp b/transformer_engine/pytorch/csrc/extensions/misc.cpp index 727c79aea3..ba4371ffe1 100644 --- a/transformer_engine/pytorch/csrc/extensions/misc.cpp +++ b/transformer_engine/pytorch/csrc/extensions/misc.cpp @@ -6,10 +6,13 @@ #include +#include +#include #include #include "../extensions.h" #include "common/common.h" +#include "pybind.h" namespace transformer_engine::pytorch { @@ -35,6 +38,69 @@ at::Tensor splits_to_offsets(const at::Tensor &first_dims, int64_t logical_last_ return output; } +std::tuple> splits_to_offsets_multi( + const at::Tensor &split_sizes, const c10::Device &device, const std::vector &strides, + const std::vector &include_leading_zero, const std::vector &dtypes, + bool bulk_allocate_outputs) { + const size_t num_outputs = strides.size(); + const size_t num_splits = static_cast(split_sizes.numel()); + + // Check inputs. + NVTE_CHECK(include_leading_zero.size() == num_outputs && dtypes.size() == num_outputs, + "strides, include_leading_zero, and dtypes must have matching lengths, but got ", + strides.size(), ", ", include_leading_zero.size(), ", and ", dtypes.size(), "."); + NVTE_CHECK(device.is_cuda(), "device must be CUDA, but got ", device.str(), "."); + + // Convert split sizes to int64 GPU tensor. + const at::Tensor split_sizes_i64 = + split_sizes.scalar_type() == at::kLong ? split_sizes : split_sizes.to(at::kLong); + const at::Tensor split_sizes_out = + split_sizes_i64.device() == device ? split_sizes_i64 : split_sizes_i64.to(device); + + // Allocate outputs. + std::vector outputs; + outputs.reserve(num_outputs); + if (bulk_allocate_outputs) { + std::vector> shapes; + shapes.reserve(num_outputs); + for (size_t i = 0; i < num_outputs; ++i) { + const size_t length = num_splits + (include_leading_zero[i] ? 1 : 0); + shapes.emplace_back(std::vector{length}); + } + // cuDNN CuTe DSL grouped GEMM kernels require padded_offsets + // aligned to 16 bytes. + const std::vector alignments(num_outputs, 16); + outputs = bulk_allocate(shapes, dtypes, device, alignments); + } else { + for (size_t i = 0; i < num_outputs; ++i) { + const int64_t length = static_cast(num_splits) + (include_leading_zero[i] ? 1 : 0); + outputs.emplace_back( + at::empty({length}, at::TensorOptions().dtype(dtypes[i]).device(device))); + } + } + + // Construct NVTETensors. + MultiTensorWrapper outputs_nvte(num_outputs); + std::vector include_leading_zero_int(num_outputs); + for (size_t i = 0; i < num_outputs; ++i) { + const size_t length = num_splits + (include_leading_zero[i] ? 1 : 0); + NVTEShape shape = nvte_make_shape(&length, 1); + NVTEBasicTensor data = {outputs[i].data_ptr(), + static_cast(GetTransformerEngineDType(dtypes[i])), shape}; + nvte_set_tensor_param_v2(outputs_nvte[i], kNVTERowwiseData, &data, sizeof(data)); + include_leading_zero_int[i] = include_leading_zero[i] ? 1 : 0; + } + + auto split_sizes_nvte = makeTransformerEngineTensor(split_sizes_out); + NVTE_SCOPED_GIL_RELEASE({ + nvte_splits_to_offsets_multi(split_sizes_nvte.data(), outputs_nvte.data(), strides.data(), + include_leading_zero_int.data(), num_outputs, + at::cuda::getCurrentCUDAStream()); + }); + + return {split_sizes_out, std::move(outputs)}; +} + at::Tensor copy_data_ptrs_to_device(const std::vector &tensors, const c10::Device &device) { // Collect data pointers diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 91fbb61e1c..38e151fcbb 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -179,11 +179,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Create an empty quantized tensor", py::arg("quantizer"), py::arg("shape"), py::arg("dtype"), py::arg("device"), py::arg("pin_memory")); m.def("group_quantize", transformer_engine::pytorch::group_quantize, py::arg("tensor"), - py::arg("quantizer"), py::arg("num_tensors"), py::arg("first_dims")); + py::arg("quantizer"), py::arg("num_tensors"), py::arg("first_dims"), + py::arg("tensor_offsets") = py::none()); m.def("group_dequantize", transformer_engine::pytorch::group_dequantize, "Dequantize group tensor", py::arg("input"), py::arg("otype")); m.def("bgrad_group_quantize", transformer_engine::pytorch::bgrad_group_quantize, - py::arg("tensor"), py::arg("quantizer"), py::arg("num_tensors"), py::arg("first_dims")); + py::arg("tensor"), py::arg("quantizer"), py::arg("num_tensors"), py::arg("first_dims"), + py::arg("tensor_offsets") = py::none()); m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize, "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer")); m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)", @@ -508,6 +510,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("splits_to_offsets", &transformer_engine::pytorch::splits_to_offsets, "Compute grouped tensor offsets from split sizes", py::arg("first_dims"), py::arg("logical_last_dim"), py::call_guard()); + m.def("splits_to_offsets_multi", &transformer_engine::pytorch::splits_to_offsets_multi, + "Compute multiple scaled inclusive-scan offsets from a split-sizes vector", + py::arg("split_sizes"), py::arg("device"), py::kw_only(), py::arg("strides"), + py::arg("include_leading_zero"), py::arg("dtypes"), py::arg("bulk_allocate") = false); m.def("get_num_cublas_streams", &nvte_get_num_compute_streams, "Get number of compute streams", py::call_guard()); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index a54a301664..fcc4db06f0 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -174,11 +174,14 @@ std::pair NoneQuantizer::create_tensor(const std::vec std::pair NoneQuantizer::create_grouped_tensor( const size_t num_tensors, const std::vector& logical_shape, const DType dtype, py::object quantizer, const std::optional& first_dims, - const size_t logical_first_dim, const size_t logical_last_dim) const { + const std::optional& precomputed_tensor_offsets, const size_t logical_first_dim, + const size_t logical_last_dim) const { using namespace pybind11::literals; const auto tensor_offsets = - build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); + precomputed_tensor_offsets.has_value() + ? precomputed_tensor_offsets + : build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); const int64_t total_elements = static_cast(logical_first_dim) * static_cast(logical_last_dim); @@ -377,11 +380,14 @@ std::pair Float8Quantizer::create_tensor( std::pair Float8Quantizer::create_grouped_tensor( const size_t num_tensors, const std::vector& logical_shape, const DType dtype, py::object quantizer, const std::optional& first_dims, - const size_t logical_first_dim, const size_t logical_last_dim) const { + const std::optional& precomputed_tensor_offsets, const size_t logical_first_dim, + const size_t logical_last_dim) const { using namespace pybind11::literals; const auto tensor_offsets = - build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); + precomputed_tensor_offsets.has_value() + ? precomputed_tensor_offsets + : build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); const int64_t total_elements = static_cast(logical_first_dim) * static_cast(logical_last_dim); @@ -687,11 +693,14 @@ std::pair Float8CurrentScalingQuantizer::create_tenso std::pair Float8CurrentScalingQuantizer::create_grouped_tensor( const size_t num_tensors, const std::vector& logical_shape, const DType dtype, py::object quantizer, const std::optional& first_dims, - const size_t logical_first_dim, const size_t logical_last_dim) const { + const std::optional& precomputed_tensor_offsets, const size_t logical_first_dim, + const size_t logical_last_dim) const { using namespace pybind11::literals; const auto tensor_offsets = - build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); + precomputed_tensor_offsets.has_value() + ? precomputed_tensor_offsets + : build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); const int64_t total_elements = static_cast(logical_first_dim) * static_cast(logical_last_dim); @@ -1066,11 +1075,14 @@ std::pair Float8BlockQuantizer::create_tensor( std::pair Float8BlockQuantizer::create_grouped_tensor( const size_t num_tensors, const std::vector& logical_shape, const DType dtype, py::object quantizer, const std::optional& first_dims, - const size_t logical_first_dim, const size_t logical_last_dim) const { + const std::optional& precomputed_tensor_offsets, const size_t logical_first_dim, + const size_t logical_last_dim) const { using namespace pybind11::literals; const auto tensor_offsets = - build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); + precomputed_tensor_offsets.has_value() + ? precomputed_tensor_offsets + : build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); const int64_t total_elements = static_cast(logical_first_dim) * static_cast(logical_last_dim); @@ -1484,11 +1496,14 @@ std::pair MXFP8Quantizer::create_tensor( std::pair MXFP8Quantizer::create_grouped_tensor( const size_t num_tensors, const std::vector& logical_shape, const DType dtype, py::object quantizer, const std::optional& first_dims, - const size_t logical_first_dim, const size_t logical_last_dim) const { + const std::optional& precomputed_tensor_offsets, const size_t logical_first_dim, + const size_t logical_last_dim) const { using namespace pybind11::literals; const auto tensor_offsets = - build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); + precomputed_tensor_offsets.has_value() + ? precomputed_tensor_offsets + : build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); const int64_t total_elements = static_cast(logical_first_dim) * static_cast(logical_last_dim); @@ -1942,11 +1957,14 @@ std::pair NVFP4Quantizer::create_tensor( std::pair NVFP4Quantizer::create_grouped_tensor( const size_t num_tensors, const std::vector& logical_shape, const DType dtype, py::object quantizer, const std::optional& first_dims, - const size_t logical_first_dim, const size_t logical_last_dim) const { + const std::optional& precomputed_tensor_offsets, const size_t logical_first_dim, + const size_t logical_last_dim) const { using namespace pybind11::literals; const auto tensor_offsets = - build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); + precomputed_tensor_offsets.has_value() + ? precomputed_tensor_offsets + : build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); const int64_t total_elements = static_cast(logical_first_dim) * static_cast(logical_last_dim); NVTE_CHECK(total_elements % 2 == 0, "NVFP4 data size must be divisible by 2."); diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 87911d76f4..b78cbae854 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -73,7 +73,13 @@ def _group_quantize_for_grouped_mlp( # Typical case: group-quantize if num_groups != 1 or not isinstance(quantizer, NVFP4Quantizer): - return tex.group_quantize(tensor, quantizer, num_groups, split_sizes) + return tex.group_quantize( + tensor, + quantizer, + num_groups, + split_sizes, + tensor_offsets=tensor_offsets, + ) # -------------------------------------------------- # Special case: single-tensor NVFP4 quantize diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 3d4eba9b65..447f35b159 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -208,11 +208,22 @@ def fuser_forward( split_sizes = fc1_split_sizes if int(split_sizes.numel()) != num_groups: raise ValueError(f"Expected {num_groups} splits, but got {int(split_sizes.numel())}.") - split_sizes = split_sizes.to(dtype=torch.int64, device=device) - base_split_offsets = tex.splits_to_offsets(split_sizes, 1) - split_points = base_split_offsets[1:].to(dtype=torch.int) - fc1_x_tensor_offsets = base_split_offsets * fc1_weight_shape[1] - fc2_x_tensor_offsets = base_split_offsets * fc2_weight_shape[1] + + # Prepare split metadata + split_sizes, ( + split_points, + base_split_offsets, + fc1_x_tensor_offsets, + fc2_x_tensor_offsets, + fc2_out_tensor_offsets, + ) = tex.splits_to_offsets_multi( + split_sizes, + device, + strides=[1, 1, fc1_weight_shape[1], fc2_weight_shape[1], fc2_weight_shape[0]], + include_leading_zero=[False, True, True, True, True], + dtypes=[torch.int32, torch.int64, torch.int64, torch.int64, torch.int64], + bulk_allocate=True, + ) # Extract per-row activation probabilities from the middle op. scales = basic_op_extra_inputs[1][0] @@ -539,7 +550,6 @@ def fuser_forward( else: fc2_out_buf = fc2_out_buf + token_bias else: - fc2_out_offsets = base_split_offsets * fc2_weight_shape[0] fc2_out_grouped = GroupedTensor( shape=(in_shape[0], fc2_weight_shape[0]), dtype=dtype, @@ -547,7 +557,7 @@ def fuser_forward( quantizer=None, data=fc2_out_buf.view(-1), first_dims=split_sizes, - tensor_offsets=fc2_out_offsets, + tensor_offsets=fc2_out_tensor_offsets, ) general_grouped_gemm_for_grouped_tensor( grouped_fc2_weight,