34 changes: 17 additions & 17 deletions case-templates/allto_allv_grouped_mat_mul.py
@@ -174,26 +174,26 @@ def shape_decl(var_name, dims):
lines.append(" .OutputShapes(<LB>" + ", ".join(output_shapes_parts) + "<RB>)")
# Node Attrs
lines.append(" .NodeAttrs(<LB>" + ", ".join([
"<LB>\\\"group\\\", ge::AnyValue::CreateFrom<std::string>(group)<RB>",
f"<LB>\\\"ep_world_size\\\", ge::AnyValue::CreateFrom<int64_t>({ep_world_size})<RB>",
"<LB>\\\"send_counts\\\", ge::AnyValue::CreateFrom<vector<int64_t>>(send_counts)<RB>",
"<LB>\\\"recv_counts\\\", ge::AnyValue::CreateFrom<vector<int64_t>>(recv_counts)<RB>",
f"<LB>\\\"trans_gmm_weight\\\", ge::AnyValue::CreateFrom<bool>({'true' if trans_gmm_weight else 'false'})<RB>",
f"<LB>\\\"trans_mm_weight\\\", ge::AnyValue::CreateFrom<bool>({'true' if trans_mm_weight else 'false'})<RB>",
f"<LB>\\\"permute_out_flag\\\", ge::AnyValue::CreateFrom<bool>({'true' if permute_out_flag else 'false'})<RB>",
]) + "><RB>)")
"<LB>\"group\", ge::AnyValue::CreateFrom<std::string>(group)<RB>",
f"<LB>\"ep_world_size\", ge::AnyValue::CreateFrom<int64_t>({ep_world_size})<RB>",
"<LB>\"send_counts\", ge::AnyValue::CreateFrom<vector<int64_t>>(send_counts)<RB>",
"<LB>\"recv_counts\", ge::AnyValue::CreateFrom<vector<int64_t>>(recv_counts)<RB>",
f"<LB>\"trans_gmm_weight\", ge::AnyValue::CreateFrom<bool>({'true' if trans_gmm_weight else 'false'})<RB>",
f"<LB>\"trans_mm_weight\", ge::AnyValue::CreateFrom<bool>({'true' if trans_mm_weight else 'false'})<RB>",
f"<LB>\"permute_out_flag\", ge::AnyValue::CreateFrom<bool>({'true' if permute_out_flag else 'false'})<RB>",
]) + "<RB>)")
lines.append(" .CompileInfo(&compile_info)")
lines.append(" .PlatformInfo(reinterpret_cast<char*>(&platform_info))")
# Input/output dtype mapping: 0 FP16, 1 FP16, 2 INT64, 3 INT64, 4 FP16, 5 FP16; all outputs FP16
lines.append(f" .NodeInputTd(0, {dt_fp16}, ge::FORMAT_ND, ge::FORMAT_ND)")
lines.append(f" .NodeInputTd(1, {dt_fp16}, ge::FORMAT_ND, ge::FORMAT_ND)")
lines.append(" .NodeInputTd(2, ge::DT_INT64, ge::FORMAT_ND, ge::FORMAT_ND)")
lines.append(" .NodeInputTd(3, ge::DT_INT64, ge::FORMAT_ND, ge::FORMAT_ND)")
lines.append(f" .NodeInputTd(4, {dt_fp16}, ge::FORMAT_ND, ge::FORMAT_ND)")
lines.append(f" .NodeInputTd(5, {dt_fp16}, ge::FORMAT_ND, ge::FORMAT_ND)")
lines.append(" .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND)")
lines.append(" .NodeOutputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND)")
lines.append(" .NodeOutputTd(2, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND)")
# lines.append(f" .NodeInputTd(0, {dt_fp16}, ge::FORMAT_ND, ge::FORMAT_ND)")
# lines.append(f" .NodeInputTd(1, {dt_fp16}, ge::FORMAT_ND, ge::FORMAT_ND)")
# lines.append(" .NodeInputTd(2, ge::DT_INT64, ge::FORMAT_ND, ge::FORMAT_ND)")
# lines.append(" .NodeInputTd(3, ge::DT_INT64, ge::FORMAT_ND, ge::FORMAT_ND)")
# lines.append(f" .NodeInputTd(4, {dt_fp16}, ge::FORMAT_ND, ge::FORMAT_ND)")
# lines.append(f" .NodeInputTd(5, {dt_fp16}, ge::FORMAT_ND, ge::FORMAT_ND)")
# lines.append(" .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND)")
# lines.append(" .NodeOutputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND)")
# lines.append(" .NodeOutputTd(2, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND)")
lines.append(" .TilingData(param.get())")
lines.append(" .Workspace(ws_size)")
lines.append(" .Build();")
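Editor's note on the template above: the change fixes the quoting of the NodeAttrs entries, which were escaped one level too deep (`\\\"` instead of `\"`), and replaces the trailing `"><RB>)"` with `"<RB>)"`, removing a stray `>` from the generated call. A minimal sketch of why the new escaping is correct, assuming `<LB>`/`<RB>` are placeholders that the template pipeline later substitutes with literal braces (that substitution step is not shown in this diff):

```python
# Illustration only, not part of the PR. Assumes <LB>/<RB> are later replaced
# with literal braces by the template pipeline (not visible in this diff).
new_attr = "<LB>\"group\", ge::AnyValue::CreateFrom<std::string>(group)<RB>"
old_attr = "<LB>\\\"group\\\", ge::AnyValue::CreateFrom<std::string>(group)<RB>"

def emit(template: str) -> str:
    """Substitute the brace placeholders to show the generated C++ fragment."""
    return template.replace("<LB>", "{").replace("<RB>", "}")

print(emit(new_attr))  # {"group", ge::AnyValue::CreateFrom<std::string>(group)}
print(emit(old_attr))  # {\"group\", ge::AnyValue::CreateFrom<std::string>(group)}  <- stray backslashes in the generated C++
```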
6 changes: 4 additions & 2 deletions case-templates/moe_distribute_dispatch.py
@@ -57,7 +57,7 @@ def render_test_case(op_name, spec, idx, helpers=None):
ep_rank_id = getattr(spec, "ep_rank_id", 0)
tp_rank_id = getattr(spec, "tp_rank_id", 0)
expert_shard_type = getattr(spec, "expert_shard_type", 0)
shared_expert_num = getattr(spec, "shared_expert_num", 1)
shared_expert_num = getattr(spec, "shared_expert_num", 1) or 1
shared_expert_rank_num = getattr(spec, "shared_expert_rank_num", 1)
moe_expert_num = getattr(spec, "moe_expert_num", 8)
quant_mode = getattr(spec, "quant_mode", 0)
@@ -67,6 +67,7 @@ def render_test_case(op_name, spec, idx, helpers=None):
# Output shapes: derived from UT rules / experience; can be overridden via spec
# 0: expand_x_output -> uses the out shape
expand_x_out = getattr(spec, "expand_x_out", (out[0], out[1]))
scales_shape = getattr(spec, "scales_shape", (out[0], out[1]))
# 1: dynamic_scales_output -> 1-D, length matches the first dim of expand_x_out
dynamic_scales_len = getattr(spec, "dynamic_scales_len", expand_x_out[0])
# 2: expand_idx_output -> 1-D, equals the number of elements in expert_ids (B * topk)
@@ -139,6 +140,7 @@ def render_test_case(op_name, spec, idx, helpers=None):
lines.append(" // 4. Define input/output shapes (dims 与 storage_dims 对齐)")
lines.append(f" gert::StorageShape expand_x_shape = <LB><LB>{x1[0]}, {x1[1]}<RB>, <LB>{x1[0]}, {x1[1]}<RB><RB>;")
lines.append(f" gert::StorageShape expert_ids_shape = <LB><LB>{x2[0]}, {x2[1]}<RB>, <LB>{x2[0]}, {x2[1]}<RB><RB>;")
lines.append(f" gert::StorageShape scales_shape = <LB><LB>{scales_shape[0]}, {scales_shape[1]}<RB>, <LB>{scales_shape[0]}, {scales_shape[1]}<RB><RB>;")
lines.append(f" gert::StorageShape expand_x_output_shape = <LB><LB>{expand_x_out[0]}, {expand_x_out[1]}<RB>, <LB>{expand_x_out[0]}, {expand_x_out[1]}<RB><RB>;")
lines.append(f" gert::StorageShape dynamic_scales_output_shape = <LB><LB>{dynamic_scales_len}<RB>, <LB>{dynamic_scales_len}<RB><RB>;")
lines.append(f" gert::StorageShape expand_idx_output_shape = <LB><LB>{expand_idx_len}<RB>, <LB>{expand_idx_len}<RB><RB>;")
@@ -169,7 +171,7 @@ def render_test_case(op_name, spec, idx, helpers=None):
f"<LB>\"quant_mode\", ge::AnyValue::CreateFrom<int64_t>({quant_mode})<RB>",
f"<LB>\"global_bs\", ge::AnyValue::CreateFrom<int64_t>({global_bs})<RB>",
f"<LB>\"expert_token_nums_type\", ge::AnyValue::CreateFrom<int64_t>({expert_token_nums_type})<RB>",
]) + "><RB>")
]) + "<RB>)")
lines.append(" .CompileInfo(&compile_info)")
lines.append(" .PlatformInfo(reinterpret_cast<char*>(&platform_info))")
lines.append(f" .NodeInputTd(0, ge::{{dt_in}}, ge::FORMAT_ND, ge::FORMAT_ND)")
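Editor's note: one easy-to-miss change above is `shared_expert_num = getattr(spec, "shared_expert_num", 1) or 1`. The extra `or 1` matters because `getattr`'s default only applies when the attribute is missing, not when it is present but falsy. A minimal sketch; the `Spec` class below is a hypothetical stand-in for the real spec object:

```python
# Illustration only: why "or 1" is added on top of getattr's default.
class Spec:
    # Hypothetical spec where the field exists but was parsed as None.
    shared_expert_num = None

spec = Spec()
print(getattr(spec, "shared_expert_num", 1))        # None -- getattr default is NOT used
print(getattr(spec, "shared_expert_num", 1) or 1)   # 1    -- the "or 1" fallback kicks in
```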
8 changes: 7 additions & 1 deletion convert_ut_from_xlsx.py
@@ -71,6 +71,7 @@ def strip_all_testf_blocks(ut_content: str) -> str:
i = 0
while i < len(lines):
line = lines[i]
print(line)
if TEST_F_PATTERN.match(line):
# Skip ahead to the matching closing brace '}'
brace = 0
@@ -83,6 +84,7 @@ def strip_all_testf_blocks(ut_content: str) -> str:
if lb != -1:
brace += 1
rest = rest[lb + 1 :]
i += 1
break
i += 1
if i >= len(lines):
@@ -96,6 +98,8 @@ def strip_all_testf_blocks(ut_content: str) -> str:
brace += 1
elif ch == '}':
brace -= 1
print(brace)
print(lines[i])
i += 1
# One TEST_F block has been skipped
continue
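Editor's note: the hunks above touch the brace-matching loop in `strip_all_testf_blocks`, which drops entire `TEST_F(...) { ... }` blocks from an existing UT file. A self-contained sketch of that idea, simplified and not the PR's exact logic; `TEST_F_PATTERN` below is an assumed stand-in for the real regex defined elsewhere in convert_ut_from_xlsx.py:

```python
import re

# Assumed stand-in for the real pattern defined in convert_ut_from_xlsx.py.
TEST_F_PATTERN = re.compile(r"^\s*TEST_F\s*\(")

def strip_testf_blocks_sketch(ut_content: str) -> str:
    """Drop every TEST_F block by counting '{' / '}' until they balance.

    Naive sketch: does not handle braces inside string literals or comments.
    """
    lines = ut_content.splitlines()
    kept, i = [], 0
    while i < len(lines):
        if TEST_F_PATTERN.match(lines[i]):
            brace = 0
            seen_open = False
            # Consume the lines belonging to this TEST_F block.
            while i < len(lines):
                brace += lines[i].count("{") - lines[i].count("}")
                seen_open = seen_open or "{" in lines[i]
                i += 1
                if seen_open and brace == 0:
                    break
            continue  # block skipped; none of its lines are kept
        kept.append(lines[i])
        i += 1
    return "\n".join(kept)
```

This only shows the shape of the approach; the real function scans characters one by one and carries the end-of-file guard visible in the diff.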
@@ -275,6 +279,7 @@ class CaseSpec:
x2_shape: Optional[Tuple[int, int]] = None
gather_output_shape: Optional[Tuple[int, int]] = None
output_shape: Optional[Tuple[int, int]] = None
scales_shape: Optional[Tuple[int, int]] = None
# bias length (optional; overrides the default out_n when provided)
bias_len: Optional[int] = None
# Extra fields (let individual operator templates inject values)
@@ -314,6 +319,7 @@ def row_to_case(row: Dict[str, Any], idx: int) -> CaseSpec:
x2_shape = parse_shape(row.get("x2_shape") or row.get("expert_ids_shape")) or None
go_shape = parse_shape(row.get("gather_output_shape") or row.get("gather_out_shape")) or None
out_shape = parse_shape(row.get("output_shape")) or None
scales_shape = parse_shape(row.get("scales_shape")) or (8, 7168)

# If input_tensor_shape is provided (e.g. [[M,K],[K,N],[N]]), it takes precedence
inputs = parse_shape_list(row.get("input_tensor_shape"))
@@ -395,6 +401,7 @@ def row_to_case(row: Dict[str, Any], idx: int) -> CaseSpec:
x2_shape=x2_shape,
gather_output_shape=go_shape,
output_shape=out_shape,
scales_shape=scales_shape,
bias_len=bias_len,
ep_world_size=ep_world_size,
ep_rank_id=ep_rank_id,
@@ -798,4 +805,3 @@ def main():
if __name__ == "__main__":
raise SystemExit(main())