
Conversation

@whitememory commented Oct 1, 2025

Purpose

This PR integrates the mori (https://github.com/ROCm/mori) kernel as an all2all backend for vLLM to boost EP (expert parallel) performance on ROCm. This is work by a team from Moreh Inc.
Note that mori currently provides only a low-latency mode (it uses an internal static buffer size).
To activate mori, install it from https://github.com/ROCm/mori and add export VLLM_ALL2ALL_BACKEND="mori" to the server launch script.

Test Plan

server

#!/bin/bash
export VLLM_ROCM_USE_AITER=1
export VLLM_RANDOMIZE_DP_DUMMY_INPUTS=1
export VLLM_ALL2ALL_BACKEND="mori"
export VLLM_MOE_DP_CHUNK_SIZE=1024
vllm serve "/app/model/DeepSeek-R1" \
    --served-model-name "deepseek-ai/DeepSeek-R1" \
    --port 8001 \
    --trust-remote-code \
    --tensor-parallel-size 1 \
    --data-parallel-size 8 \
    --no-enable-prefix-caching \
    --enable-expert-parallel \
    --max-model-len 8192 \
    --max-num-seqs 2048 \
    --max-num-batched-tokens 32768 \
    --gpu-memory-utilization 0.92 \
    --block-size 1 \
    --enforce-eager

curl test

curl http://localhost:8001/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "deepseek-ai/DeepSeek-R1",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Who won the world series in 2020?"}
        ],
        "max_tokens": 200
    }'

benchmark script

#!/bin/bash
model_name="deepseek-ai/DeepSeek-R1"
CON=2000

curl -X POST http://localhost:8001/reset_prefix_cache
date
vllm bench serve \
    --backend vllm \
    --model $model_name \
    --metric-percentiles "25,50,75" \
    --percentile-metrics "itl,tps,ttft,e2el" \
    --port 8001 \
    --num-prompts $CON \
    --max-concurrency $CON \
    --ignore-eos \
    --dataset-name random \
    --random-input-len 2000 \
    --random-output-len 100

Test Result

  1. Introduction
    Since mori supports only low-latency mode, the mori backend must use chunked forward (repeated MoE).
    Prefill performance can therefore be worse than with a2a backends that do not use chunked forward. (We set VLLM_MOE_DP_CHUNK_SIZE=1024, but the test ISL was 2000 tokens per request, so each prefill is split across multiple MoE chunks.)
    Thus, performance is measured using the ITL metric.
    CUDA graph capture seems to have a problem at main 1405f0c, so we passed --enforce-eager to the server.

  2. Performance result

mori

---------------Inter-token Latency----------------
Mean ITL (ms): 1146.22
Median ITL (ms): 189.88
P25 ITL (ms): 175.00
P50 ITL (ms): 189.88
P75 ITL (ms): 199.09

default(allgather_reducescatter)

---------------Inter-token Latency----------------
Mean ITL (ms): 1087.56
Median ITL (ms): 193.03
P25 ITL (ms): 175.06
P50 ITL (ms): 193.03
P75 ITL (ms): 203.93

naive

---------------Inter-token Latency----------------
Mean ITL (ms): 1198.89
Median ITL (ms): 259.74
P25 ITL (ms): 244.77
P50 ITL (ms): 259.74
P75 ITL (ms): 274.06

  3. Environment

test machine: MI300X
test model: DeepSeek-R1
base docker image: rocm/vllm-dev:official_0909_rc3_20250906
mori version: main 0a6f908
vllm version: 0.11.0rc1 and main 1405f0c

  4. Etc tests
    4-1. Tests at 0.11.0rc1
    The curl test passed with all a2a backends tested on ROCm (naive, allgather_reducescatter, mori).
    CUDA graph capture passed with all a2a backends tested on ROCm (allgather_reducescatter, mori); the naive backend is expected to fail.
    With this fix from mori, CUDA graph replay also works.
    4-2. Tests at main 1405f0c -> main at commit 1405f0c seems to have a problem.
    The curl test DID NOT PASS with any a2a backend tested on ROCm (naive, allgather_reducescatter, mori).
    CUDA graph capture DID NOT PASS with any a2a backend tested on ROCm (naive, allgather_reducescatter, mori); the naive backend is expected to fail.
    4-2-1.
    At main 1405f0c, the server appears to respond to "You are a helpful assistant." rather than to "Who won the world series in 2020?".
    The server at 0.11.0rc1 responded to "Who won the world series in 2020?" correctly.

  5. Future works
    5-1. Even at 0.11.0rc1, CUDA graph replay seems to have a problem: capture completes, but replay hangs. We will investigate this.
    5-2. Integrate mori with DBO.
    5-3. When mori releases its HT (high-throughput) mode, apply it to improve prefill performance.


Essential Elements of an Effective PR Description Checklist
  • [O] The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • [O] The test plan, such as providing test command.
  • [O] The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

github-actions bot commented Oct 1, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small, essential subset of CI tests to catch errors quickly.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added the rocm Related to AMD ROCm label Oct 1, 2025
gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request integrates the mori all2all backend to enhance performance on ROCm platforms. The changes are extensive, touching communication, model execution layers, and configuration. My review focuses on the correctness and robustness of the new integration. I've identified a few critical issues related to state management, incorrect caching, and potential runtime errors that should be addressed to ensure the stability and correctness of this new feature.

@whitememory whitememory changed the title vllm mori integration [Hardware][AMD} mori integration Oct 1, 2025
@whitememory whitememory changed the title [Hardware][AMD} mori integration [Hardware][AMD] mori integration Oct 1, 2025
@whitememory whitememory changed the title [Hardware][AMD] mori integration [Hardware][AMD] mori all2all backend integration Oct 1, 2025
@whitememory (Author) commented

In my environment, pre-commit hung.

I will apply the pre-commit results and respond to reviews and code-assist comments tomorrow.

whitememory and others added 5 commits October 2, 2025 11:25
Note that only low latency mode is available on mori (https://github.com/ROCm/mori)

Co-authored-by: Inhyeok Bang <[email protected]>
Co-authored-by: Dongmin Ra <[email protected]>
Co-authored-by: Jimin Park <[email protected]>
Co-authored-by: Geonwoo Choi <[email protected]>
Signed-off-by: HakJu Kim <[email protected]>
caused by version difference

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: HakJu Kim <[email protected]>
Signed-off-by: HakJu Kim <[email protected]>
accepted suggestion

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: HakJu Kim <[email protected]>
Signed-off-by: HakJu Kim <[email protected]>
@whitememory whitememory changed the title [Hardware][AMD] mori all2all backend integration [Hardware][AMD][Kernel] mori all2all backend integration Oct 2, 2025
@whitememory (Author) commented

@mgoin
Please review this PR (or forward it to the code owner).

@HAIAI HAIAI self-requested a review October 13, 2025 18:26
mergify bot commented Oct 13, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @whitememory.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 13, 2025
@mergify mergify bot removed the needs-rebase label Oct 13, 2025
| deepep_low_latency | batched | fp8 | G(128),A,T<sup>3</sup> | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] |
| flashinfer_all2allv | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferAllToAllMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferAllToAllMoEPrepareAndFinalize] |
| flashinfer<sup>4</sup> | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferCutlassMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferCutlassMoEPrepareAndFinalize] |
| flashinfer<sup>4</sup> | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferCutlassMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferCutlassMoEPrepareAndFinalize] |
Collaborator:
Please remove this line (duplicate of prior line).

Author:
This is not part of this PR, and it's not related.
May I fix it in this PR? If so, I will remove it.

5. This is a no-op dispatcher that can be used to pair with any modular experts to produce a modular kernel that runs without dispatch or combine. These cannot be selected via environment variable. They are generally used for testing or for adapting an expert subclass to the `fused_experts` API.
6. This depends on the experts implementation.
7. Currently, MoRI supports low-latency mode only.
8. This depends on the experts implementation; currently, mori supports aiter.
Collaborator:
Is this meant as an explanation, or as a direct answer regarding fp8?

Author:
Yes, we integrated mori with the aiter MoE path only.
We found it natural to follow the quant type and quant format of the ROCm aiter MoE, which is described in this doc at about line 103.

)

# Initialize mori shmem with the registered group
mori.shmem.shmem_torch_process_group_init(group_name)
Collaborator:
Are 2 TP4 instances (on the same node) supported?

Author:
Good point! We missed that possibility.
If this does not work for two different instances on the same node, we will think about how to give each a unique group_name.

@ihbang (Oct 20, 2025):
Fixed so that multiple instances each use their own unique group name, which includes their PPID (parent process ID).
This works because all ranks in one instance share the same PPID. A sketch of the idea follows.
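For reference, a minimal sketch of the idea (the exact name format and prefix are assumptions, not the PR code):

import os

# All ranks launched by one vLLM instance share a parent process, so the PPID
# distinguishes two instances running on the same node. The "mori_shmem_"
# prefix is hypothetical; the resulting name would be passed to
# mori.shmem.shmem_torch_process_group_init() as in the snippet above.
group_name = f"mori_shmem_{os.getppid()}"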

except Exception as e:
    logger.error("[rank %s] mori shmem init failed: %s", self.rank, e)
    # Don't fail completely - mark as initialized to avoid retry loops
    self._shmem_initialized = True
Collaborator:
Can we handle this differently instead of setting it to True?

Author:
@ihbang, would you take a look at this?
I was worried about this handling as well.

Reply:
Fixed it to raise an exception when shmem initialization fails; a sketch of the revised handling is below.
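A sketch of the revised handling, assuming the intent is to fail fast rather than mask the error (self, logger, mori, and group_name refer to the snippet quoted above):

try:
    mori.shmem.shmem_torch_process_group_init(group_name)
    self._shmem_initialized = True
except Exception as e:
    logger.error("[rank %s] mori shmem init failed: %s", self.rank, e)
    # Propagate the failure instead of marking the backend as initialized.
    raise RuntimeError("mori shmem initialization failed") from e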

data_type: torch.dtype = torch.bfloat16,
quant_dtype: torch.dtype | None = None,
):
import mori
Collaborator:
Use a specific import rather than importing the whole package.

Reply:
Fixed (some other imports were also tightened); see the sketch below.
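For illustration, a sketch of the narrower import style (the exact symbol list is an assumption based on the names that appear elsewhere in this thread):

# Import only the symbols that are used instead of the whole mori package.
from mori.ops.dispatch_combine import (
    EpDispatchCombineConfig,
    EpDispatchCombineKernelType,
    EpDispatchCombineOp,
)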

"""Create mori EpDispatchCombineConfig"""
import mori.ops.dispatch_combine as mori_ops
from mori.ops.dispatch_combine import EpDispatchCombineKernelType

Collaborator:
Could you add checks for data_type and quant_dtype before proceeding?

Reply:
A quant_dtype check has been added (sketched below).
A data_type check is not included because mori does not appear to have a specific dtype restriction.
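A sketch of what such a guard could look like (assumed form, not the exact PR code):

from vllm.platforms import current_platform

# Only the platform's fp8 dtype (or no quantization at all) is accepted by the
# mori dispatch/combine path in this sketch.
if quant_dtype is not None and quant_dtype != current_platform.fp8_dtype():
    raise ValueError(
        f"mori dispatch/combine does not support quant_dtype={quant_dtype}"
    )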

from contextlib import contextmanager
from typing import Any

from vllm.model_executor.layers.fused_moe.aiter_experts import AiterExperts
Collaborator:
Add a ROCm check.

Reply:
A ROCm check has been added.

@@ -0,0 +1,121 @@
# SPDX-License-Identifier: Apache-2.0
Collaborator:
Rename this file to aiter_mori_experts.py.

@ihbang (Oct 20, 2025):
The file is now renamed to aiter_mori_experts.py (and the AiterExperts class is renamed to AiterMoriExperts).

elif moe.use_mori_kernels:
use_fp8_dispatch = (
quant_config is not None
and quant_config.quant_dtype == current_platform.fp8_dtype()
Collaborator:
This works for gfx950 (OCP fp8), not gfx942.

Reply:
We have only tested mori all2all on gfx942 (MI300X), but it works fine with fp8 fnuz there.

quant_type = QuantType.per_1x128
else:
quant_type = QuantType.per_Token

Collaborator:
Add more quant types, for example .per_Tensor.

Reply:
Fixed to use per_Tensor instead of per_Token (matching rocm_aiter_fused_experts); see the sketch below.
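A sketch of the revised selection; the guarding condition is not shown in the quoted diff, so the name used for it here is hypothetical, and QuantType is aiter's quant-type enum as imported by the surrounding module:

# Block-scaled fp8 keeps the 1x128 quant type; everything else falls back to
# per-tensor scaling, matching rocm_aiter_fused_experts.
if use_block_quant:  # hypothetical name for the existing block-scale condition
    quant_type = QuantType.per_1x128
else:
    quant_type = QuantType.per_Tensor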


# Check if input is already pre-quantized (from mori dispatch)
input_is_pre_quantized = (
a1_scale is not None and hidden_states.dtype == torch.float8_e4m3fnuz
Collaborator:
Make this work with the fp8 type on both gfx942 and gfx950.

Reply:
Fixed to use current_platform.fp8_dtype() instead of torch.float8_e4m3fnuz; see the sketch below.
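A sketch of the platform-agnostic form (assumed shape; a1_scale and hidden_states come from the snippet quoted above):

from vllm.platforms import current_platform

# current_platform.fp8_dtype() resolves to the fnuz fp8 variant on gfx942 and
# to OCP fp8 on gfx950, so a single comparison covers both targets.
input_is_pre_quantized = (
    a1_scale is not None
    and hidden_states.dtype == current_platform.fp8_dtype()
)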

def has_mori() -> bool:
    """Whether the optional `mori` package is available."""

    return _has_module("mori")
Collaborator:
Consider adding a platform check.

Reply:
Fixed to also check current_platform.is_rocm(); see the sketch below.
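A sketch of the revised helper under that description (_has_module is the existing helper from the snippet above):

from vllm.platforms import current_platform

def has_mori() -> bool:
    """Whether the optional `mori` package is available (ROCm-only backend)."""
    return current_platform.is_rocm() and _has_module("mori")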

whitememory and others added 3 commits October 15, 2025 09:17
Signed-off-by: HakJu Kim <[email protected]>
this happened because of a code conflict (semantic, not textual)

Signed-off-by: HakJu Kim <[email protected]>
@whitememory (Author) commented Oct 15, 2025

NOTE: We made progress on the CUDA graph problem (capture OK, but replay hangs).
-> If some ranks run the CUDA graph while other ranks do not (like the dummy run in some versions of vLLM), mori hangs: mori depends on a CPU-side variable that is updated in eager mode but not during CUDA graph replay.
We believe this is a mori bug that can occur with some versions of vLLM.
We will test more, and my colleague will open a PR against the mori repo.
-> We found a similar fix on a mori branch, and it has been merged; we are "fixing that fix" in this PR.

The Test Result section has been updated.

mergify bot commented Oct 17, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @whitememory.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 17, 2025
ihbang commented Oct 20, 2025

I added "json_config" to EpDispatchCombineConfig and dispatch/combine methods of EpDispatchCombineOp to change its configs easily.
You can manage warp_num_per_block and block_num values with json file like below.
(VLLM_MORI_CONFIG_PATH should be set to the path of your json file)

{
  "global": {
    "warp_num_per_block": 8,
    "block_num": 80
  },
  "dispatch": {
    "warp_num_per_block": 16,
    "block_num": 80
  }
}
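For illustration only, a hypothetical sketch of how such a file could be consumed; the helper name and the rule that per-op sections override the "global" section are assumptions, not the exact PR implementation:

import json
import os

def load_mori_tuning(op: str) -> dict:
    """Read VLLM_MORI_CONFIG_PATH and merge the 'global' section with a
    per-op section such as 'dispatch' or 'combine' (per-op values win)."""
    path = os.environ.get("VLLM_MORI_CONFIG_PATH")
    if path is None:
        return {}
    with open(path) as f:
        cfg = json.load(f)
    merged = dict(cfg.get("global", {}))
    merged.update(cfg.get(op, {}))
    return merged

# With the example file above: {'warp_num_per_block': 16, 'block_num': 80}
print(load_mori_tuning("dispatch"))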

@whitememory whitememory requested a review from HAIAI October 21, 2025 08:58
HAIAI (Collaborator) commented Oct 24, 2025

@whitememory may you fix the conflict? Thanks.

@whitememory (Author) commented

@whitememory may you fix the conflict? Thanks.

@HAIAI
we fixed the conflict, thank you

@whitememory (Author) commented

@HAIAI
Is it possible to continue reviewing before more conflicts come in?


Labels

documentation (Improvements or additions to documentation), rocm (Related to AMD ROCm)
