diff --git a/case-templates/allto_allv_grouped_mat_mul.py b/case-templates/allto_allv_grouped_mat_mul.py
index 7906243..e53f095 100644
--- a/case-templates/allto_allv_grouped_mat_mul.py
+++ b/case-templates/allto_allv_grouped_mat_mul.py
@@ -174,26 +174,26 @@ def shape_decl(var_name, dims):
     lines.append("        .OutputShapes(" + ", ".join(output_shapes_parts) + ")")
     # Node Attrs
     lines.append("        .NodeAttrs(" + ", ".join([
-        "\\\"group\\\", ge::AnyValue::CreateFrom(group)",
-        f"\\\"ep_world_size\\\", ge::AnyValue::CreateFrom({ep_world_size})",
-        "\\\"send_counts\\\", ge::AnyValue::CreateFrom>(send_counts)",
-        "\\\"recv_counts\\\", ge::AnyValue::CreateFrom>(recv_counts)",
-        f"\\\"trans_gmm_weight\\\", ge::AnyValue::CreateFrom({'true' if trans_gmm_weight else 'false'})",
-        f"\\\"trans_mm_weight\\\", ge::AnyValue::CreateFrom({'true' if trans_mm_weight else 'false'})",
-        f"\\\"permute_out_flag\\\", ge::AnyValue::CreateFrom({'true' if permute_out_flag else 'false'})",
-    ]) + ">)")
+        "\"group\", ge::AnyValue::CreateFrom(group)",
+        f"\"ep_world_size\", ge::AnyValue::CreateFrom({ep_world_size})",
+        "\"send_counts\", ge::AnyValue::CreateFrom>(send_counts)",
+        "\"recv_counts\", ge::AnyValue::CreateFrom>(recv_counts)",
+        f"\"trans_gmm_weight\", ge::AnyValue::CreateFrom({'true' if trans_gmm_weight else 'false'})",
+        f"\"trans_mm_weight\", ge::AnyValue::CreateFrom({'true' if trans_mm_weight else 'false'})",
+        f"\"permute_out_flag\", ge::AnyValue::CreateFrom({'true' if permute_out_flag else 'false'})",
+    ]) + ")")
     lines.append("        .CompileInfo(&compile_info)")
     lines.append("        .PlatformInfo(reinterpret_cast(&platform_info))")
     # Input/output dtype mapping: 0 FP16, 1 FP16, 2 INT64, 3 INT64, 4 FP16, 5 FP16; all outputs FP16
-    lines.append(f"        .NodeInputTd(0, {dt_fp16}, ge::FORMAT_ND, ge::FORMAT_ND)")
-    lines.append(f"        .NodeInputTd(1, {dt_fp16}, ge::FORMAT_ND, ge::FORMAT_ND)")
-    lines.append("        .NodeInputTd(2, ge::DT_INT64, ge::FORMAT_ND, ge::FORMAT_ND)")
-    lines.append("        .NodeInputTd(3, ge::DT_INT64, ge::FORMAT_ND, ge::FORMAT_ND)")
-    lines.append(f"        .NodeInputTd(4, {dt_fp16}, ge::FORMAT_ND, ge::FORMAT_ND)")
-    lines.append(f"        .NodeInputTd(5, {dt_fp16}, ge::FORMAT_ND, ge::FORMAT_ND)")
-    lines.append("        .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND)")
-    lines.append("        .NodeOutputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND)")
-    lines.append("        .NodeOutputTd(2, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND)")
+    # lines.append(f"        .NodeInputTd(0, {dt_fp16}, ge::FORMAT_ND, ge::FORMAT_ND)")
+    # lines.append(f"        .NodeInputTd(1, {dt_fp16}, ge::FORMAT_ND, ge::FORMAT_ND)")
+    # lines.append("        .NodeInputTd(2, ge::DT_INT64, ge::FORMAT_ND, ge::FORMAT_ND)")
+    # lines.append("        .NodeInputTd(3, ge::DT_INT64, ge::FORMAT_ND, ge::FORMAT_ND)")
+    # lines.append(f"        .NodeInputTd(4, {dt_fp16}, ge::FORMAT_ND, ge::FORMAT_ND)")
+    # lines.append(f"        .NodeInputTd(5, {dt_fp16}, ge::FORMAT_ND, ge::FORMAT_ND)")
+    # lines.append("        .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND)")
+    # lines.append("        .NodeOutputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND)")
+    # lines.append("        .NodeOutputTd(2, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND)")
     lines.append("        .TilingData(param.get())")
     lines.append("        .Workspace(ws_size)")
     lines.append("        .Build();")
diff --git a/case-templates/moe_distribute_dispatch.py b/case-templates/moe_distribute_dispatch.py
index 69ddeac..d707750 100644
--- a/case-templates/moe_distribute_dispatch.py
+++ b/case-templates/moe_distribute_dispatch.py
@@ -57,7 +57,7 @@ def render_test_case(op_name, spec, idx, helpers=None):
     ep_rank_id = getattr(spec, "ep_rank_id", 0)
     tp_rank_id = getattr(spec, "tp_rank_id", 0)
     expert_shard_type = getattr(spec, "expert_shard_type", 0)
-    shared_expert_num = getattr(spec, "shared_expert_num", 1)
+    shared_expert_num = getattr(spec, "shared_expert_num", 1) or 1
     shared_expert_rank_num = getattr(spec, "shared_expert_rank_num", 1)
     moe_expert_num = getattr(spec, "moe_expert_num", 8)
     quant_mode = getattr(spec, "quant_mode", 0)
@@ -67,6 +67,7 @@ def render_test_case(op_name, spec, idx, helpers=None):
     # Output shapes: derived from UT rules/experience; may be overridden via spec
    # 0: expand_x_output -> use the out shape
     expand_x_out = getattr(spec, "expand_x_out", (out[0], out[1]))
+    scales_shape = getattr(spec, "scales_shape", (out[0], out[1]))
     # 1: dynamic_scales_output -> 1-D, length equals the first dim of expand_x_out
     dynamic_scales_len = getattr(spec, "dynamic_scales_len", expand_x_out[0])
     # 2: expand_idx_output -> 1-D, equals the element count of expert_ids (B * topk)
@@ -139,6 +140,7 @@ def render_test_case(op_name, spec, idx, helpers=None):
     lines.append("    // 4. Define input/output shapes (dims aligned with storage_dims)")
     lines.append(f"    gert::StorageShape expand_x_shape = {x1[0]}, {x1[1]}, {x1[0]}, {x1[1]};")
     lines.append(f"    gert::StorageShape expert_ids_shape = {x2[0]}, {x2[1]}, {x2[0]}, {x2[1]};")
+    lines.append(f"    gert::StorageShape scales_shape = {scales_shape[0]}, {scales_shape[1]}, {scales_shape[0]}, {scales_shape[1]};")
     lines.append(f"    gert::StorageShape expand_x_output_shape = {expand_x_out[0]}, {expand_x_out[1]}, {expand_x_out[0]}, {expand_x_out[1]};")
     lines.append(f"    gert::StorageShape dynamic_scales_output_shape = {dynamic_scales_len}, {dynamic_scales_len};")
     lines.append(f"    gert::StorageShape expand_idx_output_shape = {expand_idx_len}, {expand_idx_len};")
@@ -169,7 +171,7 @@ def render_test_case(op_name, spec, idx, helpers=None):
         f"\"quant_mode\", ge::AnyValue::CreateFrom({quant_mode})",
         f"\"global_bs\", ge::AnyValue::CreateFrom({global_bs})",
         f"\"expert_token_nums_type\", ge::AnyValue::CreateFrom({expert_token_nums_type})",
-    ]) + ">")
+    ]) + ")")
     lines.append("        .CompileInfo(&compile_info)")
     lines.append("        .PlatformInfo(reinterpret_cast(&platform_info))")
     lines.append(f"        .NodeInputTd(0, ge::{{dt_in}}, ge::FORMAT_ND, ge::FORMAT_ND)")
diff --git a/convert_ut_from_xlsx.py b/convert_ut_from_xlsx.py
index c265777..46bb007 100644
--- a/convert_ut_from_xlsx.py
+++ b/convert_ut_from_xlsx.py
@@ -71,6 +71,7 @@ def strip_all_testf_blocks(ut_content: str) -> str:
     i = 0
     while i < len(lines):
         line = lines[i]
+        print(line)
         if TEST_F_PATTERN.match(line):
             # skip ahead to the matching closing brace '}'
             brace = 0
@@ -83,6 +84,7 @@ def strip_all_testf_blocks(ut_content: str) -> str:
                 if lb != -1:
                     brace += 1
                     rest = rest[lb + 1 :]
+                    i += 1
                     break
                 i += 1
                 if i >= len(lines):
@@ -96,6 +98,8 @@ def strip_all_testf_blocks(ut_content: str) -> str:
                         brace += 1
                     elif ch == '}':
                         brace -= 1
+                print(brace)
+                print(lines[i])
                 i += 1
             # one TEST_F block has been skipped
             continue
@@ -275,6 +279,7 @@ class CaseSpec:
     x2_shape: Optional[Tuple[int, int]] = None
     gather_output_shape: Optional[Tuple[int, int]] = None
     output_shape: Optional[Tuple[int, int]] = None
+    scales_shape: Optional[Tuple[int, int]] = None
     # bias length (optional; overrides the default out_n if provided)
     bias_len: Optional[int] = None
     # extra fields (injection hooks for the per-op templates)
@@ -314,6 +319,7 @@ def row_to_case(row: Dict[str, Any], idx: int) -> CaseSpec:
     x2_shape = parse_shape(row.get("x2_shape") or row.get("expert_ids_shape")) or None
     go_shape = parse_shape(row.get("gather_output_shape") or row.get("gather_out_shape")) or None
     out_shape = parse_shape(row.get("output_shape")) or None
+    scales_shape = parse_shape(row.get("scales_shape")) or (8, 7168)
 
     # If input_tensor_shape is provided (e.g. [[M,K],[K,N],[N]]), it takes precedence
     inputs = parse_shape_list(row.get("input_tensor_shape"))
@@ -395,6 +401,7 @@ def row_to_case(row: Dict[str, Any], idx: int) -> CaseSpec:
         x2_shape=x2_shape,
         gather_output_shape=go_shape,
         output_shape=out_shape,
+        scales_shape=scales_shape,
         bias_len=bias_len,
         ep_world_size=ep_world_size,
         ep_rank_id=ep_rank_id,
@@ -798,4 +805,3 @@ def main():
 
 if __name__ == "__main__":
     raise SystemExit(main())
-