Skip to content

Commit e50cca0

Browse files
[UT][GroupedGemm] Add grouped gemm pytest mini scope (#47)
* Add mini scope. Signed-off-by: Ma, Liangliang <[email protected]>
* Fix grouped param in pytest conftest. Signed-off-by: Ma, Liangliang <[email protected]>
---------
Signed-off-by: Ma, Liangliang <[email protected]>
1 parent 6c6be64 commit e50cca0

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

tests/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,8 @@ def pytest_generate_tests(metafunc):
2020
return
2121

2222
for param_name, values in profile.items():
23-
if param_name in metafunc.fixturenames:
23+
split_names = [name.strip() for name in param_name.split(",")]
24+
if all(name in metafunc.fixturenames for name in split_names):
2425
new_markers = []
2526
for mark in metafunc.definition.own_markers:
2627
if mark.name == "parametrize" and mark.args[0] != param_name:

tests/fused_moe/test_fused_moe.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,16 @@ def random_partition(size_a: int, target: int):
2828
return result
2929

3030

31+
MINI_PYTEST_PARAMS = {
32+
"default": {
33+
"m,n,k": [(1, 256, 128), (4, 512, 256), (16, 512, 512)],
34+
"e": [16, 8],
35+
"topk": [1, 2],
36+
"dtype": [torch.bfloat16]
37+
}
38+
}
39+
40+
3141
@pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS)
3242
@pytest.mark.parametrize("e", NUM_EXPERTS)
3343
@pytest.mark.parametrize("topk", TOP_KS)
@@ -36,7 +46,6 @@ def test_grouped_gemm(m, n, k, e, topk, dtype):
3646
seed_everything(7)
3747
num_experts = e
3848
token_per_group = random_partition(e, m * topk)
39-
print(token_per_group)
4049
# input
4150
input_A = torch.randn((sum(token_per_group), k),
4251
dtype=dtype,

0 commit comments

Comments (0)