[Hardware][AMD][Kernel] mori all2all backend integration #26013
base: main
Changes from 32 commits
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -41,6 +41,7 @@ th { | |
| | flashinfer<sup>4</sup> | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferCutlassMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferCutlassMoEPrepareAndFinalize] | | ||
|
Review comment: Please remove this line (duplicate of prior line).
Reply: This is not part of this PR, and it's not related.
||
| | MoEPrepareAndFinalizeNoEP<sup>5</sup> | standard | fp8,int8 | G,A,T | N | Y | [`MoEPrepareAndFinalizeNoEP`][vllm.model_executor.layers.fused_moe.prepare_finalize.MoEPrepareAndFinalizeNoEP] | | ||
| | BatchedPrepareAndFinalize<sup>5</sup> | batched | fp8,int8 | G,A,T | N | Y | [`BatchedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedPrepareAndFinalize] | | ||
| | MoriPrepareAndFinalize<sup>7</sup> | standard | fp8<sup>8</sup> | G(128),A,T<sup>8</sup> | N | Y | [`MoriPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.mori_prepare_finalize.MoriPrepareAndFinalize] | | ||
|
|
||
| !!! info "Table key" | ||
| 1. All types: mxfp4, nvfp4, int4, int8, fp8 | ||
|
|
@@ -49,6 +50,8 @@ th { | |
| 4. Controlled by different env vars (`VLLM_FLASHINFER_MOE_BACKEND` "throughput" or "latency") | ||
| 5. This is a no-op dispatcher that can be used to pair with any modular experts to produce a modular kernel that runs w/o dispatch or combine. These cannot be selected via environment variable. These are generally used for testing or adapting an expert subclass to the `fused_experts` API. | ||
| 6. This depends on the experts implementation. | ||
| 7. Currently, MoRI supports low-latency mode only. | ||
| 8. This depends on the experts implementation; currently MoRI supports the AITER experts only. | ||
|
Review comment: Is this to explain, or a direct answer to …?
Reply: Yes, we integrated MoRI on the AITER MoE only.
||
|
|
||
| --- | ||
|
|
||
|
|
@@ -118,3 +121,4 @@ The following table shows "families" of modular kernels that are intended to wor | |
| | deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` | | ||
| | deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`BatchedTritonOrDeepGemmExperts`,</br>`CutlassBatchedExpertsFp8`| | ||
| | flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` | | ||
| | mori | `MoriPrepareAndFinalize` | `AiterExperts` | | ||
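For orientation, a minimal launch sketch of how the mori backend and the AITER experts could be selected together. It assumes this PR exposes MoRI through the existing `VLLM_ALL2ALL_BACKEND` selector (the value `"mori"` is an assumption) and that the AITER path is enabled via `VLLM_ROCM_USE_AITER`; verify the exact names against the merged code.

```python
# Hedged sketch: enable the mori all2all backend together with AITER experts.
# "mori" as a VLLM_ALL2ALL_BACKEND value is assumed to be added by this PR.
import os

os.environ["VLLM_ALL2ALL_BACKEND"] = "mori"  # assumed new backend value
os.environ["VLLM_ROCM_USE_AITER"] = "1"      # AITER experts path (see footnote 8)

from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # example MoE model
    data_parallel_size=2,
    enable_expert_parallel=True,
)
```
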
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,7 +9,7 @@ | |
| from vllm.distributed import get_dp_group, get_ep_group | ||
| from vllm.forward_context import get_forward_context | ||
| from vllm.logger import init_logger | ||
| from vllm.utils import has_deep_ep, has_pplx | ||
| from vllm.utils import has_deep_ep, has_mori, has_pplx | ||
| from vllm.utils.flashinfer import has_flashinfer_all2all | ||
|
|
||
| from .base_device_communicator import All2AllManagerBase, Cache | ||
|
|
@@ -474,3 +474,212 @@ def cleanup(self): | |
| self.prepare_workspace_tensor = None | ||
| self.mapping = None | ||
| self.initialized = False | ||
|
|
||
|
|
||
| class MoriAll2AllManager(All2AllManagerBase): | ||
| """ | ||
| All2All communication based on mori kernels. | ||
| """ | ||
|
|
||
| def __init__(self, cpu_group): | ||
| assert has_mori(), "Please install mori from the ROCm/mori GitHub repository." | ||
|
|
||
| super().__init__(cpu_group) | ||
| self.handle_cache = Cache() | ||
| self.config = None | ||
| self._shmem_initialized = False | ||
| # Delay mori shmem initialization until first use | ||
| logger.debug("[rank %s] MoriAll2AllManager created", self.rank) | ||
|
|
||
| def _ensure_shmem_initialized(self): | ||
| """Initialize mori's shared memory system lazily""" | ||
| if self._shmem_initialized: | ||
| return | ||
|
|
||
| import mori.shmem | ||
| import torch.distributed as dist | ||
|
|
||
| try: | ||
| # Check if we have a valid backend | ||
| backend = dist.get_backend() | ||
| if backend is None: | ||
| raise RuntimeError("No valid distributed backend found") | ||
|
|
||
| logger.debug( | ||
| "[rank %s] PyTorch distributed ready with backend: %s", | ||
| self.rank, | ||
| backend, | ||
| ) | ||
|
|
||
| current_group = ( | ||
| self.cpu_group if self.cpu_group is not None else dist.group.WORLD | ||
| ) | ||
| group_name = "mori_shmem_group" | ||
|
|
||
| try: | ||
| import torch._C._distributed_c10d as c10d | ||
|
|
||
| # Register the current process group | ||
| c10d._register_process_group(group_name, current_group) | ||
| logger.debug( | ||
| "[rank %s] Registered proc group %s", self.rank, group_name | ||
| ) | ||
|
|
||
| # Initialize mori shmem with the registered group | ||
| mori.shmem.shmem_torch_process_group_init(group_name) | ||
|
||
| logger.debug("[rank %s] torch proc group shmem init success", self.rank) | ||
| self._shmem_initialized = True | ||
| return | ||
|
|
||
| except Exception as torch_error: | ||
| logger.debug( | ||
| "[rank %s] torch process group shmem init failed: %s", | ||
| self.rank, | ||
| torch_error, | ||
| ) | ||
| self._shmem_initialized = True | ||
| logger.warning( | ||
| "[rank %s] Continuing without mori shmem optimization", self.rank | ||
| ) | ||
|
|
||
| except Exception as e: | ||
| logger.error("[rank %s] mori shmem init failed: %s", self.rank, e) | ||
| # Don't fail completely - mark as initialized to avoid retry loops | ||
| self._shmem_initialized = True | ||
|
||
| logger.warning( | ||
| "[rank %s] Continuing without mori shmem optimization", self.rank | ||
| ) | ||
|
|
||
| def _make_mori_config( | ||
| self, | ||
| max_num_tokens: int, | ||
| num_local_experts: int, | ||
| experts_per_token: int, | ||
| hidden_dim: int, | ||
| scale_dim: int, | ||
| scale_type_size: int, | ||
| data_type: torch.dtype = torch.bfloat16, | ||
| quant_dtype: Optional[torch.dtype] = None, | ||
| ): | ||
| """Create mori EpDispatchCombineConfig""" | ||
| import mori.ops.dispatch_combine as mori_ops | ||
| from mori.ops.dispatch_combine import EpDispatchCombineKernelType | ||
|
|
||
|
Review comment: Add a check for data_type and quant_dtype before proceeding?
Reply: The quant_dtype check is added.
||
| config = mori_ops.EpDispatchCombineConfig( | ||
| data_type=data_type if quant_dtype is None else quant_dtype, | ||
| rank=self.rank, | ||
| world_size=self.world_size, | ||
| hidden_dim=hidden_dim, | ||
| max_num_inp_token_per_rank=max_num_tokens, | ||
| num_experts_per_rank=num_local_experts, | ||
| num_experts_per_token=experts_per_token, | ||
| # Performance tuning parameters | ||
| # warp_num_per_block=8, | ||
| # block_num=80, | ||
| max_token_type_size=data_type.itemsize, | ||
| # Quantization support | ||
| scale_dim=scale_dim, | ||
| scale_type_size=scale_type_size, | ||
| # Determine kernel type based on topology | ||
| kernel_type=( | ||
| EpDispatchCombineKernelType.InterNode | ||
| if self.internode | ||
| else EpDispatchCombineKernelType.IntraNode | ||
| ), | ||
| ) | ||
|
|
||
| return config | ||
|
|
||
| def get_handle(self, kwargs): | ||
| """ | ||
| Get or create mori operation handle. | ||
| Args: | ||
| kwargs: Dictionary with keys: | ||
| - max_num_tokens: Maximum tokens per DP rank | ||
| - num_local_experts: Number of local experts | ||
| - experts_per_token: Number of experts per token (topk) | ||
| - hidden_dim: Hidden dimension size | ||
| - data_type: Tensor data type (optional, default bfloat16) | ||
| """ | ||
| # Ensure shmem is initialized before creating handles | ||
| self._ensure_shmem_initialized() | ||
|
|
||
| def create_mori_handle( | ||
| max_num_tokens: int, | ||
| num_local_experts: int, | ||
| experts_per_token: int, | ||
| hidden_dim: int, | ||
| scale_dim: int, | ||
| scale_type_size: int, | ||
| data_type: torch.dtype = torch.bfloat16, | ||
| quant_dtype: Optional[torch.dtype] = None, | ||
| ): | ||
| import mori | ||
|
||
|
|
||
| config = self._make_mori_config( | ||
| max_num_tokens=max_num_tokens, | ||
| num_local_experts=num_local_experts, | ||
| experts_per_token=experts_per_token, | ||
| hidden_dim=hidden_dim, | ||
| scale_dim=scale_dim, | ||
| scale_type_size=scale_type_size, | ||
| data_type=data_type, | ||
| quant_dtype=quant_dtype, | ||
| ) | ||
| op = mori.ops.EpDispatchCombineOp(config) | ||
| logger.debug( | ||
| "[rank %s] Created mori handle with config: tokens=%d, experts=%d," | ||
| " topk=%d, hidden_dim=%d", | ||
| self.dp_rank, | ||
| max_num_tokens, | ||
| num_local_experts, | ||
| experts_per_token, | ||
| hidden_dim, | ||
| ) | ||
| return op | ||
|
|
||
| return self.handle_cache.get_or_create(kwargs, create_mori_handle) | ||
|
|
||
| def dispatch( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
| router_logits: torch.Tensor, | ||
| is_sequence_parallel: bool = False, | ||
| ): | ||
| raise NotImplementedError | ||
|
|
||
| def combine( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
| is_sequence_parallel: bool = False, | ||
| ): | ||
| raise NotImplementedError | ||
|
|
||
| def destroy(self): | ||
| """Clean up mori resources""" | ||
| try: | ||
| # Clear operation handle cache | ||
| with self.handle_cache._lock: | ||
| for _, handle in self.handle_cache._cache.items(): | ||
| handle.destroy() | ||
|
|
||
| # finalize mori shared memory if it was initialized | ||
| if self._shmem_initialized: | ||
| try: | ||
| import mori.shmem | ||
|
|
||
| # Finalize mori shared memory; failures are logged and ignored | ||
| mori.shmem.shmem_finalize() | ||
| logger.debug("[rank %s] mori shmem finalize", self.dp_rank) | ||
| except Exception as shmem_error: | ||
| logger.debug( | ||
| "[rank %s] shmem finalize failed " | ||
| "(may not have been active): %s", | ||
| self.dp_rank, | ||
| shmem_error, | ||
| ) | ||
|
|
||
| logger.debug("[rank %s] mori resources cleaned up", self.dp_rank) | ||
|
|
||
| except Exception as e: | ||
| logger.warning("[rank %s] mori cleanup failed: %s", self.dp_rank, e) | ||
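To make the handle-cache contract above concrete, a rough usage sketch of `get_handle` follows. The kwargs keys mirror `_make_mori_config`; the concrete values, the import path, and the call site are illustrative assumptions (the real caller is the new `MoriPrepareAndFinalize`), not code from this PR.

```python
# Illustrative only: request a dispatch/combine op from the manager defined above.
# The import path is assumed from where the class is added in this diff.
import torch

from vllm.distributed import get_ep_group
from vllm.distributed.device_communicators.all2all import MoriAll2AllManager

manager = MoriAll2AllManager(get_ep_group().cpu_group)

op = manager.get_handle(
    dict(
        max_num_tokens=256,                 # max tokens per DP rank
        num_local_experts=8,                # experts hosted on this rank
        experts_per_token=8,                # top-k
        hidden_dim=4096,
        scale_dim=4096 // 128,              # illustrative: one scale per 128-wide block
        scale_type_size=torch.float32.itemsize,
        data_type=torch.bfloat16,
        quant_dtype=torch.float8_e4m3fnuz,  # ROCm fp8; None skips quantized dispatch
    )
)
# A second call with identical kwargs returns the cached EpDispatchCombineOp.
```
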
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,7 @@ | |
| from contextlib import contextmanager | ||
| from typing import Any, Optional | ||
|
|
||
| from vllm.model_executor.layers.fused_moe.aiter_experts import AiterExperts | ||
|
||
| from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig | ||
| from vllm.model_executor.layers.fused_moe.layer import ( | ||
| FusedMoE, | ||
|
|
@@ -94,6 +95,7 @@ def get_config() -> Optional[dict[str, Any]]: | |
| "BatchedDeepGemmExperts", | ||
| "TritonOrDeepGemmExperts", | ||
| "BatchedTritonOrDeepGemmExperts", | ||
| "AiterExperts", | ||
| ] | ||
| else: | ||
| # Some model classes directly use the custom ops. Add placeholders | ||
|
|
||
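Tying this registration back to the modular-kernel families table above: a modular kernel pairs a prepare/finalize object with an experts implementation, and the mori family pairs `MoriPrepareAndFinalize` with `AiterExperts`. The helper below is only a sketch of that pairing, assuming the standard `FusedMoEModularKernel` constructor is used unchanged; it is not code from this PR.

```python
# Sketch: pair mori dispatch/combine with AITER experts, per the families table.
# Constructor arguments for the two components are left to the caller because
# their exact signatures are introduced by this PR.
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
)


def build_mori_modular_kernel(
    prepare_finalize: FusedMoEPrepareAndFinalize,  # e.g. MoriPrepareAndFinalize
    experts: FusedMoEPermuteExpertsUnpermute,      # e.g. AiterExperts
) -> FusedMoEModularKernel:
    """Combine mori dispatch/combine with AITER experts into one modular kernel."""
    return FusedMoEModularKernel(prepare_finalize, experts)
```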