From cd167ad836c50aa52a5b3d581a4bb51ba052002b Mon Sep 17 00:00:00 2001 From: HakJu Kim Date: Wed, 1 Oct 2025 19:32:12 +0900 Subject: [PATCH 01/30] vllm mori integration from moreh team Note that only low latency mode is available on mori (https://github.com/ROCm/mori) Co-authored-by: Inhyeok Bang Co-authored-by: Dongmin Ra Co-authored-by: Jimin Park Co-authored-by: Geonwoo Choi Signed-off-by: HakJu Kim --- vllm/compilation/fix_functionalization.py | 2 + .../device_communicators/all2all.py | 204 +++++++++++++++++- .../base_device_communicator.py | 2 + .../device_communicators/cuda_communicator.py | 4 + vllm/envs.py | 5 +- .../layers/fused_moe/__init__.py | 2 + .../layers/fused_moe/aiter_experts.py | 121 +++++++++++ .../model_executor/layers/fused_moe/config.py | 12 +- vllm/model_executor/layers/fused_moe/layer.py | 61 +++++- .../layers/fused_moe/modular_kernel.py | 11 +- .../layers/fused_moe/mori_prepare_finalize.py | 200 +++++++++++++++++ .../layers/fused_moe/rocm_aiter_fused_moe.py | 22 +- .../model_executor/layers/quantization/fp8.py | 40 +++- vllm/utils/__init__.py | 6 + 14 files changed, 667 insertions(+), 25 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/aiter_experts.py create mode 100644 vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index 54403c1f7ca3..4851429d7720 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -53,12 +53,14 @@ def __call__(self, graph: torch.fx.Graph): # While functionalized, results at[1] and at[2] are scattered # back into mm_node. After de-functionalization, we can just # use mm_node directly. + mutated_args = {1: 'query', 2: 'key'} for idx, user in self.getitem_users(node).items(): for user_of_getitem in user.users: if is_func(user_of_getitem, torch.ops.aten.slice_scatter.default): user_of_getitem.replace_all_uses_with(mm_node) self._remove(user_of_getitem) + user.replace_all_uses_with(kwargs[mutated_args[idx]]) self._remove(user) self.insert_defunctionalized(graph, node) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index bb3fd657facd..e9ddb52577dd 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -9,7 +9,7 @@ from vllm.distributed import get_dp_group, get_ep_group from vllm.forward_context import get_forward_context from vllm.logger import init_logger -from vllm.utils import has_deep_ep, has_pplx +from vllm.utils import has_deep_ep, has_pplx, has_mori from vllm.utils.flashinfer import has_flashinfer_all2all from .base_device_communicator import All2AllManagerBase, Cache @@ -438,3 +438,205 @@ def cleanup(self): self.prepare_workspace_tensor = None self.mapping = None self.initialized = False + +class MoriAll2AllManager(All2AllManagerBase): + """ + All2All communication based on mori kernels. + """ + def __init__(self, cpu_group): + assert has_mori( + ), "mori not found. Please follow https://github.com/ROCm/mori/blob/main/README.md#installation to install mori." 
# noqa + + super().__init__(cpu_group) + self.handle_cache = Cache() + self.config = None + self._op_handles = {} # Cache for EpDispatchCombineOp instances + self._shmem_initialized = False + # Delay mori shmem initialization until first use + logger.debug(f"[rank {self.rank}] MoriAll2AllManager created, shmem will be initialized lazily") + + def _ensure_shmem_initialized(self): + """Ensure mori's shared memory system is initialized (lazy initialization)""" + if self._shmem_initialized: + return + + import mori.shmem + import torch.distributed as dist + + try: + # Wait for PyTorch distributed to be ready + if not dist.is_initialized(): + raise RuntimeError("PyTorch distributed not initialized yet") + + # Check if we have a valid backend + backend = dist.get_backend() + if backend is None: + raise RuntimeError("No valid distributed backend found") + + logger.debug(f"[rank {self.rank}] PyTorch distributed ready with backend: {backend}") + + current_group = self.cpu_group if self.cpu_group is not None else dist.group.WORLD + + # TODO(inhyeok): make group_name more reasonable + group_name = "default" + try: + import torch._C._distributed_c10d as c10d + + # Try to unregister first in case it exists + try: + c10d._unregister_process_group(group_name) + except: + pass + + # Register the current process group + c10d._register_process_group(group_name, current_group) + logger.debug(f"[rank {self.rank}] Registered process group '{group_name}'") + + # Initialize mori shmem with the registered group + mori.shmem.shmem_torch_process_group_init(group_name) + logger.debug(f"[rank {self.rank}] Torch process group shmem initialization successful") + self._shmem_initialized = True + return + + except Exception as torch_error: + logger.debug(f"[rank {self.rank}] Torch process group shmem init failed: {torch_error}") + + self._shmem_initialized = True + + except Exception as e: + logger.error(f"[rank {self.rank}] mori shmem initialization failed: {e}") + # Don't fail completely - mark as initialized to avoid retry loops + self._shmem_initialized = True + logger.warning(f"[rank {self.rank}] Continuing without mori shmem optimization") + + def _make_mori_config(self, max_num_tokens: int, num_local_experts: int, + experts_per_token: int, hidden_dim: int, + scale_dim: int, scale_type_size: int, + data_type: torch.dtype = torch.bfloat16, + quant_dtype: torch.dtype = None): + """Create mori EpDispatchCombineConfig""" + import mori.ops.dispatch_combine as mori_ops + from mori.ops.dispatch_combine import EpDispatchCombineKernelType + + # Determine data type size + dtype_to_size = { + torch.float32: 4, + torch.bfloat16: 2, + torch.float16: 2, + } + max_token_type_size = dtype_to_size.get(data_type, 2) + + config = mori_ops.EpDispatchCombineConfig( + data_type=data_type if quant_dtype is None else quant_dtype, + rank=self.rank, + world_size=self.world_size, + hidden_dim=hidden_dim, + max_num_inp_token_per_rank=max_num_tokens, + num_experts_per_rank=num_local_experts, + num_experts_per_token=experts_per_token, + + # Performance tuning parameters + # warp_num_per_block=8, + # block_num=80, + max_token_type_size=max_token_type_size, + + # Quantization support + scale_dim=scale_dim, + scale_type_size=scale_type_size, + + # Determine kernel type based on topology + kernel_type=(EpDispatchCombineKernelType.InterNode + if self.internode + else EpDispatchCombineKernelType.IntraNode) + ) + + return config + + def get_handle(self, kwargs): + """ + Get or create mori operation handle. 
+ Args: + kwargs: Dictionary with keys: + - max_num_tokens: Maximum tokens per DP rank + - num_local_experts: Number of local experts + - experts_per_token: Number of experts per token (topk) + - hidden_dim: Hidden dimension size + - data_type: Tensor data type (optional, default bfloat16) + """ + # Ensure shmem is initialized before creating handles + self._ensure_shmem_initialized() + + import mori.ops.dispatch_combine as mori_ops + + # Extract parameters + max_num_tokens = kwargs.get('max_num_tokens') + num_local_experts = kwargs.get('num_local_experts') + experts_per_token = kwargs.get('experts_per_token') + hidden_dim = kwargs.get('hidden_dim') + data_type = kwargs.get('data_type', torch.bfloat16) + scale_dim = kwargs.get('scale_dim') + scale_type_size = kwargs.get('scale_type_size') + + # Validate required parameters + if any(param is None for param in [max_num_tokens, num_local_experts, + experts_per_token, hidden_dim]): + raise ValueError("Missing required parameters for mori handle creation") + + # Create cache key + cache_key = (max_num_tokens, num_local_experts, experts_per_token, + hidden_dim, data_type) + + # Check cache first + if cache_key in self._op_handles: + return self._op_handles[cache_key] + + # Create new mori configuration and operation + config = self._make_mori_config( + max_num_tokens=max_num_tokens, + num_local_experts=num_local_experts, + experts_per_token=experts_per_token, + hidden_dim=hidden_dim, + data_type=data_type, + scale_dim=scale_dim, + scale_type_size=scale_type_size, + ) + + # Create operation handle + op = mori_ops.EpDispatchCombineOp(config) + + # Cache the handle + self._op_handles[cache_key] = op + + logger.debug(f"[rank {self.dp_rank}] Created mori handle with config: " + f"tokens={max_num_tokens}, experts={num_local_experts}, " + f"topk={experts_per_token}, hidden={hidden_dim}") + + return op + + def dispatch(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + raise NotImplementedError + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + def destroy(self): + """Clean up mori resources""" + try: + # Clear operation handle cache + self._op_handles.clear() + + # Try to finalize mori shared memory if it was successfully initialized + if self._shmem_initialized: + try: + import mori.shmem + # Check if shmem is actually active before finalizing + mori.shmem.shmem_finalize() + logger.debug(f"[rank {self.dp_rank}] mori shmem finalized") + except Exception as shmem_error: + logger.debug(f"[rank {self.dp_rank}] shmem finalize failed (may not have been active): {shmem_error}") + + logger.debug(f"[rank {self.dp_rank}] mori resources cleaned up") + + except Exception as e: + logger.warning(f"[rank {self.dp_rank}] Error during mori cleanup: {e}") \ No newline at end of file diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index a42081fb0c15..413cbd8f2313 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -7,6 +7,8 @@ import torch import torch.distributed as dist from torch.distributed import ProcessGroup +from vllm.logger import init_logger +logger = init_logger(__name__) class Cache: diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 9c2bf51a813e..2342d4b13334 100644 --- 
a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -115,6 +115,10 @@ def __init__(self, self.all2all_manager = FlashInferAllToAllManager( self.cpu_group) logger.info("Using Flashinfer all2allv manager.") + elif all2all_backend == "mori": + from .all2all import MoriAll2AllManager + self.all2all_manager = MoriAll2AllManager(self.cpu_group) + logger.info("Using Mori all2all manager.") else: raise ValueError(f"Unknown all2all backend: {all2all_backend}") diff --git a/vllm/envs.py b/vllm/envs.py index 03a22e4b2c7e..294019406575 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -157,7 +157,7 @@ VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost" VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5600 - VLLM_ALL2ALL_BACKEND: Literal["naive", "pplx", + VLLM_ALL2ALL_BACKEND: Literal["naive", "pplx", "mori", "deepep_high_throughput", "deepep_low_latency", "allgather_reducescatter", @@ -1241,12 +1241,13 @@ def get_vllm_port() -> Optional[int]: # - "allgather_reducescatter": all2all implementation based on allgather and # reducescatter # - "pplx": use pplx kernels + # - "mori": use mori kernels (currently, only low-latency is supported) # - "deepep_high_throughput", use deepep high-throughput kernels # - "deepep_low_latency", use deepep low-latency kernels # - "flashinfer_all2allv", use flashinfer alltoallv kernels for mnnvl "VLLM_ALL2ALL_BACKEND": env_with_choices("VLLM_ALL2ALL_BACKEND", "allgather_reducescatter", - ["naive", "pplx", + ["naive", "pplx", "mori", "deepep_high_throughput", "deepep_low_latency", "allgather_reducescatter", diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 75f56cd01a4e..f2be03f61fba 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -11,6 +11,7 @@ FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.utils import activation_without_mul +from vllm.model_executor.layers.fused_moe.aiter_experts import AiterExperts from vllm.triton_utils import HAS_TRITON _config: Optional[dict[str, Any]] = None @@ -77,6 +78,7 @@ def get_config() -> Optional[dict[str, Any]]: "BatchedDeepGemmExperts", "TritonOrDeepGemmExperts", "BatchedTritonOrDeepGemmExperts", + "AiterExperts", ] else: # Some model classes directly use the custom ops. Add placeholders diff --git a/vllm/model_executor/layers/fused_moe/aiter_experts.py b/vllm/model_executor/layers/fused_moe/aiter_experts.py new file mode 100644 index 000000000000..4665a344da8c --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/aiter_experts.py @@ -0,0 +1,121 @@ +""" +Aiter-based expert processing for Mori integration. +""" + +from typing import Any, Optional + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + rocm_aiter_fused_experts, +) + + +class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): + """ + Aiter-based expert processing that works with Mori dispatch/combine. + + This class bridges Mori's all2all communication with Aiter's optimized + expert computation kernels for AMD GPUs. 
+ """ + + def __init__( + self, + max_num_tokens: int, + quant_config: FusedMoEQuantConfig = None, + ): + super().__init__( + quant_config=quant_config, + ) + self.max_num_tokens = max_num_tokens + + @property + def activation_formats( + self, + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + """Aiter expects Standard format for both input and output.""" + return ( + mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard, + ) + + def supports_chunking(self) -> bool: + """Aiter kernels support chunking.""" + return True + + def supports_expert_map(self) -> bool: + """Aiter kernels support expert mapping.""" + return True + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + """Aiter handles weight and reduce internally.""" + from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceNoOP, + ) + + return TopKWeightAndReduceNoOP() + + def workspace_shapes( + self, + a: torch.Tensor, + aq: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + """ + Aiter kernels manage memory internally, so minimal workspace is needed. + """ + # Return minimal shapes since Aiter handles memory internally + workspace2 = () # No intermediate workspace needed + output_shape = aq.shape + workspace13 = output_shape + workspace_dtype = a.dtype + return (workspace13, workspace2, output_shape, workspace_dtype) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): + """ + Process expert computation using Aiter kernels. + Works with pre-dispatched tokens from Mori all2all. 
+ """ + # Call Aiter fused MoE expert processing + result = rocm_aiter_fused_experts( + hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + expert_num_tokens=expert_tokens_meta.expert_num_tokens, + output_dtype=output.dtype, + quant_config=self.quant_config, + a1q_scale=a1q_scale, + ) + + # Copy result to output tensor + output.copy_(result) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 34bfe1c16aac..3686ebb78dc0 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -438,7 +438,8 @@ def fp8_w8a8_moe_quant_config( per_act_token_quant=per_act_token_quant, per_out_ch_quant=per_out_ch_quant, block_shape=block_shape) - +# from vllm.platforms import current_platform +# return FusedMoEQuantConfig.make(current_platform.fp8_dtype(), def int8_w8a8_moe_quant_config( w1_scale: torch.Tensor, @@ -618,6 +619,11 @@ def use_deepep_ll_kernels(self): return (self.use_all2all_kernels and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") + @property + def use_mori_kernels(self): + return (self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "mori") + @staticmethod def make(tp_size_: int, dp_size_: int, vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": @@ -794,6 +800,10 @@ def use_deepep_ht_kernels(self): def use_deepep_ll_kernels(self): return self.moe_parallel_config.use_deepep_ll_kernels + @property + def use_mori_kernels(self): + return self.moe_parallel_config.use_mori_kernels + @property def use_flashinfer_cutlass_kernels(self): """ diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 8de1d14d46b3..658556f0937a 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -41,7 +41,7 @@ from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx, - round_up) + has_mori, round_up) from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.v1.worker.ubatching import dbo_current_ubatch_id @@ -56,6 +56,8 @@ from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE, DeepEPLLPrepareAndFinalize) + if has_mori(): + from .mori_prepare_finalize import MoriPrepareAndFinalize else: fused_experts = None # type: ignore FusedMoEPermuteExpertsUnpermute = None # type: ignore @@ -207,6 +209,43 @@ def _maybe_make_prepare_finalize( num_dispatchers=all2all_manager.world_size, use_fp8_dispatch=use_fp8_dispatch, ) + elif moe.use_mori_kernels: + use_fp8_dispatch = ( + quant_config is not None + and quant_config.quant_dtype == current_platform.fp8_dtype() + ) + scale_dim = 0 + scale_type_size = 0 + quant_dtype = None + if use_fp8_dispatch: + scale_dim = quant_config.scale_shape( + moe.max_num_tokens, + moe.hidden_dim, + )[-1] + scale_type_size = ( + torch.float32.itemsize + ) # aiter quantization uses float32 scale + quant_dtype = quant_config.quant_dtype + + all_to_all_args = dict( + max_num_tokens=moe.max_num_tokens, + num_local_experts=moe.num_local_experts, + experts_per_token=moe.experts_per_token, + hidden_dim=moe.hidden_dim, + data_type=moe.in_dtype, + quant_dtype=quant_dtype, + scale_dim=scale_dim, + 
scale_type_size=scale_type_size, + ) + handle = all2all_manager.get_handle(all_to_all_args) + + prepare_finalize = MoriPrepareAndFinalize( + handle, + max_num_tokens=moe.max_num_tokens, + num_local_experts=moe.num_local_experts, + num_dispatchers=all2all_manager.world_size, + use_fp8_dispatch=use_fp8_dispatch, + ) return prepare_finalize @@ -347,7 +386,14 @@ def select_gemm_impl( layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: assert self.moe_quant_config is not None - if (prepare_finalize.activation_format == + if self.moe.use_mori_kernels and is_rocm_aiter_moe_enabled(): + from vllm.model_executor.layers.fused_moe import AiterExperts + logger.debug("AiterExperts for Mori integration %s", self.moe) + return AiterExperts( + max_num_tokens=self.moe.max_num_tokens, + quant_config=self.moe_quant_config, + ) + elif (prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts): logger.debug("BatchedTritonExperts %s", self.moe) return BatchedTritonExperts( @@ -1147,6 +1193,7 @@ def __init__( # Does it really need a batched buffer? if (self.moe_parallel_config.use_pplx_kernels or self.moe_parallel_config.use_deepep_ll_kernels + or self.moe_parallel_config.use_mori_kernels or self.moe_config.use_flashinfer_cutlass_kernels): if vllm_config.parallel_config.enable_dbo: self.batched_hidden_states = torch.zeros( @@ -1215,6 +1262,10 @@ def use_deepep_ht_kernels(self): def use_deepep_ll_kernels(self): return self.moe_parallel_config.use_deepep_ll_kernels + @property + def use_mori_kernels(self): + return self.moe_parallel_config.use_mori_kernels + @property def use_flashinfer_cutlass_kernels(self): return (self.moe_quant_config is not None @@ -1792,7 +1843,7 @@ def must_reduce_shared_expert_outputs(self) -> bool: early. """ return (self.use_pplx_kernels or self.use_deepep_ht_kernels - or self.use_deepep_ll_kernels) + or self.use_deepep_ll_kernels or self.use_mori_kernels) def maybe_all_reduce_tensor_model_parallel( self, final_hidden_states: torch.Tensor): @@ -1800,7 +1851,7 @@ def maybe_all_reduce_tensor_model_parallel( The pplx combine kernel reduces across GPU ranks by default. 
""" if (self.use_pplx_kernels or self.use_deepep_ht_kernels - or self.use_deepep_ll_kernels): + or self.use_deepep_ll_kernels or self.use_mori_kernels): return final_hidden_states else: return tensor_model_parallel_all_reduce(final_hidden_states) @@ -1995,12 +2046,14 @@ def forward_impl( if (self.moe_parallel_config.use_pplx_kernels or self.moe_parallel_config.use_deepep_ll_kernels + or self.moe_parallel_config.use_mori_kernels or _use_flashinfer_cutlass_kernels): return self.forward_impl_chunked(hidden_states, router_logits) do_naive_dispatch_combine: bool = ( self.dp_size > 1 and not self.moe_parallel_config.use_deepep_ht_kernels + and not self.moe_parallel_config.use_mori_kernels and not self.moe_config.use_flashinfer_cutlass_kernels) # If there are shared experts but we are not using a modular kernel, the diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index b6afc8651e36..6910a3d661c9 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -690,9 +690,14 @@ def _do_fused_experts( workspace13 = buffers.workspace13.get(workspace13_shape, device=a1.device, dtype=workspace_dtype) - workspace2 = buffers.workspace2.get(workspace2_shape, - device=a1.device, - dtype=workspace_dtype) + # aiter does not require intermediate workspace + from vllm.model_executor.layers.fused_moe import AiterExperts + if isinstance(self.fused_experts, AiterExperts): + workspace2 = None + else: + workspace2 = self.workspace2_buffer.get(workspace2_shape, + device=a1.device, + dtype=workspace_dtype) assert fused_out is None or fused_out.shape == fused_out_shape, ( f"fused_out {fused_out.shape} but expected {fused_out_shape}") diff --git a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py new file mode 100644 index 000000000000..e3c8f5337cfc --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py @@ -0,0 +1,200 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +mori prepare and finalize module for expert parallelism. +Migration from DeepEP to mori for AMD GPU support. +""" + +from typing import Any, Optional + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): + """ + Prepare/Finalize using mori kernels for AMD GPU expert parallelism. + + This class handles the dispatch and combine operations for expert parallelism + using the mori library, which provides optimized All2All communication + primitives for AMD GPUs. + """ + + def __init__( + self, + handle: Any, # mori EpDispatchCombineOp from MoriAll2AllManager + max_num_tokens: int, + num_local_experts: int, + num_dispatchers: int, + use_fp8_dispatch: bool = False, + ): + """ + Initialize MoriPrepareAndFinalize. 
+ + Args: + handle: mori EpDispatchCombineOp instance from All2AllManager + max_num_tokens: Maximum number of tokens per rank + num_local_experts: Number of experts on this rank + num_dispatchers: Number of dispatcher ranks (world size) + use_fp8_dispatch: Whether to use FP8 quantization during dispatch + """ + super().__init__() + assert max_num_tokens > 0 + assert num_local_experts > 0 + + self.handle = handle # mori EpDispatchCombineOp + self.max_num_tokens = max_num_tokens + self.num_local_experts = num_local_experts + self.num_dispatchers_ = num_dispatchers + self.use_fp8_dispatch = use_fp8_dispatch + + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + def max_num_tokens_per_rank(self) -> Optional[int]: + return self.max_num_tokens + + def topk_indices_dtype(self) -> Optional[torch.dtype]: + return torch.int32 + + def num_dispatchers(self) -> int: + return self.num_dispatchers_ + + def prepare( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> tuple[ + torch.Tensor, + Optional[torch.Tensor], + Optional[mk.ExpertTokensMetadata], + Optional[torch.Tensor], + Optional[torch.Tensor], + ]: + """ + Prepare inputs for mori dispatch operation. + Supports pre-dispatch quantization to reduce communication overhead. + + Args: + a1: Input hidden states [num_tokens, hidden_dim] + a1_scale: Input activation scales + topk_weights: Top-k routing weights [num_experts, experts_per_token] + topk_ids: Top-k expert indices [num_experts, experts_per_token] + quant_config: Quantization config + + Returns: + Tuple of (dispatched_x, batched_scales, expert_tokens_meta, dispatch_indices, dispatch_weights) + where dispatched_x is in Standard format (2D tensor) + """ + try: + # Pre-dispatch quantization to reduce communication overhead + dispatch_input = a1 + scales = None + + if self.use_fp8_dispatch: + from aiter import get_hip_quant + from aiter import QuantType + + block_shape = quant_config.block_shape + if block_shape is not None: + assert not apply_router_weight_on_input, ( + "apply_router_weight_on_input is" + " not supported for block scaled moe" + ) + quant_type = QuantType.per_1x128 + else: + quant_type = QuantType.per_Token + + quant_func = get_hip_quant(quant_type) + + dispatch_input, scales = quant_func( + a1, + quant_dtype=quant_config.quant_dtype, + ) + + ( + dispatch_output, + dispatch_weights, + dispatch_scales, + dispatch_indices, + dispatch_recv_num_token, + ) = self.handle.dispatch( + input=dispatch_input, + weights=topk_weights, + scales=scales, + indices=topk_ids, + ) + + expert_tokens_meta = mk.ExpertTokensMetadata( + expert_num_tokens=dispatch_recv_num_token, + expert_num_tokens_cpu=None, + ) + + return ( + dispatch_output, + dispatch_scales, + expert_tokens_meta, + dispatch_indices, + dispatch_weights, + ) + + except Exception as e: + logger.error(f"mori dispatch failed: {e}") + raise RuntimeError(f"mori dispatch failed: {e}") from e + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + extra_finalize_args: Optional[dict] = None, + ) -> None: + """ + Finalize expert outputs using mori combine operation. 
+ + Args: + output: Output tensor to write results [num_original_tokens, hidden_dim] + fused_expert_output: Expert output activations in Standard format (2D tensor) + topk_weights: Original top-k weights + topk_ids: Original top-k indices + """ + assert self.handle is not None + + num_original_tokens = output.size(0) # Original number of tokens + + try: + combined_output, combined_weights = self.handle.combine( + input=fused_expert_output, + weights=topk_weights, + indices=topk_ids, + ) + + output.copy_( + combined_output[:num_original_tokens], non_blocking=True + ) + + except Exception as e: + logger.error(f"mori combine failed: {e}") + raise RuntimeError(f"mori combine failed: {e}") from e + + def __repr__(self) -> str: + return ( + f"MoriPrepareAndFinalize(" + f"max_tokens={self.max_num_tokens}, " + f"num_local_experts={self.num_local_experts}, " + f"num_dispatchers={self.num_dispatchers_})" + ) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 2764af5fc532..c1249ca3dde5 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -188,16 +188,25 @@ def rocm_aiter_fused_moe_impl( w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + expert_num_tokens: Optional[torch.Tensor] = None, + output_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: from aiter import ActivationType, QuantType from aiter.fused_moe import fused_moe + # Check if input is already pre-quantized (from mori dispatch) + input_is_pre_quantized = (a1_scale is not None and + hidden_states.dtype == torch.float8_e4m3fnuz) + dtype = output_dtype if input_is_pre_quantized else None + activation = ActivationType(activation_method) quant_type = QuantType(quant_method) return fused_moe(hidden_states, w1, w2, topk_weight, topk_ids, expert_mask, activation, quant_type, doweight_stage1, w1_scale, - w2_scale, a1_scale, a2_scale) + w2_scale, a1_scale, a2_scale, + num_local_tokens=expert_num_tokens, + dtype=dtype) def rocm_aiter_fused_moe_fake( @@ -214,6 +223,8 @@ def rocm_aiter_fused_moe_fake( w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + expert_num_tokens: Optional[torch.Tensor] = None, + output_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -308,7 +319,10 @@ def rocm_aiter_fused_experts( activation: str = "silu", apply_router_weight_on_input: bool = False, expert_map: Optional[torch.Tensor] = None, + expert_num_tokens: Optional[torch.Tensor] = None, + output_dtype: Optional[torch.dtype] = None, quant_config: Optional[FusedMoEQuantConfig] = None, + a1q_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: if quant_config is None: quant_config = FUSED_MOE_UNQUANTIZED_CONFIG @@ -385,9 +399,11 @@ def rocm_aiter_fused_experts( activation_method=activation_method, w1_scale=quant_config.w1_scale, w2_scale=quant_config.w2_scale, - a1_scale=quant_config.a1_scale, + a1_scale=quant_config.a1_scale if a1q_scale is None else a1q_scale, a2_scale=quant_config.a2_scale, - doweight_stage1=apply_router_weight_on_input) + doweight_stage1=apply_router_weight_on_input, + expert_num_tokens=expert_num_tokens, + output_dtype=output_dtype) def rocm_aiter_topk_softmax(topk_weights: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/fp8.py 
b/vllm/model_executor/layers/quantization/fp8.py index 3ebb20de9996..2c67f85a023c 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -20,6 +20,8 @@ FusedMoEQuantConfig, fp8_w8a8_moe_quant_config) from vllm.model_executor.layers.fused_moe.layer import ( UnquantizedFusedMoEMethod) +from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + is_rocm_aiter_moe_enabled) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization import QuantizationMethods @@ -617,7 +619,7 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, def process_weights_after_loading(self, layer: Module) -> None: # Lazy import to avoid importing triton too early. from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled, shuffle_weights) + shuffle_weights) self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() @@ -827,8 +829,7 @@ def process_weights_after_loading(self, layer: Module) -> None: def maybe_make_prepare_finalize( self) -> Optional[mk.FusedMoEPrepareAndFinalize]: - if (self.rocm_aiter_moe_enabled or self.use_marlin - or self.flashinfer_moe_backend + if (self.use_marlin or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM): return None elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: @@ -845,14 +846,18 @@ def select_gemm_impl( layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: from vllm.model_executor.layers.fused_moe import ( - BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts) - - assert not self.use_marlin and not self.rocm_aiter_moe_enabled, ( - "Marlin and ROCm AITER are not supported with all2all yet.") + AiterExperts, BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts) + assert not self.use_marlin, ( + "Marlin is not supported with all2all yet.") assert self.moe_quant_config is not None - - if (prepare_finalize.activation_format == + if self.moe.use_mori_kernels and is_rocm_aiter_moe_enabled(): + logger.debug("AiterExperts for Mori integration %s", self.moe) + return AiterExperts( + max_num_tokens=self.moe.max_num_tokens, + quant_config=self.moe_quant_config, + ) + elif (prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts): max_num_tokens_per_rank = ( prepare_finalize.max_num_tokens_per_rank()) @@ -1004,8 +1009,21 @@ def apply( # can override fused_experts or cutlass but not rocm or marlin. 
# topk_weights, topk_ids, zero_expert_result = select_result - - if self.rocm_aiter_moe_enabled: + if self.moe.use_mori_kernels: + common_kwargs = dict( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + return self.fused_experts(**common_kwargs) + elif self.rocm_aiter_moe_enabled: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 rocm_aiter_fused_experts) assert self.fused_experts is None diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 11d6686009b2..53acccc88974 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -3428,6 +3428,12 @@ def has_deep_gemm() -> bool: return _has_module("deep_gemm") +def has_mori() -> bool: + """Whether the optional `mori` package is available.""" + + return _has_module("mori") + + def has_triton_kernels() -> bool: """Whether the optional `triton_kernels` package is available.""" From 77c8a56d0077b6182b86387a23d527d1472116b5 Mon Sep 17 00:00:00 2001 From: HakJu Kim Date: Wed, 1 Oct 2025 21:47:21 +0900 Subject: [PATCH 02/30] Update vllm/model_executor/layers/fused_moe/modular_kernel.py caused by version difference Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: HakJu Kim Signed-off-by: HakJu Kim --- vllm/model_executor/layers/fused_moe/modular_kernel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 6910a3d661c9..90b1ff153c8b 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -695,7 +695,7 @@ def _do_fused_experts( if isinstance(self.fused_experts, AiterExperts): workspace2 = None else: - workspace2 = self.workspace2_buffer.get(workspace2_shape, + workspace2 = buffers.workspace2.get(workspace2_shape, device=a1.device, dtype=workspace_dtype) From ea17a69418bdb711799a7cec6a059a47f84c88b8 Mon Sep 17 00:00:00 2001 From: HakJu Kim Date: Wed, 1 Oct 2025 21:48:24 +0900 Subject: [PATCH 03/30] Update vllm/model_executor/layers/fused_moe/aiter_experts.py accepted suggestion Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: HakJu Kim Signed-off-by: HakJu Kim --- vllm/model_executor/layers/fused_moe/aiter_experts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/aiter_experts.py b/vllm/model_executor/layers/fused_moe/aiter_experts.py index 4665a344da8c..42a1c58182bb 100644 --- a/vllm/model_executor/layers/fused_moe/aiter_experts.py +++ b/vllm/model_executor/layers/fused_moe/aiter_experts.py @@ -111,7 +111,7 @@ def apply( activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, expert_map=expert_map, - expert_num_tokens=expert_tokens_meta.expert_num_tokens, + expert_num_tokens=expert_tokens_meta.expert_num_tokens if expert_tokens_meta is not None else None, output_dtype=output.dtype, quant_config=self.quant_config, a1q_scale=a1q_scale, From a00b38be3e8ade4a371a65c87d4cb419eeef0a1f Mon Sep 17 00:00:00 2001 From: HakJu Kim Date: Thu, 2 Oct 2025 11:08:54 +0900 Subject: [PATCH 04/30] applied pre-commit results and Read the Docs build results 
Signed-off-by: HakJu Kim --- .../device_communicators/all2all.py | 102 +++++++++++------- .../base_device_communicator.py | 2 + .../layers/fused_moe/aiter_experts.py | 26 ++--- vllm/model_executor/layers/fused_moe/layer.py | 26 ++--- .../layers/fused_moe/mori_prepare_finalize.py | 55 +++++----- .../layers/fused_moe/rocm_aiter_fused_moe.py | 20 +++- .../model_executor/layers/quantization/fp8.py | 7 +- 7 files changed, 141 insertions(+), 97 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index e9ddb52577dd..7548217dc9ba 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -9,7 +9,7 @@ from vllm.distributed import get_dp_group, get_ep_group from vllm.forward_context import get_forward_context from vllm.logger import init_logger -from vllm.utils import has_deep_ep, has_pplx, has_mori +from vllm.utils import has_deep_ep, has_mori, has_pplx from vllm.utils.flashinfer import has_flashinfer_all2all from .base_device_communicator import All2AllManagerBase, Cache @@ -439,13 +439,15 @@ def cleanup(self): self.mapping = None self.initialized = False + class MoriAll2AllManager(All2AllManagerBase): """ All2All communication based on mori kernels. """ + def __init__(self, cpu_group): assert has_mori( - ), "mori not found. Please follow https://github.com/ROCm/mori/blob/main/README.md#installation to install mori." # noqa + ), "Please install mori from https://github.com/ROCm/mori." super().__init__(cpu_group) self.handle_cache = Cache() @@ -453,10 +455,10 @@ def __init__(self, cpu_group): self._op_handles = {} # Cache for EpDispatchCombineOp instances self._shmem_initialized = False # Delay mori shmem initialization until first use - logger.debug(f"[rank {self.rank}] MoriAll2AllManager created, shmem will be initialized lazily") + logger.debug("[rank %s] MoriAll2AllManager created", self.rank) def _ensure_shmem_initialized(self): - """Ensure mori's shared memory system is initialized (lazy initialization)""" + """Initialize mori's shared memory system lazily""" if self._shmem_initialized: return @@ -473,45 +475,60 @@ def _ensure_shmem_initialized(self): if backend is None: raise RuntimeError("No valid distributed backend found") - logger.debug(f"[rank {self.rank}] PyTorch distributed ready with backend: {backend}") + logger.debug( + "[rank %s] PyTorch distributed ready with backend: %s", + self.rank, backend) - current_group = self.cpu_group if self.cpu_group is not None else dist.group.WORLD + current_group = (self.cpu_group if self.cpu_group is not None else + dist.group.WORLD) # TODO(inhyeok): make group_name more reasonable group_name = "default" try: + import contextlib + import torch._C._distributed_c10d as c10d # Try to unregister first in case it exists - try: + with contextlib.suppress(RuntimeError): c10d._unregister_process_group(group_name) - except: - pass # Register the current process group c10d._register_process_group(group_name, current_group) - logger.debug(f"[rank {self.rank}] Registered process group '{group_name}'") + logger.debug("[rank %s] Registered process group '%s'", + self.rank, group_name) # Initialize mori shmem with the registered group mori.shmem.shmem_torch_process_group_init(group_name) - logger.debug(f"[rank {self.rank}] Torch process group shmem initialization successful") + logger.debug( + "[rank %s] torch process group shmem init success", + self.rank) self._shmem_initialized = True return except Exception as torch_error: - 
logger.debug(f"[rank {self.rank}] Torch process group shmem init failed: {torch_error}") + logger.debug( + "[rank %s] torch process group shmem init failed: %s", + self.rank, torch_error) self._shmem_initialized = True except Exception as e: - logger.error(f"[rank {self.rank}] mori shmem initialization failed: {e}") + logger.error("[rank %s] mori shmem initialization failed: %s", + self.rank, e) # Don't fail completely - mark as initialized to avoid retry loops self._shmem_initialized = True - logger.warning(f"[rank {self.rank}] Continuing without mori shmem optimization") - - def _make_mori_config(self, max_num_tokens: int, num_local_experts: int, - experts_per_token: int, hidden_dim: int, - scale_dim: int, scale_type_size: int, + logger.warning( + "[rank %s] Continuing without mori shmem optimization", + self.rank) + + def _make_mori_config(self, + max_num_tokens: int, + num_local_experts: int, + experts_per_token: int, + hidden_dim: int, + scale_dim: int, + scale_type_size: int, data_type: torch.dtype = torch.bfloat16, quant_dtype: torch.dtype = None): """Create mori EpDispatchCombineConfig""" @@ -546,9 +563,8 @@ def _make_mori_config(self, max_num_tokens: int, num_local_experts: int, # Determine kernel type based on topology kernel_type=(EpDispatchCombineKernelType.InterNode - if self.internode - else EpDispatchCombineKernelType.IntraNode) - ) + if self.internode else + EpDispatchCombineKernelType.IntraNode)) return config @@ -578,13 +594,16 @@ def get_handle(self, kwargs): scale_type_size = kwargs.get('scale_type_size') # Validate required parameters - if any(param is None for param in [max_num_tokens, num_local_experts, - experts_per_token, hidden_dim]): - raise ValueError("Missing required parameters for mori handle creation") + if any( + param is None for param in + [max_num_tokens, num_local_experts, experts_per_token, hidden_dim + ]): + raise ValueError( + "Missing required parameters for mori handle creation") # Create cache key cache_key = (max_num_tokens, num_local_experts, experts_per_token, - hidden_dim, data_type) + hidden_dim, data_type) # Check cache first if cache_key in self._op_handles: @@ -607,17 +626,22 @@ def get_handle(self, kwargs): # Cache the handle self._op_handles[cache_key] = op - logger.debug(f"[rank {self.dp_rank}] Created mori handle with config: " - f"tokens={max_num_tokens}, experts={num_local_experts}, " - f"topk={experts_per_token}, hidden={hidden_dim}") + logger.debug( + "[rank %s] Created mori handle with config: tokens=%d, experts=%d," + " topk=%d, hidden_dim=%d", self.dp_rank, max_num_tokens, + num_local_experts, experts_per_token, hidden_dim) return op - def dispatch(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): + def dispatch(self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False): raise NotImplementedError - def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + def combine(self, + hidden_states: torch.Tensor, + is_sequence_parallel: bool = False): raise NotImplementedError def destroy(self): @@ -626,17 +650,23 @@ def destroy(self): # Clear operation handle cache self._op_handles.clear() - # Try to finalize mori shared memory if it was successfully initialized + # finalize mori shared memory if it was initialized if self._shmem_initialized: try: import mori.shmem + # Check if shmem is actually active before finalizing mori.shmem.shmem_finalize() - logger.debug(f"[rank {self.dp_rank}] mori shmem finalized") + logger.debug("[rank %s] mori shmem finalized", + 
self.dp_rank) except Exception as shmem_error: - logger.debug(f"[rank {self.dp_rank}] shmem finalize failed (may not have been active): {shmem_error}") + logger.debug( + "[rank %s] shmem finalize failed " + "(may not have been active): %s", self.dp_rank, + shmem_error) - logger.debug(f"[rank {self.dp_rank}] mori resources cleaned up") + logger.debug("[rank %s] mori resources cleaned up", self.dp_rank) except Exception as e: - logger.warning(f"[rank {self.dp_rank}] Error during mori cleanup: {e}") \ No newline at end of file + logger.warning("[rank %s] Error during mori cleanup: %s", + self.dp_rank, e) diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 413cbd8f2313..cec5a04e557b 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -7,7 +7,9 @@ import torch import torch.distributed as dist from torch.distributed import ProcessGroup + from vllm.logger import init_logger + logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/fused_moe/aiter_experts.py b/vllm/model_executor/layers/fused_moe/aiter_experts.py index 42a1c58182bb..b03f1fff9bf6 100644 --- a/vllm/model_executor/layers/fused_moe/aiter_experts.py +++ b/vllm/model_executor/layers/fused_moe/aiter_experts.py @@ -1,16 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ Aiter-based expert processing for Mori integration. """ -from typing import Any, Optional +from typing import Optional import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - rocm_aiter_fused_experts, -) + rocm_aiter_fused_experts) +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceNoOP) class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): @@ -24,11 +27,9 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, max_num_tokens: int, - quant_config: FusedMoEQuantConfig = None, + quant_config: FusedMoEQuantConfig, ): - super().__init__( - quant_config=quant_config, - ) + super().__init__(quant_config=quant_config, ) self.max_num_tokens = max_num_tokens @property @@ -51,10 +52,6 @@ def supports_expert_map(self) -> bool: def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: """Aiter handles weight and reduce internally.""" - from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceNoOP, - ) - return TopKWeightAndReduceNoOP() def workspace_shapes( @@ -101,6 +98,11 @@ def apply( Process expert computation using Aiter kernels. Works with pre-dispatched tokens from Mori all2all. 
""" + if expert_tokens_meta is not None: + expert_num_tokens = expert_tokens_meta.expert_num_tokens + else: + expert_num_tokens = None + # Call Aiter fused MoE expert processing result = rocm_aiter_fused_experts( hidden_states=hidden_states, @@ -111,7 +113,7 @@ def apply( activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, expert_map=expert_map, - expert_num_tokens=expert_tokens_meta.expert_num_tokens if expert_tokens_meta is not None else None, + expert_num_tokens=expert_num_tokens, output_dtype=output.dtype, quant_config=self.quant_config, a1q_scale=a1q_scale, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 658556f0937a..b040ab8aad17 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -40,8 +40,8 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx, - has_mori, round_up) +from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_mori, + has_pplx, round_up) from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.v1.worker.ubatching import dbo_current_ubatch_id @@ -75,9 +75,12 @@ def _eplb_map_to_physical_and_record( if is_rocm_aiter_moe_enabled(): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 - rocm_aiter_grouped_topk as grouped_topk) + rocm_aiter_grouped_topk) + grouped_topk_impl = rocm_aiter_grouped_topk else: from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk + grouped_topk_impl = grouped_topk + if current_platform.is_tpu(): from .moe_pallas import fused_moe as fused_moe_pallas else: @@ -210,21 +213,20 @@ def _maybe_make_prepare_finalize( use_fp8_dispatch=use_fp8_dispatch, ) elif moe.use_mori_kernels: - use_fp8_dispatch = ( - quant_config is not None - and quant_config.quant_dtype == current_platform.fp8_dtype() - ) + use_fp8_dispatch = (quant_config is not None + and quant_config.quant_dtype + == current_platform.fp8_dtype()) scale_dim = 0 scale_type_size = 0 quant_dtype = None if use_fp8_dispatch: + assert quant_config is not None scale_dim = quant_config.scale_shape( moe.max_num_tokens, moe.hidden_dim, )[-1] - scale_type_size = ( - torch.float32.itemsize - ) # aiter quantization uses float32 scale + scale_type_size = (torch.float32.itemsize + ) # aiter quantization uses float32 scale quant_dtype = quant_config.quant_dtype all_to_all_args = dict( @@ -394,7 +396,7 @@ def select_gemm_impl( quant_config=self.moe_quant_config, ) elif (prepare_finalize.activation_format == - FusedMoEActivationFormat.BatchedExperts): + FusedMoEActivationFormat.BatchedExperts): logger.debug("BatchedTritonExperts %s", self.moe) return BatchedTritonExperts( max_num_tokens=self.moe.max_num_tokens, @@ -1760,7 +1762,7 @@ def select_experts( if use_grouped_topk: assert topk_group is not None assert num_expert_group is not None - topk_weights, topk_ids = grouped_topk( + topk_weights, topk_ids = grouped_topk_impl( hidden_states=hidden_states, gating_output=router_logits, topk=top_k, diff --git a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py index e3c8f5337cfc..f4953f307c0f 100644 --- a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py 
@@ -10,8 +10,8 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig logger = init_logger(__name__) @@ -20,9 +20,9 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): """ Prepare/Finalize using mori kernels for AMD GPU expert parallelism. - This class handles the dispatch and combine operations for expert parallelism - using the mori library, which provides optimized All2All communication - primitives for AMD GPUs. + This class handles the dispatch and combine operations for + expert parallelism using the mori library, which provides optimized + All2All communication primitives for AMD GPUs. """ def __init__( @@ -76,11 +76,11 @@ def prepare( apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> tuple[ - torch.Tensor, - Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], - Optional[torch.Tensor], - Optional[torch.Tensor], + torch.Tensor, + Optional[torch.Tensor], + Optional[mk.ExpertTokensMetadata], + Optional[torch.Tensor], + Optional[torch.Tensor], ]: """ Prepare inputs for mori dispatch operation. @@ -94,7 +94,8 @@ def prepare( quant_config: Quantization config Returns: - Tuple of (dispatched_x, batched_scales, expert_tokens_meta, dispatch_indices, dispatch_weights) + Tuple of (dispatched_x, batched_scales, expert_tokens_meta, + dispatch_indices, dispatch_weights) where dispatched_x is in Standard format (2D tensor) """ try: @@ -103,15 +104,13 @@ def prepare( scales = None if self.use_fp8_dispatch: - from aiter import get_hip_quant - from aiter import QuantType + from aiter import QuantType, get_hip_quant block_shape = quant_config.block_shape if block_shape is not None: assert not apply_router_weight_on_input, ( "apply_router_weight_on_input is" - " not supported for block scaled moe" - ) + " not supported for block scaled moe") quant_type = QuantType.per_1x128 else: quant_type = QuantType.per_Token @@ -150,8 +149,8 @@ def prepare( ) except Exception as e: - logger.error(f"mori dispatch failed: {e}") - raise RuntimeError(f"mori dispatch failed: {e}") from e + logger.error("mori dispatch failed: %s", e) + raise RuntimeError("mori dispatch failed: %s", e) from e def finalize( self, @@ -167,8 +166,10 @@ def finalize( Finalize expert outputs using mori combine operation. 
Args: - output: Output tensor to write results [num_original_tokens, hidden_dim] - fused_expert_output: Expert output activations in Standard format (2D tensor) + output: Output tensor to write results [num_original_tokens, + hidden_dim] + fused_expert_output: Expert output activations in Standard format + (2D tensor) topk_weights: Original top-k weights topk_ids: Original top-k indices """ @@ -183,18 +184,14 @@ def finalize( indices=topk_ids, ) - output.copy_( - combined_output[:num_original_tokens], non_blocking=True - ) + output.copy_(combined_output[:num_original_tokens], + non_blocking=True) except Exception as e: - logger.error(f"mori combine failed: {e}") - raise RuntimeError(f"mori combine failed: {e}") from e + logger.error("mori combine failed: %s", e) + raise RuntimeError("mori combine failed: %s", e) from e def __repr__(self) -> str: - return ( - f"MoriPrepareAndFinalize(" - f"max_tokens={self.max_num_tokens}, " - f"num_local_experts={self.num_local_experts}, " - f"num_dispatchers={self.num_dispatchers_})" - ) + return (f"MoriPrepareAndFinalize(max_tokens={self.max_num_tokens}, " + f"num_local_experts={self.num_local_experts}, " + f"num_dispatchers={self.num_dispatchers_})") diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index c1249ca3dde5..3a081f91866f 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -195,16 +195,26 @@ def rocm_aiter_fused_moe_impl( from aiter.fused_moe import fused_moe # Check if input is already pre-quantized (from mori dispatch) - input_is_pre_quantized = (a1_scale is not None and - hidden_states.dtype == torch.float8_e4m3fnuz) + input_is_pre_quantized = (a1_scale is not None + and hidden_states.dtype == torch.float8_e4m3fnuz) dtype = output_dtype if input_is_pre_quantized else None activation = ActivationType(activation_method) quant_type = QuantType(quant_method) - return fused_moe(hidden_states, w1, w2, topk_weight, topk_ids, expert_mask, - activation, quant_type, doweight_stage1, w1_scale, - w2_scale, a1_scale, a2_scale, + return fused_moe(hidden_states, + w1, + w2, + topk_weight, + topk_ids, + expert_mask, + activation, + quant_type, + doweight_stage1, + w1_scale, + w2_scale, + a1_scale, + a2_scale, num_local_tokens=expert_num_tokens, dtype=dtype) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 2c67f85a023c..c0fefbd8f660 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -846,7 +846,8 @@ def select_gemm_impl( layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: from vllm.model_executor.layers.fused_moe import ( - AiterExperts, BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts) + AiterExperts, BatchedTritonOrDeepGemmExperts, + TritonOrDeepGemmExperts) assert not self.use_marlin, ( "Marlin is not supported with all2all yet.") @@ -858,7 +859,7 @@ def select_gemm_impl( quant_config=self.moe_quant_config, ) elif (prepare_finalize.activation_format == - FusedMoEActivationFormat.BatchedExperts): + FusedMoEActivationFormat.BatchedExperts): max_num_tokens_per_rank = ( prepare_finalize.max_num_tokens_per_rank()) assert max_num_tokens_per_rank is not None @@ -1009,7 +1010,7 @@ def apply( # can override fused_experts or cutlass but not rocm or marlin. 
# topk_weights, topk_ids, zero_expert_result = select_result - if self.moe.use_mori_kernels: + if self.moe.use_mori_kernels and self.fused_experts: common_kwargs = dict( hidden_states=x, w1=layer.w13_weight, From 07095ff84f81f9d935c905a9c3ceeb556b080786 Mon Sep 17 00:00:00 2001 From: HakJu Kim Date: Thu, 2 Oct 2025 11:21:05 +0900 Subject: [PATCH 05/30] applied few suggestions from code-assistant Signed-off-by: HakJu Kim --- vllm/distributed/device_communicators/all2all.py | 6 ++++-- .../layers/fused_moe/mori_prepare_finalize.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 7548217dc9ba..d2806f10da15 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -510,8 +510,10 @@ def _ensure_shmem_initialized(self): logger.debug( "[rank %s] torch process group shmem init failed: %s", self.rank, torch_error) - - self._shmem_initialized = True + self._shmem_initialized = True + logger.warning( + "[rank %s] Continuing without mori shmem optimization", + self.rank) except Exception as e: logger.error("[rank %s] mori shmem initialization failed: %s", diff --git a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py index f4953f307c0f..e71e693c2d18 100644 --- a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py @@ -88,9 +88,9 @@ def prepare( Args: a1: Input hidden states [num_tokens, hidden_dim] - a1_scale: Input activation scales topk_weights: Top-k routing weights [num_experts, experts_per_token] topk_ids: Top-k expert indices [num_experts, experts_per_token] + apply_router_weight_on_input: Whether to apply router weight quant_config: Quantization config Returns: From 26ec16dde6ad0f0bb1d3b055cc0eced0d988cd93 Mon Sep 17 00:00:00 2001 From: HakJu Kim Date: Thu, 2 Oct 2025 11:41:50 +0900 Subject: [PATCH 06/30] applied pre-commit result from github Signed-off-by: HakJu Kim --- vllm/model_executor/layers/fused_moe/layer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index b040ab8aad17..34ddf12ea97f 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -221,10 +221,12 @@ def _maybe_make_prepare_finalize( quant_dtype = None if use_fp8_dispatch: assert quant_config is not None - scale_dim = quant_config.scale_shape( + temp = quant_config.scale_shape( moe.max_num_tokens, moe.hidden_dim, - )[-1] + ) + if temp is not None: + scale_dim = temp[-1] scale_type_size = (torch.float32.itemsize ) # aiter quantization uses float32 scale quant_dtype = quant_config.quant_dtype From 5c997be733bc121e20e951fd89890ab4bd26e205 Mon Sep 17 00:00:00 2001 From: HakJu Kim Date: Thu, 2 Oct 2025 11:52:35 +0900 Subject: [PATCH 07/30] applied some pre-commit results from github Signed-off-by: HakJu Kim --- vllm/model_executor/layers/fused_moe/__init__.py | 2 +- vllm/model_executor/layers/fused_moe/config.py | 2 -- vllm/model_executor/layers/fused_moe/modular_kernel.py | 4 ++-- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index f2be03f61fba..0f595a303945 100644 --- 
a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -4,6 +4,7 @@ from contextlib import contextmanager from typing import Any, Optional +from vllm.model_executor.layers.fused_moe.aiter_experts import AiterExperts from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) @@ -11,7 +12,6 @@ FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.utils import activation_without_mul -from vllm.model_executor.layers.fused_moe.aiter_experts import AiterExperts from vllm.triton_utils import HAS_TRITON _config: Optional[dict[str, Any]] = None diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 3686ebb78dc0..e908022089ef 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -438,8 +438,6 @@ def fp8_w8a8_moe_quant_config( per_act_token_quant=per_act_token_quant, per_out_ch_quant=per_out_ch_quant, block_shape=block_shape) -# from vllm.platforms import current_platform -# return FusedMoEQuantConfig.make(current_platform.fp8_dtype(), def int8_w8a8_moe_quant_config( w1_scale: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 90b1ff153c8b..139e36da493d 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -696,8 +696,8 @@ def _do_fused_experts( workspace2 = None else: workspace2 = buffers.workspace2.get(workspace2_shape, - device=a1.device, - dtype=workspace_dtype) + device=a1.device, + dtype=workspace_dtype) assert fused_out is None or fused_out.shape == fused_out_shape, ( f"fused_out {fused_out.shape} but expected {fused_out_shape}") From 9849fa6f6913c2a1f38821bac2587d82a2c9c556 Mon Sep 17 00:00:00 2001 From: HakJu Kim Date: Thu, 2 Oct 2025 11:59:21 +0900 Subject: [PATCH 08/30] hope this is last for pre-commit Signed-off-by: HakJu Kim --- vllm/model_executor/layers/fused_moe/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index e908022089ef..74f7644fbaaf 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -439,6 +439,7 @@ def fp8_w8a8_moe_quant_config( per_out_ch_quant=per_out_ch_quant, block_shape=block_shape) + def int8_w8a8_moe_quant_config( w1_scale: torch.Tensor, w2_scale: torch.Tensor, From f4fb63af6b2f44aeaa019b1b7c143fb049f1a539 Mon Sep 17 00:00:00 2001 From: HakJu Kim Date: Thu, 2 Oct 2025 18:30:14 +0900 Subject: [PATCH 09/30] removing unrelated change Signed-off-by: HakJu Kim --- vllm/compilation/fix_functionalization.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index 4851429d7720..54403c1f7ca3 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -53,14 +53,12 @@ def __call__(self, graph: torch.fx.Graph): # While functionalized, results at[1] and at[2] are scattered # back into mm_node. After de-functionalization, we can just # use mm_node directly. 
- mutated_args = {1: 'query', 2: 'key'} for idx, user in self.getitem_users(node).items(): for user_of_getitem in user.users: if is_func(user_of_getitem, torch.ops.aten.slice_scatter.default): user_of_getitem.replace_all_uses_with(mm_node) self._remove(user_of_getitem) - user.replace_all_uses_with(kwargs[mutated_args[idx]]) self._remove(user) self.insert_defunctionalized(graph, node) From 739b489395065b9817dae4d9a3e3dd7c523eb160 Mon Sep 17 00:00:00 2001 From: HakJu Kim Date: Mon, 6 Oct 2025 11:59:13 +0900 Subject: [PATCH 10/30] applied pre-commit results after merging main to this branch Signed-off-by: HakJu Kim --- .../device_communicators/all2all.py | 148 +++++++++++------- .../device_communicators/cuda_communicator.py | 1 + .../layers/fused_moe/aiter_experts.py | 10 +- .../model_executor/layers/fused_moe/config.py | 6 +- vllm/model_executor/layers/fused_moe/layer.py | 38 +++-- .../layers/fused_moe/modular_kernel.py | 1 + .../layers/fused_moe/mori_prepare_finalize.py | 27 ++-- .../layers/fused_moe/rocm_aiter_fused_moe.py | 9 +- .../model_executor/layers/quantization/fp8.py | 13 +- 9 files changed, 161 insertions(+), 92 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index ae29daba0d71..90e19fe2173f 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -475,8 +475,7 @@ class MoriAll2AllManager(All2AllManagerBase): """ def __init__(self, cpu_group): - assert has_mori( - ), "Please install mori from https://github.com/ROCm/mori." + assert has_mori(), "Please install mori from ROCm/mori github." super().__init__(cpu_group) self.handle_cache = Cache() @@ -506,10 +505,14 @@ def _ensure_shmem_initialized(self): logger.debug( "[rank %s] PyTorch distributed ready with backend: %s", - self.rank, backend) + self.rank, + backend + ) - current_group = (self.cpu_group if self.cpu_group is not None else - dist.group.WORLD) + current_group = ( + self.cpu_group if self.cpu_group is not None + else dist.group.WORLD + ) # TODO(inhyeok): make group_name more reasonable group_name = "default" @@ -524,44 +527,56 @@ def _ensure_shmem_initialized(self): # Register the current process group c10d._register_process_group(group_name, current_group) - logger.debug("[rank %s] Registered process group '%s'", - self.rank, group_name) + logger.debug( + "[rank %s] Registered process group '%s'", + self.rank, + group_name + ) # Initialize mori shmem with the registered group mori.shmem.shmem_torch_process_group_init(group_name) logger.debug( "[rank %s] torch process group shmem init success", - self.rank) + self.rank + ) self._shmem_initialized = True return except Exception as torch_error: logger.debug( "[rank %s] torch process group shmem init failed: %s", - self.rank, torch_error) + self.rank, + torch_error + ) self._shmem_initialized = True logger.warning( "[rank %s] Continuing without mori shmem optimization", - self.rank) + self.rank + ) except Exception as e: - logger.error("[rank %s] mori shmem initialization failed: %s", - self.rank, e) + logger.error( + "[rank %s] mori shmem initialization failed: %s", + self.rank, e + ) # Don't fail completely - mark as initialized to avoid retry loops self._shmem_initialized = True logger.warning( "[rank %s] Continuing without mori shmem optimization", - self.rank) - - def _make_mori_config(self, - max_num_tokens: int, - num_local_experts: int, - experts_per_token: int, - hidden_dim: int, - scale_dim: int, - scale_type_size: int, - 
data_type: torch.dtype = torch.bfloat16, - quant_dtype: torch.dtype = None): + self.rank + ) + + def _make_mori_config( + self, + max_num_tokens: int, + num_local_experts: int, + experts_per_token: int, + hidden_dim: int, + scale_dim: int, + scale_type_size: int, + data_type: torch.dtype = torch.bfloat16, + quant_dtype: torch.dtype = None, + ): """Create mori EpDispatchCombineConfig""" import mori.ops.dispatch_combine as mori_ops from mori.ops.dispatch_combine import EpDispatchCombineKernelType @@ -582,20 +597,20 @@ def _make_mori_config(self, max_num_inp_token_per_rank=max_num_tokens, num_experts_per_rank=num_local_experts, num_experts_per_token=experts_per_token, - # Performance tuning parameters # warp_num_per_block=8, # block_num=80, max_token_type_size=max_token_type_size, - # Quantization support scale_dim=scale_dim, scale_type_size=scale_type_size, - # Determine kernel type based on topology - kernel_type=(EpDispatchCombineKernelType.InterNode - if self.internode else - EpDispatchCombineKernelType.IntraNode)) + kernel_type=( + EpDispatchCombineKernelType.InterNode + if self.internode else + EpDispatchCombineKernelType.IntraNode + ), + ) return config @@ -616,25 +631,36 @@ def get_handle(self, kwargs): import mori.ops.dispatch_combine as mori_ops # Extract parameters - max_num_tokens = kwargs.get('max_num_tokens') - num_local_experts = kwargs.get('num_local_experts') - experts_per_token = kwargs.get('experts_per_token') - hidden_dim = kwargs.get('hidden_dim') - data_type = kwargs.get('data_type', torch.bfloat16) - scale_dim = kwargs.get('scale_dim') - scale_type_size = kwargs.get('scale_type_size') + max_num_tokens = kwargs.get("max_num_tokens") + num_local_experts = kwargs.get("num_local_experts") + experts_per_token = kwargs.get("experts_per_token") + hidden_dim = kwargs.get("hidden_dim") + data_type = kwargs.get("data_type", torch.bfloat16) + scale_dim = kwargs.get("scale_dim") + scale_type_size = kwargs.get("scale_type_size") # Validate required parameters if any( - param is None for param in - [max_num_tokens, num_local_experts, experts_per_token, hidden_dim - ]): + param is None + for param in [ + max_num_tokens, + num_local_experts, + experts_per_token, + hidden_dim, + ] + ): raise ValueError( - "Missing required parameters for mori handle creation") + "Missing required parameters for mori handle creation" + ) # Create cache key - cache_key = (max_num_tokens, num_local_experts, experts_per_token, - hidden_dim, data_type) + cache_key = ( + max_num_tokens, + num_local_experts, + experts_per_token, + hidden_dim, + data_type, + ) # Check cache first if cache_key in self._op_handles: @@ -659,20 +685,27 @@ def get_handle(self, kwargs): logger.debug( "[rank %s] Created mori handle with config: tokens=%d, experts=%d," - " topk=%d, hidden_dim=%d", self.dp_rank, max_num_tokens, - num_local_experts, experts_per_token, hidden_dim) + " topk=%d, hidden_dim=%d", + self.dp_rank, + max_num_tokens, + num_local_experts, + experts_per_token, + hidden_dim + ) return op def dispatch(self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - is_sequence_parallel: bool = False): + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False, + ): raise NotImplementedError def combine(self, - hidden_states: torch.Tensor, - is_sequence_parallel: bool = False): + hidden_states: torch.Tensor, + is_sequence_parallel: bool = False, + ): raise NotImplementedError def destroy(self): @@ -688,16 +721,21 @@ def destroy(self): # Check if shmem is actually active 
before finalizing mori.shmem.shmem_finalize() - logger.debug("[rank %s] mori shmem finalized", - self.dp_rank) + logger.debug( + "[rank %s] mori shmem finalized", + self.dp_rank + ) except Exception as shmem_error: logger.debug( "[rank %s] shmem finalize failed " - "(may not have been active): %s", self.dp_rank, - shmem_error) + "(may not have been active): %s", + self.dp_rank, + shmem_error + ) logger.debug("[rank %s] mori resources cleaned up", self.dp_rank) except Exception as e: - logger.warning("[rank %s] Error during mori cleanup: %s", - self.dp_rank, e) + logger.warning( + "[rank %s] Error during mori cleanup: %s", self.dp_rank, e + ) diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 29b36dd9157e..0293d1a15760 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -124,6 +124,7 @@ def __init__( logger.info("Using Flashinfer all2allv manager.") elif all2all_backend == "mori": from .all2all import MoriAll2AllManager + self.all2all_manager = MoriAll2AllManager(self.cpu_group) logger.info("Using Mori all2all manager.") else: diff --git a/vllm/model_executor/layers/fused_moe/aiter_experts.py b/vllm/model_executor/layers/fused_moe/aiter_experts.py index b03f1fff9bf6..0ff6b1229486 100644 --- a/vllm/model_executor/layers/fused_moe/aiter_experts.py +++ b/vllm/model_executor/layers/fused_moe/aiter_experts.py @@ -11,9 +11,11 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - rocm_aiter_fused_experts) + rocm_aiter_fused_experts, +) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceNoOP) + TopKWeightAndReduceNoOP, +) class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): @@ -29,7 +31,9 @@ def __init__( max_num_tokens: int, quant_config: FusedMoEQuantConfig, ): - super().__init__(quant_config=quant_config, ) + super().__init__( + quant_config=quant_config, + ) self.max_num_tokens = max_num_tokens @property diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 0ebf97a91c05..8d5d8a8a3fd4 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -627,8 +627,10 @@ def use_deepep_ll_kernels(self): @property def use_mori_kernels(self): - return (self.use_all2all_kernels - and envs.VLLM_ALL2ALL_BACKEND == "mori") + return ( + self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "mori" + ) @staticmethod def make( diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 5332b4538c5c..53b680a749c2 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -49,8 +49,14 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_mori, - has_pplx, round_up) +from vllm.utils import ( + cdiv, + direct_register_custom_op, + has_deep_ep, + has_mori, + has_pplx, + round_up, +) from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.v1.worker.ubatching import dbo_current_ubatch_id @@ -90,10 +96,13 @@ def 
_eplb_map_to_physical_and_record( if is_rocm_aiter_moe_enabled(): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 - rocm_aiter_grouped_topk) + rocm_aiter_grouped_topk, + ) + grouped_topk_impl = rocm_aiter_grouped_topk else: from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk + grouped_topk_impl = grouped_topk if current_platform.is_tpu(): @@ -232,9 +241,10 @@ def _maybe_make_prepare_finalize( use_fp8_dispatch=use_fp8_dispatch, ) elif moe.use_mori_kernels: - use_fp8_dispatch = (quant_config is not None - and quant_config.quant_dtype - == current_platform.fp8_dtype()) + use_fp8_dispatch = ( + quant_config is not None + and quant_config.quant_dtype == current_platform.fp8_dtype() + ) scale_dim = 0 scale_type_size = 0 quant_dtype = None @@ -246,8 +256,9 @@ def _maybe_make_prepare_finalize( ) if temp is not None: scale_dim = temp[-1] - scale_type_size = (torch.float32.itemsize - ) # aiter quantization uses float32 scale + scale_type_size = ( + torch.float32.itemsize + ) # aiter quantization uses float32 scale quant_dtype = quant_config.quant_dtype all_to_all_args = dict( @@ -422,6 +433,7 @@ def select_gemm_impl( if self.moe.use_mori_kernels and is_rocm_aiter_moe_enabled(): from vllm.model_executor.layers.fused_moe import AiterExperts logger.debug("AiterExperts for Mori integration %s", self.moe) + return AiterExperts( max_num_tokens=self.moe.max_num_tokens, quant_config=self.moe_quant_config, @@ -1287,10 +1299,12 @@ def __init__( # TODO(bnell): flashinfer uses non-batched format. # Does it really need a batched buffer? - if (self.moe_parallel_config.use_pplx_kernels - or self.moe_parallel_config.use_deepep_ll_kernels - or self.moe_parallel_config.use_mori_kernels - or self.moe_config.use_flashinfer_cutlass_kernels): + if ( + self.moe_parallel_config.use_pplx_kernels + or self.moe_parallel_config.use_deepep_ll_kernels + or self.moe_parallel_config.use_mori_kernels + or self.moe_config.use_flashinfer_cutlass_kernels + ): if vllm_config.parallel_config.enable_dbo: self.batched_hidden_states = torch.zeros( (2, moe.max_num_tokens, self.hidden_size), diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index aec815f0b9ae..cd039db4f5dc 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -723,6 +723,7 @@ def _do_fused_experts( ) # aiter does not require intermediate workspace from vllm.model_executor.layers.fused_moe import AiterExperts + if isinstance(self.fused_experts, AiterExperts): workspace2 = None else: diff --git a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py index e71e693c2d18..efe965d46d3a 100644 --- a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py @@ -76,11 +76,11 @@ def prepare( apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> tuple[ - torch.Tensor, - Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], - Optional[torch.Tensor], - Optional[torch.Tensor], + torch.Tensor, + Optional[torch.Tensor], + Optional[mk.ExpertTokensMetadata], + Optional[torch.Tensor], + Optional[torch.Tensor], ]: """ Prepare inputs for mori dispatch operation. 
@@ -110,7 +110,8 @@ def prepare( if block_shape is not None: assert not apply_router_weight_on_input, ( "apply_router_weight_on_input is" - " not supported for block scaled moe") + " not supported for block scaled moe" + ) quant_type = QuantType.per_1x128 else: quant_type = QuantType.per_Token @@ -184,14 +185,18 @@ def finalize( indices=topk_ids, ) - output.copy_(combined_output[:num_original_tokens], - non_blocking=True) + output.copy_( + combined_output[:num_original_tokens], + non_blocking=True, + ) except Exception as e: logger.error("mori combine failed: %s", e) raise RuntimeError("mori combine failed: %s", e) from e def __repr__(self) -> str: - return (f"MoriPrepareAndFinalize(max_tokens={self.max_num_tokens}, " - f"num_local_experts={self.num_local_experts}, " - f"num_dispatchers={self.num_dispatchers_})") + return ( + f"MoriPrepareAndFinalize(max_tokens={self.max_num_tokens}, " + f"num_local_experts={self.num_local_experts}, " + f"num_dispatchers={self.num_dispatchers_})" + ) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 306ed04d6e5c..2f16fe78ea72 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -221,8 +221,9 @@ def rocm_aiter_fused_moe_impl( from aiter.fused_moe import fused_moe # Check if input is already pre-quantized (from mori dispatch) - input_is_pre_quantized = (a1_scale is not None - and hidden_states.dtype == torch.float8_e4m3fnuz) + input_is_pre_quantized = ( + a1_scale is not None and hidden_states.dtype == torch.float8_e4m3fnuz + ) dtype = output_dtype if input_is_pre_quantized else None activation = ActivationType(activation_method) @@ -243,7 +244,7 @@ def rocm_aiter_fused_moe_impl( a1_scale, a2_scale, num_local_tokens=expert_num_tokens, - dtype=dtype + dtype=dtype, ) @@ -443,7 +444,7 @@ def rocm_aiter_fused_experts( a2_scale=quant_config.a2_scale, doweight_stage1=apply_router_weight_on_input, expert_num_tokens=expert_num_tokens, - output_dtype=output_dtype + output_dtype=output_dtype, ) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 61c841360987..00193f84bb74 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -974,7 +974,8 @@ def process_weights_after_loading(self, layer: Module) -> None: def maybe_make_prepare_finalize(self) -> Optional[mk.FusedMoEPrepareAndFinalize]: if ( self.use_marlin - or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM): + or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM + ): return None elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( @@ -991,10 +992,12 @@ def select_gemm_impl( layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: from vllm.model_executor.layers.fused_moe import ( - AiterExperts, BatchedTritonOrDeepGemmExperts, - TritonOrDeepGemmExperts) - assert not self.use_marlin, ( - "Marlin is not supported with all2all yet.") + AiterExperts, + BatchedTritonOrDeepGemmExperts, + TritonOrDeepGemmExperts, + ) + + assert not self.use_marlin, "Marlin is not supported with all2all yet." 
assert self.moe_quant_config is not None From 770676e0a469bfb994d71b50f60a6715138a3456 Mon Sep 17 00:00:00 2001 From: HakJu Kim Date: Mon, 6 Oct 2025 12:17:08 +0900 Subject: [PATCH 11/30] applied additional pre-commit results seems like the rule is changed... Signed-off-by: HakJu Kim --- .../device_communicators/all2all.py | 74 ++++++++----------- .../model_executor/layers/fused_moe/config.py | 5 +- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 3 files changed, 33 insertions(+), 48 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 90e19fe2173f..e56e98f551ac 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -506,12 +506,13 @@ def _ensure_shmem_initialized(self): logger.debug( "[rank %s] PyTorch distributed ready with backend: %s", self.rank, - backend + backend, ) + # just to make line fit into 80 + world = dist.group.WORLD current_group = ( - self.cpu_group if self.cpu_group is not None - else dist.group.WORLD + self.cpu_group if self.cpu_group is not None else world ) # TODO(inhyeok): make group_name more reasonable @@ -528,16 +529,13 @@ def _ensure_shmem_initialized(self): # Register the current process group c10d._register_process_group(group_name, current_group) logger.debug( - "[rank %s] Registered process group '%s'", - self.rank, - group_name + "[rank %s] Registered proc group %s", self.rank, group_name ) # Initialize mori shmem with the registered group mori.shmem.shmem_torch_process_group_init(group_name) logger.debug( - "[rank %s] torch process group shmem init success", - self.rank + "[rank %s] torch proc group shmem init success", self.rank ) self._shmem_initialized = True return @@ -546,37 +544,32 @@ def _ensure_shmem_initialized(self): logger.debug( "[rank %s] torch process group shmem init failed: %s", self.rank, - torch_error + torch_error, ) self._shmem_initialized = True logger.warning( - "[rank %s] Continuing without mori shmem optimization", - self.rank + "[rank %s] Continue without mori shmem optimize", self.rank ) except Exception as e: - logger.error( - "[rank %s] mori shmem initialization failed: %s", - self.rank, e - ) + logger.error("[rank %s] mori shmem init failed: %s", self.rank, e) # Don't fail completely - mark as initialized to avoid retry loops self._shmem_initialized = True logger.warning( - "[rank %s] Continuing without mori shmem optimization", - self.rank + "[rank %s] Continuing without mori shmem optimize", self.rank ) def _make_mori_config( - self, - max_num_tokens: int, - num_local_experts: int, - experts_per_token: int, - hidden_dim: int, - scale_dim: int, - scale_type_size: int, - data_type: torch.dtype = torch.bfloat16, - quant_dtype: torch.dtype = None, - ): + self, + max_num_tokens: int, + num_local_experts: int, + experts_per_token: int, + hidden_dim: int, + scale_dim: int, + scale_type_size: int, + data_type: torch.dtype = torch.bfloat16, + quant_dtype: torch.dtype = None, + ): """Create mori EpDispatchCombineConfig""" import mori.ops.dispatch_combine as mori_ops from mori.ops.dispatch_combine import EpDispatchCombineKernelType @@ -607,8 +600,8 @@ def _make_mori_config( # Determine kernel type based on topology kernel_type=( EpDispatchCombineKernelType.InterNode - if self.internode else - EpDispatchCombineKernelType.IntraNode + if self.internode + else EpDispatchCombineKernelType.IntraNode ), ) @@ -649,9 +642,7 @@ def get_handle(self, kwargs): hidden_dim, ] ): - raise ValueError( - 
"Missing required parameters for mori handle creation" - ) + raise ValueError("Require more parameters for mori handle init") # Create cache key cache_key = ( @@ -690,19 +681,21 @@ def get_handle(self, kwargs): max_num_tokens, num_local_experts, experts_per_token, - hidden_dim + hidden_dim, ) return op - def dispatch(self, + def dispatch( + self, hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, ): raise NotImplementedError - def combine(self, + def combine( + self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False, ): @@ -721,21 +714,16 @@ def destroy(self): # Check if shmem is actually active before finalizing mori.shmem.shmem_finalize() - logger.debug( - "[rank %s] mori shmem finalized", - self.dp_rank - ) + logger.debug("[rank %s] mori shmem finalize", self.dp_rank) except Exception as shmem_error: logger.debug( "[rank %s] shmem finalize failed " "(may not have been active): %s", self.dp_rank, - shmem_error + shmem_error, ) logger.debug("[rank %s] mori resources cleaned up", self.dp_rank) except Exception as e: - logger.warning( - "[rank %s] Error during mori cleanup: %s", self.dp_rank, e - ) + logger.warning("[rank %s] mori cleanup fail: %s", self.dp_rank, e) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 8d5d8a8a3fd4..4a104c0b1a7e 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -627,10 +627,7 @@ def use_deepep_ll_kernels(self): @property def use_mori_kernels(self): - return ( - self.use_all2all_kernels - and envs.VLLM_ALL2ALL_BACKEND == "mori" - ) + return self.use_all2all_kernels and envs.VLLM_ALL2ALL_BACKEND == "mori" @staticmethod def make( diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 53b680a749c2..ebbdf2d93dab 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -432,8 +432,8 @@ def select_gemm_impl( assert self.moe_quant_config is not None if self.moe.use_mori_kernels and is_rocm_aiter_moe_enabled(): from vllm.model_executor.layers.fused_moe import AiterExperts - logger.debug("AiterExperts for Mori integration %s", self.moe) + logger.debug("AiterExperts for Mori integration %s", self.moe) return AiterExperts( max_num_tokens=self.moe.max_num_tokens, quant_config=self.moe_quant_config, From 059f29ade5c2dd81117cb112758c11cbce744ab2 Mon Sep 17 00:00:00 2001 From: HakJu Kim Date: Mon, 6 Oct 2025 12:22:41 +0900 Subject: [PATCH 12/30] applied pre-commit results... I fear this will make lines go over 80... 
Signed-off-by: HakJu Kim --- vllm/distributed/device_communicators/all2all.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index e56e98f551ac..9fc0a229b8c9 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -511,9 +511,7 @@ def _ensure_shmem_initialized(self): # just to make line fit into 80 world = dist.group.WORLD - current_group = ( - self.cpu_group if self.cpu_group is not None else world - ) + current_group = self.cpu_group if self.cpu_group is not None else world # TODO(inhyeok): make group_name more reasonable group_name = "default" @@ -534,9 +532,7 @@ def _ensure_shmem_initialized(self): # Initialize mori shmem with the registered group mori.shmem.shmem_torch_process_group_init(group_name) - logger.debug( - "[rank %s] torch proc group shmem init success", self.rank - ) + logger.debug("[rank %s] torch proc group shmem init success", self.rank) self._shmem_initialized = True return From 7d520230172aa5dd8de3c358764f71e4934cba4e Mon Sep 17 00:00:00 2001 From: HakJu Kim Date: Thu, 9 Oct 2025 12:50:23 +0900 Subject: [PATCH 13/30] applied SageMoore's comments Co-authored-by: Sage Moore Signed-off-by: HakJu Kim --- .../device_communicators/all2all.py | 13 +- .../layers/fused_moe/mori_prepare_finalize.py | 122 ++++++++---------- 2 files changed, 57 insertions(+), 78 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index ae0c76edcdbd..c931de29d558 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -496,10 +496,6 @@ def _ensure_shmem_initialized(self): import torch.distributed as dist try: - # Wait for PyTorch distributed to be ready - if not dist.is_initialized(): - raise RuntimeError("PyTorch distributed not initialized yet") - # Check if we have a valid backend backend = dist.get_backend() if backend is None: @@ -511,12 +507,11 @@ def _ensure_shmem_initialized(self): backend, ) - # just to make line fit into 80 - world = dist.group.WORLD - current_group = self.cpu_group if self.cpu_group is not None else world + current_group = ( + self.cpu_group if self.cpu_group is not None else dist.group.WORLD + ) + group_name = "mori_shmem_group" - # TODO(inhyeok): make group_name more reasonable - group_name = "default" try: import contextlib diff --git a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py index efe965d46d3a..43074009a92f 100644 --- a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py @@ -75,13 +75,7 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[ - torch.Tensor, - Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], - Optional[torch.Tensor], - Optional[torch.Tensor], - ]: + ) -> mk.PrepareResultType: """ Prepare inputs for mori dispatch operation. Supports pre-dispatch quantization to reduce communication overhead. 
@@ -98,60 +92,55 @@ def prepare( dispatch_indices, dispatch_weights) where dispatched_x is in Standard format (2D tensor) """ - try: - # Pre-dispatch quantization to reduce communication overhead - dispatch_input = a1 - scales = None - - if self.use_fp8_dispatch: - from aiter import QuantType, get_hip_quant - - block_shape = quant_config.block_shape - if block_shape is not None: - assert not apply_router_weight_on_input, ( - "apply_router_weight_on_input is" - " not supported for block scaled moe" - ) - quant_type = QuantType.per_1x128 - else: - quant_type = QuantType.per_Token - - quant_func = get_hip_quant(quant_type) - - dispatch_input, scales = quant_func( - a1, - quant_dtype=quant_config.quant_dtype, + # Pre-dispatch quantization to reduce communication overhead + dispatch_input = a1 + scales = None + + if self.use_fp8_dispatch: + from aiter import QuantType, get_hip_quant + + block_shape = quant_config.block_shape + if block_shape is not None: + assert not apply_router_weight_on_input, ( + "apply_router_weight_on_input is" + " not supported for block scaled moe" ) + quant_type = QuantType.per_1x128 + else: + quant_type = QuantType.per_Token - ( - dispatch_output, - dispatch_weights, - dispatch_scales, - dispatch_indices, - dispatch_recv_num_token, - ) = self.handle.dispatch( - input=dispatch_input, - weights=topk_weights, - scales=scales, - indices=topk_ids, - ) + quant_func = get_hip_quant(quant_type) - expert_tokens_meta = mk.ExpertTokensMetadata( - expert_num_tokens=dispatch_recv_num_token, - expert_num_tokens_cpu=None, + dispatch_input, scales = quant_func( + a1, + quant_dtype=quant_config.quant_dtype, ) - return ( - dispatch_output, - dispatch_scales, - expert_tokens_meta, - dispatch_indices, - dispatch_weights, - ) + ( + dispatch_output, + dispatch_weights, + dispatch_scales, + dispatch_indices, + dispatch_recv_num_token, + ) = self.handle.dispatch( + input=dispatch_input, + weights=topk_weights, + scales=scales, + indices=topk_ids, + ) - except Exception as e: - logger.error("mori dispatch failed: %s", e) - raise RuntimeError("mori dispatch failed: %s", e) from e + expert_tokens_meta = mk.ExpertTokensMetadata( + expert_num_tokens=dispatch_recv_num_token, + expert_num_tokens_cpu=None, + ) + + return ( + dispatch_output, + dispatch_scales, + expert_tokens_meta, + dispatch_indices, + dispatch_weights, + ) def finalize( self, @@ -178,21 +167,16 @@ def finalize( num_original_tokens = output.size(0) # Original number of tokens - try: - combined_output, combined_weights = self.handle.combine( - input=fused_expert_output, - weights=topk_weights, - indices=topk_ids, - ) - - output.copy_( - combined_output[:num_original_tokens], - non_blocking=True, - ) + combined_output, combined_weights = self.handle.combine( + input=fused_expert_output, + weights=topk_weights, + indices=topk_ids, + ) - except Exception as e: - logger.error("mori combine failed: %s", e) - raise RuntimeError("mori combine failed: %s", e) from e + output.copy_( + combined_output[:num_original_tokens], + non_blocking=True, + ) def __repr__(self) -> str: return ( From 6d8ef4363e4c31c47e3329934bcc8a74981ff2fd Mon Sep 17 00:00:00 2001 From: HakJu Kim Date: Thu, 9 Oct 2025 13:21:42 +0900 Subject: [PATCH 14/30] following code difference from main and applied two comments from bnellnm Co-authored-by: Bill Nell Signed-off-by: HakJu Kim --- vllm/model_executor/layers/fused_moe/aiter_experts.py | 2 +- vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git 
a/vllm/model_executor/layers/fused_moe/aiter_experts.py b/vllm/model_executor/layers/fused_moe/aiter_experts.py index 0ff6b1229486..11cc918302a7 100644 --- a/vllm/model_executor/layers/fused_moe/aiter_experts.py +++ b/vllm/model_executor/layers/fused_moe/aiter_experts.py @@ -74,7 +74,7 @@ def workspace_shapes( Aiter kernels manage memory internally, so minimal workspace is needed. """ # Return minimal shapes since Aiter handles memory internally - workspace2 = () # No intermediate workspace needed + workspace2 = (0, ) # No intermediate workspace needed output_shape = aq.shape workspace13 = output_shape workspace_dtype = a.dtype diff --git a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py index 43074009a92f..0bcf538204fb 100644 --- a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py @@ -66,6 +66,9 @@ def topk_indices_dtype(self) -> Optional[torch.dtype]: def num_dispatchers(self) -> int: return self.num_dispatchers_ + def output_is_reduced(self) -> bool: + return True + def prepare( self, a1: torch.Tensor, From eace564a5968bf0302eb2bdedc3f372c569b9290 Mon Sep 17 00:00:00 2001 From: HakJu Kim Date: Thu, 9 Oct 2025 13:41:38 +0900 Subject: [PATCH 15/30] Applied few other comments from bnellnm Co-authored-by: bnellnm <49004751+bnellnm@users.noreply.github.com> Signed-off-by: HakJu Kim --- .../device_communicators/all2all.py | 4 --- vllm/model_executor/layers/fused_moe/layer.py | 18 +++++----- .../model_executor/layers/quantization/fp8.py | 33 +++++-------------- 3 files changed, 18 insertions(+), 37 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index c931de29d558..134cae7b99d8 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -517,10 +517,6 @@ def _ensure_shmem_initialized(self): import torch._C._distributed_c10d as c10d - # Try to unregister first in case it exists - with contextlib.suppress(RuntimeError): - c10d._unregister_process_group(group_name) - # Register the current process group c10d._register_process_group(group_name, current_group) logger.debug( diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index aeaff54ddd8c..c3b11b69dafd 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -434,15 +434,7 @@ def select_gemm_impl( layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: assert self.moe_quant_config is not None - if self.moe.use_mori_kernels and is_rocm_aiter_moe_enabled(): - from vllm.model_executor.layers.fused_moe import AiterExperts - - logger.debug("AiterExperts for Mori integration %s", self.moe) - return AiterExperts( - max_num_tokens=self.moe.max_num_tokens, - quant_config=self.moe_quant_config, - ) - elif ( + if ( prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts ): @@ -452,6 +444,14 @@ def select_gemm_impl( num_dispatchers=prepare_finalize.num_dispatchers(), quant_config=self.moe_quant_config, ) + elif self.moe.use_mori_kernels and is_rocm_aiter_moe_enabled(): + from vllm.model_executor.layers.fused_moe import AiterExperts + + logger.debug("AiterExperts for Mori integration %s", self.moe) + return AiterExperts( + max_num_tokens=self.moe.max_num_tokens, + quant_config=self.moe_quant_config, + ) else: 
logger.debug("TritonExperts %s", self.moe) return TritonExperts(self.moe_quant_config) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 00193f84bb74..7c13a37f2673 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1001,13 +1001,7 @@ def select_gemm_impl( assert self.moe_quant_config is not None - if self.moe.use_mori_kernels and is_rocm_aiter_moe_enabled(): - logger.debug("AiterExperts for Mori integration %s", self.moe) - return AiterExperts( - max_num_tokens=self.moe.max_num_tokens, - quant_config=self.moe_quant_config, - ) - elif ( + if ( prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts ): @@ -1027,6 +1021,12 @@ def select_gemm_impl( quant_config=self.moe_quant_config, allow_deep_gemm=self.allow_deep_gemm, ) + elif self.moe.use_mori_kernels and is_rocm_aiter_moe_enabled(): + logger.debug("AiterExperts for Mori integration %s", self.moe) + return AiterExperts( + max_num_tokens=self.moe.max_num_tokens, + quant_config=self.moe_quant_config, + ) elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: experts = select_cutlass_fp8_gemm_impl( self.moe, @@ -1179,26 +1179,11 @@ def apply( # can override fused_experts or cutlass but not rocm or marlin. # topk_weights, topk_ids, zero_expert_result = select_result - if self.moe.use_mori_kernels and self.fused_experts: - common_kwargs = dict( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=False, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - return self.fused_experts(**common_kwargs) - elif self.rocm_aiter_moe_enabled: + if self.rocm_aiter_moe_enabled and self.fused_experts is None: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 rocm_aiter_fused_experts, ) - assert self.fused_experts is None result = rocm_aiter_fused_experts( x, layer.w13_weight, @@ -1237,7 +1222,7 @@ def apply( w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, - inplace=True, + inplace=False if self.moe.use_mori_kernels else True, activation=activation, global_num_experts=global_num_experts, apply_router_weight_on_input=apply_router_weight_on_input, From 73093c5d4a07e6d710a2ac07ac9fcfb47f8ce3c5 Mon Sep 17 00:00:00 2001 From: HakJu Kim Date: Thu, 9 Oct 2025 13:59:57 +0900 Subject: [PATCH 16/30] Applied some pre-commit results. 
Still need to fix AiterExperts workspace_shapes() according to mk refactor Signed-off-by: HakJu Kim --- vllm/distributed/device_communicators/all2all.py | 2 -- vllm/model_executor/layers/fused_moe/aiter_experts.py | 6 +++--- .../layers/fused_moe/mori_prepare_finalize.py | 3 +-- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 134cae7b99d8..00692ea57ccf 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -513,8 +513,6 @@ def _ensure_shmem_initialized(self): group_name = "mori_shmem_group" try: - import contextlib - import torch._C._distributed_c10d as c10d # Register the current process group diff --git a/vllm/model_executor/layers/fused_moe/aiter_experts.py b/vllm/model_executor/layers/fused_moe/aiter_experts.py index 11cc918302a7..076379005099 100644 --- a/vllm/model_executor/layers/fused_moe/aiter_experts.py +++ b/vllm/model_executor/layers/fused_moe/aiter_experts.py @@ -69,16 +69,16 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: """ Aiter kernels manage memory internally, so minimal workspace is needed. """ # Return minimal shapes since Aiter handles memory internally - workspace2 = (0, ) # No intermediate workspace needed + workspace2 = (0,) # No intermediate workspace needed output_shape = aq.shape workspace13 = output_shape workspace_dtype = a.dtype - return (workspace13, workspace2, output_shape, workspace_dtype) + return (workspace13, workspace2, output_shape) def apply( self, diff --git a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py index 0bcf538204fb..def27c9c8b9b 100644 --- a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py @@ -105,8 +105,7 @@ def prepare( block_shape = quant_config.block_shape if block_shape is not None: assert not apply_router_weight_on_input, ( - "apply_router_weight_on_input is" - " not supported for block scaled moe" + "apply_router_weight_on_input is not supported for block scaled moe" ) quant_type = QuantType.per_1x128 else: From b9b9a9bb97d088d40b26d9c0ba3cf389e98a89c9 Mon Sep 17 00:00:00 2001 From: ihbang Date: Thu, 9 Oct 2025 06:21:37 +0000 Subject: [PATCH 17/30] refactor workspace_shapes of AiterExperts and handle_cache of MoriAll2AllManager Co-authored-by: bnellnm <49004751+bnellnm@users.noreply.github.com> Signed-off-by: ihbang --- .../device_communicators/all2all.py | 102 ++++++------------ .../layers/fused_moe/aiter_experts.py | 10 +- 2 files changed, 35 insertions(+), 77 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 00692ea57ccf..738600abedc7 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -482,7 +482,6 @@ def __init__(self, cpu_group): super().__init__(cpu_group) self.handle_cache = Cache() self.config = None - self._op_handles = {} # Cache for EpDispatchCombineOp instances self._shmem_initialized = False # Delay mori shmem initialization until first use logger.debug("[rank %s] MoriAll2AllManager created", self.rank) @@ -555,20 
+554,12 @@ def _make_mori_config( scale_dim: int, scale_type_size: int, data_type: torch.dtype = torch.bfloat16, - quant_dtype: torch.dtype = None, + quant_dtype: Optional[torch.dtype] = None, ): """Create mori EpDispatchCombineConfig""" import mori.ops.dispatch_combine as mori_ops from mori.ops.dispatch_combine import EpDispatchCombineKernelType - # Determine data type size - dtype_to_size = { - torch.float32: 4, - torch.bfloat16: 2, - torch.float16: 2, - } - max_token_type_size = dtype_to_size.get(data_type, 2) - config = mori_ops.EpDispatchCombineConfig( data_type=data_type if quant_dtype is None else quant_dtype, rank=self.rank, @@ -580,7 +571,7 @@ def _make_mori_config( # Performance tuning parameters # warp_num_per_block=8, # block_num=80, - max_token_type_size=max_token_type_size, + max_token_type_size=data_type.itemsize, # Quantization support scale_dim=scale_dim, scale_type_size=scale_type_size, @@ -608,70 +599,41 @@ def get_handle(self, kwargs): # Ensure shmem is initialized before creating handles self._ensure_shmem_initialized() - import mori.ops.dispatch_combine as mori_ops - - # Extract parameters - max_num_tokens = kwargs.get("max_num_tokens") - num_local_experts = kwargs.get("num_local_experts") - experts_per_token = kwargs.get("experts_per_token") - hidden_dim = kwargs.get("hidden_dim") - data_type = kwargs.get("data_type", torch.bfloat16) - scale_dim = kwargs.get("scale_dim") - scale_type_size = kwargs.get("scale_type_size") - - # Validate required parameters - if any( - param is None - for param in [ + def create_mori_handle( + max_num_tokens: int, + num_local_experts: int, + experts_per_token: int, + hidden_dim: int, + scale_dim: int, + scale_type_size: int, + data_type: torch.dtype = torch.bfloat16, + quant_dtype: Optional[torch.dtype] = None, + ): + import mori + + config = self._make_mori_config( + max_num_tokens=max_num_tokens, + num_local_experts=num_local_experts, + experts_per_token=experts_per_token, + hidden_dim=hidden_dim, + scale_dim=scale_dim, + scale_type_size=scale_type_size, + data_type=data_type, + quant_dtype=quant_dtype, + ) + op = mori.ops.EpDispatchCombineOp(config) + logger.debug( + "[rank %s] Created mori handle with config: tokens=%d, experts=%d," + " topk=%d, hidden_dim=%d", + self.dp_rank, max_num_tokens, num_local_experts, experts_per_token, hidden_dim, - ] - ): - raise ValueError("Require more parameters for mori handle init") - - # Create cache key - cache_key = ( - max_num_tokens, - num_local_experts, - experts_per_token, - hidden_dim, - data_type, - ) - - # Check cache first - if cache_key in self._op_handles: - return self._op_handles[cache_key] - - # Create new mori configuration and operation - config = self._make_mori_config( - max_num_tokens=max_num_tokens, - num_local_experts=num_local_experts, - experts_per_token=experts_per_token, - hidden_dim=hidden_dim, - data_type=data_type, - scale_dim=scale_dim, - scale_type_size=scale_type_size, - ) - - # Create operation handle - op = mori_ops.EpDispatchCombineOp(config) - - # Cache the handle - self._op_handles[cache_key] = op - - logger.debug( - "[rank %s] Created mori handle with config: tokens=%d, experts=%d," - " topk=%d, hidden_dim=%d", - self.dp_rank, - max_num_tokens, - num_local_experts, - experts_per_token, - hidden_dim, - ) + ) + return op - return op + return self.handle_cache.get_or_create(kwargs, create_mori_handle) def dispatch( self, diff --git a/vllm/model_executor/layers/fused_moe/aiter_experts.py b/vllm/model_executor/layers/fused_moe/aiter_experts.py index 
076379005099..0c73fea78ce3 100644 --- a/vllm/model_executor/layers/fused_moe/aiter_experts.py +++ b/vllm/model_executor/layers/fused_moe/aiter_experts.py @@ -60,8 +60,6 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, @@ -73,12 +71,10 @@ def workspace_shapes( """ Aiter kernels manage memory internally, so minimal workspace is needed. """ - # Return minimal shapes since Aiter handles memory internally + workspace1 = (M, K) workspace2 = (0,) # No intermediate workspace needed - output_shape = aq.shape - workspace13 = output_shape - workspace_dtype = a.dtype - return (workspace13, workspace2, output_shape) + output_shape = (M, K) + return (workspace1, workspace2, output_shape) def apply( self, From d3c6ce069e5557e7bbcc06ce1420235e8c655910 Mon Sep 17 00:00:00 2001 From: HakJu Kim Date: Fri, 10 Oct 2025 08:59:50 +0900 Subject: [PATCH 18/30] clean-up handle_cache at destroy() of mori a2a manager Signed-off-by: HakJu Kim --- vllm/distributed/device_communicators/all2all.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 738600abedc7..158823f17429 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -654,7 +654,9 @@ def destroy(self): """Clean up mori resources""" try: # Clear operation handle cache - self._op_handles.clear() + with self.handle_cache._lock: + for _, handle in self.handle_cache._cache.items(): + handle.destroy() # finalize mori shared memory if it was initialized if self._shmem_initialized: From 32482eee535dbd2de1591914eb31005142fc13a2 Mon Sep 17 00:00:00 2001 From: HakJu Kim Date: Fri, 10 Oct 2025 09:08:15 +0900 Subject: [PATCH 19/30] fixed according to SM211 rule Signed-off-by: HakJu Kim --- vllm/model_executor/layers/quantization/fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 7c13a37f2673..3fba5018740b 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1222,7 +1222,7 @@ def apply( w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, - inplace=False if self.moe.use_mori_kernels else True, + inplace=not self.moe.use_mori_kernels, activation=activation, global_num_experts=global_num_experts, apply_router_weight_on_input=apply_router_weight_on_input, From eda8c8eac342a3528e1f5a0ad85ae99cd00d32da Mon Sep 17 00:00:00 2001 From: HakJu Kim Date: Fri, 10 Oct 2025 09:46:26 +0900 Subject: [PATCH 20/30] adding mori backend to moe kernel feature doc Signed-off-by: HakJu Kim --- docs/design/moe_kernel_features.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 0831c5bc790d..ea879d0f96c7 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -41,6 +41,7 @@ th { | flashinfer4 | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferCutlassMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferCutlassMoEPrepareAndFinalize] | | MoEPrepareAndFinalizeNoEP5 | standard | fp8,int8 | G,A,T | N | Y | [`MoEPrepareAndFinalizeNoEP`][vllm.model_executor.layers.fused_moe.prepare_finalize.MoEPrepareAndFinalizeNoEP] | | 
BatchedPrepareAndFinalize5 | batched | fp8,int8 | G,A,T | N | Y | [`BatchedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedPrepareAndFinalize] | +| MoriPrepareAndFinalize7 | standard | fp88 | G(128),A,T8 |N | Y | [`MoriPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.mori_prepare_finalize.MoriPrepareAndFinalize] | !!! info "Table key" 1. All types: mxfp4, nvfp4, int4, int8, fp8 @@ -49,6 +50,8 @@ th { 4. Controlled by different env vars (`VLLM_FLASHINFER_MOE_BACKEND` "throughput" or "latency") 5. This is a no-op dispatcher that can be used to pair with any modular experts to produce a modular kernel that runs w/o dispatch or combine. These cannot be selected via environment variable. These are generally use for testing or adapting an expert subclass to the `fused_experts` API. 6. This depends on the experts implementation. + 7. Currently, MoRI supports low-latency mode only. + 8. This depends on the experts implementation, currently mori supports aiter. --- @@ -118,3 +121,4 @@ The following table shows "families" of modular kernels that are intended to wor | deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,
`TritonExperts`,<br>`TritonOrDeepGemmExperts`,<br>`CutlassExpertsFp8`,<br>`MarlinExperts` |
| deepep_low_latency,<br>pplx | `DeepEPLLPrepareAndFinalize`,<br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,<br>`BatchedTritonExperts`,<br>`BatchedTritonOrDeepGemmExperts`,
`CutlassBatchedExpertsFp8`| | flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` | +| mori | `MoriPrepareAndFinalize` | `AiterExperts` | From ddd3563afab6c46700ef81159d4849c5a59286bf Mon Sep 17 00:00:00 2001 From: HakJu Kim Date: Sun, 12 Oct 2025 00:05:37 +0900 Subject: [PATCH 21/30] applied reviews from bnellnm Co-authored-by: Bill Nell Signed-off-by: HakJu Kim --- vllm/distributed/device_communicators/all2all.py | 8 +++----- .../device_communicators/base_device_communicator.py | 4 ---- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 879bc340ed9e..dfe41888f238 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -511,16 +511,14 @@ def _ensure_shmem_initialized(self): backend, ) - current_group = ( - self.cpu_group if self.cpu_group is not None else dist.group.WORLD - ) + assert self.cpu_group is not None, "No CPU group is given to mori" group_name = "mori_shmem_group" try: import torch._C._distributed_c10d as c10d - # Register the current process group - c10d._register_process_group(group_name, current_group) + # Register the process group + c10d._register_process_group(group_name, self.cpu_group) logger.debug( "[rank %s] Registered proc group %s", self.rank, group_name ) diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 6b7773a5dcbf..c32be0bec55c 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -8,10 +8,6 @@ import torch.distributed as dist from torch.distributed import ProcessGroup -from vllm.logger import init_logger - -logger = init_logger(__name__) - class Cache: def __init__(self): From 1dbff2c76f919df9fec6472905e6567035d78c38 Mon Sep 17 00:00:00 2001 From: HakJu Kim Date: Mon, 13 Oct 2025 09:13:30 +0900 Subject: [PATCH 22/30] new precommit removed Optional Signed-off-by: HakJu Kim --- .../layers/fused_moe/aiter_experts.py | 12 +++---- .../layers/fused_moe/mori_prepare_finalize.py | 10 +++--- .../layers/fused_moe/rocm_aiter_fused_moe.py | 34 +++++++++---------- 3 files changed, 27 insertions(+), 29 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/aiter_experts.py b/vllm/model_executor/layers/fused_moe/aiter_experts.py index 0c73fea78ce3..c615fcb9b550 100644 --- a/vllm/model_executor/layers/fused_moe/aiter_experts.py +++ b/vllm/model_executor/layers/fused_moe/aiter_experts.py @@ -4,8 +4,6 @@ Aiter-based expert processing for Mori integration. """ -from typing import Optional - import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk @@ -66,7 +64,7 @@ def workspace_shapes( topk: int, global_num_experts: int, local_num_experts: int, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + expert_tokens_meta: mk.ExpertTokensMetadata | None, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: """ Aiter kernels manage memory internally, so minimal workspace is needed. 
@@ -86,12 +84,12 @@ def apply( topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, workspace13: torch.Tensor, workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, ): """ diff --git a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py index def27c9c8b9b..593dca1c04b6 100644 --- a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py @@ -5,7 +5,7 @@ Migration from DeepEP to mori for AMD GPU support. """ -from typing import Any, Optional +from typing import Any import torch @@ -57,10 +57,10 @@ def __init__( def activation_format(self) -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard - def max_num_tokens_per_rank(self) -> Optional[int]: + def max_num_tokens_per_rank(self) -> int | None: return self.max_num_tokens - def topk_indices_dtype(self) -> Optional[torch.dtype]: + def topk_indices_dtype(self) -> torch.dtype | None: return torch.int32 def num_dispatchers(self) -> int: @@ -75,7 +75,7 @@ def prepare( topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], + expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: @@ -152,7 +152,7 @@ def finalize( topk_ids: torch.Tensor, apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, - extra_finalize_args: Optional[dict] = None, + extra_finalize_args: dict | None = None, ) -> None: """ Finalize expert outputs using mori combine operation. 
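The two methods touched above form mori's half of the modular-kernel contract: `prepare` quantizes and dispatches tokens to the ranks that own their experts, and `finalize` combines the expert outputs back per token. Below is a minimal, framework-free sketch of that data flow; the class and method bodies are placeholders, not the mori API.

```python
import torch


# Toy stand-in for the prepare/finalize split used by MoriPrepareAndFinalize.
# No all2all communication happens here; it only shows the call order the
# modular kernel expects: prepare() -> expert compute -> finalize().
class ToyPrepareFinalize:
    def prepare(self, a1: torch.Tensor) -> torch.Tensor:
        # Real code would (optionally fp8-)quantize a1 and dispatch rows to the
        # owning ranks; the toy version hands the activations straight through.
        return a1

    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor) -> None:
        # Real code would run the combine all2all and weighted reduction; the
        # toy version copies the already-reduced expert output into `output`.
        output.copy_(fused_expert_output)


pf = ToyPrepareFinalize()
a1 = torch.randn(4, 8)
out = torch.empty_like(a1)
dispatched = pf.prepare(a1)
expert_out = dispatched * 2.0  # stand-in for the fused expert computation
pf.finalize(out, expert_out)
```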
diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 537fe0bca0e9..a1d73c02415e 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -209,12 +209,12 @@ def rocm_aiter_fused_moe_impl( activation_method: int = ActivationMethod.SILU.value, quant_method: int = QuantMethod.NO.value, doweight_stage1: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - expert_num_tokens: Optional[torch.Tensor] = None, - output_dtype: Optional[torch.dtype] = None, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + expert_num_tokens: torch.Tensor | None = None, + output_dtype: torch.dtype | None = None, ) -> torch.Tensor: from aiter import ActivationType, QuantType from aiter.fused_moe import fused_moe @@ -257,12 +257,12 @@ def rocm_aiter_fused_moe_fake( activation_method: int = ActivationMethod.SILU.value, quant_method: int = QuantMethod.NO.value, doweight_stage1: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - expert_num_tokens: Optional[torch.Tensor] = None, - output_dtype: Optional[torch.dtype] = None, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + expert_num_tokens: torch.Tensor | None = None, + output_dtype: torch.dtype | None = None, ) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -353,11 +353,11 @@ def rocm_aiter_fused_experts( topk_ids: torch.Tensor, activation: str = "silu", apply_router_weight_on_input: bool = False, - expert_map: Optional[torch.Tensor] = None, - expert_num_tokens: Optional[torch.Tensor] = None, - output_dtype: Optional[torch.dtype] = None, - quant_config: Optional[FusedMoEQuantConfig] = None, - a1q_scale: Optional[torch.Tensor] = None, + expert_map: torch.Tensor | None = None, + expert_num_tokens: torch.Tensor | None = None, + output_dtype: torch.dtype | None = None, + quant_config: FusedMoEQuantConfig | None = None, + a1q_scale: torch.Tensor | None = None, ) -> torch.Tensor: if quant_config is None: quant_config = FUSED_MOE_UNQUANTIZED_CONFIG From 665a631a840f32b2105bfe9ed18911e2e7334cde Mon Sep 17 00:00:00 2001 From: HakJu Kim Date: Mon, 13 Oct 2025 09:26:37 +0900 Subject: [PATCH 23/30] additional pre-commit results about Optional Signed-off-by: HakJu Kim --- vllm/distributed/device_communicators/all2all.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index d33e55145846..bd62b54276af 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -557,7 +557,7 @@ def _make_mori_config( scale_dim: int, scale_type_size: int, data_type: torch.dtype = torch.bfloat16, - quant_dtype: Optional[torch.dtype] = None, + quant_dtype: torch.dtype | None = None, ): """Create mori EpDispatchCombineConfig""" import mori.ops.dispatch_combine as mori_ops @@ -610,7 +610,7 @@ def create_mori_handle( scale_dim: int, scale_type_size: int, 
data_type: torch.dtype = torch.bfloat16, - quant_dtype: Optional[torch.dtype] = None, + quant_dtype: torch.dtype | None = None, ): import mori From 096f938727bc65c4b97b1348f04348a77a3af318 Mon Sep 17 00:00:00 2001 From: HakJu Kim Date: Wed, 15 Oct 2025 09:16:55 +0900 Subject: [PATCH 24/30] fixed few whitespaces Signed-off-by: HakJu Kim --- docs/design/moe_kernel_features.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index ea879d0f96c7..9cc075bec0eb 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -41,7 +41,7 @@ th { | flashinfer4 | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferCutlassMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferCutlassMoEPrepareAndFinalize] | | MoEPrepareAndFinalizeNoEP5 | standard | fp8,int8 | G,A,T | N | Y | [`MoEPrepareAndFinalizeNoEP`][vllm.model_executor.layers.fused_moe.prepare_finalize.MoEPrepareAndFinalizeNoEP] | | BatchedPrepareAndFinalize5 | batched | fp8,int8 | G,A,T | N | Y | [`BatchedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedPrepareAndFinalize] | -| MoriPrepareAndFinalize7 | standard | fp88 | G(128),A,T8 |N | Y | [`MoriPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.mori_prepare_finalize.MoriPrepareAndFinalize] | +| MoriPrepareAndFinalize7 | standard | fp88 | G(128),A,T8 | N | Y | [`MoriPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.mori_prepare_finalize.MoriPrepareAndFinalize] | !!! info "Table key" 1. All types: mxfp4, nvfp4, int4, int8, fp8 From 25f8d59d344ba472fad316fc53bead99fd985d4d Mon Sep 17 00:00:00 2001 From: HakJu Kim Date: Wed, 15 Oct 2025 10:39:50 +0900 Subject: [PATCH 25/30] applied pre-commit result this happened because of code conflict(semantically) Signed-off-by: HakJu Kim --- vllm/distributed/device_communicators/cuda_communicator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 4394431fc4b4..af890e0da293 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -116,7 +116,7 @@ def __init__( self.all2all_manager = FlashInferAllToAllManager(self.cpu_group) logger.info("Using Flashinfer all2allv manager.") - elif all2all_backend == "mori": + elif self.all2all_backend == "mori": from .all2all import MoriAll2AllManager self.all2all_manager = MoriAll2AllManager(self.cpu_group) From a4769132a75b112146087a9312dd0193efdd31e2 Mon Sep 17 00:00:00 2001 From: ihbang Date: Mon, 20 Oct 2025 11:56:29 +0900 Subject: [PATCH 26/30] Add json config parsing logic to change mori configs easily Signed-off-by: ihbang --- .../device_communicators/all2all.py | 148 +++++++++++++++++- vllm/envs.py | 4 + vllm/model_executor/layers/fused_moe/layer.py | 1 + .../layers/fused_moe/mori_prepare_finalize.py | 31 ++++ 4 files changed, 180 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index ac6c836dfe3e..59e99da1a6cb 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +from pathlib import Path from 
typing import Any import torch @@ -500,6 +502,12 @@ def __init__(self, cpu_group): self.handle_cache = Cache() self.config = None self._shmem_initialized = False + + self.json_config = None + config_path = envs.VLLM_MORI_CONFIG_PATH + if config_path: + self.json_config = self._load_mori_config_from_json(config_path) + # Delay mori shmem initialization until first use logger.debug("[rank %s] MoriAll2AllManager created", self.rank) @@ -560,6 +568,104 @@ def _ensure_shmem_initialized(self): "[rank %s] Continuing without mori shmem optimize", self.rank ) + def _load_mori_config_from_json(self, json_path: str) -> dict | None: + """ + Load mori configuration parameters from JSON file. + + Supports both flat and hierarchical schema: + + Flat schema: + { + "warp_num_per_block": 8, + "block_num": 80, + } + + Hierarchical schema (dispatch/combine specific): + { + "global": { + "warp_num_per_block": 8, + "block_num": 80, + }, + "dispatch": { + "warp_num_per_block": 16, + "block_num": 160 + }, + "combine": { + "warp_num_per_block": 4, + "block_num": 40 + } + } + + Args: + json_path: Path to JSON configuration file + + Returns: + Dictionary of configuration parameters, or None if file doesn't exist + + Raises: + ValueError: If JSON is invalid or contains unsupported parameters + """ + if not json_path: + return None + + json_file = Path(json_path) + if not json_file.exists(): + logger.warning( + "[rank %d] Mori config file not found: %s", self.rank, json_path + ) + return None + + try: + with open(json_file) as f: + config = json.load(f) + + # Valid parameter keys + valid_param_keys = { + "warp_num_per_block", + "block_num", + } + + is_hierarchical = any( + key in config for key in ["global", "dispatch", "combine"] + ) + + if is_hierarchical: + valid_top_keys = {"global", "dispatch", "combine"} + invalid_keys = set(config.keys()) - valid_top_keys + if invalid_keys: + raise ValueError( + f"Invalid top-level keys: {invalid_keys}. " + f"Valid keys: {valid_top_keys}" + ) + + # Validate each section + for section in ["global", "dispatch", "combine"]: + if section in config: + section_config = config[section] + if not isinstance(section_config, dict): + raise ValueError(f"'{section}' must be a dictionary") + + invalid_keys = set(section_config.keys()) - valid_param_keys + if invalid_keys: + raise ValueError( + f"Invalid keys in '{section}': {invalid_keys}. " + f"Valid keys: {valid_param_keys}" + ) + else: + invalid_keys = set(config.keys()) - valid_param_keys + if invalid_keys: + raise ValueError( + f"Invalid config keys: {invalid_keys}. " + f"Valid keys: {valid_param_keys}" + ) + + return config + + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in mori config file {json_path}") from e + except Exception as e: + raise ValueError(f"Error loading mori config from {json_path}") from e + def _make_mori_config( self, max_num_tokens: int, @@ -571,10 +677,41 @@ def _make_mori_config( data_type: torch.dtype = torch.bfloat16, quant_dtype: torch.dtype | None = None, ): - """Create mori EpDispatchCombineConfig""" + """ + Create mori EpDispatchCombineConfig. 
+ + Args: + max_num_tokens: Maximum number of tokens per DP rank + num_local_experts: Number of local experts + experts_per_token: Number of experts per token (topk) + hidden_dim: Hidden dimension size + scale_dim: Scale dimension for quantization + scale_type_size: Scale type size for quantization + data_type: Tensor data type + quant_dtype: Quantization data type (optional) + """ import mori.ops.dispatch_combine as mori_ops from mori.ops.dispatch_combine import EpDispatchCombineKernelType + # Default values (can be overridden by JSON) + warp_num_per_block = 8 + block_num = 80 + + # Override with JSON config if provided + if self.json_config is not None: + is_hierarchical = any( + key in self.json_config for key in ["global", "dispatch", "combine"] + ) + + global_config = self.json_config + if is_hierarchical and "global" in global_config: + global_config = self.json_config["global"] + + warp_num_per_block = global_config.get( + "warp_num_per_block", warp_num_per_block + ) + block_num = global_config.get("block_num", block_num) + config = mori_ops.EpDispatchCombineConfig( data_type=data_type if quant_dtype is None else quant_dtype, rank=self.rank, @@ -583,10 +720,10 @@ def _make_mori_config( max_num_inp_token_per_rank=max_num_tokens, num_experts_per_rank=num_local_experts, num_experts_per_token=experts_per_token, - # Performance tuning parameters - # warp_num_per_block=8, - # block_num=80, max_token_type_size=data_type.itemsize, + # Performance tuning parameters + warp_num_per_block=warp_num_per_block, + block_num=block_num, # Quantization support scale_dim=scale_dim, scale_type_size=scale_type_size, @@ -610,6 +747,9 @@ def get_handle(self, kwargs): - experts_per_token: Number of experts per token (topk) - hidden_dim: Hidden dimension size - data_type: Tensor data type (optional, default bfloat16) + - scale_dim: Scale dimension (optional) + - scale_type_size: Scale type size (optional) + - ubatch_id: Microbatch ID (optional) """ # Ensure shmem is initialized before creating handles self._ensure_shmem_initialized() diff --git a/vllm/envs.py b/vllm/envs.py index ec1cfb5e5ea9..3ddc80d9eda1 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -216,6 +216,7 @@ VLLM_NCCL_INCLUDE_PATH: str | None = None VLLM_USE_FBGEMM: bool = False VLLM_GC_DEBUG: str = "" + VLLM_MORI_CONFIG_PATH: str | None = None def get_default_cache_root(): @@ -1406,6 +1407,9 @@ def get_vllm_port() -> int | None: # - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger with # top 5 collected objects "VLLM_GC_DEBUG": lambda: os.getenv("VLLM_GC_DEBUG", ""), + # Path to JSON configuration file for mori all2all parameters + # If set, mori will use parameters from this JSON file instead of defaults + "VLLM_MORI_CONFIG_PATH": lambda: os.getenv("VLLM_MORI_CONFIG_PATH", None), } # --8<-- [end:env-vars-definition] diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 86b95ed145ed..c48cbc641a53 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -279,6 +279,7 @@ def _maybe_make_prepare_finalize( num_local_experts=moe.num_local_experts, num_dispatchers=all2all_manager.world_size, use_fp8_dispatch=use_fp8_dispatch, + json_config=all2all_manager.json_config, ) return prepare_finalize diff --git a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py index 593dca1c04b6..72aff122b66e 100644 --- 
a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py @@ -32,6 +32,7 @@ def __init__( num_local_experts: int, num_dispatchers: int, use_fp8_dispatch: bool = False, + json_config: dict | None = None, ): """ Initialize MoriPrepareAndFinalize. @@ -42,6 +43,7 @@ def __init__( num_local_experts: Number of experts on this rank num_dispatchers: Number of dispatcher ranks (world size) use_fp8_dispatch: Whether to use FP8 quantization during dispatch + json_config: Optional JSON configuration with operation-specific parameters """ super().__init__() assert max_num_tokens > 0 @@ -53,6 +55,33 @@ def __init__( self.num_dispatchers_ = num_dispatchers self.use_fp8_dispatch = use_fp8_dispatch + # Extract dispatch and combine specific parameters from JSON config + self.dispatch_kwargs = {} + self.combine_kwargs = {} + + if json_config: + # Extract dispatch-specific parameters + if "dispatch" in json_config: + dispatch_config = json_config["dispatch"] + + if "block_num" in dispatch_config: + self.dispatch_kwargs["block_num"] = dispatch_config["block_num"] + if "warp_num_per_block" in dispatch_config: + self.dispatch_kwargs["warp_per_block"] = dispatch_config[ + "warp_num_per_block" + ] + + # Extract combine-specific parameters + if "combine" in json_config: + combine_config = json_config["combine"] + + if "block_num" in combine_config: + self.combine_kwargs["block_num"] = combine_config["block_num"] + if "warp_num_per_block" in combine_config: + self.combine_kwargs["warp_per_block"] = combine_config[ + "warp_num_per_block" + ] + @property def activation_format(self) -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard @@ -129,6 +158,7 @@ def prepare( weights=topk_weights, scales=scales, indices=topk_ids, + **self.dispatch_kwargs, # Apply dispatch-specific parameters from JSON ) expert_tokens_meta = mk.ExpertTokensMetadata( @@ -173,6 +203,7 @@ def finalize( input=fused_expert_output, weights=topk_weights, indices=topk_ids, + **self.combine_kwargs, # Apply combine-specific parameters from JSON ) output.copy_( From 4c4306e2f1b001cb85eca067ea3ce42b43ea3751 Mon Sep 17 00:00:00 2001 From: ihbang Date: Mon, 20 Oct 2025 14:58:55 +0900 Subject: [PATCH 27/30] Applied review from HAIAI Signed-off-by: ihbang --- docs/design/moe_kernel_features.md | 2 +- .../device_communicators/all2all.py | 40 ++++++++----------- .../layers/fused_moe/__init__.py | 10 ++++- ...aiter_experts.py => aiter_mori_experts.py} | 7 +++- vllm/model_executor/layers/fused_moe/layer.py | 6 +-- .../layers/fused_moe/mori_prepare_finalize.py | 5 +-- .../layers/fused_moe/rocm_aiter_fused_moe.py | 2 +- .../model_executor/layers/quantization/fp8.py | 6 +-- 8 files changed, 40 insertions(+), 38 deletions(-) rename vllm/model_executor/layers/fused_moe/{aiter_experts.py => aiter_mori_experts.py} (94%) diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 9cc075bec0eb..2d73aae6ed59 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -121,4 +121,4 @@ The following table shows "families" of modular kernels that are intended to wor | deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,
`TritonExperts`,
`TritonOrDeepGemmExperts`,
`CutlassExpertsFp8`,
`MarlinExperts` | | deepep_low_latency,
pplx | `DeepEPLLPrepareAndFinalize`,
`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,
`BatchedTritonExperts`,
`BatchedTritonOrDeepGemmExperts`,
`CutlassBatchedExpertsFp8`| | flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` | -| mori | `MoriPrepareAndFinalize` | `AiterExperts` | +| mori | `MoriPrepareAndFinalize` | `AiterMoriExperts` | diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 59e99da1a6cb..7d54733389f1 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -1,9 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json +import os from pathlib import Path from typing import Any +import psutil import torch import torch.distributed as dist @@ -516,8 +518,8 @@ def _ensure_shmem_initialized(self): if self._shmem_initialized: return - import mori.shmem import torch.distributed as dist + from mori.shmem import shmem_torch_process_group_init try: # Check if we have a valid backend @@ -532,7 +534,8 @@ def _ensure_shmem_initialized(self): ) assert self.cpu_group is not None, "No CPU group is given to mori" - group_name = "mori_shmem_group" + ppid = psutil.Process(os.getpid()).ppid() + group_name = f"mori_shmem_group_{ppid}" try: import torch._C._distributed_c10d as c10d @@ -544,29 +547,18 @@ def _ensure_shmem_initialized(self): ) # Initialize mori shmem with the registered group - mori.shmem.shmem_torch_process_group_init(group_name) + shmem_torch_process_group_init(group_name) logger.debug("[rank %s] torch proc group shmem init success", self.rank) self._shmem_initialized = True return except Exception as torch_error: - logger.debug( - "[rank %s] torch process group shmem init failed: %s", - self.rank, - torch_error, - ) - self._shmem_initialized = True - logger.warning( - "[rank %s] Continue without mori shmem optimize", self.rank - ) + raise RuntimeError( + "torch process group initialization failed" + ) from torch_error except Exception as e: - logger.error("[rank %s] mori shmem init failed: %s", self.rank, e) - # Don't fail completely - mark as initialized to avoid retry loops - self._shmem_initialized = True - logger.warning( - "[rank %s] Continuing without mori shmem optimize", self.rank - ) + raise RuntimeError("mori shmem initialization failed") from e def _load_mori_config_from_json(self, json_path: str) -> dict | None: """ @@ -690,7 +682,7 @@ def _make_mori_config( data_type: Tensor data type quant_dtype: Quantization data type (optional) """ - import mori.ops.dispatch_combine as mori_ops + from mori.ops import EpDispatchCombineConfig from mori.ops.dispatch_combine import EpDispatchCombineKernelType # Default values (can be overridden by JSON) @@ -712,7 +704,7 @@ def _make_mori_config( ) block_num = global_config.get("block_num", block_num) - config = mori_ops.EpDispatchCombineConfig( + config = EpDispatchCombineConfig( data_type=data_type if quant_dtype is None else quant_dtype, rank=self.rank, world_size=self.world_size, @@ -764,7 +756,7 @@ def create_mori_handle( data_type: torch.dtype = torch.bfloat16, quant_dtype: torch.dtype | None = None, ): - import mori + from mori.ops import EpDispatchCombineOp config = self._make_mori_config( max_num_tokens=max_num_tokens, @@ -776,7 +768,7 @@ def create_mori_handle( data_type=data_type, quant_dtype=quant_dtype, ) - op = mori.ops.EpDispatchCombineOp(config) + op = EpDispatchCombineOp(config) logger.debug( "[rank %s] Created mori handle with config: tokens=%d, experts=%d," " topk=%d, hidden_dim=%d", @@ -816,10 +808,10 @@ def destroy(self): # finalize 
mori shared memory if it was initialized if self._shmem_initialized: try: - import mori.shmem + from mori.shmem import shmem_finalize # Check if shmem is actually active before finalizing - mori.shmem.shmem_finalize() + shmem_finalize() logger.debug("[rank %s] mori shmem finalize", self.dp_rank) except Exception as shmem_error: logger.debug( diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 402c9ea91fe5..cb16cb39ce6f 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -4,7 +4,6 @@ from contextlib import contextmanager from typing import Any -from vllm.model_executor.layers.fused_moe.aiter_experts import AiterExperts from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, @@ -18,6 +17,7 @@ ) from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe.utils import activation_without_mul +from vllm.platforms import current_platform from vllm.triton_utils import HAS_TRITON _config: dict[str, Any] | None = None @@ -94,7 +94,6 @@ def get_config() -> dict[str, Any] | None: "BatchedDeepGemmExperts", "TritonOrDeepGemmExperts", "BatchedTritonOrDeepGemmExperts", - "AiterExperts", ] else: # Some model classes directly use the custom ops. Add placeholders @@ -104,3 +103,10 @@ def _raise_exception(method: str): fused_topk = lambda *args, **kwargs: _raise_exception("fused_topk") fused_experts = lambda *args, **kwargs: _raise_exception("fused_experts") + +if current_platform.is_rocm(): + from vllm.model_executor.layers.fused_moe.aiter_mori_experts import AiterMoriExperts + + __all__ += [ + "AiterMoriExperts", + ] diff --git a/vllm/model_executor/layers/fused_moe/aiter_experts.py b/vllm/model_executor/layers/fused_moe/aiter_mori_experts.py similarity index 94% rename from vllm/model_executor/layers/fused_moe/aiter_experts.py rename to vllm/model_executor/layers/fused_moe/aiter_mori_experts.py index c615fcb9b550..d9bbd6ac3ed5 100644 --- a/vllm/model_executor/layers/fused_moe/aiter_experts.py +++ b/vllm/model_executor/layers/fused_moe/aiter_mori_experts.py @@ -16,7 +16,7 @@ ) -class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): +class AiterMoriExperts(mk.FusedMoEPermuteExpertsUnpermute): """ Aiter-based expert processing that works with Mori dispatch/combine. 
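Since `AiterMoriExperts` is now only exported on ROCm builds (see the `__init__.py` hunk above), code outside this package should import it behind the same platform check. A short usage sketch, using only names that appear in this series:

```python
# Usage sketch: mirror the conditional export added to
# vllm/model_executor/layers/fused_moe/__init__.py.
from vllm.platforms import current_platform

if current_platform.is_rocm():
    from vllm.model_executor.layers.fused_moe import AiterMoriExperts
else:
    AiterMoriExperts = None  # the symbol is simply not exported elsewhere
```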
@@ -29,6 +29,11 @@ def __init__( max_num_tokens: int, quant_config: FusedMoEQuantConfig, ): + from vllm.platforms.rocm import on_mi3xx + + if not on_mi3xx(): + raise RuntimeError("AiterMoriExperts should be used on AMD mi3xx GPUs") + super().__init__( quant_config=quant_config, ) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index c48cbc641a53..6db5a22234bc 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -446,10 +446,10 @@ def select_gemm_impl( quant_config=self.moe_quant_config, ) elif self.moe.use_mori_kernels and is_rocm_aiter_moe_enabled(): - from vllm.model_executor.layers.fused_moe import AiterExperts + from vllm.model_executor.layers.fused_moe import AiterMoriExperts - logger.debug("AiterExperts for Mori integration %s", self.moe) - return AiterExperts( + logger.debug("AiterMoriExperts for Mori integration %s", self.moe) + return AiterMoriExperts( max_num_tokens=self.moe.max_num_tokens, quant_config=self.moe_quant_config, ) diff --git a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py index 72aff122b66e..f16821152b23 100644 --- a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py @@ -131,14 +131,13 @@ def prepare( if self.use_fp8_dispatch: from aiter import QuantType, get_hip_quant - block_shape = quant_config.block_shape - if block_shape is not None: + if quant_config.block_shape is not None: assert not apply_router_weight_on_input, ( "apply_router_weight_on_input is not supported for block scaled moe" ) quant_type = QuantType.per_1x128 else: - quant_type = QuantType.per_Token + quant_type = QuantType.per_Tensor quant_func = get_hip_quant(quant_type) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index a1d73c02415e..1451d834c786 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -221,7 +221,7 @@ def rocm_aiter_fused_moe_impl( # Check if input is already pre-quantized (from mori dispatch) input_is_pre_quantized = ( - a1_scale is not None and hidden_states.dtype == torch.float8_e4m3fnuz + a1_scale is not None and hidden_states.dtype == current_platform.fp8_dtype() ) dtype = output_dtype if input_is_pre_quantized else None diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 262a565dd8e0..3112930fe043 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -996,7 +996,7 @@ def select_gemm_impl( layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: from vllm.model_executor.layers.fused_moe import ( - AiterExperts, + AiterMoriExperts, BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts, ) @@ -1026,8 +1026,8 @@ def select_gemm_impl( allow_deep_gemm=self.allow_deep_gemm, ) elif self.moe.use_mori_kernels and is_rocm_aiter_moe_enabled(): - logger.debug("AiterExperts for Mori integration %s", self.moe) - return AiterExperts( + logger.debug("AiterMoriExperts for Mori integration %s", self.moe) + return AiterMoriExperts( max_num_tokens=self.moe.max_num_tokens, quant_config=self.moe_quant_config, ) From 3adc79ee6dc9e0c0df5c31dca281eba55971b344 Mon Sep 17 00:00:00 2001 From: ihbang Date: Tue, 21 Oct 2025 
15:26:34 +0900 Subject: [PATCH 28/30] add quant_dtype check on _make_mori_config Signed-off-by: ihbang --- vllm/distributed/device_communicators/all2all.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 7d54733389f1..a67ce29327cb 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -685,6 +685,10 @@ def _make_mori_config( from mori.ops import EpDispatchCombineConfig from mori.ops.dispatch_combine import EpDispatchCombineKernelType + from vllm.platforms import current_platform + + assert quant_dtype is None or quant_dtype == current_platform.fp8_dtype() + # Default values (can be overridden by JSON) warp_num_per_block = 8 block_num = 80 From 756898a78f07c46faa10e7ba096bd60b661d1aa5 Mon Sep 17 00:00:00 2001 From: HakJu Kim Date: Mon, 27 Oct 2025 11:12:40 +0900 Subject: [PATCH 29/30] moved has_mori to import_utils.py Signed-off-by: HakJu Kim --- vllm/utils/import_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/utils/import_utils.py b/vllm/utils/import_utils.py index 65f588b52e5e..f7b8094a2949 100644 --- a/vllm/utils/import_utils.py +++ b/vllm/utils/import_utils.py @@ -346,6 +346,9 @@ def has_deep_ep() -> bool: """Whether the optional `deep_ep` package is available.""" return _has_module("deep_ep") +def has_mori() -> bool: + """Whether the optional `mori` package is available.""" + return _has_module("mori") def has_deep_gemm() -> bool: """Whether the optional `deep_gemm` package is available.""" From e9b8624c5bea160673d6305d22e1bd97fa94f9f5 Mon Sep 17 00:00:00 2001 From: HakJu Kim Date: Mon, 27 Oct 2025 11:31:15 +0900 Subject: [PATCH 30/30] fixed few more things (conflict and pre-commit) Signed-off-by: HakJu Kim --- vllm/distributed/device_communicators/all2all.py | 2 +- vllm/model_executor/layers/fused_moe/layer.py | 6 +----- vllm/utils/import_utils.py | 2 ++ 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 6563799d9a43..de22d10bcefb 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -14,7 +14,7 @@ from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.utils.flashinfer import has_flashinfer_all2all -from vllm.utils.import_utils import has_deep_ep, has_pplx, has_mori +from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx from .base_device_communicator import All2AllManagerBase, Cache diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index e46dc8b18725..af062b8c2ed9 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -56,7 +56,7 @@ from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe -from vllm.utils.import_utils import has_deep_ep, has_pplx, has_mori +from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx from vllm.utils.math_utils import cdiv, round_up from vllm.utils.torch_utils import current_stream, direct_register_custom_op from vllm.v1.worker.ubatching import dbo_current_ubatch_id @@ -99,13 +99,9 @@ def _eplb_map_to_physical_and_record( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: 
E501 rocm_aiter_grouped_topk as grouped_topk_aiter, ) - - grouped_topk_impl = rocm_aiter_grouped_topk else: from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk - grouped_topk_impl = grouped_topk - if current_platform.is_tpu(): from .moe_pallas import fused_moe as fused_moe_pallas else: diff --git a/vllm/utils/import_utils.py b/vllm/utils/import_utils.py index f7b8094a2949..5c94ca0b300f 100644 --- a/vllm/utils/import_utils.py +++ b/vllm/utils/import_utils.py @@ -346,10 +346,12 @@ def has_deep_ep() -> bool: """Whether the optional `deep_ep` package is available.""" return _has_module("deep_ep") + def has_mori() -> bool: """Whether the optional `mori` package is available.""" return _has_module("mori") + def has_deep_gemm() -> bool: """Whether the optional `deep_gemm` package is available.""" return _has_module("deep_gemm")
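To close, here is a deployment-time sketch tying the series together: it writes a tuning file in the hierarchical schema documented in `_load_mori_config_from_json` and points vLLM at it via `VLLM_MORI_CONFIG_PATH`. The numeric values are the illustrative ones from the docstring, not tuned recommendations, and selecting the backend through `VLLM_ALL2ALL_BACKEND=mori` is an assumption based on vLLM's existing all2all backend switch.

```python
import json
import os

# Hierarchical schema from _load_mori_config_from_json: "global" applies to both
# ops, while "dispatch"/"combine" override it per operation.
tuning = {
    "global": {"warp_num_per_block": 8, "block_num": 80},
    "dispatch": {"warp_num_per_block": 16, "block_num": 160},
    "combine": {"warp_num_per_block": 4, "block_num": 40},
}

config_path = "/tmp/mori_config.json"  # hypothetical location
with open(config_path, "w") as f:
    json.dump(tuning, f, indent=2)

# Both variables must be set before vLLM initializes its device communicators.
os.environ["VLLM_MORI_CONFIG_PATH"] = config_path
os.environ["VLLM_ALL2ALL_BACKEND"] = "mori"  # assumed backend selector
```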