
Conversation

@dongmin-ra (Contributor) commented Oct 23, 2025

This is work by a team from Moreh Inc.

Motivation

Fixed the problem where the output occasionally becomes incorrect.

Technical Details

Issue

When running the intranode EP, the output occasionally becomes incorrect.

Cause

This issue occurs because both the dispatch and combine phases share the same output buffers,
shmemOutTokMemObj and shmemOutWeightsMemObj.
Since these buffers are reused across phases, they can be overwritten due to differences in execution timing between ranks.

In intranode EP, the dispatch and combine kernels behave as follows:

  • Dispatch (send phase): Sends data to the remote device’s shmemOutTokMemObj and shmemOutWeightsMemObj.
  • Combine (send phase): Writes the combine output and output weights into its own shmemOutTokMemObj and shmemOutWeightsMemObj.

The Core Issue

Because the same buffers are reused for every dispatch/combine call, the combine result must be copied out to the user buffer before the buffers are touched again.

  • Within a single rank, the sequence dispatch → memcpy → combine → memcpy is strictly ordered, so it’s safe.
  • However, across ranks, there’s no guaranteed order between dispatch and combine phases except for the cross barrier used inside combine.

While one rank is performing the combine recv phase, another rank might already be executing the dispatch send phase.

  • If there’s a delay between the combine call and the memcpy that follows it, another rank’s dispatch may overwrite the shared buffer in that gap.
  • This is what makes the output incorrect.
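
The race can be illustrated with plain host threads. The sketch below is not mori code; the thread functions, the sharedOutTok array, and the sleep durations are illustrative assumptions that stand in for two ranks and the shared shmemOutTokMemObj:

// Minimal host-side sketch of the race: two std::threads stand in for two
// ranks, and a plain array stands in for the shared output buffer.
#include <array>
#include <chrono>
#include <cstdio>
#include <thread>

std::array<int, 4> sharedOutTok{};  // stand-in for shmemOutTokMemObj

void rank0() {
    // Combine (send phase): write the combine result into the shared buffer.
    sharedOutTok.fill(42);
    // Delay between the combine call and the memcpy that follows it.
    std::this_thread::sleep_for(std::chrono::milliseconds(10));
    // Copy to the user buffer happens too late: the shared buffer may
    // already have been overwritten by another rank's dispatch.
    std::array<int, 4> userBuf = sharedOutTok;
    std::printf("rank0 read %d (expected 42)\n", userBuf[0]);  // may print 7
}

void rank1() {
    // Dispatch (send phase) of the next call: writes into rank0's buffer
    // inside the gap above, clobbering the not-yet-copied combine result.
    std::this_thread::sleep_for(std::chrono::milliseconds(1));
    sharedOutTok.fill(7);
}

int main() {
    std::thread t0(rank0), t1(rank1);
    t0.join();
    t1.join();
    return 0;
}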

Fix

Split shmemOutTokMemObj and shmemOutWeightsMemObj into separate objects for dispatch and combine.
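
A minimal sketch of the shape of this change, assuming hypothetical names (SymmMemObj, EpBuffers, and the Dispatch/Combine-prefixed fields are illustrative, not the actual mori declarations):

#include <cstddef>

// Illustrative stand-in for a symmetric-memory handle; not the mori type.
struct SymmMemObj { void* ptr = nullptr; std::size_t bytes = 0; };

struct EpBuffers {
    // Before: one pair of objects reused by both phases.
    //   SymmMemObj shmemOutTokMemObj;
    //   SymmMemObj shmemOutWeightsMemObj;

    // After: dispatch and combine each own their output objects, so a
    // dispatch arriving from another rank can no longer clobber a combine
    // result that has not yet been copied to the user buffer.
    SymmMemObj shmemDispatchOutTokMemObj;
    SymmMemObj shmemDispatchOutWeightsMemObj;
    SymmMemObj shmemCombineOutTokMemObj;
    SymmMemObj shmemCombineOutWeightsMemObj;
};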

Test Plan

  1. Merge the moreh-dev:enhance_mori_ep_test and moreh-dev:fix_ep_result_error branches
  2. Modify tests/python/ops/test_dispatch_combine.py as below:
diff --git tests/python/ops/test_dispatch_combine.py tests/python/ops/test_dispatch_combine.py
index c515434..d9a4dd7 100644
--- tests/python/ops/test_dispatch_combine.py
+++ tests/python/ops/test_dispatch_combine.py
@@ -454,7 +454,7 @@ def _test_dispatch_combine(
 @pytest.mark.parametrize("max_num_inp_token_per_rank", (1, 128, 2048))
 @pytest.mark.parametrize("num_experts_per_rank", (32,))
 @pytest.mark.parametrize("num_experts_per_token", (8,))
[email protected]("num_reps", (1,))
[email protected]("num_reps", (8,))
 def test_dispatch_combine(
     torch_dist_process_manager,
     world_size,
  3. Execute the test script:
pytest -s -v tests/python/ops/test_dispatch_combine.py

Test Result

Before modification: the test fails

E       ValueError: Caught ValueError in rank 0.
E       Original Traceback (most recent call last):
E         File "/app/mori/tests/python/utils.py", line 171, in _worker
E           result = func(rank, *args)
E                    ^^^^^^^^^^^^^^^^^
E         File "/app/mori/tests/python/ops/test_dispatch_combine.py", line 445, in _test_dispatch_combine
E           test_case.run_test(op, test_data)
E         File "/app/mori/tests/python/ops/test_dispatch_combine.py", line 375, in run_test
E           self.check_result(test_data, mori_result, i)
E         File "/app/mori/tests/python/ops/test_dispatch_combine.py", line 400, in check_result
E           self.check_combine_result(test_data, combine_output, combine_weights, round)
E         File "/app/mori/tests/python/ops/test_dispatch_combine.py", line 269, in check_combine_result
E           raise ValueError(error_msg)
E       ValueError: 0-th Combine result mismatch for token 0:
E         indices[0]: [104, 49, 208, 235, 114, 18, 28, 55]
E         got: tensor([ 9.3259e-15, -5.9292e-20, -1.2821e-22,  ...,  3.1641e-01,
E                3.7969e+00, -8.4375e-01], device='cuda:0', dtype=torch.bfloat16)
E         expected : tensor([ 1.9922,  0.9961, -1.2500,  ...,  0.3164,  3.7969, -0.8438],
E              device='cuda:0', dtype=torch.bfloat16)

After modification: all tests passed


@jhchouuu (Collaborator)
Thanks to the Moreh team! We will check these PRs as soon as possible.

@jhchouuu (Collaborator)
LGTM, this is the same problem as the shared input buffer between dispatch and combine, thanks.

@jhchouuu jhchouuu merged commit 8dd2f97 into ROCm:main Oct 27, 2025
