@dongmin-ra (Contributor) commented Oct 23, 2025

This is a work by a team from Moreh Inc.

Motivation

Currently, the Mori EP unit test (i.e., test_dispatch_combine.py) does not sufficiently cover real use cases.

Problem 1. Each test case performs dispatch/combine only once.

  • In practice, dispatch/combine is executed multiple times per step, corresponding to the number of layers in the model.
    • e.g., in DeepSeek-R1, dispatch/combine is performed 58 times per step.
  • When dispatch/combine is executed multiple times, the results may become incorrect, but this is currently not checked.

Problem 2. The test synchronizes across devices before checking the results.

  • In practice, for performance reasons, no explicit synchronization is performed after dispatch and combine.
  • Most issues with the output values arise from device-to-device timing differences.
    • Because the unit test synchronizes across devices after dispatch/combine, it cannot detect these timing issues.

Technical Details

Updated the unit tests to more closely resemble real use cases (a minimal sketch follows this list):

  • Introduced the num_reps parameter so that each test case can perform dispatch/combine multiple times.
  • Removed the explicit synchronization across devices after dispatch/combine.
    • Instead, results are stored in a list during execution.
    • Verification is performed only after all dispatch/combine operations are complete.
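
A minimal sketch of this "run many, verify later" structure, not the actual code in test_dispatch_combine.py; dispatch_fn, combine_fn, verify_fn, and inputs are hypothetical stand-ins for whatever the real test uses:

```python
import torch

def run_repeated(dispatch_fn, combine_fn, verify_fn, inputs, num_reps):
    """Run dispatch/combine num_reps times without intermediate syncs, then verify."""
    results = []
    for rep in range(num_reps):
        # No cross-device synchronization between repetitions: dispatch and
        # combine are enqueued back to back, as they would be across the
        # layers of a real model step.
        dispatch_out = dispatch_fn(inputs[rep])
        combine_out = combine_fn(dispatch_out)
        # Keep every output alive; checking is deferred.
        results.append((rep, dispatch_out, combine_out))

    # Only after all repetitions have been issued do we wait and verify.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    for rep, dispatch_out, combine_out in results:
        verify_fn(rep, dispatch_out, combine_out)
```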

Note

Currently, there is an issue in Mori EP where results are incorrect if num_reps is greater than 1. I will create a PR to fix this.

Fixed an issue where the test would hang if an error occurred on some ranks:

  • Previously, the parent process waited for results from all child processes; if an error occurred on some ranks, the other ranks kept executing dispatch/combine and never sent a test result to the result queue.
    • As a result, the parent process would wait indefinitely for responses that never arrived, causing the test to hang.
  • Changed the behavior so that if an error occurs in any process, the error is raised immediately instead of waiting on the result queue (see the sketch after this list).
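
A minimal sketch of this fail-fast behavior, assuming a multiprocessing-based harness in which each rank runs in a child process and reports to a shared result queue; the function and argument names here are hypothetical, not the actual test code:

```python
import queue

def collect_results(procs, result_queue, num_ranks, poll_sec=1.0):
    """Gather one result per rank, but fail fast if any child process dies."""
    results = []
    while len(results) < num_ranks:
        # If any rank has already exited with an error, raise immediately
        # instead of blocking forever on results that will never arrive.
        for rank, proc in enumerate(procs):
            if proc.exitcode is not None and proc.exitcode != 0:
                raise RuntimeError(f"rank {rank} exited with code {proc.exitcode}")
        try:
            # Poll with a timeout so the liveness check above keeps running.
            results.append(result_queue.get(timeout=poll_sec))
        except queue.Empty:
            continue
    return results
```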

Test Plan

pytest -s -v tests/python/ops/test_dispatch_combine.py

Test Result

=================================================================================================================================================== test session starts ====================================================================================================================================================
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0 -- /usr/bin/python
cachedir: .pytest_cache
rootdir: /app/mori
plugins: assume-2.4.3, anyio-4.10.0, asyncio-1.1.0
asyncio: mode=Mode.STRICT, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collected 48 items

test_dispatch_combine.py::test_dispatch_combine[1-8-32-1-1-0-7168-data_type0-8] Multiprocessing start method set to spawn
SKIPPED (skip fp8 with scale_dim == 0)
test_dispatch_combine.py::test_dispatch_combine[1-8-32-1-1-0-7168-data_type1-8] rank 0 RDMA devices: mlx5_0, mlx5_1, mlx5_2, mlx5_3, mlx5_4
rank 7 rankInNode 7 select device [0] mlx5_0
rank 2 rankInNode 2 select device [3] mlx5_3
rank 4 rankInNode 4 select device [1] mlx5_1
rank 3 rankInNode 3 select device [4] mlx5_4
rank 0 rankInNode 0 select device [0] mlx5_0
rank 1 rankInNode 1 select device [2] mlx5_2
rank 6 rankInNode 6 select device [3] mlx5_3
rank 5 rankInNode 5 select device [3] mlx5_3
Running MORI dispatch/combine (#tokens=8): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 32.81it/s]
Checking Result: 1it [00:00, 20.22it/s]
PASSED
test_dispatch_combine.py::test_dispatch_combine[1-8-32-1-1-0-4096-data_type0-8] SKIPPED (skip fp8 with scale_dim == 0)
Running MORI dispatch/combine (#tokens=8): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 3024.01it/s]
Checking Result: 1it [00:00, 1384.72it/s]
PASSED
Running MORI dispatch/combine (#tokens=8): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 94.81it/s]
Checking Result: 1it [00:00, 1506.03it/s]
PASSED
test_dispatch_combine.py::test_dispatch_combine[1-8-32-1-1-32-7168-data_type1-8] SKIPPED (skip fp8 with scale_dim == 0)
Running MORI dispatch/combine (#tokens=8): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 2435.72it/s]
Checking Result: 1it [00:00, 1615.68it/s]
PASSED
... (skip) ...
test_dispatch_combine.py::test_dispatch_combine[1-8-32-2048-4-32-7168-data_type1-8] SKIPPED (skip fp8 with scale_dim == 0)
Running MORI dispatch/combine (#tokens=9202): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 13.28it/s]
Checking Result: 1it [00:00, 18.28it/s]
PASSED
test_dispatch_combine.py::test_dispatch_combine[1-8-32-2048-4-32-4096-data_type1-8] SKIPPED (skip fp8 with scale_dim == 0)

============================================================================================================================================= 24 passed, 24 skipped in 16.69s ==============================================================================================================================================

@jhchouuu (Collaborator) commented Oct 24, 2025

@dongmin-ra Could I be granted more access to the moreh-dev repo?
I reproduced your issue during testing and used your PR to fix it,
but I found that your modification causes a benchmark initialization issue in
python3 tests/python/ops/bench_dispatch_combine.py.
I have fixed that, but I then found a combine result mismatch during the warmup step. Could you please check the benchmark status after cherry-picking this commit?

@dongmin-ra (Contributor, Author) commented:

@jhchouuu: "Could I be granted more access to the moreh-dev repo?"

Okay. I granted you write access to the moreh-dev/mori repo.

"I have fixed that, but I then found a combine result mismatch during the warmup step. Could you please check the benchmark status after cherry-picking this commit?"

Sure, I'll check.

@dongmin-ra (Contributor, Author) commented:

@jhchouuu I fixed the errors in the bench_dispatch_combine.py and bench_dispatch_combine_tune.py scripts. Please take a look when you get a chance.
