Commits (47)
cd167ad
vllm mori integration from moreh team
whitememory Oct 1, 2025
77c8a56
Update vllm/model_executor/layers/fused_moe/modular_kernel.py
whitememory Oct 1, 2025
ea17a69
Update vllm/model_executor/layers/fused_moe/aiter_experts.py
whitememory Oct 1, 2025
a00b38b
applied pre-commit results and Read the Docs build results
whitememory Oct 2, 2025
07095ff
applied few suggestions from code-assistant
whitememory Oct 2, 2025
26ec16d
applied pre-commit result from github
whitememory Oct 2, 2025
5c997be
applied some pre-commit results from github
whitememory Oct 2, 2025
9849fa6
hope this is last for pre-commit
whitememory Oct 2, 2025
f4fb63a
removing unrelated change
whitememory Oct 2, 2025
c4bbc1a
Merge branch 'main' into mori_moreh
whitememory Oct 3, 2025
6dabf63
Merge branch 'main' into mori_moreh
whitememory Oct 3, 2025
19e7d40
Merge branch 'main' into mori_moreh
whitememory Oct 3, 2025
a1de125
Merge branch 'main' into mori_moreh
whitememory Oct 4, 2025
d2f65f6
Merge branch 'main' into mori_moreh
whitememory Oct 4, 2025
fb01286
Merge branch 'main' into mori_moreh
whitememory Oct 6, 2025
739b489
applied pre-commit results after merging main to
whitememory Oct 6, 2025
770676e
applied additional pre-commit results
whitememory Oct 6, 2025
ff08bc8
Merge branch 'main' into mori_moreh
whitememory Oct 6, 2025
059f29a
applied pre-commit results...
whitememory Oct 6, 2025
221edae
Merge branch 'main' into mori_moreh
whitememory Oct 7, 2025
fc55d72
Merge branch 'main' into mori_moreh
whitememory Oct 7, 2025
7d52023
applied SageMoore's comments
whitememory Oct 9, 2025
e357840
Merge branch 'main' into mori_moreh
whitememory Oct 9, 2025
6d8ef43
following code difference from main
whitememory Oct 9, 2025
eace564
Applied few other comments from bnellnm
whitememory Oct 9, 2025
73093c5
Applied some pre-commit results.
whitememory Oct 9, 2025
b9b9a9b
refactor workspace_shapes of AiterExperts and handle_cache of MoriAll…
ihbang Oct 9, 2025
d3c6ce0
clean-up handle_cache at destroy() of mori a2a manager
whitememory Oct 9, 2025
32482ee
fixed according to SM211 rule
whitememory Oct 10, 2025
73a17d5
Merge branch 'main' into mori_moreh
whitememory Oct 10, 2025
eda8c8e
adding mori backend to moe kernel feature doc
whitememory Oct 10, 2025
797d819
Merge branch 'main' into mori_moreh
whitememory Oct 10, 2025
ddd3563
applied reviews from bnellnm
whitememory Oct 11, 2025
272205e
Merge branch 'main' into mori_moreh
whitememory Oct 11, 2025
4abf225
Merge branch 'main' into mori_moreh
whitememory Oct 12, 2025
1dbff2c
new precommit removed Optional
whitememory Oct 13, 2025
665a631
additional pre-commit results about Optional
whitememory Oct 13, 2025
753a506
Merge branch 'main' into mori_moreh
whitememory Oct 13, 2025
096f938
fixed few whitespaces
whitememory Oct 15, 2025
8badbd6
Merge branch 'main' into mori_moreh
whitememory Oct 15, 2025
25f8d59
applied pre-commit result
whitememory Oct 15, 2025
a476913
Add json config parsing logic to change mori configs easily
ihbang Oct 20, 2025
4c4306e
Applied review from HAIAI
ihbang Oct 20, 2025
3adc79e
add quant_dtype check on _make_mori_config
ihbang Oct 21, 2025
d4a1529
Merge branch 'main' into mori_moreh
whitememory Oct 27, 2025
756898a
moved has_mori to import_utils.py
whitememory Oct 27, 2025
e9b8624
fixed few more things (conflict and pre-commit)
whitememory Oct 27, 2025
4 changes: 4 additions & 0 deletions docs/design/moe_kernel_features.md
@@ -41,6 +41,7 @@ th {
| flashinfer<sup>4</sup> | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferCutlassMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferCutlassMoEPrepareAndFinalize] |
Collaborator: Please remove this line (duplicate of prior line).

Author: This is not part of this PR, and it's not related. May I fix it in this PR? If so, I will remove it.
| MoEPrepareAndFinalizeNoEP<sup>5</sup> | standard | fp8,int8 | G,A,T | N | Y | [`MoEPrepareAndFinalizeNoEP`][vllm.model_executor.layers.fused_moe.prepare_finalize.MoEPrepareAndFinalizeNoEP] |
| BatchedPrepareAndFinalize<sup>5</sup> | batched | fp8,int8 | G,A,T | N | Y | [`BatchedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedPrepareAndFinalize] |
| MoriPrepareAndFinalize<sup>7</sup> | standard | fp8<sup>8</sup> | G(128),A,T<sup>8</sup> | N | Y | [`MoriPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.mori_prepare_finalize.MoriPrepareAndFinalize] |

!!! info "Table key"
1. All types: mxfp4, nvfp4, int4, int8, fp8
@@ -49,6 +50,8 @@ th {
4. Controlled by different env vars (`VLLM_FLASHINFER_MOE_BACKEND` "throughput" or "latency")
5. This is a no-op dispatcher that can be used to pair with any modular experts to produce a modular kernel that runs w/o dispatch or combine. These cannot be selected via environment variable. These are generally used for testing or adapting an expert subclass to the `fused_experts` API.
6. This depends on the experts implementation.
7. Currently, MoRI supports low-latency mode only.
8. This depends on the experts implementation; currently MoRI is paired with the AITER experts only.
Collaborator: Is this meant to explain, or is it a direct answer to fp8?

Author: Yes, we integrated MoRI with the AITER MoE only. We found it natural to follow the quant type and quant format of the ROCm AITER MoE, which is in this doc at about line 103.


---

@@ -118,3 +121,4 @@ The following table shows "families" of modular kernels that are intended to wor
| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` |
| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`BatchedTritonOrDeepGemmExperts`,</br>`CutlassBatchedExpertsFp8`|
| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
| mori | `MoriPrepareAndFinalize` | `AiterExperts` |
211 changes: 210 additions & 1 deletion vllm/distributed/device_communicators/all2all.py
@@ -9,7 +9,7 @@
from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils import has_deep_ep, has_pplx
from vllm.utils import has_deep_ep, has_mori, has_pplx
from vllm.utils.flashinfer import has_flashinfer_all2all

from .base_device_communicator import All2AllManagerBase, Cache
@@ -474,3 +474,212 @@ def cleanup(self):
self.prepare_workspace_tensor = None
self.mapping = None
self.initialized = False


class MoriAll2AllManager(All2AllManagerBase):
"""
All2All communication based on mori kernels.
"""

def __init__(self, cpu_group):
assert has_mori(), "Please install mori from ROCm/mori github."

super().__init__(cpu_group)
self.handle_cache = Cache()
self.config = None
self._shmem_initialized = False
# Delay mori shmem initialization until first use
logger.debug("[rank %s] MoriAll2AllManager created", self.rank)

def _ensure_shmem_initialized(self):
"""Initialize mori's shared memory system lazily"""
if self._shmem_initialized:
return

import mori.shmem
import torch.distributed as dist

try:
# Check if we have a valid backend
backend = dist.get_backend()
if backend is None:
raise RuntimeError("No valid distributed backend found")

logger.debug(
"[rank %s] PyTorch distributed ready with backend: %s",
self.rank,
backend,
)

current_group = (
self.cpu_group if self.cpu_group is not None else dist.group.WORLD
)
group_name = "mori_shmem_group"

try:
import torch._C._distributed_c10d as c10d

# Register the current process group
c10d._register_process_group(group_name, current_group)
logger.debug(
"[rank %s] Registered proc group %s", self.rank, group_name
)

# Initialize mori shmem with the registered group
mori.shmem.shmem_torch_process_group_init(group_name)
Collaborator: Are two TP4 instances supported?

Author: Good point! We missed that possibility. If this does not work for two different instances on the same node, we will think about how to give a unique group_name.

ihbang (Oct 20, 2025): Fixed it so that multiple instances use their own unique group name, which includes their PPID (parent process ID). This works because all ranks in one instance share the same PPID.
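A minimal sketch of the per-instance group naming described in this thread (an assumption about the follow-up commit, not the code shown in this diff): derive the process-group name from the parent PID so two vLLM instances on one node register distinct mori shmem groups.

```python
import os


def _mori_shmem_group_name() -> str:
    # All ranks spawned by one vLLM instance share the same parent PID, so the
    # name is unique per instance while identical across that instance's ranks.
    return f"mori_shmem_group_{os.getppid()}"
```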

logger.debug("[rank %s] torch proc group shmem init success", self.rank)
self._shmem_initialized = True
return

except Exception as torch_error:
logger.debug(
"[rank %s] torch process group shmem init failed: %s",
self.rank,
torch_error,
)
self._shmem_initialized = True
logger.warning(
"[rank %s] Continue without mori shmem optimize", self.rank
)

except Exception as e:
logger.error("[rank %s] mori shmem init failed: %s", self.rank, e)
# Don't fail completely - mark as initialized to avoid retry loops
self._shmem_initialized = True
Collaborator: Can we handle this differently instead of setting it to True?

Author: @ihbang, would you take a look into this? I was worried about this handling as well.

Reply: Fixed it to raise an exception when shmem initialization fails.
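A minimal sketch of the fail-fast behavior described above (an assumption about the follow-up change; it reuses only calls already present in this diff): surface the error instead of marking the manager initialized and silently degrading.

```python
import mori.shmem
import torch._C._distributed_c10d as c10d


def _init_mori_shmem_or_raise(group_name: str, process_group) -> None:
    """Hypothetical helper: register the process group and initialize mori
    shmem, raising instead of falling back when initialization fails."""
    try:
        c10d._register_process_group(group_name, process_group)
        mori.shmem.shmem_torch_process_group_init(group_name)
    except Exception as e:
        # Fail fast so misconfiguration surfaces at startup rather than
        # showing up later as degraded all2all behavior.
        raise RuntimeError(
            f"mori shmem initialization failed for group '{group_name}'"
        ) from e
```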

logger.warning(
"[rank %s] Continuing without mori shmem optimize", self.rank
)

def _make_mori_config(
self,
max_num_tokens: int,
num_local_experts: int,
experts_per_token: int,
hidden_dim: int,
scale_dim: int,
scale_type_size: int,
data_type: torch.dtype = torch.bfloat16,
quant_dtype: Optional[torch.dtype] = None,
):
"""Create mori EpDispatchCombineConfig"""
import mori.ops.dispatch_combine as mori_ops
from mori.ops.dispatch_combine import EpDispatchCombineKernelType

Collaborator: Should we add checks for data_type and quant_dtype before proceeding?

Reply: A quant_dtype check was added. A data_type check is not included because mori doesn't seem to have a specific dtype restriction.
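A minimal sketch of the quant_dtype guard described in the reply (the accepted dtype set below is an assumption about the ROCm fp8 path; the final guard is not shown in this diff):

```python
import torch

# Hypothetical guard at the top of _make_mori_config: the AITER fp8 path on
# ROCm uses the fnuz fp8 format, so reject anything else early.
_SUPPORTED_QUANT_DTYPES = (torch.float8_e4m3fnuz,)


def _check_quant_dtype(quant_dtype: torch.dtype | None) -> None:
    if quant_dtype is not None and quant_dtype not in _SUPPORTED_QUANT_DTYPES:
        raise ValueError(
            f"Unsupported quant_dtype {quant_dtype} for mori dispatch/combine; "
            f"expected one of {_SUPPORTED_QUANT_DTYPES}"
        )
```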

config = mori_ops.EpDispatchCombineConfig(
data_type=data_type if quant_dtype is None else quant_dtype,
rank=self.rank,
world_size=self.world_size,
hidden_dim=hidden_dim,
max_num_inp_token_per_rank=max_num_tokens,
num_experts_per_rank=num_local_experts,
num_experts_per_token=experts_per_token,
# Performance tuning parameters
# warp_num_per_block=8,
# block_num=80,
max_token_type_size=data_type.itemsize,
# Quantization support
scale_dim=scale_dim,
scale_type_size=scale_type_size,
# Determine kernel type based on topology
kernel_type=(
EpDispatchCombineKernelType.InterNode
if self.internode
else EpDispatchCombineKernelType.IntraNode
),
)

return config

def get_handle(self, kwargs):
"""
Get or create mori operation handle.
Args:
kwargs: Dictionary with keys:
- max_num_tokens: Maximum tokens per DP rank
- num_local_experts: Number of local experts
- experts_per_token: Number of experts per token (topk)
- hidden_dim: Hidden dimension size
- scale_dim: Quantization scale dimension
- scale_type_size: Quantization scale element size in bytes
- data_type: Tensor data type (optional, default bfloat16)
- quant_dtype: Quantized data type (optional)
"""
# Ensure shmem is initialized before creating handles
self._ensure_shmem_initialized()

def create_mori_handle(
max_num_tokens: int,
num_local_experts: int,
experts_per_token: int,
hidden_dim: int,
scale_dim: int,
scale_type_size: int,
data_type: torch.dtype = torch.bfloat16,
quant_dtype: Optional[torch.dtype] = None,
):
import mori
Collaborator: Use a specific import here.

Reply: Fixed (some other imports were also fixed).


config = self._make_mori_config(
max_num_tokens=max_num_tokens,
num_local_experts=num_local_experts,
experts_per_token=experts_per_token,
hidden_dim=hidden_dim,
scale_dim=scale_dim,
scale_type_size=scale_type_size,
data_type=data_type,
quant_dtype=quant_dtype,
)
op = mori.ops.EpDispatchCombineOp(config)
logger.debug(
"[rank %s] Created mori handle with config: tokens=%d, experts=%d,"
" topk=%d, hidden_dim=%d",
self.dp_rank,
max_num_tokens,
num_local_experts,
experts_per_token,
hidden_dim,
)
return op

return self.handle_cache.get_or_create(kwargs, create_mori_handle)
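A hedged usage sketch of `get_handle` (it assumes `manager` is an already-constructed `MoriAll2AllManager` inside an initialized EP group; the numeric values are illustrative placeholders, not defaults):

```python
import torch

# Illustrative kwargs for a fp8-quantized MoE layer; real values come from the
# layer's FusedMoEConfig.
handle = manager.get_handle(
    {
        "max_num_tokens": 256,
        "num_local_experts": 8,
        "experts_per_token": 8,
        "hidden_dim": 7168,
        "scale_dim": 7168 // 128,  # one scale per 128-element group (G(128))
        "scale_type_size": 4,      # float32 scale entries
        "data_type": torch.bfloat16,
        "quant_dtype": torch.float8_e4m3fnuz,
    }
)
```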

def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
):
raise NotImplementedError

def combine(
self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False,
):
raise NotImplementedError

def destroy(self):
"""Clean up mori resources"""
try:
# Clear operation handle cache
with self.handle_cache._lock:
for _, handle in self.handle_cache._cache.items():
handle.destroy()

# finalize mori shared memory if it was initialized
if self._shmem_initialized:
try:
import mori.shmem

# Check if shmem is actually active before finalizing
mori.shmem.shmem_finalize()
logger.debug("[rank %s] mori shmem finalize", self.dp_rank)
except Exception as shmem_error:
logger.debug(
"[rank %s] shmem finalize failed "
"(may not have been active): %s",
self.dp_rank,
shmem_error,
)

logger.debug("[rank %s] mori resources cleaned up", self.dp_rank)

except Exception as e:
logger.warning("[rank %s] mori cleanup fail: %s", self.dp_rank, e)
vllm/distributed/device_communicators/base_device_communicator.py
@@ -8,6 +8,10 @@
import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.logger import init_logger

logger = init_logger(__name__)


class Cache:
def __init__(self):
5 changes: 5 additions & 0 deletions vllm/distributed/device_communicators/cuda_communicator.py
@@ -122,6 +122,11 @@ def __init__(

self.all2all_manager = FlashInferAllToAllManager(self.cpu_group)
logger.info("Using Flashinfer all2allv manager.")
elif all2all_backend == "mori":
from .all2all import MoriAll2AllManager

self.all2all_manager = MoriAll2AllManager(self.cpu_group)
logger.info("Using Mori all2all manager.")
else:
raise ValueError(f"Unknown all2all backend: {all2all_backend}")

3 changes: 3 additions & 0 deletions vllm/envs.py
@@ -157,6 +157,7 @@
VLLM_ALL2ALL_BACKEND: Literal[
"naive",
"pplx",
"mori",
"deepep_high_throughput",
"deepep_low_latency",
"allgather_reducescatter",
@@ -1141,6 +1142,7 @@ def get_vllm_port() -> Optional[int]:
# - "allgather_reducescatter": all2all implementation based on allgather and
# reducescatter
# - "pplx": use pplx kernels
# - "mori": use mori kernels (currently, only low-latency is supported)
# - "deepep_high_throughput", use deepep high-throughput kernels
# - "deepep_low_latency", use deepep low-latency kernels
# - "flashinfer_all2allv", use flashinfer alltoallv kernels for mnnvl
@@ -1150,6 +1152,7 @@ def get_vllm_port() -> Optional[int]:
[
"naive",
"pplx",
"mori",
"deepep_high_throughput",
"deepep_low_latency",
"allgather_reducescatter",
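For context, a hedged sketch of selecting this backend from user code (the model name and parallel settings are illustrative only, and a ROCm build with mori installed is assumed):

```python
import os

# Select the mori all2all backend before vLLM is initialized.
os.environ["VLLM_ALL2ALL_BACKEND"] = "mori"

from vllm import LLM

# Expert parallelism is what exercises the all2all dispatch/combine path; a
# single-GPU run without EP will not hit the mori kernels.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    tensor_parallel_size=8,
    enable_expert_parallel=True,
)
```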
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/fused_moe/__init__.py
@@ -4,6 +4,7 @@
from contextlib import contextmanager
from typing import Any, Optional

from vllm.model_executor.layers.fused_moe.aiter_experts import AiterExperts
Collaborator: Add a ROCm check.

Reply: A ROCm check was added.
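A minimal sketch of the guarded import described above (an assumption about the follow-up commit; exact placement and fallback may differ):

```python
from vllm.platforms import current_platform

# AiterExperts wraps ROCm AITER kernels, so only export it on ROCm builds.
if current_platform.is_rocm():
    from vllm.model_executor.layers.fused_moe.aiter_experts import AiterExperts
```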

from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
@@ -94,6 +95,7 @@ def get_config() -> Optional[dict[str, Any]]:
"BatchedDeepGemmExperts",
"TritonOrDeepGemmExperts",
"BatchedTritonOrDeepGemmExperts",
"AiterExperts",
]
else:
# Some model classes directly use the custom ops. Add placeholders