Changes from all commits
47 commits
cd167ad
vllm mori integration from moreh team
whitememory Oct 1, 2025
77c8a56
Update vllm/model_executor/layers/fused_moe/modular_kernel.py
whitememory Oct 1, 2025
ea17a69
Update vllm/model_executor/layers/fused_moe/aiter_experts.py
whitememory Oct 1, 2025
a00b38b
applied pre-commit results and Read the Docs build results
whitememory Oct 2, 2025
07095ff
applied few suggestions from code-assistant
whitememory Oct 2, 2025
26ec16d
applied pre-commit result from github
whitememory Oct 2, 2025
5c997be
applied some pre-commit results from github
whitememory Oct 2, 2025
9849fa6
hope this is last for pre-commit
whitememory Oct 2, 2025
f4fb63a
removing unrelated change
whitememory Oct 2, 2025
c4bbc1a
Merge branch 'main' into mori_moreh
whitememory Oct 3, 2025
6dabf63
Merge branch 'main' into mori_moreh
whitememory Oct 3, 2025
19e7d40
Merge branch 'main' into mori_moreh
whitememory Oct 3, 2025
a1de125
Merge branch 'main' into mori_moreh
whitememory Oct 4, 2025
d2f65f6
Merge branch 'main' into mori_moreh
whitememory Oct 4, 2025
fb01286
Merge branch 'main' into mori_moreh
whitememory Oct 6, 2025
739b489
applied pre-commit results after merging main to
whitememory Oct 6, 2025
770676e
applied additional pre-commit results
whitememory Oct 6, 2025
ff08bc8
Merge branch 'main' into mori_moreh
whitememory Oct 6, 2025
059f29a
applied pre-commit results...
whitememory Oct 6, 2025
221edae
Merge branch 'main' into mori_moreh
whitememory Oct 7, 2025
fc55d72
Merge branch 'main' into mori_moreh
whitememory Oct 7, 2025
7d52023
applied SageMoore's comments
whitememory Oct 9, 2025
e357840
Merge branch 'main' into mori_moreh
whitememory Oct 9, 2025
6d8ef43
following code difference from main
whitememory Oct 9, 2025
eace564
Applied few other comments from bnellnm
whitememory Oct 9, 2025
73093c5
Applied some pre-commit results.
whitememory Oct 9, 2025
b9b9a9b
refactor workspace_shapes of AiterExperts and handle_cache of MoriAll…
ihbang Oct 9, 2025
d3c6ce0
clean-up handle_cache at destroy() of mori a2a manager
whitememory Oct 9, 2025
32482ee
fixed according to SM211 rule
whitememory Oct 10, 2025
73a17d5
Merge branch 'main' into mori_moreh
whitememory Oct 10, 2025
eda8c8e
adding mori backend to moe kernel feature doc
whitememory Oct 10, 2025
797d819
Merge branch 'main' into mori_moreh
whitememory Oct 10, 2025
ddd3563
applied reviews from bnellnm
whitememory Oct 11, 2025
272205e
Merge branch 'main' into mori_moreh
whitememory Oct 11, 2025
4abf225
Merge branch 'main' into mori_moreh
whitememory Oct 12, 2025
1dbff2c
new precommit removed Optional
whitememory Oct 13, 2025
665a631
additional pre-commit results about Optional
whitememory Oct 13, 2025
753a506
Merge branch 'main' into mori_moreh
whitememory Oct 13, 2025
096f938
fixed few whitespaces
whitememory Oct 15, 2025
8badbd6
Merge branch 'main' into mori_moreh
whitememory Oct 15, 2025
25f8d59
applied pre-commit result
whitememory Oct 15, 2025
a476913
Add json config parsing logic to change mori configs easily
ihbang Oct 20, 2025
4c4306e
Applied review from HAIAI
ihbang Oct 20, 2025
3adc79e
add quant_dtype check on _make_mori_config
ihbang Oct 21, 2025
d4a1529
Merge branch 'main' into mori_moreh
whitememory Oct 27, 2025
756898a
moved has_mori to import_utils.py
whitememory Oct 27, 2025
e9b8624
fixed few more things (conflict and pre-commit)
whitememory Oct 27, 2025
6 changes: 5 additions & 1 deletion docs/design/moe_kernel_features.md
@@ -41,6 +41,7 @@ th {
| flashinfer<sup>4</sup> | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferCutlassMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferCutlassMoEPrepareAndFinalize] |
Collaborator:
Please remove this line (duplicate of prior line)

Author:
This is not part of this PR, and it's not related. May I fix it in this PR? If so, I will remove it.

| MoEPrepareAndFinalizeNoEP<sup>5</sup> | standard | fp8,int8 | G,A,T | N | Y | [`MoEPrepareAndFinalizeNoEP`][vllm.model_executor.layers.fused_moe.prepare_finalize.MoEPrepareAndFinalizeNoEP] |
| BatchedPrepareAndFinalize<sup>5</sup> | batched | fp8,int8 | G,A,T | N | Y | [`BatchedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedPrepareAndFinalize] |
| MoriPrepareAndFinalize<sup>7</sup> | standard | fp8<sup>8</sup> | G(128),A,T<sup>8</sup> | N | Y | [`MoriPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.mori_prepare_finalize.MoriPrepareAndFinalize] |

!!! info "Table key"
1. All types: mxfp4, nvfp4, int4, int8, fp8
@@ -49,6 +50,8 @@ th {
4. Controlled by different env vars (`VLLM_FLASHINFER_MOE_BACKEND` "throughput" or "latency")
5. This is a no-op dispatcher that can be used to pair with any modular experts to produce a modular kernel that runs without dispatch or combine. These cannot be selected via environment variable; they are generally used for testing or for adapting an expert subclass to the `fused_experts` API.
6. This depends on the experts implementation.
7. Currently, MoRI supports low-latency mode only.
8. This depends on the experts implementation; currently, mori supports aiter only.
Collaborator:
Is this to explain, or a direct answer to fp8?

Author:
Yes, we integrated mori with the aiter MoE only. We found it natural to follow the quant type and quant format of the ROCm aiter MoE, which is covered in this doc at about line 103.


---

@@ -117,4 +120,5 @@ The following table shows "families" of modular kernels that are intended to work together
|----------------------------------|------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------|
| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` |
| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`BatchedTritonOrDeepGemmExperts`,</br>`CutlassBatchedExpertsFp8`,</br>`BatchedMarlinExperts`|
| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
| mori | `MoriPrepareAndFinalize` | `AiterMoriExperts` |
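
The new row composes the same way as the other families in this table. Below is a minimal sketch of that composition; it assumes the `FusedMoEModularKernel(prepare_finalize, fused_experts)` constructor and the abstract base classes from `vllm.model_executor.layers.fused_moe.modular_kernel`, and it deliberately leaves out how the mori prepare/finalize and the aiter-based experts objects are constructed (that wiring is backend-specific and lives in the code added by this PR).

```python
# Minimal sketch: pairing a dispatch/combine backend with a matching experts
# implementation into one modular kernel, as the families table describes.
# The FusedMoEModularKernel constructor signature is assumed here; concrete
# construction of MoriPrepareAndFinalize / the aiter experts is not shown.
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
)


def build_modular_kernel(
    prepare_finalize: FusedMoEPrepareAndFinalize,
    experts: FusedMoEPermuteExpertsUnpermute,
) -> FusedMoEModularKernel:
    # e.g. prepare_finalize = MoriPrepareAndFinalize(...),
    #      experts = the aiter-based mori experts from the row above
    return FusedMoEModularKernel(prepare_finalize, experts)
```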
345 changes: 344 additions & 1 deletion vllm/distributed/device_communicators/all2all.py
@@ -1,7 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
from pathlib import Path
from typing import Any

import psutil
import torch
import torch.distributed as dist

@@ -10,7 +14,7 @@
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils.flashinfer import has_flashinfer_all2all
from vllm.utils.import_utils import has_deep_ep, has_pplx
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx

from .base_device_communicator import All2AllManagerBase, Cache

@@ -488,3 +492,342 @@ def cleanup(self):
self.prepare_workspace_tensor = None
self.mapping = None
self.initialized = False


class MoriAll2AllManager(All2AllManagerBase):
"""
All2All communication based on mori kernels.
"""

def __init__(self, cpu_group):
assert has_mori(), "Please install mori from ROCm/mori github."

super().__init__(cpu_group)
self.handle_cache = Cache()
self.config = None
self._shmem_initialized = False

self.json_config = None
config_path = envs.VLLM_MORI_CONFIG_PATH
if config_path:
self.json_config = self._load_mori_config_from_json(config_path)

# Delay mori shmem initialization until first use
logger.debug("[rank %s] MoriAll2AllManager created", self.rank)

def _ensure_shmem_initialized(self):
"""Initialize mori's shared memory system lazily"""
if self._shmem_initialized:
return

import torch.distributed as dist
from mori.shmem import shmem_torch_process_group_init

try:
# Check if we have a valid backend
backend = dist.get_backend()
if backend is None:
raise RuntimeError("No valid distributed backend found")

logger.debug(
"[rank %s] PyTorch distributed ready with backend: %s",
self.rank,
backend,
)

assert self.cpu_group is not None, "No CPU group is given to mori"
ppid = psutil.Process(os.getpid()).ppid()
group_name = f"mori_shmem_group_{ppid}"

try:
import torch._C._distributed_c10d as c10d

# Register the process group
c10d._register_process_group(group_name, self.cpu_group)
logger.debug(
"[rank %s] Registered proc group %s", self.rank, group_name
)

# Initialize mori shmem with the registered group
shmem_torch_process_group_init(group_name)
logger.debug("[rank %s] torch proc group shmem init success", self.rank)
self._shmem_initialized = True
return

except Exception as torch_error:
raise RuntimeError(
"torch process group initialization failed"
) from torch_error

except Exception as e:
raise RuntimeError("mori shmem initialization failed") from e

def _load_mori_config_from_json(self, json_path: str) -> dict | None:
"""
Load mori configuration parameters from JSON file.

Supports both flat and hierarchical schema:

Flat schema:
{
"warp_num_per_block": 8,
"block_num": 80,
}

Hierarchical schema (dispatch/combine specific):
{
"global": {
"warp_num_per_block": 8,
"block_num": 80,
},
"dispatch": {
"warp_num_per_block": 16,
"block_num": 160
},
"combine": {
"warp_num_per_block": 4,
"block_num": 40
}
}

Args:
json_path: Path to JSON configuration file

Returns:
Dictionary of configuration parameters, or None if file doesn't exist

Raises:
ValueError: If JSON is invalid or contains unsupported parameters
"""
if not json_path:
return None

json_file = Path(json_path)
if not json_file.exists():
logger.warning(
"[rank %d] Mori config file not found: %s", self.rank, json_path
)
return None

try:
with open(json_file) as f:
config = json.load(f)

# Valid parameter keys
valid_param_keys = {
"warp_num_per_block",
"block_num",
}

is_hierarchical = any(
key in config for key in ["global", "dispatch", "combine"]
)

if is_hierarchical:
valid_top_keys = {"global", "dispatch", "combine"}
invalid_keys = set(config.keys()) - valid_top_keys
if invalid_keys:
raise ValueError(
f"Invalid top-level keys: {invalid_keys}. "
f"Valid keys: {valid_top_keys}"
)

# Validate each section
for section in ["global", "dispatch", "combine"]:
if section in config:
section_config = config[section]
if not isinstance(section_config, dict):
raise ValueError(f"'{section}' must be a dictionary")

invalid_keys = set(section_config.keys()) - valid_param_keys
if invalid_keys:
raise ValueError(
f"Invalid keys in '{section}': {invalid_keys}. "
f"Valid keys: {valid_param_keys}"
)
else:
invalid_keys = set(config.keys()) - valid_param_keys
if invalid_keys:
raise ValueError(
f"Invalid config keys: {invalid_keys}. "
f"Valid keys: {valid_param_keys}"
)

return config

except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in mori config file {json_path}") from e
except Exception as e:
raise ValueError(f"Error loading mori config from {json_path}") from e

def _make_mori_config(
self,
max_num_tokens: int,
num_local_experts: int,
experts_per_token: int,
hidden_dim: int,
scale_dim: int,
scale_type_size: int,
data_type: torch.dtype = torch.bfloat16,
quant_dtype: torch.dtype | None = None,
):
"""
Create mori EpDispatchCombineConfig.

Args:
max_num_tokens: Maximum number of tokens per DP rank
num_local_experts: Number of local experts
experts_per_token: Number of experts per token (topk)
hidden_dim: Hidden dimension size
scale_dim: Scale dimension for quantization
scale_type_size: Scale type size for quantization
data_type: Tensor data type
quant_dtype: Quantization data type (optional)
"""
from mori.ops import EpDispatchCombineConfig
from mori.ops.dispatch_combine import EpDispatchCombineKernelType

Collaborator:
Adding a check for data_type and quant_dtype before proceeding?

Reply:
A quant_dtype check is added. A data_type check is not included because mori doesn't seem to have a specific dtype restriction.

from vllm.platforms import current_platform

assert quant_dtype is None or quant_dtype == current_platform.fp8_dtype()

# Default values (can be overridden by JSON)
warp_num_per_block = 8
block_num = 80

# Override with JSON config if provided
if self.json_config is not None:
is_hierarchical = any(
key in self.json_config for key in ["global", "dispatch", "combine"]
)

global_config = self.json_config
if is_hierarchical and "global" in global_config:
global_config = self.json_config["global"]

warp_num_per_block = global_config.get(
"warp_num_per_block", warp_num_per_block
)
block_num = global_config.get("block_num", block_num)

config = EpDispatchCombineConfig(
data_type=data_type if quant_dtype is None else quant_dtype,
rank=self.rank,
world_size=self.world_size,
hidden_dim=hidden_dim,
max_num_inp_token_per_rank=max_num_tokens,
num_experts_per_rank=num_local_experts,
num_experts_per_token=experts_per_token,
max_token_type_size=data_type.itemsize,
# Performance tuning parameters
warp_num_per_block=warp_num_per_block,
block_num=block_num,
# Quantization support
scale_dim=scale_dim,
scale_type_size=scale_type_size,
# Determine kernel type based on topology
kernel_type=(
EpDispatchCombineKernelType.InterNode
if self.internode
else EpDispatchCombineKernelType.IntraNode
),
)

return config

def get_handle(self, kwargs):
"""
Get or create mori operation handle.
Args:
kwargs: Dictionary with keys:
- max_num_tokens: Maximum tokens per DP rank
- num_local_experts: Number of local experts
- experts_per_token: Number of experts per token (topk)
- hidden_dim: Hidden dimension size
- data_type: Tensor data type (optional, default bfloat16)
- scale_dim: Scale dimension (optional)
- scale_type_size: Scale type size (optional)
- ubatch_id: Microbatch ID (optional)
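- quant_dtype: Quantization data type (optional)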
"""
# Ensure shmem is initialized before creating handles
self._ensure_shmem_initialized()

def create_mori_handle(
max_num_tokens: int,
num_local_experts: int,
experts_per_token: int,
hidden_dim: int,
scale_dim: int,
scale_type_size: int,
data_type: torch.dtype = torch.bfloat16,
quant_dtype: torch.dtype | None = None,
):
from mori.ops import EpDispatchCombineOp

config = self._make_mori_config(
max_num_tokens=max_num_tokens,
num_local_experts=num_local_experts,
experts_per_token=experts_per_token,
hidden_dim=hidden_dim,
scale_dim=scale_dim,
scale_type_size=scale_type_size,
data_type=data_type,
quant_dtype=quant_dtype,
)
op = EpDispatchCombineOp(config)
logger.debug(
"[rank %s] Created mori handle with config: tokens=%d, experts=%d,"
" topk=%d, hidden_dim=%d",
self.dp_rank,
max_num_tokens,
num_local_experts,
experts_per_token,
hidden_dim,
)
return op

return self.handle_cache.get_or_create(kwargs, create_mori_handle)

def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
):
raise NotImplementedError

def combine(
self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False,
):
raise NotImplementedError

def destroy(self):
"""Clean up mori resources"""
try:
# Clear operation handle cache
with self.handle_cache._lock:
for _, handle in self.handle_cache._cache.items():
handle.destroy()

# finalize mori shared memory if it was initialized
if self._shmem_initialized:
try:
from mori.shmem import shmem_finalize

# Check if shmem is actually active before finalizing
shmem_finalize()
logger.debug("[rank %s] mori shmem finalize", self.dp_rank)
except Exception as shmem_error:
logger.debug(
"[rank %s] shmem finalize failed "
"(may not have been active): %s",
self.dp_rank,
shmem_error,
)

logger.debug("[rank %s] mori resources cleaned up", self.dp_rank)

except Exception as e:
logger.warning("[rank %s] mori cleanup fail: %s", self.dp_rank, e)
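
To make the JSON config path concrete, here is an illustrative way to produce a file that passes the validation in `_load_mori_config_from_json` and to point vLLM at it via `VLLM_MORI_CONFIG_PATH`. The file location and the numeric values are made up for the example; only the accepted keys and the environment variable come from the diff above.

```python
# Illustrative sketch only: write a hierarchical mori config matching the
# schema validated above, then expose it through VLLM_MORI_CONFIG_PATH so the
# MoriAll2AllManager picks it up when it is constructed.
import json
import os
from pathlib import Path

config = {
    "global": {"warp_num_per_block": 8, "block_num": 80},
    "dispatch": {"warp_num_per_block": 16, "block_num": 160},
    "combine": {"warp_num_per_block": 4, "block_num": 40},
}

config_path = Path("/tmp/mori_config.json")  # hypothetical location
config_path.write_text(json.dumps(config, indent=2))

# Must be set before the engine (and the all2all manager) is created.
os.environ["VLLM_MORI_CONFIG_PATH"] = str(config_path)
```

Any key outside `global`/`dispatch`/`combine` and `warp_num_per_block`/`block_num` raises a `ValueError` during loading; a missing file only logs a warning, so the defaults in `_make_mori_config` (`warp_num_per_block=8`, `block_num=80`) are used instead.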