diff --git a/csrc/trtllm_mnnvl_allreduce.cu b/csrc/trtllm_mnnvl_allreduce.cu index 7c151fd91d..81f2c6d334 100644 --- a/csrc/trtllm_mnnvl_allreduce.cu +++ b/csrc/trtllm_mnnvl_allreduce.cu @@ -26,77 +26,90 @@ using tvm::ffi::Optional; } \ }() -void trtllm_mnnvl_all_reduce(TensorView in, int64_t multicast_buffer_ptr, int64_t buffer_ptrs_dev, - int64_t buffer_M, TensorView buffer_flags_mnnvl, int64_t nranks, - int64_t rank, bool wait_for_results, bool launch_with_pdl, - Optional out) { - ffi::CUDADeviceGuard device_guard(in.device().device_id); - auto stream = get_stream(in.device()); +void trtllm_mnnvl_allreduce_fusion(TensorView input, int64_t multicast_buffer_ptr, + int64_t buffer_ptrs_dev, int64_t buffer_ptr_local, + TensorView buffer_flags_mnnvl, int64_t nranks, int64_t rank, + bool rmsnorm_fusion, bool launch_with_pdl, bool use_oneshot, + TensorView output, Optional residual_out, + Optional residual_in, Optional gamma, + Optional epsilon) { + ffi::CUDADeviceGuard device_guard(input.device().device_id); + auto stream = get_stream(input.device()); - DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(in.dtype(), c_type, [&] { + DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(input.dtype(), c_type, [&] { // Extract parameters from tensors - int64_t num_tokens = in.size(0); - int64_t token_dim = in.size(1); + int64_t num_tokens = input.size(0); + int64_t token_dim = input.size(1); // Validate input parameters - TVM_FFI_ICHECK_EQ(token_dim % (sizeof(float2) / sizeof(c_type)), 0) - << "token_dim must be divisible by " << sizeof(float2) / sizeof(c_type); + TVM_FFI_ICHECK_EQ(token_dim % (sizeof(float4) / sizeof(c_type)), 0) + << "token_dim must be divisible by " << sizeof(float4) / sizeof(c_type); + TVM_FFI_ICHECK(output.size(0) == input.size(0) && output.size(1) == input.size(1)) + << "output shape mismatch: expected (" << input.size(0) << ", " << input.size(1) + << ") but got (" << output.size(0) << ", " << output.size(1) << ")"; TVM_FFI_ICHECK(nranks >= 2 && nranks <= 64) << "nranks must be between 2 and 64, got " << nranks; TVM_FFI_ICHECK(rank >= 0 && rank < nranks) << "rank must be between 0 and nranks-1, got " << rank; - TVM_FFI_ICHECK(out.has_value() || !wait_for_results) - << "out tensor must be provided if wait_for_results is true"; + TVM_FFI_ICHECK((residual_in.has_value() && residual_out.has_value() && gamma.has_value() && + epsilon.has_value()) || + !rmsnorm_fusion) + << "residual_in, residual_out, gamma, and epsilon must be provided if rmsnorm_fusion is " + "true"; + + if (rmsnorm_fusion) { + TVM_FFI_ICHECK(residual_in.value().size(0) == num_tokens && + residual_in.value().size(1) == token_dim) + << "residual_in shape mismatch: expected (" << input.size(0) << ", " << input.size(1) + << ") but got (" << residual_in.value().size(0) << ", " << residual_in.value().size(1) + << ")"; + TVM_FFI_ICHECK(residual_out.value().size(0) == num_tokens && + residual_out.value().size(1) == token_dim) + << "residual_out shape mismatch: expected (" << input.size(0) << ", " << input.size(1) + << ") but got (" << residual_out.value().size(0) << ", " << residual_out.value().size(1) + << ")"; + TVM_FFI_ICHECK(gamma.value().size(0) == token_dim) + << "gamma must have the same shape as token dimension (" << token_dim << ") but got (" + << gamma.value().size(0) << ")"; + } // Create the parameters struct - AllReduceParams params; - params.nranks = nranks; - params.rank = rank; - params.buffer_M = buffer_M; - params.num_tokens = num_tokens; - params.token_dim = token_dim; - params.buffer_ptrs_dev = reinterpret_cast(buffer_ptrs_dev); - params.multicast_ptr = reinterpret_cast(multicast_buffer_ptr); - params.buffer_flags = buffer_flags_mnnvl.data_ptr(); - params.wait_for_results = wait_for_results; - params.launch_with_pdl = launch_with_pdl; - params.input = in.data_ptr(); - params.output = out.has_value() ? out.value().data_ptr() : nullptr; - params.stream = stream; + AllReduceFusionParams params; - auto status = twoshot_allreduce_dispatch_world_size(params); - TVM_FFI_ICHECK(status == cudaSuccess) - << "twoshot_allreduce_dispatch_world_size failed with error code " - << cudaGetErrorString(status); - }); -} + // Aux Information + params.nRanks = nranks; + params.rank = rank; + params.numTokens = num_tokens; + params.tokenDim = token_dim; + params.bufferPtrsDev = reinterpret_cast(buffer_ptrs_dev); + params.bufferPtrLocal = reinterpret_cast(buffer_ptr_local); + params.multicastPtr = reinterpret_cast(multicast_buffer_ptr); + params.bufferFlags = reinterpret_cast(buffer_flags_mnnvl.data_ptr()); + params.rmsNormFusion = rmsnorm_fusion; + params.launchWithPdl = launch_with_pdl; -void trtllm_mnnvl_rmsnorm(int64_t multicast_buffer_ptr, TensorView prenorm_output, - TensorView normed_output, TensorView gamma, double epsilon, - TensorView residual, TensorView buffer_flags, bool launch_with_pdl) { - ffi::CUDADeviceGuard device_guard(prenorm_output.device().device_id); - auto stream = get_stream(prenorm_output.device()); + // input data + params.input = const_cast(input.data_ptr()); + params.residualIn = + residual_in.has_value() ? const_cast(residual_in.value().data_ptr()) : nullptr; + params.gamma = gamma.has_value() ? const_cast(gamma.value().data_ptr()) : nullptr; + params.epsilon = epsilon.has_value() ? epsilon.value() : 1e-5; - DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(prenorm_output.dtype(), c_type, [&] { - // Create the parameters struct - RMSNormParams params; - params.residual_output = prenorm_output.data_ptr(); - params.output = normed_output.data_ptr(); - params.input = reinterpret_cast(multicast_buffer_ptr); - params.gamma = gamma.data_ptr(); - params.epsilon = epsilon; - params.residual = residual.data_ptr(); - params.buffer_flags = reinterpret_cast(buffer_flags.data_ptr()); - params.batch = normed_output.size(0); - params.hidden_dim = normed_output.size(1); + // output data + params.output = const_cast(output.data_ptr()); + params.residualOut = + residual_out.has_value() ? const_cast(residual_out.value().data_ptr()) : nullptr; params.stream = stream; - params.launch_with_pdl = launch_with_pdl; - auto status = twoshot_rmsnorm_dispatch_hidden_dim(params); + + cudaError_t status; + if (use_oneshot) { + status = oneshotAllreduceFusionDispatch(params); + } else { + status = twoshotAllreduceFusionDispatch(params); + } TVM_FFI_ICHECK(status == cudaSuccess) - << "twoshot_rmsnorm_dispatch_hidden_dim failed with error code " - << cudaGetErrorString(status); + << "trtllm_mnnvl_allreduce_fusion failed with error code " << cudaGetErrorString(status); }); } -TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_mnnvl_all_reduce, trtllm_mnnvl_all_reduce); -TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_mnnvl_rmsnorm, trtllm_mnnvl_rmsnorm); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_mnnvl_allreduce_fusion, trtllm_mnnvl_allreduce_fusion); diff --git a/flashinfer/comm/__init__.py b/flashinfer/comm/__init__.py index f7ae3754ac..6d945980be 100644 --- a/flashinfer/comm/__init__.py +++ b/flashinfer/comm/__init__.py @@ -39,4 +39,15 @@ from .vllm_ar import register_buffer as vllm_register_buffer from .vllm_ar import register_graph_buffers as vllm_register_graph_buffers +# Unified AllReduce Fusion API +from .allreduce import AllReduceFusionWorkspace as AllReduceFusionWorkspace +from .trtllm_mnnvl_ar import ( + MNNVLAllReduceFusionWorkspace as MNNVLAllReduceFusionWorkspace, +) +from .allreduce import TRTLLMAllReduceFusionWorkspace as TRTLLMAllReduceFusionWorkspace +from .allreduce import allreduce_fusion as allreduce_fusion +from .allreduce import ( + create_allreduce_fusion_workspace as create_allreduce_fusion_workspace, +) + # from .mnnvl import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py new file mode 100644 index 0000000000..b9c3b20ed9 --- /dev/null +++ b/flashinfer/comm/allreduce.py @@ -0,0 +1,697 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +""" +Unified AllReduce Fusion API + +This module provides a unified interface for AllReduce + RMSNorm fusion operations +across different backends (TensorRT-LLM, MNNVL). + +Example usage: + >>> # Auto-select best backend based on topology + >>> workspace = create_allreduce_fusion_workspace( + ... backend="auto", + ... world_size=8, + ... rank=0, + ... max_token_num=2048, + ... hidden_dim=4096, + ... dtype=torch.bfloat16, + ... topology="single_node" + ... ) + >>> + >>> # Perform AllReduce + RMSNorm fusion + >>> prenorm = torch.empty_like(hidden_states) + >>> normed = torch.empty_like(hidden_states) + >>> output = allreduce_fusion( + ... input=hidden_states, + ... workspace=workspace, + ... launch_with_pdl=True, + ... residual_out=prenorm, + ... norm_out=normed, + ... residual_in=residual, + ... rms_gamma=norm_weight + ... ) + >>> + >>> workspace.destroy() +""" + +from typing import Union, Literal, Optional, Tuple, List, cast, Any +from .workspace_base import AllReduceFusionWorkspace + +import torch + +from .trtllm_ar import trtllm_allreduce_fusion +from .trtllm_ar import trtllm_create_ipc_workspace_for_all_reduce_fusion +from .trtllm_ar import trtllm_destroy_ipc_workspace_for_all_reduce_fusion +from .trtllm_ar import check_trtllm_allreduce_fusion_workspace_metadata + +from .mapping import Mapping + +from .mnnvl import CommBackend + +# Note: AllReduceFusionPattern and QuantizationSFLayout are pseudo-types (classes with int constants) +# Import them for runtime use but type hint as int for mypy compatibility +from .trtllm_ar import AllReduceFusionPattern +from .trtllm_mnnvl_ar import MNNVLAllReduceFusionWorkspace +from .trtllm_mnnvl_ar import trtllm_mnnvl_allreduce +from .trtllm_mnnvl_ar import trtllm_mnnvl_fused_allreduce_add_rmsnorm + +# ============================================================================ +# WORKSPACE IMPLEMENTATIONS +# ============================================================================ +# +# Workspace classes wrap the underlying backend workspace implementations: +# - TRTLLMAllReduceFusionWorkspace: Wraps trtllm_create_ipc_workspace_for_all_reduce_fusion +# - MNNVLAllReduceFusionWorkspace: Wraps MNNVL workspace (see trtllm_mnnvl_ar.py) +# +# Each workspace: +# 1. Calls the backend-specific workspace creation function in __init__ +# 2. Stores the internal workspace as _internal_workspace +# 3. Exposes essential attributes for the unified API +# 4. Can be destroyed using workspace.destroy() +# ============================================================================ + + +class TRTLLMAllReduceFusionWorkspace(AllReduceFusionWorkspace): + """TensorRT-LLM workspace for AllReduce fusion.""" + + def __init__( + self, + tp_size: int, + tp_rank: int, + max_token_num: int, + hidden_dim: int, + dtype: torch.dtype = torch.float16, + process_group: Optional["torch.distributed.ProcessGroup"] = None, + ): + """ + Create TensorRT-LLM AllReduce fusion workspace. + + Args: + tp_size: Tensor parallel size (world size) + tp_rank: Tensor parallel rank + max_token_num: Maximum number of tokens + hidden_dim: Hidden dimension size + dtype: Data type + process_group: PyTorch distributed process group + **kwargs: Additional arguments for workspace creation + """ + super().__init__(tp_size, tp_rank) + + # Call the actual workspace creation function + self._internal_workspace = trtllm_create_ipc_workspace_for_all_reduce_fusion( + tp_rank=tp_rank, + tp_size=tp_size, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + group=process_group, + create_metadata=True, + use_fp32_lamport=dtype == torch.float32, + ) + + # Store essential attributes for easy access + # Cast to 3-tuple to make linter happy, since we always call with create_metadata=True + workspace_tuple = cast( + Tuple[List[List[int]], torch.Tensor, dict], self._internal_workspace + ) + self.ipc_handles = workspace_tuple[0] + self.workspace_tensor = workspace_tuple[1] + self.metadata = workspace_tuple[2] + + @property + def backend(self) -> str: + return "trtllm" + + def __getattr__(self, name): + """Delegate attribute access to internal workspace if not found.""" + if name.startswith("_"): + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + return getattr(self._internal_workspace, name) + + def is_buffer_size_sufficient( + self, + tp_size: int, + num_tokens: int, + hidden_dim: int, + dtype: torch.dtype, + use_oneshot: Optional[Any] = None, + ) -> bool: + try: + check_trtllm_allreduce_fusion_workspace_metadata( + num_tokens, hidden_dim, tp_size, dtype, self.metadata + ) + return True + except ValueError as e: + print(f"Workspace is insufficient for problem size. {e}") + return False + + def destroy(self) -> None: + """Destroy workspace and free resources.""" + if self._destroyed is True: + return # Already destroyed, nothing to do + + trtllm_destroy_ipc_workspace_for_all_reduce_fusion(self.ipc_handles) + self._destroyed = True + + +# ============================================================================ +# BACKEND CHECKS - Hard requirements for backend selection +# ============================================================================ + + +def _trtllm_workspace_check( + backend: str, + world_size: int, + rank: int, + max_token_num: int, + hidden_dim: int, + dtype: torch.dtype, + topology: str, + **kwargs, +) -> bool: + """ + Check if trtllm backend CAN be used for workspace creation. + + Hard requirements: + - Single-node topology (multi-node not supported) + + """ + # trtllm is optimized for single-node + if topology == "multi_node": + return False + + return True + + +def _mnnvl_workspace_check( + backend: str, + world_size: int, + rank: int, + max_token_num: int, + hidden_dim: int, + dtype: torch.dtype, + topology: str, + **kwargs, +) -> bool: + """ + Check if mnnvl backend CAN be used for workspace creation. + + """ + + if topology == "multi_node": + return True + + return False + + +# ============================================================================ +# HEURISTIC - Performance-based backend selection +# ============================================================================ + + +def _workspace_creation_heuristic( + suitable_backends: list[str], + backend: str, + world_size: int, + rank: int, + max_token_num: int, + hidden_dim: int, + dtype: torch.dtype, + topology: str, + # TODO(nvmbreughe): Remove this + **kwargs, +) -> list[str]: + """ + Select best backend for workspace creation based on performance. + + Called by decorator after checking which backends pass requirements. + Uses benchmarking data to pick fastest option. + + Args: + suitable_backends: List of backends that passed hard requirement checks + backend: Requested backend ("auto", "trtllm", or "mnnvl") + world_size: Number of ranks + rank: Current rank + max_token_num: Maximum number of tokens + hidden_dim: Hidden dimension size + dtype: Data type + topology: Network topology ("single_node" or "multi_node") + **kwargs: Additional arguments + + Returns: + List containing the selected backend (single element) + """ + if not suitable_backends: + return [] + + if len(suitable_backends) == 1: + return suitable_backends + + # Decision tree based on benchmark data + # TODO: Replace with actual benchmarking results + + # Multi-node: MNNVL is designed for this + if topology == "multi_node": + if "mnnvl" in suitable_backends: + return ["mnnvl"] + + # Single-node scenarios + elif "trtllm" in suitable_backends: + return ["trtllm"] + else: + return [] + + +# ============================================================================ +# WORKSPACE CREATION +# ============================================================================ + + +def create_allreduce_fusion_workspace( + backend: Literal["trtllm", "mnnvl", "auto"] = "auto", + world_size: int = None, + rank: int = None, + max_token_num: int = None, + hidden_dim: int = None, + dtype: torch.dtype = None, + topology: str = "single_node", + process_group: Optional["torch.distributed.ProcessGroup"] = None, + gpus_per_node: int = None, + comm_backend: Optional[CommBackend] = None, +) -> AllReduceFusionWorkspace: + """ + Create workspace for AllReduce fusion operations. + + Backend selection uses topology-based checks and heuristics. + + **Important: Workspace Reusability** + The workspace is allocated based on the total size (max_token_num * hidden_dim * dtype_size). + You can reuse the same workspace with different shapes as long as the total size fits: + + - Workspace(max_token_num=2048, hidden_dim=4096) can handle: + - (token_num=2048, hidden_dim=4096) ✓ + - (token_num=1024, hidden_dim=4096) ✓ + - (token_num=4096, hidden_dim=2048) ✓ (same total size) + - (token_num=1024, hidden_dim=8192) ✓ (same total size) + - (token_num=4096, hidden_dim=4096) ✗ (too large) + + Use `workspace.is_sufficient_for(token_num, hidden_dim, dtype)` to check before use. + + Args: + backend: Backend to use ("trtllm", "mnnvl", or "auto") + "auto" uses heuristic to select best backend based on topology + and problem size + world_size: Number of ranks in the process group + rank: Current rank ID + max_token_num: Maximum number of tokens to support + hidden_dim: Hidden dimension size + dtype: Data type for communication tensors + topology: Network topology hint for backend selection + "single_node" - All ranks on one node (default) + "multi_node" - Ranks span multiple nodes + process_group: PyTorch distributed process group (for trtllm backend). + gpus_per_node: Number of GPUs per node (for multi-node topology). + comm_backend: Communication backend to use (for multi-node topology). + + Returns: + Workspace object (TRTLLMAllReduceFusionWorkspace or MNNVLAllReduceFusionWorkspace) + The workspace type determines which backend will be used in allreduce_fusion() + + Raises: + BackendSupportedError: If no suitable backend available for the configuration + ValueError: If problem size not supported for the specified backend + + Examples: + >>> # Auto-select best backend based on topology + >>> workspace = create_allreduce_fusion_workspace( + ... backend="auto", + ... world_size=8, + ... rank=0, + ... max_token_num=2048, + ... hidden_dim=4096, + ... dtype=torch.bfloat16, + ... topology="single_node" + ... ) + >>> print(workspace.backend) # "trtllm" + >>> print(workspace.get_workspace_capacity()) # 8388608 elements + + >>> # Check if workspace can handle different problem sizes + >>> workspace.is_sufficient_for(1024, 4096, 8, torch.bfloat16) # True + >>> workspace.is_sufficient_for(4096, 2048, 8, torch.bfloat16) # True (same total) + + >>> # Explicit backend selection + >>> workspace = create_allreduce_fusion_workspace( + ... backend="mnnvl", + ... world_size=16, + ... rank=0, + ... max_token_num=2048, + ... hidden_dim=4096, + ... dtype=torch.bfloat16, + ... topology="multi_node" + ... ) + >>> print(workspace.backend) # "mnnvl" + """ + if gpus_per_node is None: + gpus_per_node = min(torch.cuda.device_count(), world_size) + # Determine the actual backend to use + if backend == "auto": + # Find suitable backends based on topology (anny CC check needs to be checked at kernel runtime, since there are no tensor available at this point) + suitable_backends = [] + if _trtllm_workspace_check( + backend=backend, + world_size=world_size, + rank=rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + topology=topology, + ): + suitable_backends.append("trtllm") + if _mnnvl_workspace_check( + backend=backend, + world_size=world_size, + rank=rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + topology=topology, + ): + suitable_backends.append("mnnvl") + + if not suitable_backends: + raise ValueError( + f"No suitable backend found for topology={topology}. " + f"trtllm requires single_node topology, mnnvl works with both." + ) + + # Apply heuristic to select best backend + selected = _workspace_creation_heuristic( + suitable_backends=suitable_backends, + backend=backend, + world_size=world_size, + rank=rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + topology=topology, + ) + actual_backend = selected[0] if selected else suitable_backends[0] + else: + actual_backend = backend + + # Create workspace for selected backend using workspace constructors + if actual_backend == "trtllm": + return TRTLLMAllReduceFusionWorkspace( + tp_size=world_size, + tp_rank=rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + process_group=process_group, + ) + + elif actual_backend == "mnnvl": + mapping = Mapping( + world_size=world_size, + rank=rank, + gpus_per_node=gpus_per_node, + tp_size=world_size, + ) + return MNNVLAllReduceFusionWorkspace( + mapping=mapping, + max_num_tokens=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + comm_backend=comm_backend, + ) + else: + raise RuntimeError(f"Unknown backend: {actual_backend}") + + +# ============================================================================ +# MAIN API - NO backend parameter, infers from workspace type +# ============================================================================ + + +def allreduce_fusion( + input: torch.Tensor, + workspace: AllReduceFusionWorkspace, + pattern: int, + launch_with_pdl: bool = False, + # ===== OUTPUT tensors (pre-allocated, will be filled) ===== + output: Optional[torch.Tensor] = None, + residual_out: Optional[torch.Tensor] = None, + norm_out: Optional[torch.Tensor] = None, + quant_out: Optional[torch.Tensor] = None, + scale_out: Optional[torch.Tensor] = None, + # ===== INPUT parameters ===== + residual_in: Optional[torch.Tensor] = None, + rms_gamma: Optional[torch.Tensor] = None, + rms_eps: float = 1e-6, + scale_factor: Optional[Union[torch.Tensor, float]] = None, + layout_code: Optional[int] = None, + # ===== Control parameters ===== + use_oneshot: Optional[bool] = None, + fp32_acc: bool = False, +) -> torch.Tensor: + """ + AllReduce + RMSNorm fusion operation. + + Backend is automatically determined from workspace type. + No backend parameter needed! + + Supports multiple fusion patterns: + - AllReduce only + - AllReduce + Residual + RMSNorm + - AllReduce + Residual + RMSNorm + Quantization (FP8/FP4) + + **Note on Workspace Reusability:** + You can reuse the same workspace with different (token_num, hidden_dim) combinations + as long as `workspace.is_sufficient_for(token_num, hidden_dim, tp_size, dtype)` returns True. + + Args: + input: Input tensor [token_num, hidden_dim] + workspace: Workspace object (type determines backend) + pattern: Fusion pattern (AllReduceFusionPattern constant, 0-5) + - kAllReduce = 0 + - kARResidualRMSNorm = 1 + - kARResidualRMSNormFP8Quant = 2 + - kARResidualRMSNormFP4Quant = 3 + - kARResidualRMSNormOutFP8Quant = 4 + - kARResidualRMSNormOutFP4Quant = 5 + Note: MNNVL only supports patterns 0 and 1 + launch_with_pdl: Use Persistent Dependency Launch + + # ===== OUTPUT tensors (pre-allocated, filled by function) ===== + output: AllReduce output [token_num, hidden_dim] + residual_out: Prenorm output (after residual add, before norm) [token_num, hidden_dim] + norm_out: Normalized output [token_num, hidden_dim] + quant_out: Quantized output [token_num, hidden_dim] [trtllm only] + scale_out: Quantization scale factors [trtllm only] + + # ===== INPUT parameters ===== + residual_in: Residual tensor to ADD [token_num, hidden_dim] + rms_gamma: RMSNorm weight [hidden_dim] + rms_eps: RMSNorm epsilon for numerical stability + scale_factor: Input scale factor for quantization [trtllm only] + layout_code: Scale factor layout (QuantizationSFLayout) [trtllm only] + + # ===== Control parameters ===== + use_oneshot: Use oneshot strategy vs twoshot + If None, uses internal heuristics. + Note that the MNNVL backend needs to be initialized with a sufficiently large workspace if one_shot is used. + fp32_acc: [trtllm only] Use FP32 accumulation for AllReduce + + Returns: + Output tensor (typically norm_out for fusion cases, output otherwise) + + Examples: + >>> # Basic AllReduce + Residual + RMSNorm + >>> workspace = create_allreduce_fusion_workspace( + ... backend="auto", + ... world_size=8, + ... rank=0, + ... max_token_num=2048, + ... hidden_dim=4096, + ... dtype=torch.bfloat16, + ... topology="single_node" + ... ) + >>> + >>> # Pre-allocate output tensors + >>> prenorm = torch.empty_like(hidden_states) + >>> normed = torch.empty_like(hidden_states) + >>> + >>> # Call fusion - backend inferred from workspace type + >>> output = allreduce_fusion( + ... input=hidden_states, + ... workspace=workspace, + ... pattern=AllReduceFusionPattern.kARResidualRMSNorm, + ... launch_with_pdl=True, + ... residual_out=prenorm, + ... norm_out=normed, + ... residual_in=residual, + ... rms_gamma=norm_weight + ... ) + >>> # output == normed (final result) + + >>> # With FP8 quantization + >>> quant = torch.empty_like(hidden_states, dtype=torch.float8_e4m3fn) + >>> scales = torch.empty(token_num * hidden_dim // 16, dtype=torch.float16) + >>> + >>> output = allreduce_fusion( + ... input=hidden_states, + ... workspace=workspace, + ... pattern=AllReduceFusionPattern.kARResidualRMSNormFP8Quant, + ... norm_out=normed, + ... quant_out=quant, + ... scale_out=scales, + ... residual_in=residual, + ... rms_gamma=norm_weight, + ... scale_factor=scale_tensor + ... ) + """ + # Dispatch based on workspace type + if isinstance(workspace, TRTLLMAllReduceFusionWorkspace): + # TensorRT-LLM backend implementation + # Extract shape from 2D input + token_num, hidden_dim = input.shape + + # Allocate output if needed (keep 2D shape) + if output is None: + output = torch.empty_like(input) + + # Flatten all tensors to 1D for legacy trtllm_allreduce_fusion API + # The legacy API expects flattened tensors and explicit token_num/hidden_dim + # We require contiguous tensors so that view(-1) creates a view (not a copy), + # ensuring writes to the flattened tensors are reflected in the original 2D tensors + def _flatten_checked(t, name): + if not t.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + return t.view(-1) + + input_flat = _flatten_checked(input, "input") + output_flat = _flatten_checked(output, "output") + residual_in_flat = ( + _flatten_checked(residual_in, "residual_in") + if residual_in is not None + else None + ) + residual_out_flat = ( + _flatten_checked(residual_out, "residual_out") + if residual_out is not None + else None + ) + norm_out_flat = ( + _flatten_checked(norm_out, "norm_out") if norm_out is not None else None + ) + quant_out_flat = ( + _flatten_checked(quant_out, "quant_out") if quant_out is not None else None + ) + + # Call legacy API with flattened tensors + # Note: pattern and layout_code are ints but legacy API uses pseudo-type hints + trtllm_allreduce_fusion( + allreduce_in=input_flat, + world_size=workspace.world_size, + world_rank=workspace.rank, + token_num=token_num, + hidden_dim=hidden_dim, + workspace_ptrs=workspace.workspace_tensor, + launch_with_pdl=launch_with_pdl, + trigger_completion_at_end=launch_with_pdl, # Same meaning + fp32_acc=fp32_acc, + pattern_code=pattern, # type: ignore[arg-type] + use_oneshot=use_oneshot, + allreduce_out=output_flat, + residual_in=residual_in_flat, + residual_out=residual_out_flat, + norm_out=norm_out_flat, + quant_out=quant_out_flat, + scale_out=scale_out, # scale_out is not reshaped + rms_gamma=rms_gamma, # 1D tensor, no reshape needed + rms_eps=rms_eps, + scale_factor=scale_factor, + layout_code=layout_code, # type: ignore[arg-type] + metadata=workspace.metadata, + ) + + # Return the most downstream output (already in 2D shape from input views) + if norm_out is not None: + return norm_out + elif quant_out is not None: + return quant_out + else: + return output + + elif isinstance(workspace, MNNVLAllReduceFusionWorkspace): + if ( + pattern != AllReduceFusionPattern.kARResidualRMSNorm + and pattern != AllReduceFusionPattern.kAllReduce + ): + raise ValueError( + f"MNNVL AllReduce+RMS fusion does not support pattern {pattern}" + ) + + # MNNVL backend implementation + if pattern == AllReduceFusionPattern.kAllReduce: + # AllReduce only + if output is None: + output = torch.empty_like(input) + trtllm_mnnvl_allreduce( + input=input, + workspace=workspace, + launch_with_pdl=launch_with_pdl, + output=output, + ) + return output + + elif pattern == AllReduceFusionPattern.kARResidualRMSNorm: + # AllReduce + Residual + RMSNorm fusion + # Validate required parameters + if residual_in is None: + raise ValueError("MNNVL AllReduce+RMS fusion requires residual_in") + if rms_gamma is None: + raise ValueError("MNNVL AllReduce+RMS fusion requires rms_gamma") + + # Allocate output tensors if not provided + if norm_out is None: + norm_out = torch.empty_like(input) + if residual_out is None: + residual_out = torch.empty_like(input) + + # Call the MNNVL fusion function + norm_result, residual_result = trtllm_mnnvl_fused_allreduce_add_rmsnorm( + input=input, + residual_in=residual_in, + gamma=rms_gamma, + workspace=workspace, + epsilon=rms_eps, + output=norm_out, + residual_out=residual_out, + launch_with_pdl=launch_with_pdl, + ) + return norm_result + + else: + raise ValueError(f"Unsupported pattern for MNNVL backend: {pattern}") + + else: + raise TypeError( + f"Unknown workspace type: {type(workspace)}. " + f"Expected TRTLLMAllReduceFusionWorkspace or MNNVLAllReduceFusionWorkspace" + ) diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index 2d280a68e8..13ca4f534d 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -16,6 +16,12 @@ import ctypes import logging import os +import socket +import array +import random + +import contextlib + from abc import ABC, abstractmethod from dataclasses import dataclass import platform @@ -123,7 +129,7 @@ def test_cuda_memory_access(ptr: int, size: int, device_id: int) -> bool: return False -def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> Optional[int]: +def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int: """ A helper function that allocates memory on cuda and copies the data from the host to the device. """ @@ -140,7 +146,7 @@ def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> Optional[int]: ) # c_array should be freed by GC - return device_ptr + return int(device_ptr) class CommBackend(ABC): @@ -155,6 +161,9 @@ def Get_size(self) -> int: ... @abstractmethod def allgather(self, data: int) -> List[int]: ... + @abstractmethod + def bcast(self, data: Any, root: int) -> Any: ... + @abstractmethod def barrier(self) -> None: ... @@ -212,6 +221,9 @@ def Get_size(self) -> int: def allgather(self, data: int) -> List[int]: return self._mpicomm.allgather(data) + def bcast(self, data: Any, root: int) -> Any: + return self._mpicomm.bcast(data, root) + def barrier(self): self._mpicomm.Barrier() @@ -551,6 +563,208 @@ def supports_mnnvl() -> bool: return support_nvlink_and_all_up +# The helper class for passing the FD handle over the socket. +class IpcSocket: + """Unix Domain Socket for IPC file descriptor passing""" + + def __init__(self, rank: int, op_id: int, use_abstract=True): + """ + Initialize IPC socket + + Args: + rank: Process rank + op_id: Unique operation ID (hash) + use_abstract: Use Linux abstract socket namespace + """ + self.rank = rank + self.op_id = op_id + self.use_abstract = use_abstract + + # Create Unix domain socket (DGRAM for compatibility with C code) + self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + + # Create unique socket name + socket_name = f"/tmp/mcastmem-socket-{rank}-{op_id:x}" + + if use_abstract: + # Linux abstract socket: prepend null byte + self.socket_path = "\0" + socket_name + else: + self.socket_path = socket_name + # Remove existing socket file if it exists + with contextlib.suppress(FileNotFoundError): + os.unlink(socket_name) + + # Bind socket + self.sock.bind(self.socket_path) + + def send_fd(self, fd: int, dest_rank: int, dest_op_id: Optional[int] = None): + """ + Send a file descriptor to another process + + Args: + fd: File descriptor to send + dest_rank: Destination process rank + dest_op_id: Destination operation ID + """ + # Construct destination socket path + dest_op_id = dest_op_id or self.op_id + dest_socket_name = f"/tmp/mcastmem-socket-{dest_rank}-{dest_op_id:x}" + + if self.use_abstract: + dest_path = "\0" + dest_socket_name + else: + dest_path = dest_socket_name + + # Prepare message with file descriptor + # Send dummy byte as data (required) + dummy_data = b"\x00" + + # Pack file descriptor in ancillary data (SCM_RIGHTS) + fds = array.array("i", [fd]) + ancillary = [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds.tobytes())] + + # Send message with file descriptor + self.sock.sendmsg([dummy_data], ancillary, 0, dest_path) + + def recv_fd(self): + """ + Receive a file descriptor from another process + + Returns: + int: Received file descriptor + """ + # Receive message with ancillary data + # Maximum size for ancillary data containing one fd + fds = array.array("i") + msg, ancdata, flags, addr = self.sock.recvmsg( + 1, + socket.CMSG_SPACE( + fds.itemsize + ), # Buffer size for dummy data # Ancillary data size + ) + + # Extract file descriptor from ancillary data + for cmsg_level, cmsg_type, cmsg_data in ancdata: + if cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS: + fds = array.array("i") + fds.frombytes( + cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)] + ) + return fds[0] + + raise RuntimeError("No file descriptor received") + + def close(self): + """Close the socket""" + self.sock.close() + if not self.use_abstract and self.socket_path: + with contextlib.suppress(FileNotFoundError): + os.unlink(self.socket_path) + + +class HandleExchanger(ABC): + """Abstract interface for exchanging CUDA shareable handles across ranks.""" + + def __init__(self, comm_backend: "CommBackend", group_rank: int, group_size: int): + self.comm = comm_backend + self.rank = group_rank + self.size = group_size + + @property + @abstractmethod + def handle_type(self) -> cuda.CUmemAllocationHandleType: + """The CUDA handle type this exchanger works with.""" + ... + + @abstractmethod + def allgather(self, local_handle) -> List: + """All-gather shareable handles from all ranks.""" + ... + + @abstractmethod + def broadcast(self, handle, root: int): + """Broadcast a handle from root to all ranks.""" + ... + + @abstractmethod + def cleanup(self, handle) -> None: ... + + @abstractmethod + def close(self) -> None: ... + + +class FabricHandleExchanger(HandleExchanger): + """Handle exchange using CUDA Fabric handles via MPI/collective backend.""" + + @property + def handle_type(self) -> cuda.CUmemAllocationHandleType: + return cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + + def allgather(self, local_handle) -> List: + return self.comm.allgather(local_handle.data) + + def broadcast(self, handle, root: int): + return self.comm.bcast(handle.data if handle else None, root=root) + + def cleanup(self, handle) -> None: + pass # No cleanup needed for Fabric handles. + + def close(self) -> None: + pass # No close needed for Fabric handles. + + +class PosixFDHandleExchanger(HandleExchanger): + """Handle exchange using POSIX file descriptors via IPC sockets.""" + + def __init__(self, comm_backend: "CommBackend", group_rank: int, group_size: int): + super().__init__(comm_backend, group_rank, group_size) + self._socket = self._init_ipc_socket() + + def _init_ipc_socket(self) -> IpcSocket: + if self.rank == 0: + opId = random.randint(0, 2**64 - 1) + else: + opId = None + opId = self.comm.bcast(opId, root=0) + return IpcSocket(self.rank, opId) + + @property + def handle_type(self) -> cuda.CUmemAllocationHandleType: + return cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR + + def allgather(self, local_handle) -> List: + result = [None] * self.size + for i in range(self.size): + self.comm.barrier() + self._socket.send_fd(local_handle, (self.rank + i) % self.size) + src = (self.rank + self.size - i) % self.size + result[src] = self._socket.recv_fd() + return result + + def broadcast(self, handle, root: int): + if self.rank == root: + for p in range(1, self.size): + self.comm.barrier() + self._socket.send_fd(handle, p) + return handle + else: + # Ordered receive to avoid race condition + for _ in range(self.rank): + self.comm.barrier() + result = self._socket.recv_fd() + for _ in range(self.size - self.rank - 1): + self.comm.barrier() + return result + + def cleanup(self, handle) -> None: + os.close(handle) + + def close(self) -> None: + self._socket.close() + + +# TODO: This class follows similar logic with MnnvlMemory, but the latter use single instance mode to manage the memory allocation. class McastDeviceMemory: """Python port of McastDeviceMemory from TensorRT-LLM""" @@ -588,6 +802,7 @@ def __init__( self.buf_size = buf_size self.signal_pad_offset = 0 self.allocation_size = 0 + self.comm_backend = comm_backend_for_handle_transfer or MPIBackend() # CUDA memory handles and pointers self.mc_ptr = 0 # CUdeviceptr mMcPtr @@ -625,6 +840,7 @@ def __init__( f"Signal pad offset: {self.signal_pad_offset}" ) + # Create handle exchanger based on multi-node mode if self.is_multi_node: # Check if fabric handle is supported fabric_handle_supported = checkCudaErrors( @@ -637,11 +853,14 @@ def __init__( raise RuntimeError( "[McastDeviceMemory] Device does not support fabric handle." ) - - self._alloc_mn_mcast_mem(buf_size, comm_backend_for_handle_transfer) + self._exchanger: HandleExchanger = FabricHandleExchanger( + self.comm_backend, self.group_rank, self.group_size + ) else: - # For single-node NVLS, would need to implement _alloc_nvls_mcast_mem - raise NotImplementedError("Single-node NVLS allocation not implemented yet") + self._exchanger = PosixFDHandleExchanger( + self.comm_backend, self.group_rank, self.group_size + ) + self._alloc_mn_mcast_mem(buf_size) # Initialize signal pads self.signal_pads = [0] * self.group_size @@ -663,8 +882,8 @@ def __del__(self): if not hasattr(self, "is_multi_node"): return - if not self.is_multi_node: - return + if hasattr(self, "_exchanger"): + self._exchanger.close() # Skip cleanup during Python finalization to avoid segfaults # Especially cause the CUDA context could be destroyed at this point. @@ -760,48 +979,58 @@ def get_world_size(self) -> int: """Get the total number of devices in the group""" return self.group_size - def _alloc_mn_mcast_mem( - self, buf_size: int, comm_backend_for_handle_transfer: Any = None - ): + def get_allocation_size(self) -> int: + """Get the total allocation size (including signal pad)""" + return self.allocation_size + + def get_usable_buffer_size(self) -> int: + """Get the usable buffer size (excluding signal pad)""" + return self.allocation_size - self.SIGNAL_PAD_SIZE + + def _alloc_mn_mcast_mem(self, buf_size: int): """Allocate multi-node multicast memory using MNNVL""" + self._verify_cuda_context() + + # Compute allocation size and get allocation properties + allocation_prop, mc_prop = self._get_allocation_prop(buf_size) + + # Allocate, exchange, and map unicast buffers + self._allocate_unicast_buffers(allocation_prop) - # Verify CUDA context + # Setup multicast object, exchange handles, map and bind memory + self._setup_multicast(mc_prop) + + def _verify_cuda_context(self): + """Verify CUDA context is set to the correct device.""" try: current_device = checkCudaErrors(cuda.cuCtxGetDevice()) - if int(current_device) != self.device_idx: print( f"CUDA context device mismatch! Current: {current_device}, Expected: {self.device_idx}" ) except Exception as e: print(f"Error checking CUDA context: {e}") - if comm_backend_for_handle_transfer is None: - comm = MpiComm() - else: - comm = comm_backend_for_handle_transfer - # Set up allocation properties - handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + def _get_allocation_prop(self, buf_size: int): + """Compute allocation size and return allocation/multicast properties.""" allocation_prop = cuda.CUmemAllocationProp() - allocation_prop.requestedHandleTypes = handle_type + allocation_prop.requestedHandleTypes = self._exchanger.handle_type allocation_prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED allocation_prop.location = cuda.CUmemLocation() allocation_prop.location.type = ( cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE ) allocation_prop.location.id = self.device_idx - allocation_prop.allocFlags.gpuDirectRDMACapable = 1 # Get allocation granularity alloc_granularity = checkCudaErrors( cuda.cuMemGetAllocationGranularity( allocation_prop, - cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_MINIMUM, + cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_RECOMMENDED, ) ) - # mAllocationSize = roundUp(bufSize + kSIGNAL_PAD_SIZE, alloc_granularity); self.allocation_size = round_up( buf_size + self.SIGNAL_PAD_SIZE, alloc_granularity ) @@ -810,18 +1039,21 @@ def _alloc_mn_mcast_mem( mc_prop = cuda.CUmulticastObjectProp() mc_prop.numDevices = self.group_size mc_prop.size = self.allocation_size - mc_prop.handleTypes = handle_type + mc_prop.handleTypes = self._exchanger.handle_type - # Get multicast granularity - mc_granularity = checkCudaErrors( + # Get multicast granularity and adjust allocation size + self._mc_granularity = checkCudaErrors( cuda.cuMulticastGetGranularity( mc_prop, cuda.CUmulticastGranularity_flags.CU_MULTICAST_GRANULARITY_RECOMMENDED, ) ) + self.allocation_size = round_up(self.allocation_size, self._mc_granularity) - self.allocation_size = round_up(self.allocation_size, mc_granularity) + return allocation_prop, mc_prop + def _allocate_unicast_buffers(self, allocation_prop): + """Allocate local UC memory, exchange handles with peers, and map memory.""" # Initialize UC handles list self.uc_handles = [0] * self.group_size @@ -830,17 +1062,17 @@ def _alloc_mn_mcast_mem( cuda.cuMemCreate(self.allocation_size, allocation_prop, 0) ) - # Export local handle to fabric handle - my_fabric_handle = checkCudaErrors( + # Export local handle to shareable handle + local_shareable_uc_handle = checkCudaErrors( cuda.cuMemExportToShareableHandle( self.uc_handles[self.group_rank], - cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, + self._exchanger.handle_type, 0, ) ) - # All-gather fabric handles - all_fabric_handles = comm.allgather(my_fabric_handle.data) + # All-gather shareable handles + all_shareable_uc_handles = self._exchanger.allgather(local_shareable_uc_handle) cuda.cuCtxSynchronize() # Import remote handles @@ -848,62 +1080,20 @@ def _alloc_mn_mcast_mem( if p != self.group_rank: self.uc_handles[p] = checkCudaErrors( cuda.cuMemImportFromShareableHandle( - all_fabric_handles[p], - cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, + all_shareable_uc_handles[p], + self._exchanger.handle_type, ) ) - - # Initialize multicasting - if self.group_rank == 0: - # Create multicast object - self.mc_handle = checkCudaErrors(cuda.cuMulticastCreate(mc_prop)) - - # Export multicast handle - mc_fabric_handle = checkCudaErrors( - cuda.cuMemExportToShareableHandle( - self.mc_handle, - cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, - 0, - ) - ) - else: - mc_fabric_handle = None - - # Broadcast multicast handle - mc_fabric_handle_data = comm.bcast( - mc_fabric_handle.data if mc_fabric_handle else None, root=0 - ) - # Sync device to ensure broadcast is complete - cuda.cuCtxSynchronize() - # Import multicast handle for non-root ranks - if self.group_rank != 0: - self.mc_handle = checkCudaErrors( - cuda.cuMemImportFromShareableHandle( - mc_fabric_handle_data, - cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, - ) - ) - - # Add device to multicast - checkCudaErrors(cuda.cuMulticastAddDevice(self.mc_handle, self.device_idx)) - - # Bind memory addresses - self.uc_ptrs = [0] * self.group_size + self._exchanger.cleanup(all_shareable_uc_handles[p]) # Reserve address space for UC pointers + self.uc_ptrs = [0] * self.group_size total_uc_size = self.allocation_size * self.group_size self.total_uc_size = total_uc_size uc_base_ptr = checkCudaErrors( - cuda.cuMemAddressReserve(total_uc_size, mc_granularity, 0, 0) + cuda.cuMemAddressReserve(total_uc_size, self._mc_granularity, 0, 0) ) - self.uc_base_ptr = uc_base_ptr # Store for cleanup - - # Set up memory access descriptor - access_desc = cuda.CUmemAccessDesc() - access_desc.location = cuda.CUmemLocation() - access_desc.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE - access_desc.location.id = self.device_idx - access_desc.flags = cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE + self.uc_base_ptr = uc_base_ptr # Map UC memory for i in range(self.group_size): @@ -915,23 +1105,57 @@ def _alloc_mn_mcast_mem( ) ) - # Set memory access permissions + # Set memory access permissions for UC + access_desc = self._get_mem_access_desc() checkCudaErrors( cuda.cuMemSetAccess(uc_base_ptr, total_uc_size, [access_desc], 1) ) - # Bind MC pointer + def _setup_multicast(self, mc_prop): + """Create multicast object, exchange handle, map memory, and bind.""" + # Rank 0 creates the multicast object + if self.group_rank == 0: + self.mc_handle = checkCudaErrors(cuda.cuMulticastCreate(mc_prop)) + shareable_mc_handle = checkCudaErrors( + cuda.cuMemExportToShareableHandle( + self.mc_handle, + self._exchanger.handle_type, + 0, + ) + ) + else: + shareable_mc_handle = None + + # Broadcast multicast handle from rank 0 + shareable_mc_handle = self._exchanger.broadcast(shareable_mc_handle, root=0) + cuda.cuCtxSynchronize() + + # Import multicast handle for non-root ranks + if self.group_rank != 0: + self.mc_handle = checkCudaErrors( + cuda.cuMemImportFromShareableHandle( + shareable_mc_handle, + self._exchanger.handle_type, + ) + ) + self._exchanger.cleanup(shareable_mc_handle) + + # Add device to multicast + checkCudaErrors(cuda.cuMulticastAddDevice(self.mc_handle, self.device_idx)) + + # Reserve and map MC pointer self.mc_ptr = checkCudaErrors( - cuda.cuMemAddressReserve(self.allocation_size, mc_granularity, 0, 0) + cuda.cuMemAddressReserve(self.allocation_size, self._mc_granularity, 0, 0) ) checkCudaErrors( cuda.cuMemMap(self.mc_ptr, self.allocation_size, 0, self.mc_handle, 0) ) + access_desc = self._get_mem_access_desc() checkCudaErrors( cuda.cuMemSetAccess(self.mc_ptr, self.allocation_size, [access_desc], 1) ) - # Bind memory to multicast + # Bind local memory to multicast checkCudaErrors( cuda.cuMulticastBindMem( self.mc_handle, @@ -943,6 +1167,15 @@ def _alloc_mn_mcast_mem( ) ) + def _get_mem_access_desc(self): + """Create memory access descriptor for this device.""" + access_desc = cuda.CUmemAccessDesc() + access_desc.location = cuda.CUmemLocation() + access_desc.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE + access_desc.location.id = self.device_idx + access_desc.flags = cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE + return access_desc + def lamport_initialize(self, rank: int, dtype: torch.dtype): if dtype == torch.bfloat16 or dtype == torch.float16: neg_zero = 0x8000 @@ -955,8 +1188,8 @@ def lamport_initialize(self, rank: int, dtype: torch.dtype): else: raise ValueError(f"Unsupported dtype: {dtype}") - # Calculate number of elements that fit in allocation_size - num_elements = self.allocation_size // dsize + # Calculate number of elements that fit in allocation_size; We don't want to include the signal pad. + num_elements = (self.allocation_size - self.SIGNAL_PAD_SIZE) // dsize checkCudaErrors( memset_func(int(self.uc_ptrs[self.group_rank]), neg_zero, num_elements) @@ -984,7 +1217,7 @@ def __init__( Constructor for McastGpuBuffer. Args: - buf_size: The total size of the buffer in bytes + buf_size: The requested size of the buffer in bytes. The actual usable size may differ due to alignment requirements. group_size: The number of ranks in the communication group group_rank: The rank of the local process within the group device: The CUDA device for buffer allocation @@ -999,13 +1232,14 @@ def __init__( mn_nvlink, comm_backend_for_handle_transfer, ) - self.buf_size = buf_size + # Update buf_size to reflect the actual usable buffer size after allocation + self.buf_size = self.mcast_device_memory.get_usable_buffer_size() self.local_device = device def lamport_initialize(self, rank: int, dtype: torch.dtype): self.mcast_device_memory.lamport_initialize(rank, dtype) - def get_mc_buffer( + def get_multicast_buffer( self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0 ) -> torch.Tensor: """ @@ -1019,12 +1253,28 @@ def get_mc_buffer( Returns: A PyTorch tensor wrapping the multicast buffer section """ + + # FIXME: Is this needed? As the behavior of reading from mc_ptr is undefined. + raise NotImplementedError("Not implemented yet") + + def get_unicast_buffer( + self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0 + ) -> torch.Tensor: + """ + Returns a PyTorch tensor view of the unicast buffer portion. + """ + + # TODO: How can I warp a raw pointer to a tensor in python level? raise NotImplementedError("Not implemented yet") def get_multicast_ptr(self) -> int: """Get the raw multicast pointer""" return self.mcast_device_memory.get_multicast_ptr() + def get_unicast_ptr(self, rank: int) -> int: + """Get the raw unicast pointer to a given rank""" + return self.mcast_device_memory.get_unicast_ptr(rank) + def get_buffer_ptrs_dev(self) -> int: """Get the buffer pointers device array""" return self.mcast_device_memory.get_buffer_ptrs_dev() diff --git a/flashinfer/comm/trtllm_ar.py b/flashinfer/comm/trtllm_ar.py index 33bb7ac97b..87246f739a 100644 --- a/flashinfer/comm/trtllm_ar.py +++ b/flashinfer/comm/trtllm_ar.py @@ -804,6 +804,51 @@ def _should_use_oneshot( return comm_size_mb <= _use_oneshot_heuristics[world_size] +def check_trtllm_allreduce_fusion_workspace_metadata( + token_num: int, + hidden_dim: int, + world_size: int, + dtype: torch.dtype, + metadata: dict, +) -> None: + errors = [] + required_keys = ["max_token_num", "tp_size", "hidden_dim", "use_fp32_lamport"] + for key in required_keys: + if key not in metadata: + errors.append(f"Workspace metadata is missing required key: {key}") + if errors: + error_msg = "Workspace metadata validation failed:\n" + "\n".join( + f" - {e}" for e in errors + ) + raise ValueError(error_msg) + + # world_size must match tp_size (flag size depends on it) + if world_size != metadata["tp_size"]: + errors.append( + f"world_size ({world_size}) does not match workspace tp_size ({metadata['tp_size']}). " + f"Workspace was created for tp_size={metadata['tp_size']}." + ) + + # token_num * hidden_dim must not exceed max_token_num * hidden_dim + if token_num * hidden_dim > metadata["max_token_num"] * metadata["hidden_dim"]: + errors.append( + f"token_num ({token_num}) * hidden_dim ({hidden_dim}) exceeds workspace max_token_num ({metadata['max_token_num']}) * hidden_dim ({metadata['hidden_dim']}). " + f"This may cause Illegal Memory Access." + ) + + # use_fp32_lamport must match + if metadata["use_fp32_lamport"] != (dtype == torch.float32): + errors.append( + f"use_fp32_lamport ({metadata['use_fp32_lamport']}) does not match allreduce_in.dtype ({dtype}). " + f"Workspace was created for use_fp32_lamport={metadata['use_fp32_lamport']}." + ) + if errors: + error_msg = "Workspace validation failed:\n" + "\n".join( + f" - {e}" for e in errors + ) + raise ValueError(error_msg) + + def trtllm_allreduce_fusion( allreduce_in: torch.Tensor, world_size: int, @@ -858,50 +903,9 @@ def trtllm_allreduce_fusion( # Validate against workspace metadata if provided if metadata is not None: - errors = [] - required_keys = ["max_token_num", "tp_size", "hidden_dim", "use_fp32_lamport"] - for key in required_keys: - if key not in metadata: - errors.append(f"Workspace metadata is missing required key: {key}") - if errors: - error_msg = "Workspace metadata validation failed:\n" + "\n".join( - f" - {e}" for e in errors - ) - raise ValueError(error_msg) - - # Check 1: token_num must not exceed max_token_num - if token_num > metadata["max_token_num"]: - errors.append( - f"token_num ({token_num}) exceeds workspace max_token_num ({metadata['max_token_num']}). " - f"This may cause Illegal Memory Access." - ) - - # Check 2: world_size must match tp_size - if world_size != metadata["tp_size"]: - errors.append( - f"world_size ({world_size}) does not match workspace tp_size ({metadata['tp_size']}). " - f"Workspace was created for tp_size={metadata['tp_size']}." - ) - - # Check 3: hidden_dim must match - if hidden_dim != metadata["hidden_dim"]: - errors.append( - f"hidden_dim ({hidden_dim}) does not match workspace hidden_dim ({metadata['hidden_dim']}). " - f"Workspace was created for hidden_dim={metadata['hidden_dim']}." - ) - - # Check 4: use_fp32_lamport must match - if metadata["use_fp32_lamport"] != (allreduce_in.dtype == torch.float32): - errors.append( - f"use_fp32_lamport ({metadata['use_fp32_lamport']}) does not match allreduce_in.dtype ({allreduce_in.dtype}). " - f"Workspace was created for use_fp32_lamport={metadata['use_fp32_lamport']}." - ) - - if errors: - error_msg = "Workspace validation failed:\n" + "\n".join( - f" - {e}" for e in errors - ) - raise ValueError(error_msg) + check_trtllm_allreduce_fusion_workspace_metadata( + token_num, hidden_dim, world_size, allreduce_in.dtype, metadata + ) if use_oneshot is None: use_oneshot = _should_use_oneshot( diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index 84a9c150de..c236020185 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -5,17 +5,20 @@ import functools import math -import os +import logging from types import SimpleNamespace from typing import Optional, Tuple +from enum import Enum import torch +from typing_extensions import deprecated from flashinfer.comm.mapping import Mapping from ..jit import gen_trtllm_mnnvl_comm_module from ..utils import register_custom_op -from .mnnvl import McastGPUBuffer, CommBackend +from .mnnvl import McastGPUBuffer, CommBackend, MPIBackend +from .workspace_base import AllReduceFusionWorkspace def mpi_barrier(): @@ -25,102 +28,474 @@ def mpi_barrier(): MPI.COMM_WORLD.Barrier() +class MNNVLAllreduceFusionStrategy(Enum): + ONESHOT = 0 + TWOSHOT = 1 + AUTO = 99 + + @staticmethod + def select_strategy( + tp_size: int, num_tokens: int, hidden_dim: int, dtype: torch.dtype + ) -> "MNNVLAllreduceFusionStrategy": + elem_size = torch.tensor([], dtype=dtype).element_size() + if num_tokens * hidden_dim * tp_size * elem_size <= MNNVL_ONE_SHOT_THRESHOLD: + return MNNVLAllreduceFusionStrategy.ONESHOT + else: + return MNNVLAllreduceFusionStrategy.TWOSHOT + + +# Empirical result calculated from num_tokens * hidden_dim * tp_size * elem_size +MNNVL_ONE_SHOT_THRESHOLD = 64 * 1024 * 8 * 2 + + +class MNNVLAllReduceFusionWorkspace(AllReduceFusionWorkspace): + NUM_LAMPORT_BUFFERS = 3 + + def __init__( + self, + mapping: Mapping, + max_num_tokens: Optional[int] = None, + hidden_dim: Optional[int] = None, + dtype: Optional[torch.dtype] = None, + buffer_size_in_bytes: Optional[int] = None, + comm_backend: Optional[CommBackend] = None, + ): + """ + Initialize the MNNVL Allreduce Fusion Workspace. The workspace will be allocated and initialized based on the provided problem size. If max_num_tokens is larger than the one-shot threshold, the workspace will be created according to the max of required one-shot size at threshold, or the required two-shot size. Note that the workspace is not bind to the given problem size. It can be reused for different problem size without reinitialization given the allocated size is sufficient. + + If the buffer_size_in_bytes is provided, the workspace will be created according to the provided size. The user is expected to use the utility function get_required_buffer_size_bytes to calculate the required size. The actual allocation size may be larger due to alignment requirements. This covers the advanced used case, for example, the user may want to enforce oneshot strategy and ignore the heuristics. + + Either max_num_tokens or buffer_size_in_bytes must be provided. + + comm_backend will be used for creating the workspace and synchronization. If not provided, MPIBackend will be used which will use COMM_WORLD for synchronization. + + Args: + mapping: Mapping configuration containing rank info + max_num_tokens: The maximum number of tokens in the input tensor. + hidden_dim: The hidden dimension of the tensors to be reduced. + dtype: The data type of the tensors to be reduced. + buffer_size_in_bytes: The requested size in bytes for each lamport buffer. The actual allocation size may be larger due to alignment requirements. The actual usable size will be NUM_LAMPORT_BUFFERS * actual_buffer_size_per_lamport_buffer. + """ + super().__init__(mapping.world_size, mapping.rank) + + print("Allocating MNNVL Allreduce Fusion Workspace...") + if buffer_size_in_bytes is None: + assert ( + max_num_tokens is not None + and hidden_dim is not None + and dtype is not None + ), ( + "max_num_tokens, hidden_dim, and dtype must be provided if buffer_size_in_bytes is not provided." + ) + + # If the user want to explictly use one-shot pass the threshold, which requires larger workspace size, + # We expect the user to set workspace size manually. + elem_size = torch.tensor([], dtype=dtype).element_size() + oneshot_max_num_tokens = min( + MNNVL_ONE_SHOT_THRESHOLD // (mapping.tp_size * elem_size * hidden_dim), + max_num_tokens, + ) + one_shot_size_bytes = self.get_required_buffer_size_bytes( + mapping.tp_size, + oneshot_max_num_tokens, + hidden_dim, + dtype, + MNNVLAllreduceFusionStrategy.ONESHOT, + ) + two_shot_size_bytes = self.get_required_buffer_size_bytes( + mapping.tp_size, + max_num_tokens, + hidden_dim, + dtype, + MNNVLAllreduceFusionStrategy.TWOSHOT, + ) + # We don't do roundup here as it will happen at the allocation. + buffer_size_in_bytes = max(one_shot_size_bytes, two_shot_size_bytes) + else: + logging.debug( + f"[MNNVL Allreduce] Using provided buffer size override in bytes: {buffer_size_in_bytes} bytes." + ) + + if comm_backend is None: + comm_backend = MPIBackend() + if buffer_size_in_bytes > (2**32 - 1): + raise ValueError( + f"The buffer size in bytes {buffer_size_in_bytes} is greater than the maximum supported size (UINT32_MAX)." + ) + + # Calculate total requested workspace size + requested_workspace_size = buffer_size_in_bytes * self.NUM_LAMPORT_BUFFERS + + self.rank = mapping.tp_rank + self.tp_size = mapping.tp_size + logging.debug( + f"[MNNVL Allreduce] TP size: {mapping.tp_size}, rank: {mapping.tp_rank}, Allocating workspace with requested size {buffer_size_in_bytes} bytes per buffer." + ) + + # Allocate the workspace + self.mcast_buffer_handle = McastGPUBuffer( + requested_workspace_size, + mapping.tp_size, + mapping.tp_rank, + torch.device("cuda", mapping.local_rank), + mapping.is_multi_node(), + comm_backend, + ) + + # Get the actual usable buffer size after allocation (buf_size is updated by McastGPUBuffer) + allocated_size = self.mcast_buffer_handle.buf_size + # We want the buffer size to be aligned to 16B which is the granularity for buffer management. + self.buffer_size_bytes = ( + math.floor(allocated_size / self.NUM_LAMPORT_BUFFERS) // 16 * 16 + ) + # This workspace size is used for checking the buffer. We need to set it to the actual size in use. The buffer free logic does not rely on this size. + self.workspace_size_bytes = self.buffer_size_bytes * self.NUM_LAMPORT_BUFFERS + + logging.debug( + f"[MNNVL Allreduce] Actual allocated size: {allocated_size} bytes, Actual buffer size per lamport buffer: {self.buffer_size_bytes} bytes, total workspace: {self.workspace_size_bytes} bytes." + ) + + # We use FP32 for sentinel value regardless of the real dtype + self.mcast_buffer_handle.lamport_initialize(mapping.tp_rank, torch.float32) + # Wait until the initialization is done + torch.cuda.synchronize() + comm_backend.barrier() + + # This is a buffer to maintain the state of this allreduce Op + # Should have the same lifetime with self._buffer + # The flag should be binded to each buffer allocation + # Layout: [cur idx, dirty idx, bytes per buffer, dirty num stages, numBytesToClear[4], access count ptr] + num_bytes_to_clear = [0] * 4 + self.buffer_flags = torch.tensor( + [0, 2, self.buffer_size_bytes, 0, *num_bytes_to_clear, 0], + dtype=torch.uint32, + device=torch.device("cuda", mapping.local_rank), + ) + + self.uc_ptrs_dev = self.mcast_buffer_handle.get_buffer_ptrs_dev() + self.uc_ptr_local = self.mcast_buffer_handle.get_unicast_ptr(self.rank) + self.mc_ptr = self.mcast_buffer_handle.get_multicast_ptr() + + @functools.cache + def is_buffer_size_sufficient( + self, + tp_size: int, + num_tokens: int, + hidden_dim: int, + dtype: torch.dtype, + strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, + ) -> bool: + """ + Calculate the required buffer size for a given problem size. + """ + required_buffer_size = self.get_required_buffer_size_bytes( + tp_size, num_tokens, hidden_dim, dtype, strategy + ) + if required_buffer_size > self.buffer_size_bytes: + return False + else: + return True + + @staticmethod + @functools.cache + def get_required_buffer_size_bytes( + tp_size: int, + num_tokens: int, + hidden_dim: int, + dtype: torch.dtype, + strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, + ) -> int: + """ + Calculate the required buffer size for a given problem size. + """ + elem_size = torch.tensor([], dtype=dtype).element_size() + if strategy == MNNVLAllreduceFusionStrategy.AUTO: + strategy = MNNVLAllreduceFusionStrategy.select_strategy( + tp_size, num_tokens, hidden_dim, dtype + ) + + if strategy == MNNVLAllreduceFusionStrategy.ONESHOT: + # For one-shot, each rank needs to store num_tokens * tp_size tokens + buffer_size = num_tokens * hidden_dim * tp_size * elem_size + else: + # For two-shot, each rank stores a slices of tokens. We need to round up to the nearest tp_size. + # 2 Stage is required for the two-shot allreduce. + buffer_size = ( + 2 * math.ceil(num_tokens / tp_size) * tp_size * hidden_dim * elem_size + ) + return buffer_size + + @property + def backend(self) -> str: + return "mnnvl" + + def destroy(self) -> None: + """Destroy workspace and free resources.""" + if self._destroyed: + return # Already destroyed, nothing to do + + # TODO: Implement proper cleanup of mcast_buffer_handle if needed + self._destroyed = True + + @functools.cache def get_trtllm_mnnvl_comm_module(): module = gen_trtllm_mnnvl_comm_module().build_and_load() @register_custom_op( - "flashinfer::trtllm_mnnvl_all_reduce", + "flashinfer::trtllm_mnnvl_allreduce_fusion", mutates_args=[ - "inp", + "input", "multicast_buffer_ptr", "buffer_ptrs_dev", - "buffer_mnnvl", + "buffer_ptr_local", "buffer_flags_mnnvl", "nranks", "rank", - "wait_for_results", + "rmsnorm_fusion", "launch_with_pdl", - "out", + "use_oneshot", + "output", + "residual_out", + "residual_in", + "gamma", + "epsilon", ], ) - def trtllm_mnnvl_all_reduce( - inp: torch.Tensor, + def trtllm_mnnvl_allreduce_fusion( + input: torch.Tensor, multicast_buffer_ptr: int, # Pointer address as integer buffer_ptrs_dev: int, # Pointer address as integer - buffer_mnnvl: torch.Tensor, + buffer_ptr_local: int, # Pointer address as integer buffer_flags_mnnvl: torch.Tensor, nranks: int, rank: int, - wait_for_results: bool, + rmsnorm_fusion: bool, launch_with_pdl: bool, - out: Optional[torch.Tensor], + use_oneshot: bool, + output: torch.Tensor, + residual_out: Optional[torch.Tensor], + residual_in: Optional[torch.Tensor], + gamma: Optional[torch.Tensor], + epsilon: Optional[float], ) -> None: - module.trtllm_mnnvl_all_reduce( - inp, + """ + Perform a multi-node NVLink all-reduce operation with fusion. + Args: + input: Input tensor + multicast_buffer_ptr: Pointer to the multicast buffer as an integer + buffer_ptrs_dev: Pointer to the device array of buffer pointers as an integer + buffer_ptr_local: Pointer to local buffer as an integer + buffer_flags_mnnvl: Buffer flags tensor for synchronization + nranks: Total number of ranks participating in the all-reduce + rank: Current process rank + rmsnorm_fusion: Whether to perform RMSNorm fusion + launch_with_pdl: Whether to launch with PDL + use_oneshot: Whether to use one-shot (true) or two-shot (false) + output: Output tensor + residual_out: Residual output tensor (if rmsnorm) + gamma: Gamma tensor (if rmsnorm) + epsilon: Epsilon value (if rmsnorm) + """ + module.trtllm_mnnvl_allreduce_fusion( + input, multicast_buffer_ptr, buffer_ptrs_dev, - buffer_mnnvl, + buffer_ptr_local, buffer_flags_mnnvl, nranks, rank, - wait_for_results, + rmsnorm_fusion, launch_with_pdl, - out, + use_oneshot, + output, + residual_out, + residual_in, + gamma, + epsilon, ) - @register_custom_op( - "flashinfer::trtllm_mnnvl_rmsnorm", - mutates_args=[ - "mcast_buffer_input", - "prenorm_output", - "normed_output", - "gamma", - "epsilon", - "residual", - "buffer_flags", - "launch_with_pdl", - ], + return SimpleNamespace( + trtllm_mnnvl_allreduce_fusion=trtllm_mnnvl_allreduce_fusion, ) - def trtllm_mnnvl_rmsnorm( - mcast_buffer_input: int, - prenorm_output: torch.Tensor, - normed_output: torch.Tensor, - gamma: torch.Tensor, - epsilon: float, - residual: torch.Tensor, - buffer_flags: torch.Tensor, - launch_with_pdl: bool, - ) -> None: - """Performs MNNVL TwoShot RMSNorm on the communication buffer. - Args: - prenorm_output: Output tensor for prenorm results - normed_output: Output tensor for normalized results - mcast_buffer_input: Input tensor - gamma: The gamma parameter for RMSNorm - epsilon: The epsilon parameter for RMSNorm - residual: The residual tensor to add - buffer_flags: Buffer flags for synchronization - launch_with_pdl: Whether to launch with PDL - """ - return module.trtllm_mnnvl_rmsnorm( - mcast_buffer_input, - prenorm_output, - normed_output, - gamma, - epsilon, - residual, - buffer_flags, - launch_with_pdl, + +def trtllm_mnnvl_allreduce( + input: torch.Tensor, + workspace: MNNVLAllReduceFusionWorkspace, + launch_with_pdl: bool, + output: Optional[torch.Tensor] = None, + strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, +) -> torch.Tensor: + """Perform a multi-node NVLink all-reduce operation across multiple GPUs. + + This function performs an all-reduce (sum) operation using NVIDIA's multi-node NVLink (MNNVL) + technology to efficiently combine tensors across multiple GPUs and nodes. + + There are 2 variants: One-shot and Two-shot: + - One-shot: Each rank stores local shard to all other ranks. Each ranks will receive all shards at the end of the communication round and perfom local reduction. Suitable for small data size and is optimized for low latency. + - Two-shot: There will be 3 steps: + 1. Scatter each GPU's input shard to other ranks. Each rank will received all shards of a slice of tokens. + 2. Each rank perform reduction on the local tokens. + 3. Each rank broadcast the result to all ranks. + Suitable for large data size and is optimized for balancing throughput and latency. + + Args: + input: Local Input Shard [num_tokens, hidden_dim] + workspace: MNNVLAllReduceFusionWorkspace + launch_with_pdl: Whether to launch with PDL + output: Output tensor to store the result, empty tensor will be created if not provided. + strategy: MNNVLAllreduceFusionStrategy. Internal heuristics will be used if not provided. + Returns: + output: Reduced tensor [num_tokens, hidden_dim] + """ + + # Check ndims here as the shape check is done in the kernel launch code. + if len(input.shape) != 2: + raise ValueError( + f"The input tensor must be 2D, got {len(input.shape)}D. The shape is {input.shape}." ) - return SimpleNamespace( - trtllm_mnnvl_all_reduce=trtllm_mnnvl_all_reduce, - trtllm_mnnvl_rmsnorm=trtllm_mnnvl_rmsnorm, + if output is None: + output = torch.empty_like(input) + elif len(output.shape) != 2: + raise ValueError( + f"The output tensor must be 2D, got {len(output.shape)}D. The shape is {output.shape}." + ) + + module = get_trtllm_mnnvl_comm_module() + + if strategy == MNNVLAllreduceFusionStrategy.AUTO: + strategy = MNNVLAllreduceFusionStrategy.select_strategy( + workspace.tp_size, input.shape[0], input.shape[1], input.dtype + ) + + if not workspace.is_buffer_size_sufficient( + workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy + ): + raise ValueError( + f"The buffer size in the given workspace is insufficient for the given problem size. Buffer: {workspace.buffer_size_bytes} bytes, Required: {workspace.get_required_buffer_size_bytes(workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy)} bytes." + ) + + module.trtllm_mnnvl_allreduce_fusion( + input, + workspace.mc_ptr, + workspace.uc_ptrs_dev, + workspace.uc_ptr_local, + workspace.buffer_flags, + workspace.tp_size, + workspace.rank, + False, # No RMSNorm Fusion + launch_with_pdl, + strategy == MNNVLAllreduceFusionStrategy.ONESHOT, + output, + None, + None, + None, + None, ) + return output + +def trtllm_mnnvl_fused_allreduce_add_rmsnorm( + input: torch.Tensor, + residual_in: torch.Tensor, + gamma: torch.Tensor, + workspace: MNNVLAllReduceFusionWorkspace, + epsilon: Optional[float] = None, + output: Optional[torch.Tensor] = None, + residual_out: Optional[torch.Tensor] = None, + launch_with_pdl: bool = False, + strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Performs MNNVL Allreduce + Residual + RMSNorm. + + This function performs a multi-node all-reduce (sum) operation by first calling trtllm_mnnvl_allreduce on the shard_input. + After this, it performs residual addition and RMSNorm on the all-reduced result, reading it directly from the multicast buffer. + Note: multicast buffer is the same as the unicast buffer for the current rank. + + Args: + input: Input tensor [num_tokens, hidden_dim] + residual_in: Residual input tensor [num_tokens, hidden_dim] + gamma: Gamma tensor [hidden_dim] + workspace: MNNVLAllReduceFusionWorkspace + epsilon: The epsilon parameter for RMSNorm, torch.finfo.eps will be used if not provided. + output: Output tensor for normalized results [num_tokens, hidden_dim], empty tensor will be created if not provided. + residual_out: Residual output tensor [num_tokens, hidden_dim], empty tensor will be created if not provided. + launch_with_pdl: Whether to launch with PDL + strategy: MNNVLAllreduceFusionStrategy. Internal heuristics will be used if not provided. + + Returns: + output: Add-residual and normalized tensor [num_tokens, hidden_dim] + residual_out: Add-residual tensor [num_tokens, hidden_dim] + """ + + if epsilon is None: + epsilon = torch.finfo(input.dtype).eps + + if len(input.shape) != 2: + raise ValueError( + f"The input tensor must be 2D, got {len(input.shape)}D. The shape is {input.shape}." + ) + if len(residual_in.shape) != 2: + raise ValueError( + f"The residual input tensor must be 2D, got {len(residual_in.shape)}D. The shape is {residual_in.shape}." + ) + if gamma.numel() != input.shape[1]: + raise ValueError( + f"The gamma tensor must have the same number of elements as the hidden dimension, got {gamma.numel()} elements but expected {input.shape[1]} elements." + ) + if output is None: + output = torch.empty_like(input) + elif len(output.shape) != 2: + raise ValueError( + f"The output tensor must be 2D, got {len(output.shape)}D. The shape is {output.shape}." + ) + if residual_out is None: + residual_out = torch.empty_like(residual_in) + elif len(residual_out.shape) != 2: + raise ValueError( + f"The residual output tensor must be 2D, got {len(residual_out.shape)}D. The shape is {residual_out.shape}." + ) + + module = get_trtllm_mnnvl_comm_module() + + if strategy == MNNVLAllreduceFusionStrategy.AUTO: + strategy = MNNVLAllreduceFusionStrategy.select_strategy( + workspace.tp_size, input.shape[0], input.shape[1], input.dtype + ) + if not workspace.is_buffer_size_sufficient( + workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy + ): + raise ValueError( + f"The buffer size in the given workspace is insufficient for the given problem size. Buffer: {workspace.buffer_size_bytes} bytes, Required: {workspace.get_required_buffer_size_bytes(workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy)} bytes." + ) + + module.trtllm_mnnvl_allreduce_fusion( + input, + workspace.mc_ptr, + workspace.uc_ptrs_dev, + workspace.uc_ptr_local, + workspace.buffer_flags, + workspace.tp_size, + workspace.rank, + True, # RMSNorm Fusion + launch_with_pdl, + strategy == MNNVLAllreduceFusionStrategy.ONESHOT, + output, + residual_out, + residual_in, + gamma, + epsilon, + ) + return output, residual_out + + +# Legacy API that has been deprecated; Left for backward compatibility +@deprecated( + "get_allreduce_mnnvl_workspace is deprecated, use MNNVLAllReduceFusionWorkspace class to manage the workspace instead" +) def get_allreduce_mnnvl_workspace( mapping: Mapping, dtype: torch.dtype, @@ -141,7 +516,6 @@ def get_allreduce_mnnvl_workspace( Args: mapping: Tensor parallel mapping configuration containing rank info dtype: Data type of the tensors being reduced - comm: Optional communication backend for multi-node synchronization buffer_size_in_bytes: Optional buffer size. Practically, assign this to 3 * 2 * dtype.itemsize * hidden_dim * max_tokens Returns: @@ -150,8 +524,6 @@ def get_allreduce_mnnvl_workspace( - torch.Tensor: Buffer flags tensor tracking state - int: Maximum number of elements that can fit in buffer """ - force_mn = os.environ.get("TRTLLM_FORCE_MNNVL_AR", "0") == "1" - # buffer shape: [3, 2, buffer_tokens, hidden_dim] stride = 3 * 2 * dtype.itemsize # LCM for hidden_dim: 2048, 4096, 5120, 7168, 8192 = 286720 @@ -163,35 +535,19 @@ def get_allreduce_mnnvl_workspace( buffer_size_in_bytes = math.ceil( TARGET_WORKSPACE_SIZE_BYTES / (lcm_hidden_dim * stride) ) * (lcm_hidden_dim * stride) - max_num_elements = buffer_size_in_bytes // stride - - mcast_buffer = McastGPUBuffer( - buffer_size_in_bytes, - mapping.tp_size, - mapping.tp_rank, - torch.device("cuda", mapping.local_rank), - mapping.is_multi_node() or force_mn, - comm_backend_for_handle_transfer=comm_backend_for_handle_transfer, - ) - # Initialize the unicast buffer with -0.0 - mcast_buffer.lamport_initialize(mapping.tp_rank, dtype) - - # CPU barrier since we assume this should not be called in cuda graph - torch.cuda.synchronize() - if comm_backend_for_handle_transfer is None: - mpi_barrier() - else: - comm_backend_for_handle_transfer.barrier() - - # This is a buffer to maintain the state of this allreduce Op - # [Buffer_ptr, Clear_ptr, Buffer_size, num_tokens_prev, atomic access counter] - buffer_flags = torch.tensor( - [0, 2, max_num_elements, 0, 0], - dtype=torch.uint32, - device=torch.device("cuda", mapping.local_rank), + # Redirect to the new workspace allocation logic. The new kernel needs the new flag buffer layout. + workspace = MNNVLAllReduceFusionWorkspace( + mapping, + buffer_size_in_bytes=buffer_size_in_bytes, + comm_backend=comm_backend_for_handle_transfer, ) + mcast_buffer = workspace.mcast_buffer_handle + buffer_flags = workspace.buffer_flags + # this is calculated using the legacy behavior. We do not use the actual allocated size. + max_num_elements = workspace.buffer_size_bytes // stride + return ( mcast_buffer, buffer_flags, @@ -199,6 +555,9 @@ def get_allreduce_mnnvl_workspace( ) +@deprecated( + "trtllm_mnnvl_all_reduce is deprecated, use trtllm_mnnvl_allreduce instead. This function will be removed in the future." +) def trtllm_mnnvl_all_reduce( inp: torch.Tensor, multicast_buffer_ptr: int, # Pointer address as integer @@ -240,26 +599,39 @@ def trtllm_mnnvl_all_reduce( f"The input tensor must be 2D, got {len(inp.shape)}D. The shape is {inp.shape}." ) + # buffer_M is no longer used in this kernel but let's keep this check for consistency in behavior. if inp.shape[0] > buffer_M: raise ValueError( f"The number of tokens in the input tensor {inp.shape[0]} is greater than the buffer_M {buffer_M}. This is not supported. Please increase the workspace size, or decrease the amount of tokens to at most {buffer_M}." ) + # Even in legacy code, this should only be used when we implement the fused allreduce+rmsnorm. + assert wait_for_results and (out is not None), ( + "Calling the legacy trtllm_mnnvl_all_reduce with wait_for_results=False is not supported. Please use trtllm_mnnvl_allreduce instead." + ) module = get_trtllm_mnnvl_comm_module() - module.trtllm_mnnvl_all_reduce( + module.trtllm_mnnvl_allreduce_fusion( inp, multicast_buffer_ptr, - int(buffer_ptrs_dev), - buffer_M, + buffer_ptrs_dev, + 0, # Allreduce kernel itself does not use this local pointer; still this could be risky but it is only used for legacy code compatibility. buffer_flags_mnnvl, nranks, rank, - wait_for_results, + False, # No RMSNorm Fusion launch_with_pdl, + False, # Use two-shot out, + None, + None, + None, + None, ) +@deprecated( + "trtllm_mnnvl_fused_allreduce_rmsnorm is deprecated, use trtllm_mnnvl_fused_allreduce_add_rmsnorm instead. This function will be removed in the future." +) def trtllm_mnnvl_fused_allreduce_rmsnorm( prenorm_output: torch.Tensor, normed_output: torch.Tensor, @@ -299,30 +671,52 @@ def trtllm_mnnvl_fused_allreduce_rmsnorm( launch_with_pdl: Whether to launch with PDL """ - # allreduce_result = Σ(shard_input across all ranks) - trtllm_mnnvl_all_reduce( + if len(shard_input.shape) != 2: + raise ValueError( + f"The input tensor must be 2D, got {len(shard_input.shape)}D. The shape is {shard_input.shape}." + ) + + # buffer_M is no longer used in this kernel but let's keep this check for consistency in behavior. + if shard_input.shape[0] > buffer_M: + raise ValueError( + f"The number of tokens in the input tensor {shard_input.shape[0]} is greater than the buffer_M {buffer_M}. This is not supported. Please increase the workspace size, or decrease the amount of tokens to at most {buffer_M}." + ) + + if len(residual.shape) != 2: + raise ValueError( + f"The residual input tensor must be 2D, got {len(residual.shape)}D. The shape is {residual.shape}." + ) + if gamma.numel() != shard_input.shape[1]: + raise ValueError( + f"The gamma tensor must have the same number of elements as the hidden dimension, got {gamma.numel()} elements but expected {shard_input.shape[1]} elements." + ) + + if len(normed_output.shape) != 2: + raise ValueError( + f"The output tensor must be 2D, got {len(normed_output.shape)}D. The shape is {normed_output.shape}." + ) + + if len(prenorm_output.shape) != 2: + raise ValueError( + f"The prenorm output tensor must be 2D, got {len(prenorm_output.shape)}D. The shape is {prenorm_output.shape}." + ) + + module = get_trtllm_mnnvl_comm_module() + + module.trtllm_mnnvl_allreduce_fusion( shard_input, multicast_buffer_ptr, buffer_ptrs_dev, - buffer_M, + unicast_ptr, buffer_flags_mnnvl, nranks, rank, - False, # No need to wait to write AR results here as we are not writing them + True, # RMSNorm Fusion launch_with_pdl, - None, # out parameter - None since wait_for_results=False - ) - - # prenorm_output = AllReduce(shard_input) + residual - # rms = sqrt(mean(prenorm_output²) + epsilon) - # normed_output = (prenorm_output / rms) * gamma - get_trtllm_mnnvl_comm_module().trtllm_mnnvl_rmsnorm( - unicast_ptr, - prenorm_output, + False, normed_output, + prenorm_output, + residual, gamma, epsilon, - residual, - buffer_flags_mnnvl, - launch_with_pdl, ) diff --git a/flashinfer/comm/workspace_base.py b/flashinfer/comm/workspace_base.py new file mode 100644 index 0000000000..5de8d07483 --- /dev/null +++ b/flashinfer/comm/workspace_base.py @@ -0,0 +1,89 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from abc import ABC, abstractmethod +from typing import Optional, Any + +import torch + + +class AllReduceFusionWorkspace(ABC): + """Base class for AllReduce fusion workspaces.""" + + # Explicit type annotations for mypy (needed due to __getattr__ in subclasses) + world_size: int + rank: int + _destroyed: bool + + def __init__(self, world_size: int, rank: int): + self.world_size = world_size + self.rank = rank + self._destroyed = False + + @property + @abstractmethod + def backend(self) -> str: + """Return backend name.""" + pass + + @abstractmethod + def destroy(self) -> None: + """ + Destroy workspace and free resources. + + This should be called explicitly when done using the workspace. + Prefer using AllReduceFusionContext context manager for automatic cleanup. + """ + pass + + @abstractmethod + def is_buffer_size_sufficient( + self, + tp_size: int, + num_tokens: int, + hidden_dim: int, + dtype: torch.dtype, + use_oneshot: Optional[Any] = None, + ) -> bool: + pass + + def __del__(self): + """ + Destructor - safety net if destroy() wasn't called explicitly. + + Warns if cleanup wasn't done properly. Not recommended to rely on this + as __del__ timing is non-deterministic and can cause issues with + distributed/CUDA resources. + """ + if not self._destroyed: + import warnings + + warnings.warn( + f"{self.__class__.__name__} was not explicitly destroyed. " + f"Call workspace.destroy() or use AllReduceFusionContext to ensure " + f"proper cleanup of distributed/CUDA resources.", + ResourceWarning, + stacklevel=2, + ) + try: + self.destroy() + except Exception as e: + # Can't raise in __del__, just warn + warnings.warn( + f"Error during automatic cleanup of {self.__class__.__name__}: {e}", + ResourceWarning, + stacklevel=2, + ) diff --git a/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh b/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh index 3dbed4b649..2177cfc618 100644 --- a/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh +++ b/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh @@ -18,52 +18,54 @@ #include #include #include +#include #include +#include #include "../exception.h" #include "../logging.h" +#include "../utils.cuh" namespace flashinfer { namespace trtllm_mnnvl_allreduce { -template -struct AllReduceParams { - int nranks; +struct AllReduceFusionParams { + int nRanks; int rank; - int buffer_M; - int num_tokens; - int token_dim; - void** buffer_ptrs_dev; - void* multicast_ptr; - void* buffer_flags; - bool wait_for_results; - bool launch_with_pdl; - - void* input; - void* output; - cudaStream_t stream; -}; + int numTokens; + int tokenDim; + void** bufferPtrsDev; + void* bufferPtrLocal; + void* multicastPtr; + uint32_t* bufferFlags; + bool rmsNormFusion; + bool launchWithPdl; -template -struct RMSNormParams { - void* residual_output; - void* output; void const* input; + void const* residualIn; void const* gamma; double epsilon; - void* residual; - uint32_t* buffer_flags; - int batch; - int hidden_dim; - cudaStream_t stream; - bool launch_with_pdl; + + void* residualOut; + void* output; + cudaStream_t stream = nullptr; }; -__device__ bool isNegZero(float v) { return v == 0.f && signbit(v); } +namespace utils { + +constexpr uint16_t kNEGZERO_FP16 = 0x8000U; + +template +union Fp16BitCast { + T mFp; + uint16_t mInt; + + constexpr Fp16BitCast() : mInt(0) {} -__device__ bool isNegZero(__nv_bfloat16 val) { return isNegZero(__bfloat162float(val)); } + constexpr Fp16BitCast(T val) : mFp(val) {} -__device__ bool isNegZero(__nv_half val) { return isNegZero(__half2float(val)); } + constexpr Fp16BitCast(uint16_t val) : mInt(val) {} +}; template inline __device__ float toFloat(T val) { @@ -74,7 +76,6 @@ template <> inline __device__ float toFloat<__nv_bfloat16>(__nv_bfloat16 val) { return __bfloat162float(val); } - template <> inline __device__ float toFloat<__nv_half>(__nv_half val) { return __half2float(val); @@ -95,581 +96,1126 @@ inline __device__ __nv_half fromFloat<__nv_half>(float val) { return __float2half(val); } -inline __device__ float2 loadfloat2(void const* ptr) { - float2 return_value; - asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n" - : "=f"(return_value.x), "=f"(return_value.y) - : "l"(ptr)); - return return_value; +template +static constexpr __device__ __host__ T negZero() { + if constexpr (std::is_same_v) { + return -0.0F; + } else if constexpr (std::is_same_v || std::is_same_v) { + return Fp16BitCast(kNEGZERO_FP16).mFp; + } else { + static_assert(sizeof(T) == 0, "negativeZero not specialized for this type"); + } + return T{}; // Never reached, but needed for compilation } template -inline __device__ T divUp(T val, T divisor) { - return (val + divisor - 1) / divisor; +static inline __device__ bool isNegZero(T val) { + if constexpr (std::is_same_v) { + return val == 0.F && signbit(val); + } else if constexpr (std::is_same_v || std::is_same_v) { + return Fp16BitCast(val).mInt == kNEGZERO_FP16; + } else { + static_assert(sizeof(T) == 0, "isNegZero not specialized for this type"); + } + return false; // Never reached, but needed for compilation } -__device__ struct __attribute__((aligned(32))) LamportFlags { - uint32_t buffer_size; - uint32_t input_offset; - uint32_t clear_offset; - uint32_t num_tokens_prev; - uint32_t* offset_access_ptr; - uint32_t* buffer_flags; - - __device__ explicit LamportFlags(uint32_t* buffer_flags) - : offset_access_ptr(&buffer_flags[4]), buffer_flags(buffer_flags) { - uint4 flag = reinterpret_cast(buffer_flags)[0]; - buffer_size = flag.z; - input_offset = flag.x * (buffer_size << 1U); - clear_offset = flag.y * (buffer_size << 1U); - num_tokens_prev = flag.w; - } - - __device__ void cta_arrive() { - __syncthreads(); - if (threadIdx.x == 0) { -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) - asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) - : "memory"); -#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory"); -#else - atomicAdd(offset_access_ptr, 1); -#endif - } - } +template +constexpr __device__ __host__ PackedType getPackedLamportInit() { + static_assert(sizeof(PackedType) % sizeof(T) == 0, "PackedType size must be divisible by T size"); + constexpr int kNumElements = sizeof(PackedType) / sizeof(T); - __device__ void wait_and_update(uint32_t num_tokens) { - if (threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == 0) { - while (*reinterpret_cast(offset_access_ptr) < gridDim.x * gridDim.y) { + union PackedT { + PackedType mPacked; + std::array mElements; + + constexpr PackedT() : mElements{} { + for (int i = 0; i < kNumElements; i++) { + mElements[i] = negZero(); } - uint4 flag = reinterpret_cast(buffer_flags)[0]; - buffer_flags[0] = (flag.x + 1) % 3; - buffer_flags[1] = (flag.y + 1) % 3; - buffer_flags[3] = num_tokens; - *(offset_access_ptr) = 0; } + }; + + PackedT initValue{}; + return initValue.mPacked; +} + +// A helper class to get the correct base pointer for a given layout +struct LamportBufferLayout { + uint32_t numStages = 1; + uint32_t bytesPerBuffer = 0; + static constexpr uint32_t sNumLamportBuffers = 3; + + // Implicitly inlined + [[nodiscard]] __device__ __host__ size_t getTotalBytes() const { + return numStages * static_cast(bytesPerBuffer / numStages) * sNumLamportBuffers; + } + + // Implicitly inlined + [[nodiscard]] __device__ __host__ void* getStagePtr(void* bufferBasePtr, uint32_t lamportIndex, + uint32_t stageIndex) const { + // Typecast to avoid warnings + return reinterpret_cast( + reinterpret_cast(bufferBasePtr) + + static_cast((lamportIndex * numStages + stageIndex) * + static_cast(bytesPerBuffer / numStages))); } }; +// Current Index +// Dirty Index +// bytes_per_buffer +// Dirty num_stages +// Dirty bytes_to_clear = {stage0, stage1, stage2, stage3} # We fix this to 4 stages +// offset_access_ptr -template -__global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ptrs, T* mcast_ptr, - int num_tokens, int buffer_M, int token_dim, int rank, - uint32_t* buffer_flags, bool wait_for_results) { - int elt = blockIdx.y * blockDim.x + threadIdx.x; +namespace cg = cooperative_groups; - if (elt >= token_dim) return; - int token = blockIdx.x; +// PackedType is the one used in kernel for Lamport buffer (LDG.128 or LDG.64) +template +__device__ struct __attribute__((aligned(32))) LamportFlags { + public: + __device__ explicit LamportFlags(uint32_t* bufferFlags, uint32_t numStages = 1) + : mBufferFlagsPtr(bufferFlags), mFlagAccessPtr(&bufferFlags[8]) { + mCurBufferLayout.numStages = numStages; + uint4 flag = reinterpret_cast(bufferFlags)[0]; + mCurrentIndex = flag.x; + mDirtyIndex = flag.y; + // Buffer size is unchanged as the flag should be coupled to each buffer + mCurBufferLayout.bytesPerBuffer = flag.z; + mDirtyBufferLayout.bytesPerBuffer = flag.z; + mDirtyBufferLayout.numStages = flag.w; + *reinterpret_cast(&mBytesToClear) = reinterpret_cast(bufferFlags)[1]; + } -#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - cudaGridDependencySynchronize(); -#endif + // Return the base pointer of the lamport buffer indexed by mCurrentIndex and the stageIdx + [[nodiscard]] __device__ void* getCurLamportBuf(void* bufferBasePtr, int stageIdx = 0) const { + return mCurBufferLayout.getStagePtr(bufferBasePtr, mCurrentIndex, stageIdx); + } - LamportFlags flags(buffer_flags); - - // Capture the number of tokens in previous iteration so that we can properly clear the buffer - // The scatter stage will use the buffer in WORLD_SIZE granularity, thus we need to round up - uint32_t clr_toks_cta = - divUp(flags.num_tokens_prev > num_tokens ? flags.num_tokens_prev : num_tokens, - WORLD_SIZE) * - WORLD_SIZE; - clr_toks_cta = divUp(clr_toks_cta, gridDim.x); - - if (elt < token_dim) { - // Scatter token - int dest_rank = token % WORLD_SIZE; - int dest_token_offset = token / WORLD_SIZE; - T val = shard_ptr[token * token_dim + elt]; - if (isNegZero(val)) val = fromFloat(0.f); - input_ptrs[dest_rank][flags.input_offset + dest_token_offset * token_dim * WORLD_SIZE + - rank * token_dim + elt] = val; - - // Clear the buffer used by the previous call. Note the number of tokens to clear could be - // larger than the - // number of tokens in the current call. - for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++) { - uint32_t clr_token_idx = token + clr_tok * gridDim.x; - if (clr_token_idx < buffer_M) { - input_ptrs[rank][flags.clear_offset + clr_token_idx * token_dim + elt] = fromFloat(-0.f); + // Fill the dirty lamport buffer with the init value; Use stageIdx to select the stage to clear, + // -1 to clear all + // FIXME: Current kernel may use less stages than the dirty numStages; How to guarantee the + // correctness? CAUTION: This function requires all threads in the grid to participate and ASSUME + // 1D thread block layout! + __device__ void clearDirtyLamportBuf(void* bufferBasePtr, int stageIdx = -1) { + // Rasterize the threads to 1D for flexible clearing + + uint32_t globalCtaIdx = blockIdx.x * gridDim.y + blockIdx.y; + uint32_t globalTid = globalCtaIdx * blockDim.x + threadIdx.x; + uint32_t numThreads = gridDim.x * gridDim.y * blockDim.x; + + if (stageIdx == -1) { + // Clear all stages + for (uint32_t i = 0; i < mDirtyBufferLayout.numStages; i++) { + clearPackedBuf(bufferBasePtr, globalTid, numThreads, mBytesToClear[i], mDirtyIndex, i); } + } else if (stageIdx < mDirtyBufferLayout.numStages) { + clearPackedBuf(bufferBasePtr, globalTid, numThreads, mBytesToClear[stageIdx], mDirtyIndex, + stageIdx); } + } - // Reduce and broadcast - if ((token % WORLD_SIZE) == rank) { - int local_token = token / WORLD_SIZE; - float accum = 0.f; - - T values[WORLD_SIZE]; - - while (1) { - bool valid = true; - for (int r = 0; r < WORLD_SIZE; r++) { - T volatile* lamport_ptr = - (T volatile*)&input_ptrs[rank] - [flags.input_offset + local_token * token_dim * WORLD_SIZE + - r * token_dim + elt]; - values[r] = *lamport_ptr; - valid &= !isNegZero(values[r]); - } - if (valid) break; - } - for (int r = 0; r < WORLD_SIZE; r++) { - accum += toFloat(values[r]); - } - mcast_ptr[flags.input_offset + buffer_M * token_dim + token * token_dim + elt] = - fromFloat(accum); + __device__ void ctaArrive() { + int tid{0}; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + + cg::cluster_group cluster = cg::this_cluster(); + // We update the atomic counter per cluster + tid = cluster.thread_rank(); + cluster.sync(); +#else + tid = threadIdx.x; + __syncthreads(); +#endif + if (tid == 0) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) + asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(mFlagAccessPtr), "r"(1) + : "memory"); +#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) + asm volatile("red.release.global.gpu.add.u32 [%0], %1;" ::"l"(mFlagAccessPtr), "r"(1) + : "memory"); +#else + atomicAdd(mFlagAccessPtr, 1); +#endif } } -#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - cudaTriggerProgrammaticLaunchCompletion(); + __device__ void waitAndUpdate(uint4 bytesToClearPerStage) { + bool isLastCtaT0{false}; + int targetCount{0}; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cg::grid_group grid = cg::this_grid(); + // Use the first thread instead of the last thread as the last thread may exit early + isLastCtaT0 = grid.thread_rank() == 0; + targetCount = grid.num_clusters(); +#else + isLastCtaT0 = threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0; + targetCount = gridDim.x * gridDim.y; #endif + if (isLastCtaT0) { + uint4* flagPtr = reinterpret_cast(mBufferFlagsPtr); + while (*reinterpret_cast(mFlagAccessPtr) < targetCount) { + } + // 'Current' becomes 'Dirty' + flagPtr[0] = {(mCurrentIndex + 1) % 3, // Current index + mCurrentIndex, // Dirty index + mCurBufferLayout.bytesPerBuffer, // Buffer size + mCurBufferLayout.numStages}; // Dirty - Number of stages + flagPtr[1] = bytesToClearPerStage; + *mFlagAccessPtr = 0; + } + } - // Similarly clear broadcast buffer here - for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++) { - uint32_t clr_token_idx = token + clr_tok * gridDim.x; - if (clr_token_idx < buffer_M) { - input_ptrs[rank][flags.clear_offset + buffer_M * token_dim + clr_token_idx * token_dim + - elt] = fromFloat(-0.f); + private: + uint32_t* mBufferFlagsPtr; + uint32_t* mFlagAccessPtr; + + uint32_t mCurrentIndex, mDirtyIndex; + // So that we can access it with uint4 + alignas(16) std::array mBytesToClear; + LamportBufferLayout mCurBufferLayout, mDirtyBufferLayout; + + inline __device__ void clearPackedBuf(void* bufferBasePtr, uint32_t globalTid, + uint32_t numThreads, uint32_t bytesToClear, + uint8_t dirtyIndex, uint8_t stageIdx) { + // Round up to the float4 boundary + uint32_t clearBoundary = ceil_div(bytesToClear, sizeof(PackedType)); + for (uint32_t packedIdx = globalTid; packedIdx < clearBoundary; packedIdx += numThreads) { + reinterpret_cast( + mDirtyBufferLayout.getStagePtr(bufferBasePtr, dirtyIndex, stageIdx))[packedIdx] = + getPackedLamportInit(); } } +}; - // Optionally wait for results if the next layer isn't doing the Lamport check - if (wait_for_results) { - // Update the atomic counter to indicate the block has read the offsets - flags.cta_arrive(); - // Only use a set of CTAs for lamport sync, reargange the grid - constexpr int ELTS_PER_LOAD = sizeof(float2) / sizeof(T); - // blockDim.x / ELTS_PER_LOAD should be at least the size of a warp (32) - if (threadIdx.x < (blockDim.x / ELTS_PER_LOAD)) { - uint64_t current_pos = - blockIdx.x * token_dim + blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD; - - void* lamport_ptr = - (void*)&input_ptrs[rank][flags.input_offset + buffer_M * token_dim + current_pos]; - // We have 2 assumptions here: - // 1. The write is atomic in 8B granularity -> Each buffer in the buffer group should be - // aligned to 8B - // 2. The num_token * token_dim is divisible by ELTS_PER_LOAD (4 for BF16 and 2 for FP32) - float2 val = loadfloat2(lamport_ptr); - while (isNegZero(*(T*)&val)) { - val = loadfloat2(lamport_ptr); - } - if (output_ptr) { - *((float2*)&output_ptr[current_pos]) = val; - } +template +union PackedVec { + PackedType packed; + T elements[sizeof(PackedType) / sizeof(T)]; + + __device__ PackedVec& operator+=(PackedVec& other) { +#pragma unroll + for (int i = 0; i < sizeof(PackedType) / sizeof(T); i++) { + elements[i] += other.elements[i]; } + return *this; + } - // Update the buffer flags - flags.wait_and_update(num_tokens); + __device__ PackedVec operator+(PackedVec& other) { + PackedVec result; +#pragma unroll + for (int i = 0; i < sizeof(PackedType) / sizeof(T); i++) { + result.elements[i] = elements[i] + other.elements[i]; + } + return result; } +}; + +template +inline __device__ PackedType loadPacked(T* ptr) { + return *reinterpret_cast(ptr); } -// Template-based dispatch functions following the same pattern as trtllm_allreduce.cuh -template -cudaError_t twoshot_allreduce_dispatch(AllReduceParams& params) { - int const num_threads = 128; - int const num_blocks = (params.token_dim + num_threads - 1) / num_threads; - - dim3 grid(params.num_tokens, num_blocks); - - cudaLaunchConfig_t config; - cudaLaunchAttribute attrs[1]; - config.dynamicSmemBytes = 0; - config.stream = params.stream; - config.gridDim = grid; - config.blockDim = num_threads; - config.attrs = attrs; - attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = params.launch_with_pdl ? 1 : 0; - config.numAttrs = 1; +template +inline __device__ const PackedType loadPacked(T const* ptr) { + return *reinterpret_cast(ptr); +} - cudaLaunchKernelEx(&config, &twoshot_allreduce_kernel, - reinterpret_cast(params.output), reinterpret_cast(params.input), - reinterpret_cast(params.buffer_ptrs_dev), - reinterpret_cast(params.multicast_ptr), params.num_tokens, params.buffer_M, - params.token_dim, params.rank, - reinterpret_cast(params.buffer_flags), params.wait_for_results); +template +inline __device__ PackedType loadPackedVolatile(void const* ptr) { + static_assert(sizeof(PackedType) == 0, "Not implemented"); + return PackedType{}; +} - return cudaSuccess; +template <> +inline __device__ float4 loadPackedVolatile(void const* ptr) { + float4 returnValue; + asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n" + : "=f"(returnValue.x), "=f"(returnValue.y), "=f"(returnValue.z), "=f"(returnValue.w) + : "l"(ptr)); + return returnValue; } -template -cudaError_t twoshot_allreduce_dispatch_world_size(AllReduceParams& params) { - FLASHINFER_LOG_DEBUG("twoshot_allreduce_dispatch_world_size"); - switch (params.nranks) { - case 2: - return twoshot_allreduce_dispatch(params); - case 4: - return twoshot_allreduce_dispatch(params); - case 8: - return twoshot_allreduce_dispatch(params); - case 16: - return twoshot_allreduce_dispatch(params); - case 32: - return twoshot_allreduce_dispatch(params); - case 64: - return twoshot_allreduce_dispatch(params); - default: - FLASHINFER_ERROR("MNNVL AllReduce: unsupported world_size " + std::to_string(params.nranks) + - ". Supported sizes: {2, 4, 8, 16, 32, 64}"); - return cudaErrorInvalidValue; - } +template <> +inline __device__ float2 loadPackedVolatile(void const* ptr) { + float2 returnValue; + asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n" + : "=f"(returnValue.x), "=f"(returnValue.y) + : "l"(ptr)); + return returnValue; } template -__device__ void copy_f4(T_IN* dst, T_IN const* src) { - float4* dst4 = (float4*)dst; - float4 const* src4 = (float4 const*)src; +inline __device__ void copyF4(T_IN* dst, T_IN const* src) { + float4* dst4 = reinterpret_cast(dst); + float4 const* src4 = reinterpret_cast(src); __pipeline_memcpy_async(dst4, src4, sizeof(float4)); } -template -__device__ void copy_f4_ldg(T_IN* dst, T_IN const* src) { - float4* dst4 = (float4*)dst; - float4 const* src4 = (float4*)src; - *dst4 = *src4; -} +uint32_t constexpr kWARP_SIZE = 32U; +uint32_t constexpr kLOG2_WARP_SIZE = 5U; +uint32_t constexpr kLANE_ID_MASK = 0x1f; +uint32_t constexpr kFINAL_MASK = 0xffffffff; -__device__ float4 loadfloat4(void const* ptr) { - // Check alignment - ptr should be 16-byte aligned for safe float4 load - if (reinterpret_cast(ptr) % 16 != 0) { - // Fall back to scalar loads if not aligned - float4 return_value; - float const* float_ptr = reinterpret_cast(ptr); - return_value.x = float_ptr[0]; - return_value.y = float_ptr[1]; - return_value.z = float_ptr[2]; - return_value.w = float_ptr[3]; - return return_value; +template +inline __device__ T warpReduceSumFull(T val) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + val += __shfl_xor_sync(kFINAL_MASK, val, mask, kWARP_SIZE); } - - float4 return_value; - - asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n" - : "=f"(return_value.x), "=f"(return_value.y), "=f"(return_value.z), - "=f"(return_value.w) - : "l"(ptr)); - - return return_value; + return val; } -// Safer version that checks bounds before loading template -__device__ float4 loadfloat4_safe(T const* ptr, int remaining_elements) { - float return_value[4] = {0.0f, 0.0f, 0.0f, 0.0f}; +inline __device__ T warpReduceSumPartial(T val) { + int laneId = threadIdx.x & kLANE_ID_MASK; + // We make sure only the last warp will call this function + int warpSize = blockDim.x - (threadIdx.x & ~(kWARP_SIZE - 1)); + unsigned int active_mask = (1U << warpSize) - 1; - if (remaining_elements <= 0) { - return *(float4*)return_value; +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + int targetLane = laneId ^ mask; + auto tmp = __shfl_xor_sync(active_mask, val, mask, kWARP_SIZE); + val += targetLane < warpSize ? tmp : 0; } + return val; +} - // Check alignment - ptr should be 16-byte aligned for safe float4 load - bool is_aligned = (reinterpret_cast(ptr) % 16 == 0); - - if (is_aligned && remaining_elements >= 4) { - // Safe to do vectorized load - asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n" - : "=f"(return_value[0]), "=f"(return_value[1]), "=f"(return_value[2]), - "=f"(return_value[3]) - : "l"(ptr)); - } else { - // Fall back to scalar loads with bounds checking - float const* float_ptr = reinterpret_cast(ptr); - for (int i = 0; i < 4 && i < remaining_elements; i++) { - return_value[i] = toFloat(float_ptr[i]); - } +// SYNC: +// - True: share the sum across all threads +// - False: only thread 0 get the sum; Other thread's value is undefined. +template +inline __device__ T blockReduceSumPartial(T val) { + __shared__ T smem[kWARP_SIZE]; + int laneId = threadIdx.x & kLANE_ID_MASK; + int warpId = threadIdx.x >> kLOG2_WARP_SIZE; + int warpNum = (blockDim.x + kWARP_SIZE - 1) >> + kLOG2_WARP_SIZE; // Ceiling division to include partial warps + + val = (warpId == warpNum - 1) ? warpReduceSumPartial(val) : warpReduceSumFull(val); + if (laneId == 0) { + smem[warpId] = val; } + __syncthreads(); - return *(float4*)return_value; -} + if (warpId == 0) { + val = (laneId < warpNum) ? smem[laneId] : (T)0.f; + // Need to consider the corner case where we only have one warp and it is partial + val = (warpNum == 1) ? warpReduceSumPartial(val) : warpReduceSumFull(val); -template -inline __device__ T add(T a, T b) { - return a + b; + if constexpr (SYNC) { + if (laneId == 0) { + smem[warpId] = val; + } + } + } + if constexpr (SYNC) { + __syncthreads(); + val = smem[0]; + } + return val; } -#define FINAL_MASK 0xffffffff - template -__inline__ __device__ T warpReduceSum(T val) { -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val = add(val, __shfl_xor_sync(FINAL_MASK, val, mask, - 32)); //__shfl_sync bf16 return float when sm < 80 - return val; -} +inline __device__ T blockReduceSumFull(T val) { + __shared__ T smem[kWARP_SIZE]; + int lane_id = threadIdx.x & kLANE_ID_MASK; + int warp_id = threadIdx.x >> kLOG2_WARP_SIZE; + int warp_num = blockDim.x >> kLOG2_WARP_SIZE; -inline __device__ float block_reduce_sum(float val) { - __shared__ float smem[32]; - int lane_id = threadIdx.x % 32, warp_id = threadIdx.x / 32, warp_num = blockDim.x / 32; - val = warpReduceSum(val); + val = warpReduceSumFull(val); if (lane_id == 0) { smem[warp_id] = val; } __syncthreads(); - val = lane_id < warp_num ? smem[lane_id] : 0.f; - val = warpReduceSum(val); + + val = (lane_id < warp_num) ? smem[lane_id] : (T)0.f; + val = warpReduceSumFull(val); + return val; } -template -__global__ void __launch_bounds__(128, 1) - RMSNorm(T_IN* input_plus_residual, T_OUT* output_norm, T_IN const* buffer_input, - T_IN const* gamma, float epsilon, T_IN const* residual, int batch_size, - uint32_t* buffer_flags) { +template +inline __device__ T blockReduceSum(T val) { + bool hasPartialWarp = (blockDim.x & kLANE_ID_MASK) != 0; + if (hasPartialWarp) { + return blockReduceSumPartial(val); + } else { + return blockReduceSumFull(val); + } +} +// A helper function to tune the grid configuration for fused oneshot and rmsnorm kernels +// Return (block_size, cluster_size, loads_per_thread) +std::tuple adjustGridConfig(int numTokens, int dim, int eltsPerThread) { + // Start with preferred block_size and cluster_size #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + int clusterSize = 8; +#else + int clusterSize = 1; +#endif + int blockSize = 128; + // ========================== Adjust the grid configuration ========================== + int threadsNeeded = ceil_div(dim, eltsPerThread); + int loadsPerThread = 1; - static bool const LAMPORT = true; + blockSize = ceil_div(threadsNeeded, clusterSize); +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + while (threadsNeeded % clusterSize != 0 && clusterSize > 1) { + clusterSize /= 2; + } + blockSize = ceil_div(threadsNeeded, clusterSize); + while (blockSize < 128 && clusterSize >= 2) { + blockSize *= 2; + clusterSize /= 2; + } + int smCount = GetCudaMultiProcessorCount(); + while (numTokens * clusterSize > smCount && clusterSize > 1 && blockSize <= 512) { + blockSize *= 2; + clusterSize /= 2; + } +#endif - extern __shared__ uint8_t smem[]; + // Trying to scale up use multiple loads or CGA + while (blockSize > 1024) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + if (clusterSize < 8) { + clusterSize = clusterSize << 1; + } else { + break; + } +#else + if (loadsPerThread < 8) { + loadsPerThread += 1; + } else { + break; + } +#endif + blockSize = ceil_div(threadsNeeded, clusterSize * loadsPerThread); + } + return {blockSize, clusterSize, loadsPerThread}; +} +}; // namespace utils + +using utils::blockReduceSum; +using utils::fromFloat; +using utils::isNegZero; +using utils::LamportFlags; +using utils::loadPacked; +using utils::loadPackedVolatile; +using utils::PackedVec; +using utils::toFloat; + +template +__global__ void __launch_bounds__(1024) + oneshotAllreduceFusionKernel(T* outputPtr, T* prenormedPtr, T const* shardPtr, + T const* residualInPtr, T const* gammaPtr, T** inputPtrs, + T* mcastPtr, int const numTokens, int const tokenDim, + float epsilon, int const rank, uint32_t* bufferFlags) { + constexpr int kELTS_PER_THREAD = sizeof(PackedType) / sizeof(T); + constexpr int kLAMPORT_ELTS_PER_PACKED = sizeof(PackedType) / sizeof(float); + constexpr uint32_t kELT_SIZE = sizeof(T); +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + namespace cg = cooperative_groups; + cg::cluster_group cluster = cg::this_cluster(); + int packedIdx = cluster.thread_rank(); + int token = blockIdx.x; + int threadOffset = token * tokenDim + packedIdx * kELTS_PER_THREAD; - int sample = blockIdx.y; + cudaGridDependencySynchronize(); +#else + int packedIdx = blockIdx.y * blockDim.x + threadIdx.x; + int token = blockIdx.x; + // Offset w.r.t. the input shard + int threadOffset = token * tokenDim + packedIdx * kELTS_PER_THREAD; +#endif - static int const CGA_THREADS = NUM_THREADS * 1; + // We only use 1 stage for the oneshot allreduce + LamportFlags flag(bufferFlags, 1); + T* stagePtrMcast = reinterpret_cast(flag.getCurLamportBuf(mcastPtr, 0)); + T* stagePtrLocal = reinterpret_cast(flag.getCurLamportBuf(inputPtrs[rank], 0)); - static int const ITERS = DIM / CGA_THREADS; - float r_input[ITERS]; - float r_gamma[ITERS]; + if (packedIdx * kELTS_PER_THREAD >= tokenDim) { + flag.ctaArrive(); + flag.clearDirtyLamportBuf(inputPtrs[rank], -1); + return; + } - T_IN* sh_input = (T_IN*)&smem[0]; - T_IN* sh_residual = (T_IN*)&smem[NUM_INPUTS * NUM_THREADS * ITERS * sizeof(T_IN)]; - T_IN* sh_gamma = (T_IN*)&smem[(NUM_INPUTS + 1) * NUM_THREADS * ITERS * sizeof(T_IN)]; + // ==================== Broadcast tokens to each rank ============================= + PackedVec val; + val.packed = loadPacked(&shardPtr[threadOffset]); +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + if (isNegZero(val.elements[i])) val.elements[i] = fromFloat(0.f); + } - static int const ELTS_PER_THREAD = sizeof(float4) / sizeof(T_IN); + reinterpret_cast( + &stagePtrMcast[token * tokenDim * WorldSize + rank * tokenDim])[packedIdx] = val.packed; - int offsets[NUM_INPUTS][DIM / (1 * ELTS_PER_THREAD * NUM_THREADS)]; + flag.ctaArrive(); + // ======================= Lamport Sync and clear the output buffer from previous iteration + // ============================= + flag.clearDirtyLamportBuf(inputPtrs[rank], -1); - LamportFlags flags(buffer_flags); - T_IN const* input = &buffer_input[flags.input_offset + flags.buffer_size]; + PackedVec valuesLamport[WorldSize]; + while (1) { + bool valid = true; +#pragma unroll + for (int r = 0; r < WorldSize; r++) { + valuesLamport[r].packed = loadPackedVolatile( + &stagePtrLocal[token * tokenDim * WorldSize + r * tokenDim + + packedIdx * kELTS_PER_THREAD]); -#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +#pragma unroll + for (int i = 0; i < kLAMPORT_ELTS_PER_PACKED; i++) { + valid &= !isNegZero(valuesLamport[r].elements[i]); + } + } + if (valid) { + break; + } + } + + auto values = reinterpret_cast*>(valuesLamport); + // ======================= Reduction ============================= + float accum[kELTS_PER_THREAD]; + PackedVec packedAccum; + +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + accum[i] = toFloat(values[0].elements[i]); + } + +#pragma unroll + for (int r = 1; r < WorldSize; r++) { +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + accum[i] += toFloat(values[r].elements[i]); + } + } + +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + packedAccum.elements[i] = fromFloat(accum[i]); + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaTriggerProgrammaticLaunchCompletion(); #endif + if constexpr (RMSNormFusion) { + // =============================== Residual =============================== + PackedVec residualIn; + residualIn.packed = *reinterpret_cast(&residualInPtr[threadOffset]); + packedAccum += residualIn; + *reinterpret_cast(&prenormedPtr[threadOffset]) = packedAccum.packed; + // =============================== Rmsnorm ================================ + PackedVec gamma; + gamma.packed = *reinterpret_cast(&gammaPtr[packedIdx * kELTS_PER_THREAD]); + + float threadSum = 0.F; +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + // FIXME: Use float square if accuracy issue + threadSum += toFloat(packedAccum.elements[i] * packedAccum.elements[i]); + } + float blockSum = blockReduceSum(threadSum); - for (int i = 0; i < NUM_INPUTS; i++) { - for (int j = 0; j < DIM / (1 * ELTS_PER_THREAD * NUM_THREADS); j++) { - int k = j * NUM_THREADS + threadIdx.x; - offsets[i][j] = - i * batch_size * DIM + sample * DIM + blockIdx.x * DIM / 1 + k * ELTS_PER_THREAD; + __shared__ float sharedVal[8]; // Temporary variable to share the sum within block + float fullSum = blockSum; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + namespace cg = cooperative_groups; + cg::cluster_group cluster = cg::this_cluster(); + int const numBlocks = cluster.num_blocks(); + if (numBlocks > 1) { + fullSum = 0.F; + // Need to reduce over the entire cluster + int const blockRank = cluster.block_rank(); + if (threadIdx.x < numBlocks) { + cluster.map_shared_rank(&sharedVal[0], threadIdx.x)[blockRank] = blockSum; + } + cluster.barrier_wait(cluster.barrier_arrive()); + for (int i = 0; i < numBlocks; ++i) { + fullSum += sharedVal[i]; + } + } +#endif + float rcpRms = rsqrtf(fullSum / tokenDim + epsilon); +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + packedAccum.elements[i] = fromFloat(toFloat(packedAccum.elements[i]) * rcpRms * + toFloat(gamma.elements[i])); } } + reinterpret_cast(&outputPtr[threadOffset])[0] = packedAccum.packed; + flag.waitAndUpdate( + {static_cast(numTokens * tokenDim * WorldSize * kELT_SIZE), 0, 0, 0}); +} -#pragma unroll - for (int j = 0; j < DIM / (1 * ELTS_PER_THREAD * NUM_THREADS); j++) { - int i = j * NUM_THREADS + threadIdx.x; - copy_f4(&sh_residual[i * ELTS_PER_THREAD], - &residual[sample * DIM + blockIdx.x * DIM + i * ELTS_PER_THREAD]); +using utils::adjustGridConfig; + +template +cudaError_t oneshotAllreduceFusionDispatch(AllReduceFusionParams const& params) { + int const numTokens = params.numTokens; + int const tokenDim = params.tokenDim; + int const eltsPerThread = sizeof(float4) / sizeof(T); + + auto [blockSize, clusterSize, loadsPerThread] = + adjustGridConfig(numTokens, tokenDim, eltsPerThread); + dim3 grid(numTokens, clusterSize, 1); + + FLASHINFER_CHECK(blockSize <= 1024 && loadsPerThread == 1, + "Hidden Dimension %d exceeds the maximum supported hidden dimension (%d)", + tokenDim, +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + 1024 * 8 * eltsPerThread); +#else + 1024 * eltsPerThread); +#endif + + FLASHINFER_LOG_DEBUG( + "[MNNVL AllReduceOneShot] Dispatch: grid size: (%d, %d, 1), block_size: %d, cluster_size: " + "%d, " + "loads_per_thread: %d, " + "threads_needed: %d", + numTokens, clusterSize, blockSize, clusterSize, loadsPerThread, + ceil_div(tokenDim, eltsPerThread)); + + cudaLaunchAttribute attrs[2]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = params.launchWithPdl ? 1 : 0; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + attrs[1].id = cudaLaunchAttributeClusterDimension; + attrs[1].val.clusterDim.x = 1; + attrs[1].val.clusterDim.y = clusterSize; + attrs[1].val.clusterDim.z = 1; +#endif + + cudaLaunchConfig_t config{ + .gridDim = grid, + .blockDim = blockSize, + .dynamicSmemBytes = 0, + .stream = params.stream, + .attrs = attrs, +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + .numAttrs = 2, +#else + .numAttrs = 1, +#endif + }; + +#define LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, RMSNORM) \ + FLASHINFER_CUDA_CALL(cudaLaunchKernelEx( \ + &config, &oneshotAllreduceFusionKernel, output, residualOut, input, \ + residualIn, gamma, ucPtrs, mcPtr, numTokens, tokenDim, static_cast(params.epsilon), \ + params.rank, params.bufferFlags)); +#define DISPATCH_ALLREDUCE_KERNEL(WORLD_SIZE) \ + if (params.rmsNormFusion) { \ + LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, true); \ + } else { \ + LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, false); \ } - __pipeline_commit(); + T** ucPtrs = reinterpret_cast(params.bufferPtrsDev); + T* mcPtr = reinterpret_cast(params.multicastPtr); + T* output = reinterpret_cast(params.output); + T* residualOut = reinterpret_cast(params.residualOut); + T const* input = reinterpret_cast(params.input); + T const* residualIn = reinterpret_cast(params.residualIn); + T const* gamma = reinterpret_cast(params.gamma); -#pragma unroll - for (int j = 0; j < DIM / (ELTS_PER_THREAD * NUM_THREADS); j++) { - int i = j * NUM_THREADS + threadIdx.x; - copy_f4(&sh_gamma[i * ELTS_PER_THREAD], &gamma[blockIdx.x * DIM + i * ELTS_PER_THREAD]); + switch (params.nRanks) { + // FIXME: Do we need other world sizes? + case 2: + DISPATCH_ALLREDUCE_KERNEL(2); + break; + case 4: + DISPATCH_ALLREDUCE_KERNEL(4); + break; + case 8: + DISPATCH_ALLREDUCE_KERNEL(8); + break; + case 16: + DISPATCH_ALLREDUCE_KERNEL(16); + break; + case 32: + DISPATCH_ALLREDUCE_KERNEL(32); + break; + case 64: + DISPATCH_ALLREDUCE_KERNEL(64); + break; + default: + FLASHINFER_ERROR("MNNVL AllReduce: unsupported world_size " + std::to_string(params.nRanks) + + ". Supported sizes: {2, 4, 8, 16, 32, 64}"); + return cudaErrorInvalidValue; } +#undef LAUNCH_ALLREDUCE_KERNEL + return cudaSuccess; +} - __pipeline_commit(); - flags.cta_arrive(); +enum MNNVLTwoShotStage : uint8_t { + SCATTER = 0, + BROADCAST = 1, + NUM_STAGES = 2, +}; - // Load all inputs - bool valid = false; +template +__global__ __launch_bounds__(128) void twoshotAllreduceKernel( + T* outputPtr, T const* shardPtr, T** inputPtrs, T* mcastPtr, uint32_t const numTokens, + uint32_t const tokenDim, uint32_t const rank, uint32_t* bufferFlags, + bool const wait_for_results) { + constexpr int kELTS_PER_THREAD = sizeof(PackedType) / sizeof(T); + constexpr int kLAMPORT_ELTS_PER_PACKED = sizeof(PackedType) / sizeof(float); + constexpr uint32_t kELT_SIZE = sizeof(T); + + int packedIdx = blockIdx.y * blockDim.x + threadIdx.x; + int token = blockIdx.x; + // Offset w.r.t. the input shard + int threadOffset = token * tokenDim + packedIdx * kELTS_PER_THREAD; -#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - if (!LAMPORT) cudaGridDependencySynchronize(); + int destRank = token % WorldSize; + int destTokenOffset = token / WorldSize; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); #endif + LamportFlags flag(bufferFlags, MNNVLTwoShotStage::NUM_STAGES); - while (!valid) { - valid = true; -#pragma unroll - for (int i = 0; i < NUM_INPUTS; i++) { - for (int j = 0; j < DIM / (ELTS_PER_THREAD * NUM_THREADS); j++) { - int k = j * NUM_THREADS + threadIdx.x; - - float4* dst4 = (float4*)&sh_input[i * NUM_THREADS * ITERS + k * ELTS_PER_THREAD]; - - // Calculate the absolute element offset from the start of buffer_input - int element_offset = offsets[i][j]; - - // The input pointer is already offset to: &buffer_input[buffer_offset + buffer_size] - // So the actual pointer we're accessing is: input + element_offset - // Which equals: &buffer_input[buffer_offset + buffer_size + element_offset] - - float4* src4 = (float4*)&input[element_offset]; - - float4 value; - // Check if we have enough elements remaining for a safe float4 load - if (element_offset >= 0 && element_offset + ELTS_PER_THREAD <= flags.buffer_size) { - value = loadfloat4(src4); - } else { - // Use safe load for boundary cases or out-of-bounds - int remaining_elements = flags.buffer_size - element_offset; - if (remaining_elements <= 0) { - // Completely out of bounds, return zeros - float4 return_value = {0.0f, 0.0f, 0.0f, 0.0f}; - value = return_value; - } else { - value = loadfloat4_safe(reinterpret_cast(src4), remaining_elements); - } - } + T* scatterBufLocal = + reinterpret_cast(flag.getCurLamportBuf(inputPtrs[rank], MNNVLTwoShotStage::SCATTER)); + T* scatterBufDest = + reinterpret_cast(flag.getCurLamportBuf(inputPtrs[destRank], MNNVLTwoShotStage::SCATTER)); + T* broadcastBufW = + reinterpret_cast(flag.getCurLamportBuf(mcastPtr, MNNVLTwoShotStage::BROADCAST)); + T* broadcastBufR = + reinterpret_cast(flag.getCurLamportBuf(inputPtrs[rank], MNNVLTwoShotStage::BROADCAST)); - if (LAMPORT) { - // Assume that the 16B were written atomically, so we only need to check one value - T_IN lowest_val = *(T_IN*)&value; - valid &= !isNegZero(lowest_val); - } - *dst4 = value; - } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif + // Make sure the clear function is called before OOB thread exits + if (packedIdx * kELTS_PER_THREAD >= tokenDim) { + flag.clearDirtyLamportBuf(inputPtrs[rank], -1); + return; + } + + // =============================== Scatter =============================== + + // Load vectorized data + PackedVec val; + val.packed = loadPacked(&shardPtr[threadOffset]); +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + if (isNegZero(val.elements[i])) { + val.elements[i] = fromFloat(0.F); } } - __syncthreads(); + // Store vectorized data + reinterpret_cast( + &scatterBufDest[destTokenOffset * tokenDim * WorldSize + rank * tokenDim])[packedIdx] = + val.packed; - // Perform the initial input reduction - if (NUM_INPUTS > 0) { - T_IN accum[ELTS_PER_THREAD]; - float4* accum4 = (float4*)&accum; + flag.clearDirtyLamportBuf(inputPtrs[rank], MNNVLTwoShotStage::SCATTER); - for (int j = 0; j < DIM / (ELTS_PER_THREAD * NUM_THREADS); j++) { - int k = j * NUM_THREADS + threadIdx.x; + // =============================== Reduction and Broadcast =============================== - *accum4 = *(float4*)&sh_input[k * ELTS_PER_THREAD]; + if ((token % WorldSize) == rank) { + int localToken = token / WorldSize; + float accum[kELTS_PER_THREAD] = {0.F}; + + // Use float as we only check each float value for validity + PackedVec valuesLamport[WorldSize]; + while (1) { + bool valid = true; +#pragma unroll + for (int r = 0; r < WorldSize; r++) { + valuesLamport[r].packed = loadPackedVolatile( + &scatterBufLocal[localToken * tokenDim * WorldSize + r * tokenDim + + packedIdx * kELTS_PER_THREAD]); - for (int i = 1; i < NUM_INPUTS; i++) { - float4 data = *(float4*)&sh_input[i * NUM_THREADS * ITERS + k * ELTS_PER_THREAD]; - T_IN* p_d = (T_IN*)&data; - for (int x = 0; x < ELTS_PER_THREAD; x++) { - accum[x] += p_d[x]; + // Check validity across all elements +#pragma unroll + for (int i = 0; i < kLAMPORT_ELTS_PER_PACKED; i++) { + valid &= !isNegZero(valuesLamport[r].elements[i]); } } + if (valid) { + break; + } + } + + // Now we view it as the value for reduction + auto values = reinterpret_cast*>(valuesLamport); +#pragma unroll + for (int r = 0; r < WorldSize; r++) { +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + accum[i] += toFloat(values[r].elements[i]); + } + } - // Write back to input 0's staging location. No sync needed since all data localized to - // thread. - *(float4*)&sh_input[k * ELTS_PER_THREAD] = *accum4; + // Store vectorized result + PackedVec packedAccum; +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + packedAccum.elements[i] = fromFloat(accum[i]); } + reinterpret_cast(&broadcastBufW[token * tokenDim])[packedIdx] = packedAccum.packed; } - // Wait for residual - __pipeline_wait_prior(1); - __syncthreads(); + flag.clearDirtyLamportBuf(inputPtrs[rank], MNNVLTwoShotStage::BROADCAST); - float thread_sum = 0.f; + // Optionally wait for results if the next layer isn't doing the Lamport check + if (wait_for_results) { + // Update the atomic counter to indicate the block has read the offsets + flag.ctaArrive(); -#pragma unroll - for (int io = 0; io < ITERS / ELTS_PER_THREAD; io++) { - float4 inp4 = - *(float4*)&sh_input[io * NUM_THREADS * ELTS_PER_THREAD + threadIdx.x * ELTS_PER_THREAD]; - float4 res4 = - *(float4*)&sh_residual[io * NUM_THREADS * ELTS_PER_THREAD + threadIdx.x * ELTS_PER_THREAD]; + PackedVec valLamport; + valLamport.packed = loadPackedVolatile(&broadcastBufR[threadOffset]); + while (isNegZero(valLamport.elements[0])) { + valLamport.packed = loadPackedVolatile(&broadcastBufR[threadOffset]); + } + if (outputPtr) { + reinterpret_cast(&outputPtr[threadOffset])[0] = valLamport.packed; + } - T_IN* r_inp = (T_IN*)&inp4; - T_IN* r_res = (T_IN*)&res4; + // Update the buffer flags + flag.waitAndUpdate( + {static_cast(round_up(numTokens, WorldSize) * tokenDim * + kELT_SIZE), // Clear Size for scatter stage + static_cast(numTokens * tokenDim * kELT_SIZE), // Clear Size for broadcast stage + 0, 0}); + // If not wait for results, we will rely on the following kernel to update the buffer + } +} - float4 out4; +using utils::copyF4; +// This kernel works performant when loads_per_thread is 1. +// For this mode, we are able to support up to 1024 (threads) x 8 (elements) = 8192 hidden +// dimension. There are two options for further scaling up: +// 1. Use CGA if supported. It expands the hidden dimension to 8k x 8 = 64k. +// 2. Set loads_per_thread >1. Which can be used if CGA is not supported. Note that this will +// be limited by the shared memory size and register count. +template +__global__ __launch_bounds__(1024) void rmsNormLamport(T_IN* outputPreNorm, T_OUT* outputNorm, + T_IN* bufferInput, T_IN const* gamma, + float epsilon, T_IN const* residual, + uint32_t numTokens, uint32_t dim, + uint32_t worldSize, uint32_t* bufferFlags) { + static_assert(std::is_same_v, "T_IN and T_OUT must be the same type"); + static int const kELTS_PER_LOAD = sizeof(float4) / sizeof(T_IN); + + uint32_t const token = blockIdx.x; + uint32_t const blockSize = blockDim.x; + uint32_t const threadOffset = threadIdx.x; + + uint32_t numThreads = blockSize; + uint32_t clusterSize = 1; + uint32_t blockOffset = 0; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + namespace cg = cooperative_groups; + cg::cluster_group cluster = cg::this_cluster(); + numThreads = cluster.num_threads(); + clusterSize = cluster.num_blocks(); + blockOffset = cluster.block_rank(); +#endif + uint32_t const dimPadded = round_up(dim, kELTS_PER_LOAD * numThreads); + uint32_t const elemsPerThread = dimPadded / numThreads; + uint32_t const loadStride = blockSize; - T_IN* r_out = (T_IN*)&out4; + extern __shared__ uint8_t smem[]; + float rInput[LoadsPerThread * kELTS_PER_LOAD]; + uint32_t offsets[LoadsPerThread * kELTS_PER_LOAD]; - for (int ii = 0; ii < ELTS_PER_THREAD; ii++) { - int i = io * ELTS_PER_THREAD + ii; + uint32_t const smemBufferSize = blockSize * elemsPerThread * sizeof(T_IN); + T_IN* smemInput = (T_IN*)&smem[0]; + T_IN* smemResidual = (T_IN*)&smem[smemBufferSize]; + T_IN* smemGamma = (T_IN*)&smem[2 * smemBufferSize]; - T_IN inp_plus_resid = r_inp[ii] + r_res[ii]; - r_out[ii] = inp_plus_resid; - r_input[i] = toFloat(inp_plus_resid); + LamportFlags flag(bufferFlags, MNNVLTwoShotStage::NUM_STAGES); + T_IN* input = reinterpret_cast( + flag.getCurLamportBuf(reinterpret_cast(bufferInput), MNNVLTwoShotStage::BROADCAST)); - // Accumulate the squares for RMSNorm - thread_sum += toFloat(inp_plus_resid * inp_plus_resid); - } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif + // The offset that current thread should load from. Note that the hidden dimension is split by CGA + // size and each block loads a contiguous chunk; The size of chunk that each block processes + uint32_t const blockChunkSize = ceil_div(dim, clusterSize * kELTS_PER_LOAD) * kELTS_PER_LOAD; + uint32_t const blockLoadOffset = token * dim + blockOffset * blockChunkSize; - *(float4*)&input_plus_residual[sample * DIM + blockIdx.x * DIM + - io * NUM_THREADS * ELTS_PER_THREAD + - threadIdx.x * ELTS_PER_THREAD] = out4; +#pragma unroll + for (uint32_t i = 0; i < LoadsPerThread; i++) { + // Each block load a contiguous chunk of tokens + uint32_t const threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD; + offsets[i] = blockLoadOffset + threadLoadOffset; } - // Wait for Gamma. There will be a global synchronization as part of the reduction - __pipeline_wait_prior(0); +#pragma unroll + for (uint32_t i = 0; i < LoadsPerThread; i++) { + uint32_t const threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD; + if (blockOffset * blockChunkSize + threadLoadOffset < dim) { + copyF4(&smemResidual[threadLoadOffset], &residual[blockLoadOffset + threadLoadOffset]); + } + } + __pipeline_commit(); +#pragma unroll + for (uint32_t i = 0; i < LoadsPerThread; i++) { + uint32_t const threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD; + if (blockOffset * blockChunkSize + threadLoadOffset < dim) { + copyF4(&smemGamma[threadLoadOffset], &gamma[blockOffset * blockChunkSize + threadLoadOffset]); + } + } + __pipeline_commit(); - float cluster_sum = block_reduce_sum(thread_sum); + flag.ctaArrive(); + bool valid = false; + // ACQBLK if not lamport + while (!valid) { + valid = true; +#pragma unroll + for (uint32_t i = 0; i < LoadsPerThread; i++) { + uint32_t threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD; - float rcp_rms = rsqrtf(cluster_sum / DIM + epsilon); + if (blockOffset * blockChunkSize + threadLoadOffset < dim) { + float4* dst4 = reinterpret_cast(&smemInput[threadLoadOffset]); + float4 const* src4 = reinterpret_cast(&input[offsets[i]]); -#pragma unroll - for (int io = 0; io < ITERS / ELTS_PER_THREAD; io++) { - float4 gamma4 = - *(float4*)&sh_gamma[io * NUM_THREADS * ELTS_PER_THREAD + threadIdx.x * ELTS_PER_THREAD]; - T_IN* r_g4 = (T_IN*)&gamma4; - - float4 out4; - // FIXME: this only works if T_OUT == T_IN - T_OUT* r_out = (T_OUT*)&out4; - - for (int ii = 0; ii < ELTS_PER_THREAD; ii++) { - int i = io * ELTS_PER_THREAD + ii; - r_gamma[i] = toFloat(r_g4[ii]); - r_out[ii] = fromFloat(r_gamma[i] * r_input[i] * rcp_rms); + float4 value = loadPackedVolatile(src4); + // Assume that the 16B were written atomically, so we only need to check one value + valid &= !isNegZero(value.x); + *dst4 = value; + } } + } + + __pipeline_wait_prior(1); + __syncthreads(); + + float threadSum = 0.f; +#pragma unroll + for (int i = 0; i < LoadsPerThread; i++) { + int threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD; + if (blockOffset * blockChunkSize + threadLoadOffset < dim) { + PackedVec inp{.packed = loadPacked(&smemInput[threadLoadOffset])}; + PackedVec res{.packed = loadPacked(&smemResidual[threadLoadOffset])}; - *(float4*)&output_norm[sample * DIM + blockIdx.x * DIM + io * NUM_THREADS * ELTS_PER_THREAD + - threadIdx.x * ELTS_PER_THREAD] = out4; + PackedVec inp_plus_res = inp + res; +#pragma unroll + for (int j = 0; j < kELTS_PER_LOAD; j++) { + rInput[i * kELTS_PER_LOAD + j] = toFloat(inp_plus_res.elements[j]); + threadSum += toFloat(inp_plus_res.elements[j] * inp_plus_res.elements[j]); + } + + *reinterpret_cast(&outputPreNorm[blockLoadOffset + threadLoadOffset]) = + inp_plus_res.packed; + } } - // Update the buffer pointers - flags.wait_and_update(batch_size); -#endif -} -template -cudaError_t twoshot_rmsnorm_dispatch(RMSNormParams& params) { - static constexpr int NUM_THREADS = 128; - static constexpr int CGA_THREADS = NUM_THREADS; - constexpr int iters = H_DIM / CGA_THREADS; + __pipeline_wait_prior(0); - dim3 grid(1, params.batch, 1); + float blockSum = blockReduceSum(threadSum); - cudaLaunchConfig_t config; - cudaLaunchAttribute attrs[1]; - config.stream = params.stream; - config.gridDim = grid; - config.blockDim = NUM_THREADS; - config.attrs = attrs; - attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = params.launch_with_pdl ? 1 : 0; - config.numAttrs = 1; + float fullSum = blockSum; + __shared__ float sharedVal[8]; + // Use CGA Reduction if supported +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + int const numBlocks = cluster.num_blocks(); + if (numBlocks > 1) { + fullSum = 0.F; + // Need to reduce over the entire cluster + int const blockRank = cluster.block_rank(); + if (threadIdx.x < numBlocks) { + cluster.map_shared_rank(&sharedVal[0], threadIdx.x)[blockRank] = blockSum; + } + cluster.barrier_wait(cluster.barrier_arrive()); + for (int i = 0; i < numBlocks; ++i) { + fullSum += sharedVal[i]; + } + } +#endif - size_t shmem_size = 3 * NUM_THREADS * iters * sizeof(T); - config.dynamicSmemBytes = shmem_size; + float rcpRms = rsqrtf(fullSum / dim + epsilon); - cudaFuncSetAttribute(&RMSNorm, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); +#pragma unroll + for (int i = 0; i < LoadsPerThread; i++) { + PackedVec r_out; + uint32_t threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD; + if (blockOffset * blockChunkSize + threadLoadOffset < dim) { + PackedVec gamma = {.packed = loadPacked(&smemGamma[threadLoadOffset])}; - cudaLaunchKernelEx( - &config, &RMSNorm, reinterpret_cast(params.residual_output), - reinterpret_cast(params.output), reinterpret_cast(params.input), - reinterpret_cast(params.gamma), static_cast(params.epsilon), - reinterpret_cast(params.residual), params.batch, params.buffer_flags); +#pragma unroll + for (uint32_t j = 0; j < kELTS_PER_LOAD; j++) { + r_out.elements[j] = fromFloat(toFloat(gamma.elements[j]) * + rInput[i * kELTS_PER_LOAD + j] * rcpRms); + } - return cudaSuccess; + *reinterpret_cast(&outputNorm[blockLoadOffset + threadLoadOffset]) = r_out.packed; + } + } + constexpr int kELTS_SIZE = sizeof(T_IN); + + // Update the buffer pointers + flag.waitAndUpdate({static_cast(round_up(numTokens, worldSize) * dim * kELTS_SIZE), + static_cast(numTokens * dim * kELTS_SIZE), 0, 0}); } template -cudaError_t twoshot_rmsnorm_dispatch_hidden_dim(RMSNormParams& params) { - FLASHINFER_LOG_DEBUG("twoshot_rmsnorm_dispatch_hidden_dim"); - switch (params.hidden_dim) { - case 2048: - return twoshot_rmsnorm_dispatch(params); - case 4096: - return twoshot_rmsnorm_dispatch(params); - case 5120: - return twoshot_rmsnorm_dispatch(params); // Llama-4 - case 7168: - return twoshot_rmsnorm_dispatch(params); // DeepSeek - case 8192: - return twoshot_rmsnorm_dispatch(params); +cudaError_t twoshotAllreduceFusionDispatch(AllReduceFusionParams const& params) { + int const numTokens = params.numTokens; + int const tokenDim = params.tokenDim; + int const numEltsPerThread = sizeof(float4) / sizeof(T); + FLASHINFER_CHECK(tokenDim % numEltsPerThread == 0, + "[MNNVL AllReduceTwoShot] token_dim must be divisible by %d", numEltsPerThread); + + int const arNumThreads = ceil_div(tokenDim, numEltsPerThread); + int const arNumBlocksPerToken = ceil_div(arNumThreads, 128); + + dim3 arGrid(numTokens, arNumBlocksPerToken); + + cudaLaunchAttribute arAttrs[1]; + arAttrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + arAttrs[0].val.programmaticStreamSerializationAllowed = params.launchWithPdl ? 1 : 0; + + cudaLaunchConfig_t arConfig{ + .gridDim = arGrid, + .blockDim = 128, + .dynamicSmemBytes = 0, + .stream = params.stream, + .attrs = arAttrs, + .numAttrs = 1, + }; + + FLASHINFER_LOG_DEBUG("[MNNVL AllReduceTwoShot] Dispatch: grid size: (%d, %d, 1), block_size: 128", + numTokens, arNumBlocksPerToken); + +#define LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE) \ + FLASHINFER_CUDA_CALL(cudaLaunchKernelEx( \ + &arConfig, &twoshotAllreduceKernel, output, input, ucPtrs, mcastPtr, \ + numTokens, tokenDim, params.rank, params.bufferFlags, (!params.rmsNormFusion))); + T** ucPtrs = reinterpret_cast(params.bufferPtrsDev); + T* mcastPtr = reinterpret_cast(params.multicastPtr); + T* output = reinterpret_cast(params.output); + T const* input = reinterpret_cast(params.input); + switch (params.nRanks) { + case 2: + LAUNCH_ALLREDUCE_KERNEL(2); + break; + case 4: + LAUNCH_ALLREDUCE_KERNEL(4); + break; + case 8: + LAUNCH_ALLREDUCE_KERNEL(8); + break; + case 16: + LAUNCH_ALLREDUCE_KERNEL(16); + break; + case 32: + LAUNCH_ALLREDUCE_KERNEL(32); + break; + case 64: + LAUNCH_ALLREDUCE_KERNEL(64); + break; default: - FLASHINFER_ERROR("MNNVL TwoShot RMSNorm: unsupported hidden_dim " + - std::to_string(params.hidden_dim) + - ". Supported sizes: {2048, 4096, 5120, 7168, 8192}"); + FLASHINFER_ERROR("[MNNVL AllReduceTwoShot] Unsupported world_size" + + std::to_string(params.nRanks) + ". Supported sizes: {2, 4, 8, 16, 32, 64}"); return cudaErrorInvalidValue; } -} +#undef LAUNCH_ALLREDUCE_KERNEL + + // Launch the rmsnorm lamport kernel if fusion is enabled + if (params.rmsNormFusion) { + auto gridConfig = adjustGridConfig(numTokens, tokenDim, numEltsPerThread); + int rnBlockSize = std::get<0>(gridConfig); + int rnClusterSize = std::get<1>(gridConfig); + int rnLoadsPerThread = std::get<2>(gridConfig); + + int rnNumThreads = rnClusterSize * rnBlockSize; + dim3 rnGrid(numTokens, rnClusterSize, 1); + cudaLaunchConfig_t rnConfig; + cudaLaunchAttribute rnAttrs[2]; + rnConfig.stream = params.stream; + rnConfig.gridDim = rnGrid; + rnConfig.blockDim = rnBlockSize; + rnConfig.attrs = rnAttrs; + rnAttrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + rnAttrs[0].val.programmaticStreamSerializationAllowed = params.launchWithPdl ? 1 : 0; +#ifndef DISABLE_CGA + rnAttrs[1].id = cudaLaunchAttributeClusterDimension; + rnAttrs[1].val.clusterDim.x = 1; + rnAttrs[1].val.clusterDim.y = rnClusterSize; + rnAttrs[1].val.clusterDim.z = 1; + rnConfig.numAttrs = 2; +#else + rnConfig.numAttrs = 1; +#endif + bool const rnUseCGA = rnClusterSize > 1; + int const dimPadded = round_up(tokenDim, numEltsPerThread * rnNumThreads); + int const iters = dimPadded / rnNumThreads; + + size_t const smemSize = 3 * rnBlockSize * iters * sizeof(T); + + FLASHINFER_LOG_DEBUG( + "[MNNVL AllReduceTwoShotRMSNorm] Dispatch: grid size: (%d, %d, 1), block_size: %d, " + "cluster_size: %d, " + "loads_per_thread: %d, " + "threads_needed: %d", + numTokens, rnClusterSize, rnBlockSize, rnClusterSize, rnLoadsPerThread, + ceil_div(tokenDim, numEltsPerThread)); + +#define RUN_RMSNORM_KERNEL(LOADS_PER_THREAD) \ + FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(&rmsNormLamport, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + smemSize)); \ + rnConfig.dynamicSmemBytes = smemSize; \ + FLASHINFER_CUDA_CALL( \ + cudaLaunchKernelEx(&rnConfig, &rmsNormLamport, residualOut, output, \ + bufferInput, gamma, static_cast(params.epsilon), residualIn, \ + numTokens, tokenDim, params.nRanks, params.bufferFlags)); + + T* residualOut = reinterpret_cast(params.residualOut); + T* output = reinterpret_cast(params.output); + T* bufferInput = reinterpret_cast(params.bufferPtrLocal); + T const* gamma = reinterpret_cast(params.gamma); + T const* residualIn = reinterpret_cast(params.residualIn); + if (rnUseCGA) { + RUN_RMSNORM_KERNEL(1); + } else { + switch (rnLoadsPerThread) { + case 1: + RUN_RMSNORM_KERNEL(1); + break; + case 2: + RUN_RMSNORM_KERNEL(2); + break; + case 3: + RUN_RMSNORM_KERNEL(3); + break; + case 4: + RUN_RMSNORM_KERNEL(4); + break; + case 5: + RUN_RMSNORM_KERNEL(5); + break; + case 6: + RUN_RMSNORM_KERNEL(6); + break; + case 7: + RUN_RMSNORM_KERNEL(7); + break; + case 8: + RUN_RMSNORM_KERNEL(8); + break; + default: + FLASHINFER_ERROR("[MNNVL AllReduceTwoShotRMSNorm] Unsupported loads_per_thread" + + std::to_string(rnLoadsPerThread) + + ". Supported sizes: {1, 2, 3, 4, 5, 6, 7, 8}"); + return cudaErrorInvalidValue; + } // switch (rnLoadsPerThread) + } // if (rnUseCGA) +#undef RUN_RMSNORM_KERNEL + + } // if (params.rmsNormFusion) + return cudaSuccess; +} } // namespace trtllm_mnnvl_allreduce } // namespace flashinfer diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index 0471bd1081..8481aabf39 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -335,6 +336,22 @@ inline std::pair GetCudaComputeCapability() { return std::make_pair(major, minor); } +// This function is thread-safe and cached the sm_count. +// But it will only check the current CUDA device, thus assuming each process handles single GPU. +inline int GetCudaMultiProcessorCount() { + static std::atomic sm_count{0}; + int cached = sm_count.load(std::memory_order_relaxed); + if (cached == 0) { + int device_id; + cudaGetDevice(&device_id); + cudaDeviceProp device_prop; + cudaGetDeviceProperties(&device_prop, device_id); + cached = device_prop.multiProcessorCount; + sm_count.store(cached, std::memory_order_relaxed); + } + return cached; +} + template inline void DebugPrintCUDAArray(T* device_ptr, size_t size, std::string prefix = "") { std::vector host_array(size); diff --git a/tests/comm/test_allreduce_unified_api.py b/tests/comm/test_allreduce_unified_api.py new file mode 100644 index 0000000000..7968750903 --- /dev/null +++ b/tests/comm/test_allreduce_unified_api.py @@ -0,0 +1,334 @@ +# Test for unified AllReduce API with multiple backends +# Run with: mpirun -np pytest tests/comm/test_allreduce_unified_api.py -vv -s +import os +import traceback +from typing import Tuple + +import pytest +import torch +import torch.distributed as dist +from mpi4py import MPI + +import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar + +# Unified API imports +from flashinfer.comm import ( + create_allreduce_fusion_workspace, + allreduce_fusion, + AllReduceFusionPattern, + AllReduceFusionWorkspace, +) + +# Use flashinfer.norm.rmsnorm as reference implementation. +from flashinfer.norm import rmsnorm + + +def init_torch_distributed_from_mpi(): + """Initialize torch.distributed using MPI rank info.""" + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + world_size = comm.Get_size() + + if dist.is_initialized(): + return + + # Set environment variables for torch.distributed + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "29500" + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + + dist.init_process_group( + backend="nccl", + rank=rank, + world_size=world_size, + ) + + +def cleanup_torch_distributed(): + """Cleanup torch.distributed if initialized.""" + if dist.is_initialized(): + dist.destroy_process_group() + + +@torch.inference_mode() +def run_allreduce_fusion_test( + x: torch.Tensor, + residual: torch.Tensor, + norm_weight: torch.Tensor, + eps: float, + rank: int, + fusion: bool, + reference_output: tuple[torch.Tensor, ...], + workspace: AllReduceFusionWorkspace, +): + """Test function using the unified API (create_allreduce_fusion_workspace + allreduce_fusion).""" + MPI.COMM_WORLD.barrier() + + def func( + input, + residual, + norm_weight, + eps, + enable_fusion, + workspace, + ): + # For both fused and unfused cases: + shape = input.shape + input = input.view(-1, shape[-1]) + use_pdl = True + + if enable_fusion: + trtllm_mnnvl_ar.mpi_barrier() + + # Use unified API + norm_out = torch.empty_like(input) + residual_out = torch.empty_like(input) + + allreduce_fusion( + input=input, + workspace=workspace, + pattern=AllReduceFusionPattern.kARResidualRMSNorm, + launch_with_pdl=use_pdl, + residual_out=residual_out, + norm_out=norm_out, + residual_in=residual.view(-1, shape[-1]), + rms_gamma=norm_weight, + rms_eps=eps, + ) + + return norm_out.view(shape), residual_out.view(shape) + + else: + # Use unified API for AllReduce only + output = torch.empty_like(input) + + allreduce_fusion( + input=input, + workspace=workspace, + pattern=AllReduceFusionPattern.kAllReduce, + launch_with_pdl=use_pdl, + output=output, + ) + return (output.view(shape),) + + output = func(x.clone(), residual.clone(), norm_weight, eps, fusion, workspace) + + assert output[0].shape == reference_output[0].shape + + if rank == 0: + print("output[0] (first 10 values):", output[0].flatten()[:10]) + print( + "reference_output[0] (first 10 values):", + reference_output[0].flatten()[:10], + ) + + if fusion: + print("output[1] (first 10 values):", output[1].flatten()[:10]) + print( + "reference_output[1] (first 10 values):", + reference_output[1].flatten()[:10], + ) + + torch.testing.assert_close( + output[0], + reference_output[0], + rtol=0.05, + atol=0.15, + ) + + if fusion: + torch.testing.assert_close( + output[1], + reference_output[1], + rtol=0.05, + atol=0.15, + ) + + +def prepare_test_data(seq_len: int, hidden_size: int, dtype: torch.dtype, fusion: bool): + """Prepare test data distributed across MPI ranks.""" + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + world_size = comm.Get_size() + if rank == 0: + x_full = torch.randn((world_size, seq_len, hidden_size), dtype=dtype) + residual = torch.randn((seq_len, hidden_size), dtype=dtype) + norm_weight = torch.randn((hidden_size,), dtype=dtype) + else: + x_full = None + residual = None + norm_weight = None + + # Use lowercase bcast() for Python object broadcasting + x_full = comm.bcast(x_full, root=0) + residual = comm.bcast(residual, root=0) + norm_weight = comm.bcast(norm_weight, root=0) + + x_full = x_full.cuda() + residual = residual.cuda() + norm_weight = norm_weight.cuda() + + x_local = x_full[rank, :, :] + reference_output: Tuple[torch.Tensor, ...] = None + if fusion: + # Fused case: AllReduce + Residual Add + RMS Norm + allreduce_result = torch.sum(x_full, dim=0) # AllReduce result + residual_out = allreduce_result + residual # Add residual + norm_out = rmsnorm( + residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False + ) + + reference_output = (norm_out, residual_out) + else: + # Non-fused case: Only AllReduce + allreduce_result = torch.sum(x_full, dim=0) # AllReduce result + reference_output = (allreduce_result,) + return (x_local, residual, norm_weight), reference_output + + +def run_allreduce_test( + monkeypatch, + seq_lens: list[int], + fusion: bool, + dtype: torch.dtype, + hidden_size: int, + backend: str, +): + """Core test logic for AllReduce operations using the unified API. + + Args: + monkeypatch: pytest monkeypatch fixture + seq_lens: List of sequence lengths to test + fusion: Whether to test fused allreduce+rmsnorm or just allreduce + dtype: Data type for tensors + hidden_size: Hidden dimension size + backend: Backend to use ("auto", "trtllm", "mnnvl") + """ + + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + world_size = comm.Get_size() + gpus_per_node = torch.cuda.device_count() + + if gpus_per_node == 0: + pytest.skip("AllReduce test requires at least one CUDA device per node") + if world_size < 2: + pytest.skip(f"This test requires at least 2 MPI ranks, got {world_size}") + + # Set CUDA device based on rank + local_rank = rank % gpus_per_node + torch.cuda.set_device(local_rank) + + # Initialize torch.distributed for trtllm backend (needed for IPC workspace) + # TODO: check if it is ok to do this with auto backend + process_group = None + if backend in ("trtllm", "auto"): + init_torch_distributed_from_mpi() + process_group = dist.group.WORLD + + if local_rank == 0: + print(f"Running AllReduce test with {world_size} ranks, backend={backend}") + print(f"Rank {rank} using GPU {torch.cuda.current_device()}") + + eps = 1e-5 + torch.manual_seed(42 + rank) + + workspace = None + + try: + # Create workspace using unified API + workspace = create_allreduce_fusion_workspace( + backend=backend, + world_size=world_size, + rank=rank, + max_token_num=max(seq_lens), + hidden_dim=hidden_size, + dtype=dtype, + topology="single_node", + gpus_per_node=gpus_per_node, + process_group=process_group, + ) + + print(f"Rank {rank}: Created workspace with backend={workspace.backend}") + + # Prepare test data for all sequence lengths + test_data = [] + for seq_len in seq_lens: + (x_local, residual, norm_weight), reference_output = prepare_test_data( + seq_len, hidden_size, dtype, fusion + ) + test_data.append( + (seq_len, x_local, residual, norm_weight, reference_output) + ) + + # Test each sequence length with the same workspace + for seq_len, x, residual, norm_weight, reference_output in test_data: + if rank == 0: + print( + f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}, dtype={dtype}" + ) + + run_allreduce_fusion_test( + x, + residual, + norm_weight, + eps, + rank, + fusion, + reference_output, + workspace, + ) + + # Synchronize before next test + trtllm_mnnvl_ar.mpi_barrier() + + print( + f"PASSED[rank={rank}]: seq_len={seq_len}, fusion={fusion}, dtype={dtype}, backend={backend}" + ) + + except Exception as e: + failure_message = f"FAILED[rank={rank}]: seq_lens={seq_lens}, fusion={fusion}, dtype={dtype}, backend={backend} failed: {e}" + print(failure_message) + print(traceback.format_exc()) + + # Gather failure status from all ranks for logging + all_failures = MPI.COMM_WORLD.allgather(True) + failed_ranks = [i for i, failed in enumerate(all_failures) if failed] + if rank == 0: + print(f"Test failed on ranks: {failed_ranks}") + + raise + + finally: + if workspace is not None: + workspace.destroy() + # Cleanup torch.distributed if we initialized it + if backend in ("trtllm", "auto"): + cleanup_torch_distributed() + + # Final synchronization + trtllm_mnnvl_ar.mpi_barrier() + + +@pytest.mark.parametrize( + "seq_lens", + [[1], [4], [15], [27, 11, 24, 256], [127], [998, 2048]], +) +@pytest.mark.parametrize("fusion", [False, True]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("hidden_size", [2880, 5120, 7168, 8192]) +@pytest.mark.parametrize("backend", ["auto", "trtllm", "mnnvl"]) +def test_allreduce_unified( + monkeypatch, + seq_lens: list[int], + fusion: bool, + dtype: torch.dtype, + hidden_size: int, + backend: str, +): + """Test AllReduce with unified API across different backends. + + Run with: mpirun -np pytest tests/comm/test_allreduce_unified_api.py -vv -s + """ + run_allreduce_test(monkeypatch, seq_lens, fusion, dtype, hidden_size, backend) diff --git a/tests/comm/test_trtllm_allreduce_fusion.py b/tests/comm/test_trtllm_allreduce_fusion.py index c3aa8c8252..dab4877fb9 100644 --- a/tests/comm/test_trtllm_allreduce_fusion.py +++ b/tests/comm/test_trtllm_allreduce_fusion.py @@ -22,7 +22,9 @@ SCALE_FACTOR_RANGE = (-1, 1) -def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_init_port): +def _run_correctness_worker( + world_size, rank, dtype, hidden_dim, distributed_init_port, legacy_api=True +): device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) distributed_init_method = f"tcp://localhost:{distributed_init_port}" @@ -57,18 +59,33 @@ def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_ini lamport_use_fp32 = dtype == torch.float32 - # create workspace for allreduce fusion with metadata - ipc_handles, workspace_tensor, workspace_metadata = ( - comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( - rank, - world_size, - MAX_TOKEN_NUM, - hidden_dim, - group=group, - use_fp32_lamport=lamport_use_fp32, - create_metadata=True, # Get metadata for validation + # Create workspace - choose between legacy and new API + if legacy_api: + # Legacy API: create workspace for allreduce fusion with metadata + ipc_handles, workspace_tensor, workspace_metadata = ( + comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( + rank, + world_size, + MAX_TOKEN_NUM, + hidden_dim, + group=group, + use_fp32_lamport=lamport_use_fp32, + create_metadata=True, # Get metadata for validation + ) + ) + else: + workspace = None + # New unified API: create workspace + workspace = comm.create_allreduce_fusion_workspace( + backend="trtllm", + world_size=world_size, + rank=rank, + max_token_num=MAX_TOKEN_NUM, + hidden_dim=hidden_dim, + dtype=dtype, + topology="single_node", + process_group=group, ) - ) test_loop = 5 @@ -163,60 +180,128 @@ def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_ini s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): for _ in range(test_loop): - comm.trtllm_allreduce_fusion( - allreduce_in=allreduce_in, - world_size=world_size, - world_rank=rank, - token_num=token_num, - hidden_dim=hidden_dim, - workspace_ptrs=workspace_tensor, - launch_with_pdl=launch_with_pdl, - use_oneshot=use_oneshot, - trigger_completion_at_end=trigger_completion_at_end, - fp32_acc=fp32_acc, - pattern_code=pattern_code, - allreduce_out=all_reduce_out, - residual_in=residual_in, - residual_out=residual_out, - norm_out=norm_out, - quant_out=quant_out, - scale_out=scale_out, - rms_gamma=rms_gamma, - rms_eps=rms_eps, - scale_factor=scale_factor, - layout_code=swizzled_layout_code, - metadata=workspace_metadata, - ) + if legacy_api: + # Legacy API - uses flattened tensors + comm.trtllm_allreduce_fusion( + allreduce_in=allreduce_in, + world_size=world_size, + world_rank=rank, + token_num=token_num, + hidden_dim=hidden_dim, + workspace_ptrs=workspace_tensor, + launch_with_pdl=launch_with_pdl, + use_oneshot=use_oneshot, + trigger_completion_at_end=trigger_completion_at_end, + fp32_acc=fp32_acc, + pattern_code=pattern_code, + allreduce_out=all_reduce_out, + residual_in=residual_in, + residual_out=residual_out, + norm_out=norm_out, + quant_out=quant_out, + scale_out=scale_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_factor, + layout_code=swizzled_layout_code, + metadata=workspace_metadata, + ) + else: + # New unified API - expects 2D tensors [token_num, hidden_dim] + comm.allreduce_fusion( + input=allreduce_in.view( + token_num, hidden_dim + ), + workspace=workspace, + launch_with_pdl=launch_with_pdl, + output=all_reduce_out.view( + token_num, hidden_dim + ), + residual_in=residual_in.view( + token_num, hidden_dim + ), + residual_out=residual_out.view( + token_num, hidden_dim + ), + norm_out=norm_out.view( + token_num, hidden_dim + ), + quant_out=quant_out.view( + token_num, hidden_dim + ), + scale_out=scale_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_factor, + layout_code=swizzled_layout_code, + pattern=pattern_code, + use_oneshot=use_oneshot, + fp32_acc=fp32_acc, + ) # NOTE: in real case, you dont have to set all optional params. You could set those required by fusion pattern. # capture g = torch.cuda.CUDAGraph() with torch.cuda.graph(g): for _ in range(test_loop): - comm.trtllm_allreduce_fusion( - allreduce_in=allreduce_in, - world_size=world_size, - world_rank=rank, - token_num=token_num, - hidden_dim=hidden_dim, - workspace_ptrs=workspace_tensor, - launch_with_pdl=launch_with_pdl, - use_oneshot=use_oneshot, - trigger_completion_at_end=trigger_completion_at_end, - fp32_acc=fp32_acc, - pattern_code=pattern_code, - allreduce_out=all_reduce_out, - residual_in=residual_in, - residual_out=residual_out, - norm_out=norm_out, - quant_out=quant_out, - scale_out=scale_out, - rms_gamma=rms_gamma, - rms_eps=rms_eps, - scale_factor=scale_factor, - layout_code=swizzled_layout_code, - metadata=workspace_metadata, - ) + if legacy_api: + # Legacy API - uses flattened tensors + comm.trtllm_allreduce_fusion( + allreduce_in=allreduce_in, + world_size=world_size, + world_rank=rank, + token_num=token_num, + hidden_dim=hidden_dim, + workspace_ptrs=workspace_tensor, + launch_with_pdl=launch_with_pdl, + use_oneshot=use_oneshot, + trigger_completion_at_end=trigger_completion_at_end, + fp32_acc=fp32_acc, + pattern_code=pattern_code, + allreduce_out=all_reduce_out, + residual_in=residual_in, + residual_out=residual_out, + norm_out=norm_out, + quant_out=quant_out, + scale_out=scale_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_factor, + layout_code=swizzled_layout_code, + metadata=workspace_metadata, + ) + else: + # New unified API - expects 2D tensors [token_num, hidden_dim] + comm.allreduce_fusion( + input=allreduce_in.view( + token_num, hidden_dim + ), + workspace=workspace, + launch_with_pdl=launch_with_pdl, + output=all_reduce_out.view( + token_num, hidden_dim + ), + residual_in=residual_in.view( + token_num, hidden_dim + ), + residual_out=residual_out.view( + token_num, hidden_dim + ), + norm_out=norm_out.view( + token_num, hidden_dim + ), + quant_out=quant_out.view( + token_num, hidden_dim + ), + scale_out=scale_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_factor, + layout_code=swizzled_layout_code, + pattern=pattern_code, + use_oneshot=use_oneshot, + fp32_acc=fp32_acc, + ) # replay g.replay() torch.cuda.synchronize() @@ -307,9 +392,14 @@ def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_ini finally: dist.barrier(group=group) - comm.trtllm_destroy_ipc_workspace_for_all_reduce_fusion( - ipc_handles, group=group - ) + # Destroy workspace - choose between legacy and new API + if legacy_api: + comm.trtllm_destroy_ipc_workspace_for_all_reduce_fusion( + ipc_handles, group=group + ) + elif workspace is not None: + # New unified API + workspace.destroy() dist.destroy_process_group(group=group) @@ -358,7 +448,8 @@ def multi_process_parallel( @pytest.mark.parametrize("world_size", [2, 4, 8]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("hidden_dim", [1024, 2048, 4096, 7168, 8192]) -def test_trtllm_allreduce_fusion(world_size, dtype, hidden_dim): +@pytest.mark.parametrize("legacy_api", [True, False]) +def test_trtllm_allreduce_fusion(world_size, dtype, hidden_dim, legacy_api): np.random.seed(42) torch.manual_seed(42) torch.cuda.manual_seed_all(42) @@ -367,17 +458,22 @@ def test_trtllm_allreduce_fusion(world_size, dtype, hidden_dim): pytest.skip( f"world_size {world_size} is greater than available_gpus {available_gpus}" ) - print(f"Running test for world_size={world_size}") + api_str = "legacy" if legacy_api else "unified" + print(f"Running test for world_size={world_size} with {api_str} API") multi_process_parallel( world_size, dtype, hidden_dim, _run_correctness_worker, - target_args=(), + target_args=(legacy_api,), ) - print(f"allreduce fusion tp = {world_size}: OK") + print(f"allreduce fusion tp = {world_size} ({api_str} API): OK") if __name__ == "__main__": - test_trtllm_allreduce_fusion(2, torch.float16, 1024) + # Test both legacy and unified APIs + print("Testing legacy API...") + test_trtllm_allreduce_fusion(2, torch.float16, 1024, legacy_api=True) + print("\nTesting unified API...") + test_trtllm_allreduce_fusion(2, torch.float16, 1024, legacy_api=False) diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py index e7274c46f0..ce7880e406 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce.py @@ -1,4 +1,5 @@ # Check torch version: +import traceback from typing import Tuple, Optional import pytest @@ -7,7 +8,6 @@ import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar from flashinfer.comm.mapping import Mapping -from flashinfer.comm.mnnvl import CommBackend, MpiComm # Use flashinfer.norm.rmsnorm as reference implementation. from flashinfer.norm import rmsnorm @@ -15,6 +15,95 @@ @torch.inference_mode() def row_linear_residual_norm_fusion_forward( + x: torch.Tensor, + residual: torch.Tensor, + norm_weight: torch.Tensor, + eps: float, + mapping: Mapping, + fusion: bool, + reference_output: tuple[torch.Tensor, ...], + workspace: trtllm_mnnvl_ar.MNNVLAllReduceFusionWorkspace, +): + tensor_parallel_rank = mapping.tp_rank + MPI.COMM_WORLD.barrier() + + def func( + input, + residual, + norm_weight, + eps, + enable_fusion, + workspace, + ): + # For both fused and unfused cases: + shape = input.shape + input = input.view(-1, shape[-1]) + use_pdl = True + + if enable_fusion: + trtllm_mnnvl_ar.mpi_barrier() + + output, residual_out = ( + trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_add_rmsnorm( + input, + residual, + norm_weight, + workspace, + eps, + launch_with_pdl=use_pdl, + strategy=trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO, + ) + ) + + return output.view(shape), residual_out.view(shape) + + else: + output = torch.empty_like(input) + + output = trtllm_mnnvl_ar.trtllm_mnnvl_allreduce( + input, + workspace, + launch_with_pdl=use_pdl, + strategy=trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO, + ) + return (output.view(shape),) + + output = func(x.clone(), residual.clone(), norm_weight, eps, fusion, workspace) + + assert output[0].shape == reference_output[0].shape + + if tensor_parallel_rank == 0: + print("output[0] (first 10 values):", output[0].flatten()[:10]) + print( + "reference_output[0] (first 10 values):", + reference_output[0].flatten()[:10], + ) + + if fusion: + print("output[1] (first 10 values):", output[1].flatten()[:10]) + print( + "reference_output[1] (first 10 values):", + reference_output[1].flatten()[:10], + ) + + torch.testing.assert_close( + output[0], + reference_output[0], + rtol=0.05, + atol=0.15, + ) + + if fusion: + torch.testing.assert_close( + output[1], + reference_output[1], + rtol=0.05, + atol=0.15, + ) + + +@torch.inference_mode() +def row_linear_residual_norm_fusion_forward_legacy( x: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, @@ -29,20 +118,10 @@ def row_linear_residual_norm_fusion_forward( unicast_ptr: int, max_num_elements_mnnvl: int, buffer_flags_mnnvl: torch.Tensor, - comm_backend_for_handle_transfer: Optional[CommBackend] = None, ): - x = x.cuda() - residual = residual.cuda() - norm_weight = norm_weight.cuda() - reference_output = tuple(t.cuda() for t in reference_output) - tensor_parallel_size = mapping.tp_size tensor_parallel_rank = mapping.tp_rank - if comm_backend_for_handle_transfer is None: - comm = MpiComm() - else: - comm = comm_backend_for_handle_transfer - comm.barrier() + MPI.COMM_WORLD.barrier() def func( input, @@ -57,11 +136,7 @@ def func( ): # For both fused and unfused cases: shape = input.shape - - assert max_num_elements_mnnvl % hidden_size == 0 - input = input.view(-1, shape[-1]) - buffer_M = max_num_elements_mnnvl // hidden_size if enable_fusion: @@ -155,13 +230,55 @@ def func( """Helper function to run the core MNNVL AllReduce test logic""" +def prepare_test_data(seq_len: int, hidden_size: int, dtype: torch.dtype, fusion: bool): + # Communicator used for passing data between ranks + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + world_size = comm.Get_size() + if rank == 0: + x_full = torch.randn((world_size, seq_len, hidden_size), dtype=dtype) + residual = torch.randn((seq_len, hidden_size), dtype=dtype) + norm_weight = torch.randn((hidden_size,), dtype=dtype) + else: + x_full = None + residual = None + norm_weight = None + + # Use lowercase bcast() for Python object broadcasting + x_full = comm.bcast(x_full, root=0) + residual = comm.bcast(residual, root=0) + norm_weight = comm.bcast(norm_weight, root=0) + + x_full = x_full.cuda() + residual = residual.cuda() + norm_weight = norm_weight.cuda() + + x_local = x_full[rank, :, :] + reference_output: Tuple[torch.Tensor, ...] = None + if fusion: + # Fused case: AllReduce + Residual Add + RMS Norm + allreduce_result = torch.sum(x_full, dim=0) # AllReduce result + residual_out = allreduce_result + residual # Add residual + norm_out = rmsnorm( + residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False + ) + + reference_output = (norm_out, residual_out) + else: + # Non-fused case: Only AllReduce + allreduce_result = torch.sum(x_full, dim=0) # AllReduce result + reference_output = (allreduce_result,) + return (x_local, residual, norm_weight), reference_output + + def run_mnnvl_ar_full( monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int, - explicit_workspace_bytes: int | None = None, + legacy_explicit_workspace_bytes: Optional[int] = None, + legacy_api: bool = False, ): """Core test logic for MNNVL AllReduce operations. @@ -173,17 +290,15 @@ def run_mnnvl_ar_full( hidden_size: Hidden dimension size explicit_workspace_bytes: If provided, use this workspace size instead of default """ - monkeypatch.setenv("TRTLLM_FORCE_MNNVL_AR", "1") # force multi-node allreduce. + comm = MPI.COMM_WORLD # Get MPI info - rank = MPI.COMM_WORLD.Get_rank() - world_size = MPI.COMM_WORLD.Get_size() + rank = comm.Get_rank() + world_size = comm.Get_size() gpus_per_node = torch.cuda.device_count() if gpus_per_node == 0: pytest.skip("MNNVL allreduce test requires at least one CUDA device per node") - - # Ensure we have exactly 2 ranks for this test if world_size < 2: pytest.skip(f"This test requires at least 2 MPI ranks, got {world_size}") @@ -204,90 +319,78 @@ def run_mnnvl_ar_full( print( f"[Node {mapping.node_rank}] Rank {rank} using GPU {torch.cuda.current_device()}" ) - - tensor_parallel_size = world_size eps = 1e-5 - torch.manual_seed(42) + torch.manual_seed(42 + rank) # Track if this rank failed rank_failed = False failure_message = "" try: - # Get workspace buffers using MPI rank - allocate once per seq_lens list and reuse within the list - # This workspace is sized for the maximum expected sequence length and can be reused within each list - # Each parameterized list gets its own fresh workspace allocation - mcast_buffer_mnnvl, buffer_flags_mnnvl, max_num_elements_mnnvl = ( - trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace( - mapping, dtype, buffer_size_in_bytes=explicit_workspace_bytes + if legacy_api: + mcast_buffer_mnnvl, buffer_flags_mnnvl, max_num_elements_mnnvl = ( + trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace( + mapping, dtype, buffer_size_in_bytes=legacy_explicit_workspace_bytes + ) ) - ) - multicast_ptr = mcast_buffer_mnnvl.get_multicast_ptr() - buffer_ptrs_dev = mcast_buffer_mnnvl.get_buffer_ptrs_dev() - unicast_ptr = mcast_buffer_mnnvl.mcast_device_memory.get_unicast_ptr( - mapping.tp_rank - ) - - # Test each sequence length with the same workspace (reusing allocated buffers within this list) - for seq_len in seq_lens: - if rank == 0: - print( - f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}, dtype={dtype}" - ) + multicast_ptr = mcast_buffer_mnnvl.get_multicast_ptr() + buffer_ptrs_dev = mcast_buffer_mnnvl.get_buffer_ptrs_dev() + unicast_ptr = mcast_buffer_mnnvl.mcast_device_memory.get_unicast_ptr( + mapping.tp_rank + ) - # Generate test data (same on all ranks due to same seed) - x_full = torch.randn( - (tensor_parallel_size, seq_len, hidden_size), + else: + workspace = trtllm_mnnvl_ar.MNNVLAllReduceFusionWorkspace( + mapping, + max_num_tokens=max(seq_lens), + hidden_dim=hidden_size, dtype=dtype, - device=torch.device("cuda"), ) - residual = torch.randn( - (seq_len, hidden_size), dtype=dtype, device=torch.device("cuda") + + test_data = [] + for seq_len in seq_lens: + (x_local, residual, norm_weight), reference_output = prepare_test_data( + seq_len, hidden_size, dtype, fusion ) - norm_weight = torch.randn( - (hidden_size,), dtype=dtype, device=torch.device("cuda") + test_data.append( + (seq_len, x_local, residual, norm_weight, reference_output) ) - # Each rank gets its slice of the input - x = x_full[rank, :, :] - - # Compute reference output based on fusion mode - reference_output: Tuple[torch.Tensor, ...] = None - if fusion: - # Fused case: AllReduce + Residual Add + RMS Norm - allreduce_result = torch.sum(x_full, dim=0) # AllReduce result - residual_out = allreduce_result + residual # Add residual + # Test each sequence length with the same workspace (reusing allocated buffers within this list) + for seq_len, x, residual, norm_weight, reference_output in test_data: + if rank == 0: print( - "Device of residual_out:{}, norm_weight:{}".format( - residual_out.device, norm_weight.device - ) + f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}, dtype={dtype}" + ) + if legacy_api: + row_linear_residual_norm_fusion_forward_legacy( + x, + residual, + norm_weight, + eps, + hidden_size, + dtype, + mapping, + fusion, + reference_output, + multicast_ptr, + buffer_ptrs_dev, + unicast_ptr, + max_num_elements_mnnvl, + buffer_flags_mnnvl, ) - norm_out = rmsnorm(residual_out, norm_weight, eps, enable_pdl=False) - - reference_output = (norm_out, residual_out) else: - # Non-fused case: Only AllReduce - allreduce_result = torch.sum(x_full, dim=0) # AllReduce result - reference_output = (allreduce_result,) - - # Run the test with the same workspace - row_linear_residual_norm_fusion_forward( - x, - residual, - norm_weight, - eps, - hidden_size, - dtype, - mapping, - fusion, - reference_output, - multicast_ptr, - buffer_ptrs_dev, - unicast_ptr, - max_num_elements_mnnvl, - buffer_flags_mnnvl, - ) + row_linear_residual_norm_fusion_forward( + x, + residual, + norm_weight, + eps, + mapping, + fusion, + reference_output, + workspace, + ) # Synchronize before next test trtllm_mnnvl_ar.mpi_barrier() @@ -300,6 +403,7 @@ def run_mnnvl_ar_full( rank_failed = True failure_message = f"FAILED[rank={rank}]: seq_lens={seq_lens}, fusion={fusion}, dtype={dtype} failed: {e}" print(failure_message) + print(traceback.format_exc()) # Gather failure status from all ranks for logging all_failures = MPI.COMM_WORLD.allgather(rank_failed) @@ -310,16 +414,16 @@ def run_mnnvl_ar_full( print(f"Test failed on ranks: {failed_ranks}") # Cleanup before re-raising - if "mcast_buffer_mnnvl" in locals(): - del mcast_buffer_mnnvl + if "workspace" in locals(): + del workspace # Re-raise the original exception so it can be caught by pytest.raises in negative tests raise finally: # Ensure cleanup happens for this list's workspace - if "mcast_buffer_mnnvl" in locals(): - del mcast_buffer_mnnvl + if "workspace" in locals(): + del workspace # Final synchronization and check for failures across all ranks trtllm_mnnvl_ar.mpi_barrier() @@ -330,79 +434,28 @@ def run_mnnvl_ar_full( @pytest.mark.parametrize( "seq_lens", - [ - [1], - [4], - [15], - [27, 11, 24], - [127], - ], + [[1], [4], [15], [27, 11, 24, 256], [127], [998, 2048]], ) @pytest.mark.parametrize("fusion", [False, True]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("hidden_size", [2048, 4096, 5120, 7168, 8192]) -def test_mnnvl_allreduce_default_workspace( +@pytest.mark.parametrize("hidden_size", [2880, 5120, 7168, 8192]) +def test_mnnvl_allreduce_refactored( monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int ): - """Test MNNVL AllReduce with default workspace size.""" - run_mnnvl_ar_full(monkeypatch, seq_lens, fusion, dtype, hidden_size) - - -"""Test with explicit workspace size""" + """Test MNNVL AllReduce with refactored API.""" + run_mnnvl_ar_full( + monkeypatch, seq_lens, fusion, dtype, hidden_size, legacy_api=False + ) -@pytest.mark.parametrize( - "seq_lens", - [ - [1, 4, 180], - ], -) +@pytest.mark.parametrize("seq_lens", [[1], [4], [15], [27, 11, 24], [127]]) @pytest.mark.parametrize("fusion", [False, True]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [2048, 4096, 5120, 7168, 8192]) -def test_mnnvl_allreduce_explicit_workspace( +def test_mnnvl_allreduce_legacy( monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int ): - """Test MNNVL AllReduce with explicitly calculated workspace size.""" - # Calculate workspace to fit the maximum sequence length - # buffer shape: [3, 2, buffer_tokens, hidden_dim] - explicit_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * max(seq_lens) + """Test MNNVL AllReduce with legacy API.""" run_mnnvl_ar_full( - monkeypatch, - seq_lens, - fusion, - dtype, - hidden_size, - explicit_workspace_bytes=explicit_workspace_bytes, + monkeypatch, seq_lens, fusion, dtype, hidden_size, legacy_api=True ) - - -"""Negative test: workspace too small""" - - -@pytest.mark.parametrize("fusion", [False, True]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("hidden_size", [2048, 4096]) -def test_mnnvl_allreduce_workspace_too_small( - monkeypatch, fusion: bool, dtype: torch.dtype, hidden_size: int -): - """Test that MNNVL AllReduce fails gracefully when workspace is too small.""" - # Use a large sequence length that won't fit in a small workspace - seq_len = 180 - - # Create a workspace that's too small (only enough for 10 tokens) - small_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * 10 - - # Expect a ValueError with a message about buffer_M being too small - with pytest.raises((ValueError, RuntimeError)) as exc_info: - run_mnnvl_ar_full( - monkeypatch, - [seq_len], - fusion, - dtype, - hidden_size, - explicit_workspace_bytes=small_workspace_bytes, - ) - - # Verify the error message contains the expected text - assert "greater than the buffer_M" in str(exc_info.value)