Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
468560f
add InterNodeNormal
isytwu Sep 12, 2025
0c4d390
can compile
isytwu Sep 19, 2025
2fd3ac9
fix dispatch hang
isytwu Sep 21, 2025
3527f53
[WIP] fix dispatch data mismatch
isytwu Sep 22, 2025
59f7f2e
fix dispatch data mismatch
isytwu Sep 23, 2025
67527be
remove unuse printf
isytwu Sep 23, 2025
77beba8
remove unuse printf
isytwu Sep 23, 2025
924ceb0
[WIP] fix run 2 times hang
isytwu Sep 23, 2025
7d8dd19
RDMA head/tail change from tokenNum to tokenStep
isytwu Sep 24, 2025
76f2925
fix data mismatch and assert fail
isytwu Sep 24, 2025
cefcb8a
add spins and fix OOM
isytwu Sep 24, 2025
bab85ca
fix typo that lead to hang
isytwu Sep 24, 2025
2f8b79d
fix OOM
isytwu Sep 24, 2025
3e21db0
add config numGPUsPerNode and maxRDMAStepTokens
isytwu Sep 24, 2025
cf01d2a
add config numGPUsPerNode and maxRDMAStepTokens
isytwu Sep 24, 2025
bd52af9
add config numGPUsPerNode and maxRDMAStepTokens
isytwu Sep 24, 2025
f56e0c8
refine buffer/data ready loop
isytwu Sep 27, 2025
df06921
refine warpCopy offset
isytwu Sep 28, 2025
58e0938
enlarge staging buffer and remove output memset
isytwu Sep 28, 2025
d1512cc
try to fix max_rdma_step_tokens=1 cqe error(maybe env issue) and refi…
isytwu Sep 29, 2025
ac58eb8
add WarpBroadcast
isytwu Sep 29, 2025
caab959
remove unused code and formatting
isytwu Sep 29, 2025
c1d493e
optimize perf to 120 (80 blocks 64 rdma_tokens)
isytwu Sep 30, 2025
9759663
add max_rdma_step_tokens/maxP2PStepTokens
isytwu Oct 10, 2025
6313c8d
fix stepP2pTokens typo
isytwu Oct 10, 2025
afb6d4f
modify the shmem quiet API
jhchouuu Oct 10, 2025
bd00acf
refine ShmemQuietThread and support multi-qp
isytwu Oct 11, 2025
33e7986
test_dispatch_combine_internode.py bench add data check
isytwu Oct 11, 2025
f27dd96
refine kernel DEBUG log
isytwu Oct 11, 2025
64b9498
add KERNEL_ASSERT
isytwu Oct 13, 2025
0ed11f7
combine code ready
isytwu Oct 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 158 additions & 80 deletions examples/ops/dispatch_combine/test_dispatch_combine_internode.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
# SOFTWARE.
import mori
import os
import sys

import torch
import torch.distributed as dist
Expand All @@ -38,18 +39,26 @@ def __init__(
data_type=dtype,
rank=self.rank,
world_size=self.world_size,
num_gpus_per_node=gpu_per_node,
hidden_dim=7168,
scale_dim=32,
scale_type_size=4,
max_token_type_size=4,
max_num_inp_token_per_rank=max_tokens,
num_experts_per_rank=16,
# num_experts_per_rank=256 // world_size,
num_experts_per_token=8,
warp_num_per_block=16,
block_num=64,
max_token_type_size=2,
kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode,
block_num=80,
max_rdma_step_tokens=64,
max_p2p_step_tokens=8,
max_channel_staging_tokens=128,
# kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode,
kernel_type=mori.ops.EpDispatchCombineKernelType.InterNodeNormal,
)
self.log_file = open(f"dispatch_rank_{self.rank}.log", "w")
# sys.stdout = open(f"dispatch_rank_{self.rank}.log", "w")
self.log = self.log_file.write

def setup(self):
local_rank = self.rank % self.gpu_per_node
Expand Down Expand Up @@ -78,6 +87,7 @@ def setup(self):
self.rng.manual_seed(123)

def cleanup(self):
self.log_file.close()
mori.shmem.shmem_finalize()
dist.destroy_process_group()

Expand Down Expand Up @@ -112,7 +122,7 @@ def _allgather_with_token_num_padding(self, input, max_token_num):
dist.all_gather(output, padded_input)
return output

def gen_test_data(self, use_max_token_num=False):
def gen_test_data(self, round=1, use_max_token_num=False):
# gen num_tokens
if use_max_token_num:
num_token = torch.tensor(
Expand All @@ -126,6 +136,7 @@ def gen_test_data(self, use_max_token_num=False):
generator=self.rng,
device=self.device,
)
print(f"rank {self.rank} num_token={num_token[self.rank]}")

# gen indices
all_rank_indices = []
Expand Down Expand Up @@ -170,8 +181,17 @@ def gen_test_data(self, use_max_token_num=False):
)
for r in range(self.world_size)
]
# all_rank_weights = [
# torch.full(
# (num_token[r], self.config.num_experts_per_token),
# fill_value=(r + 1) * 0.1 + round,
# dtype=torch.float32,
# device=self.device,
# )
# for r in range(self.world_size)
# ]

# gen weights
# gen scales
all_rank_scales = [
torch.rand(
num_token[r],
Expand All @@ -197,6 +217,35 @@ def gen_test_data(self, use_max_token_num=False):
).to(self.config.data_type)
)

# for r in range(self.world_size):
# all_rank_input.append(
# torch.full(
# (num_token[r], self.config.hidden_dim),
# fill_value=float((r + 1) * 10000* + round),
# dtype=torch.float32,
# device=self.device,
# ).to(self.config.data_type)
# )

# for r in range(self.world_size):
# # base = float(r * self.config.max_num_inp_token_per_rank)
# base = float(r * 0.1 + 1)
# # base = float(0)
# # 生成 token 索引 [0, 1, ..., num_token[r]-1]
# token_offset = torch.arange(num_token[r], dtype=torch.float32, device=self.device)
# # 广播到 (num_token[r], hidden_dim)
# fill_tensor = (base + token_offset[:, None]).expand(-1, self.config.hidden_dim)
# all_rank_input.append(fill_tensor.to(self.config.data_type).contiguous())

indices = all_rank_indices[self.rank].cpu().numpy()
for i, row in enumerate(indices):
val = all_rank_input[self.rank][i].to(torch.float32).cpu().numpy()
val_str = ",".join(f"{x:.6f}" for x in val[:10])
self.log(
f"rank {self.rank} round {round} token {i}: {','.join(str(x) for x in row)} val: {val_str} \n"
)
self.log_file.flush()

return (
num_token,
all_rank_indices,
Expand Down Expand Up @@ -225,9 +274,11 @@ def run_test_once(self, op, test_data, error_round, round):
all_rank_weights[self.rank],
# None,
all_rank_scales[self.rank],
# None,
all_rank_indices[self.rank],
)
torch.cuda.synchronize()
print(f"rank {self.rank} kernel finished")
dist.barrier()

src_token_pos = op.get_dispatch_src_token_pos().tolist()
Expand All @@ -236,94 +287,118 @@ def run_test_once(self, op, test_data, error_round, round):
for i, src_token_id in enumerate(src_token_pos):
src_pe = src_token_id // max_num_token_to_send_per_rank
src_tok_id = src_token_id % max_num_token_to_send_per_rank
is_pass = torch.equal(
token_match = torch.equal(
dispatch_output[i], all_rank_input[src_pe][src_tok_id]
)
if not is_pass:
if not token_match:
print(
f"rank {self.rank} token {i} assert {is_pass} expected { all_rank_input[src_pe][src_tok_id]} got {dispatch_output[i]}"
f"round {round} rank {self.rank} token {i} src_token_id {src_token_id} src_pe {src_pe} src_tok_id {src_tok_id} assert {token_match} expect_shape {all_rank_input[src_pe][src_tok_id].shape} got shape {dispatch_output[i].shape} expected { all_rank_input[src_pe][src_tok_id]} got {dispatch_output[i]}"
)
# assert False
# assert token_match
error_round.add(round)
if dispatch_weights is not None:
assert torch.equal(
weight_match = torch.equal(
dispatch_weights[i], all_rank_weights[src_pe][src_tok_id]
)
assert torch.equal(
if not weight_match:
print(
f"round {round} rank {self.rank} token {i} src_token_id {src_token_id} src_pe {src_pe} src_tok_id {src_tok_id} expected weight { all_rank_weights[src_pe][src_tok_id]} got weight {dispatch_weights[i]}"
)
error_round.add(round)
# assert weight_match

indices_match = torch.equal(
dispatch_indices[i], all_rank_indices[src_pe][src_tok_id]
)
if not indices_match:
print(
f"round {round} rank {self.rank} token {i} src_token_id {src_token_id} src_pe {src_pe} src_tok_id {src_tok_id} expected indices { all_rank_indices[src_pe][src_tok_id]} got indices {dispatch_indices[i]}"
)
error_round.add(round)
# assert indices_match

# TODO: test output scales
scales_match = torch.equal(
dispatch_scales[i], all_rank_scales[src_pe][src_tok_id]
)
if not scales_match:
print(
f"round {round} rank {self.rank} token {i} src_token_id {src_token_id} src_pe {src_pe} src_tok_id {src_tok_id} expected indices { all_rank_scales[src_pe][src_tok_id]} got indices {dispatch_scales[i]}"
)
error_round.add(round)
# assert scales_match

if self.rank % self.gpu_per_node == 0:
print(f"Node {self.rank // self.gpu_per_node} Dispatch Pass")

dist.barrier()
# dist.barrier()

combine_output, combine_output_weight = op.combine(
dispatch_output,
dispatch_weights,
all_rank_indices[self.rank],
)
torch.cuda.synchronize()

for i in range(all_rank_num_token[self.rank]):
pes = [
(idx // self.config.num_experts_per_rank)
for idx in all_rank_indices[self.rank][i].cpu().tolist()
]
unique_pes = len(set(pes))

got, expected = combine_output[i], (
all_rank_input[self.rank][i].to(torch.float32) * unique_pes
).to(self.config.data_type)

ok = torch.allclose(got.float(), expected.float(), atol=1e-2, rtol=1e-2)
if not ok:
print(self.rank, "got: ", got)
print(self.rank, "expected: ", expected)
print(self.rank, "delta:", got - expected)
assert False
error_round.add(round)

if dispatch_weights is not None:
got_weight, expected_weight = (
combine_output_weight[i],
all_rank_weights[self.rank][i] * unique_pes,
)
weight_match = torch.allclose(
got_weight, expected_weight, atol=1e-5, rtol=1e-5
)
if not weight_match and self.config.rank == 0:
print(f"Weight mismatch for token {i}:")
print(
f" indices[{i}]: {all_rank_indices[self.rank][i].cpu().tolist()}"
)
print(f" pes: {pes}")
print(f" unique_pes: {unique_pes}")
print(f" got_weight: {got_weight}")
print(
f" expected_weight (weights[{i}] * {unique_pes}): {expected_weight}"
)
print(f" original weights[{i}]: {all_rank_weights[self.rank][i]}")
print(f" diff: {torch.abs(got_weight - expected_weight)}")
print(
f" max_diff: {torch.abs(got_weight - expected_weight).max()}"
)
assert weight_match, f"Weight assertion failed for token {i}"

if self.rank % self.gpu_per_node == 0:
print(f"Node {self.rank // self.gpu_per_node} Combine Pass")
# combine_output, combine_output_weight = op.combine(
# dispatch_output,
# dispatch_weights,
# all_rank_indices[self.rank],
# )
# torch.cuda.synchronize()

# for i in range(all_rank_num_token[self.rank]):
# pes = [
# (idx // self.config.num_experts_per_rank)
# for idx in all_rank_indices[self.rank][i].cpu().tolist()
# ]
# unique_pes = len(set(pes))

# got, expected = combine_output[i], (
# all_rank_input[self.rank][i].to(torch.float32) * unique_pes
# ).to(self.config.data_type)

# ok = torch.allclose(got.float(), expected.float(), atol=1e-2, rtol=1e-2)
# if not ok:
# print(self.rank, "got: ", got)
# print(self.rank, "expected: ", expected)
# print(self.rank, "delta:", got - expected)
# assert False
# error_round.add(round)

# if dispatch_weights is not None:
# got_weight, expected_weight = (
# combine_output_weight[i],
# all_rank_weights[self.rank][i] * unique_pes,
# )
# weight_match = torch.allclose(
# got_weight, expected_weight, atol=1e-5, rtol=1e-5
# )
# if not weight_match and self.config.rank == 0:
# print(f"Weight mismatch for token {i}:")
# print(
# f" indices[{i}]: {all_rank_indices[self.rank][i].cpu().tolist()}"
# )
# print(f" pes: {pes}")
# print(f" unique_pes: {unique_pes}")
# print(f" got_weight: {got_weight}")
# print(
# f" expected_weight (weights[{i}] * {unique_pes}): {expected_weight}"
# )
# print(f" original weights[{i}]: {all_rank_weights[self.rank][i]}")
# print(f" diff: {torch.abs(got_weight - expected_weight)}")
# print(
# f" max_diff: {torch.abs(got_weight - expected_weight).max()}"
# )
# assert weight_match, f"Weight assertion failed for token {i}"

# if self.rank % self.gpu_per_node == 0:
# print(f"Node {self.rank // self.gpu_per_node} Combine Pass")

def test_dispatch_combine(self):
op = mori.ops.EpDispatchCombineOp(self.config)
error_round = set()
for i in range(500):
for i in range(10):
if self.rank == 0:
print(f"Round {i} begin")
test_data = self.gen_test_data()
test_data = self.gen_test_data(i)
if self.rank == 0:
print(f"Round {i} gen test_data done")
self.run_test_once(op, test_data, error_round, i)
assert len(error_round) == 0
print(
"rank: ",
self.rank,
Expand Down Expand Up @@ -406,7 +481,7 @@ def run_bench_once(self, op, test_data):

def bench_dispatch_combine(self):
op = mori.ops.EpDispatchCombineOp(self.config)
test_data = self.gen_test_data(use_max_token_num=True)
test_data = self.gen_test_data(1, use_max_token_num=True)

disp_duration_us_list = []
disp_bandwidth_GB_list = []
Expand All @@ -425,9 +500,9 @@ def bench_dispatch_combine(self):
if self.rank == 0:
print(f"WarmUp Round {i} begin")
self.run_test_once(op, test_data, error_round, i)
assert (
len(error_round) == 0
), f"Warmup failed with errors in rounds: {error_round}"
assert (
len(error_round) == 0
), f"Warmup failed with errors in rounds: {error_round}"

for i in range(50):
if self.rank == 0:
Expand Down Expand Up @@ -464,13 +539,13 @@ def bench_dispatch_combine(self):
f"avg {sum(disp_bandwidth_GB_list[i]) / self.config.world_size:.2f} GB/s"
)

for i in range(len(comb_duration_us_list)):
print(
f"Round {i} combine duration {comb_duration_us_list[i]} "
f"bandwidth {comb_bandwidth_GB_list[i]} "
f"avg {sum(comb_duration_us_list[i]) / self.config.world_size:.2f} µs "
f"avg {sum(comb_bandwidth_GB_list[i]) / self.config.world_size:.2f} GB/s"
)
# for i in range(len(comb_duration_us_list)):
# print(
# f"Round {i} combine duration {comb_duration_us_list[i]} "
# f"bandwidth {comb_bandwidth_GB_list[i]} "
# f"avg {sum(comb_duration_us_list[i]) / self.config.world_size:.2f} µs "
# f"avg {sum(comb_bandwidth_GB_list[i]) / self.config.world_size:.2f} GB/s"
# )

disp_bandwidth_GB_list = disp_bandwidth_GB_list[0:]
avg_disp_bw_per_round = [
Expand Down Expand Up @@ -526,7 +601,9 @@ def test_dispatch_combine(
gpu_per_node,
world_size,
max_tokens,
torch.bfloat16, # torch.float8_e4m3fnuz
torch.bfloat16,
# torch.float8_e4m3fnuz,
# torch.float32,
)
test_case.setup()
if is_bench:
Expand Down Expand Up @@ -562,3 +639,4 @@ def test_dispatch_combine(
nprocs=gpu_per_node,
join=True,
)

Loading