diff --git a/.agents/evals/minimal_skill_trigger_eval.jsonl b/.agents/evals/minimal_skill_trigger_eval.jsonl
index 484ed5f20..dcc502cb5 100644
--- a/.agents/evals/minimal_skill_trigger_eval.jsonl
+++ b/.agents/evals/minimal_skill_trigger_eval.jsonl
@@ -15,3 +15,5 @@
{"id":"review-02","prompt":"提交前帮我跑一遍格式和 lint 检查,看看 CI 风险","expected_skill":"tilelang-review-skill"}
{"id":"github-01","prompt":"我准备把分支提 PR 到 npuir,给我完整 commit push PR 流程","expected_skill":"tilelang-github-operations"}
{"id":"github-02","prompt":"帮我同步 upstream/npuir,rebase 后再发起 PR","expected_skill":"tilelang-github-operations"}
+{"id":"precision-01","prompt":"算子输出有 nan,并且和 Torch 的结果有细微差异,如何看精度报告","expected_skill":"tilelang-precision-debug-skill"}
+{"id":"precision-02","prompt":"请用 assert_close 给我跑一下测试,输出 ascii diff map 帮我看看误差分布","expected_skill":"tilelang-precision-debug-skill"}
diff --git a/.agents/skills/tilelang-debug-helper/SKILL.md b/.agents/skills/tilelang-debug-helper/SKILL.md
index 46f319cc8..ec83a5fab 100644
--- a/.agents/skills/tilelang-debug-helper/SKILL.md
+++ b/.agents/skills/tilelang-debug-helper/SKILL.md
@@ -35,3 +35,4 @@ Before answering, follow AGENTS.md section "Docs Auto Routing Rules (Mandatory)"

- tilelang-mlir-skill
- tilelang-error-fixer
+- tilelang-precision-debug-skill
diff --git a/.agents/skills/tilelang-precision-debug-skill/SKILL.md b/.agents/skills/tilelang-precision-debug-skill/SKILL.md
new file mode 100644
index 000000000..2782973df
--- /dev/null
+++ b/.agents/skills/tilelang-precision-debug-skill/SKILL.md
@@ -0,0 +1,56 @@
+---
+name: tilelang-precision-debug-skill
+description: TileLang 算子精度比对与误差分析技能。当用户提及“精度错误”、“结果不一致”、“比对失败”、“ASCII diff map”或需要深入定位数值差异时必须使用本技能。
+---
+
+# TileLang Precision Debug Skill
+
+## Mandatory routing rule
+
+Before answering, follow AGENTS.md section "Docs Auto Routing Rules (Mandatory)".
+
+## Trigger Guidance
+
+- When users face accuracy issues, mismatch with reference implementation (e.g., PyTorch), or precision errors.
+- When users mention `assert_close` failure, Top-10 errors, relative error too large, or need to see the ASCII diff map.
+- When a `.precision_debug` directory is mentioned.
+
+## Instructions SOP
+
+1. **Replace Comparison API:**
+   Ensure `assert_close` is imported from `testcommon` (the standard in `testing/npuir`), or use `tilelang.utils.prec_assert_close` directly.
+
+2. **Activate Debug Reporting:**
+   The detailed reporting is **disabled by default** to avoid clutter. You can activate it in two ways:
+   - **Surgical (recommended):** Add `@pytest.mark.precision_debug` to the test function.
+     ```python
+     import pytest
+     from testcommon import assert_close
+
+     @pytest.mark.precision_debug
+     def test_my_op():
+         ...
+         assert_close(actual, expected, dtype=dtype)
+     ```
+   - **Global:** Set the environment variable `TL_PREC_DEBUG=1` before running `pytest`.
+     ```bash
+     TL_PREC_DEBUG=1 pytest testing/npuir/test_xxx.py
+     ```
+
+3. **Run and Collect:**
+   Execute the test script. If there is a mismatch, the tool will automatically generate a `.precision_debug/_/` directory containing `report.txt`, `diff_map.txt`, and the serialized tensors.
+
+4. **Analyze the Precision Debug Report:**
+   - **`report.txt`:** Check the mismatch ratio, `Max abs/rel diff`, and the `Top-10 largest differences`. Determine if the error is widespread (logic error) or isolated (boundary, overflow).
+   - **`diff_map.txt`** (ASCII Map): Identify spatial distribution patterns:
+     - **Blocky distribution:** Incorrect tiling or block_M/N parameters.
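+       A hypothetical example of such a pattern in `diff_map.txt` (`·` = pass, `@` = large error):
+
+       ```text
+       ····@@@@········
+       ····@@@@········
+       ················
+       ```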
+     - **Periodic stripes:** Incorrect memory layout, strides, or vectorization broadcast issues.
+     - **Edge/Boundary spikes:** Padding, masking, or loop boundary conditions not handled correctly.
+
+5. **Root Cause Localization:**
+   Use the pattern analysis to trace back to Load (memory alignment), Compute (accumulation precision, type casting), or Store (write-back boundary) stages in the TVM IR or MLIR.
+
+## Related skills
+
+- tilelang-debug-helper
+- tilelang-error-fixer
diff --git a/.gitignore b/.gitignore
index 727beb0a8..0c82eee5d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -96,3 +96,9 @@ tilelang/jit/adapter/cython/.cycache

# unit test output files
unittest/npuir/output/
+
+# precision debug tool output
+.precision_debug/
+
+# agents
+.worktree/
diff --git a/AGENTS.md b/AGENTS.md
index 972c28cad..3e62c9083 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -55,6 +55,7 @@ Use the matching skill whenever the user asks for:
- compile/runtime error analysis on npuir branch
- code review, lint/format checks, PR readiness
- commit/push/rebase/upstream sync, PR or issue workflow
+- accuracy issues, assert_close failures, ASCII diff map, or precision errors

Developer-mode MixCV trigger rule:
- If one kernel contains Cube-side T.gemm and Vector-side at least one v-prefix op (such as T.vadd/T.vmul/T.vexp/T.vcast/T.vbrc), treat it as MixCV and use tilelang-mixcv-skill.
diff --git "a/docs/Tilelang\347\256\227\345\255\220\350\260\203\350\257\225\346\214\207\345\215\227.md" "b/docs/Tilelang\347\256\227\345\255\220\350\260\203\350\257\225\346\214\207\345\215\227.md"
index 3a5181293..91001b619 100644
--- "a/docs/Tilelang\347\256\227\345\255\220\350\260\203\350\257\225\346\214\207\345\215\227.md"
+++ "b/docs/Tilelang\347\256\227\345\255\220\350\260\203\350\257\225\346\214\207\345\215\227.md"
@@ -15,9 +15,29 @@

| 问题类型 | 具体表现 | 推荐调试方法 |
| ------------ | ------------ |------------ |
-| **1.精度问题** | 算子成功编译并生成.o文件,但NPU运行结果和标杆参考(如torch)存在差异 | `T.print`打印调试 |
-| **2.编译失败**| 算子编译失败,未生成预期的TVM IR和MLIR | 编译调试 |
-| **3.运行时失败** | 算子编译成功,但未生成.o文件,进程终止 | 运行时调试 |
+| **1.精度问题** | 算子成功编译并生成.o文件,但NPU运行结果和标杆参考(如torch)存在差异 | `T.print`打印调试 |
+| **2.编译失败**| 算子编译失败,未生成预期的TVM IR和MLIR | 编译调试 |
+| **3.运行时失败** | 算子编译成功,但未生成.o文件,进程终止 | 运行时调试 |
+
+### 1.2 精度问题快速定位(precision debug)
+
+对于 `assert_close` 失败、结果不一致或需要观察误差分布的场景,推荐优先使用
+`testing/npuir/testcommon.py` 中的 `assert_close` 包装器,或直接使用
+`tilelang.utils.prec_assert_close`。
+
+开启方式有两种:
+
+- 在单个 pytest 用例上增加 `@pytest.mark.precision_debug`
+- 运行测试前设置环境变量 `TL_PREC_DEBUG=1`
+
+当比对失败时,工具会在 `.precision_debug/_/` 下生成:
+
+- `report.txt`:误差统计、Top-10 差异点、误差分布
+- `diff_map.txt`:ASCII diff map,便于观察块状、条纹、边界等模式
+- `actual.pt` / `expected.pt`:失败时的张量快照
+
+建议优先在最小复现用例上使用该工具,先判断误差是全局扩散、局部边界异常,还是
+特定布局/广播模式下的周期性偏差,再回到 TVM IR 或 MLIR 阶段做进一步定位。

# 2 Tilelang-AscendNPUIR 编译流程概览

@@ -408,4 +428,4 @@ _compile_option_list = [

- **运行时失败**:xx。
- **性能优化**:通过控制jit_npu中的编译选项来配置流水优化等功能。

-掌握这些调试技巧后,您就可以在TileLang中高效地调试NPU算子,充分利用昇腾硬件的计算能力。
\ No newline at end of file
+掌握这些调试技巧后,您就可以在TileLang中高效地调试NPU算子,充分利用昇腾硬件的计算能力。
diff --git a/testing/npuir/README.md b/testing/npuir/README.md
index 97bd7a61a..3ce00c066 100644
--- a/testing/npuir/README.md
+++ b/testing/npuir/README.md
@@ -6,24 +6,27 @@ single source of truth.
## Design -The test model is split into three layers: +The test model is split into four layers: - Marker metadata: `op` and `mode` +- Optional debug marker: `precision_debug` - Runtime controls: `--op`, `--mode`, `--npu-device` - Test matrix: `dtype`, shapes, and other case dimensions via `@pytest.mark.parametrize(...)` That split is intentional: -- `op` and `mode` are the only custom markers. They define what a test is. +- `op` and `mode` define what a test is. +- `precision_debug` is an opt-in debug switch for mismatch reporting. - CLI options only choose which tests to run and which NPU device to use. - Data-oriented coverage stays inside the test matrix instead of growing more CLI flags. ## Marker Rules -Only two custom markers are valid: +Three custom markers are valid: - `@pytest.mark.op("")` - `@pytest.mark.mode("")` +- `@pytest.mark.precision_debug` Use file-level `pytestmark` when a whole file shares the same metadata: @@ -55,6 +58,9 @@ def test_copy_release(): The closest marker wins. In practice that means a test-level marker overrides a file-level marker of the same kind. +Use `precision_debug` only on focused repro tests. It is not part of the test +identity and should not be used as a broad file-level marker. + ## Runtime Rules `--npu-device` is the only supported device selector in this pytest layer. The @@ -69,6 +75,10 @@ warning so the remapping is visible in the test output. `with ascend_mode(...)`. The pytest runtime reads the closest `mode` marker and applies `ascend_mode(mode)` automatically around the test body. +`precision_debug` is also marker-driven. When present, pytest sets +`TL_PREC_DEBUG=1` around that test so `testcommon.assert_close(...)` emits a +report directory under `.precision_debug/` on mismatch. + ## Test Matrix Rules Use `@pytest.mark.parametrize(...)` for case dimensions such as: @@ -104,12 +114,13 @@ def test_copy_shape(M, N, block_M, block_N, in_dtype, out_dtype): ## Contributor Rules -- Do not add custom markers beyond `op` and `mode`. +- Do not add custom markers beyond `op`, `mode`, and `precision_debug`. - Do not add `@pytest.mark.dtype(...)`. - Do not add folder-category markers such as `@pytest.mark.memory`. - Do not call `torch.npu.set_device(...)` inside tests. - Prefer file-level `pytestmark` for shared `op` / `mode`. - Use test-level markers only when a file intentionally needs overrides. +- Use `precision_debug` sparingly for mismatch-focused repros. - Keep compile and execution work inside test functions, not at module import time. The directory name remains the category signal for humans. For example, @@ -136,6 +147,9 @@ pytest testing/npuir --mode=Developer # combined selection pytest testing/npuir --op=copy --mode=Developer --npu-device=0 + +# focused precision report for one repro +TL_PREC_DEBUG=1 pytest testing/npuir/memory_ops/test_copy_shape_dev.py ``` `--op` and `--mode` accept comma-separated values. @@ -146,6 +160,7 @@ pytest testing/npuir --op=copy --mode=Developer --npu-device=0 - `--mode` matches the closest `@pytest.mark.mode(...)` - `--npu-device` sets the session device before tests execute - out-of-range `--npu-device` values are remapped with modulo and reported as warnings +- `@pytest.mark.precision_debug` enables detailed mismatch reports for that test Tests without a matching marker are excluded when that selector is provided. @@ -169,3 +184,26 @@ CASES = [ def test_copy_shape_dev(M, N, block_M, block_N, dtype): ... ``` + +## Precision Debug + +Use `from testcommon import assert_close` for NPUIR tests. 
On mismatch, the +precision debug tool can emit: + +- `report.txt` +- `diff_map.txt` +- serialized `actual.pt` and `expected.pt` + +Enable it in one of two ways: + +```python +import pytest + +@pytest.mark.precision_debug +def test_copy_debug(): + ... +``` + +```bash +TL_PREC_DEBUG=1 pytest testing/npuir/test_xxx.py +``` diff --git a/testing/npuir/conftest.py b/testing/npuir/conftest.py index 511fdd502..8afbcf606 100644 --- a/testing/npuir/conftest.py +++ b/testing/npuir/conftest.py @@ -21,6 +21,13 @@ def _get_npu_device_id(config: pytest.Config) -> tuple[int, Optional[str]]: return resolve_npu_device_id(config.getoption("--npu-device")) +def pytest_configure(config): + config.addinivalue_line( + "markers", + "precision_debug: enable tilelang advanced precision debugging for this test", + ) + + def pytest_addoption(parser): parser.addoption( "--op", @@ -116,3 +123,12 @@ def _apply_mode_marker(request: pytest.FixtureRequest): with ascend_mode(mode): yield + + +@pytest.fixture(autouse=True) +def _apply_precision_debug_marker( + request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch +): + if request.node.get_closest_marker("precision_debug") is not None: + monkeypatch.setenv("TL_PREC_DEBUG", "1") + yield diff --git a/testing/npuir/pytest.ini b/testing/npuir/pytest.ini index 3be6d8e4f..c448210fa 100644 --- a/testing/npuir/pytest.ini +++ b/testing/npuir/pytest.ini @@ -4,3 +4,4 @@ addopts = -ra --strict-markers markers = op(name): Real operation identity, used by --op. mode(name): Runtime Ascend mode, used by --mode and applied automatically during test execution. + precision_debug: Enable detailed precision mismatch reporting for a focused test. diff --git a/testing/npuir/test_precision_debug_smoke.py b/testing/npuir/test_precision_debug_smoke.py new file mode 100644 index 000000000..2ae3edaac --- /dev/null +++ b/testing/npuir/test_precision_debug_smoke.py @@ -0,0 +1,79 @@ +import pathlib + +import pytest +import torch + +from testcommon import assert_close +from tilelang.utils import prec_assert_close + + +pytestmark = [ + pytest.mark.op("precision_debug"), + pytest.mark.mode("Developer"), +] + + +def _single_run_dir(root: pathlib.Path) -> pathlib.Path: + runs = [path for path in root.iterdir() if path.is_dir()] + assert len(runs) == 1 + return runs[0] + + +def test_prec_assert_close_exported_from_tilelang_utils(): + from tilelang.utils.precision_debug import prec_assert_close as direct_import + + assert prec_assert_close is direct_import + + +def test_assert_close_without_debug_does_not_write_reports( + tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch +): + monkeypatch.chdir(tmp_path) + actual = torch.tensor([0.0, 1.0], dtype=torch.float32) + expected = torch.tensor([0.0, 0.0], dtype=torch.float32) + + with pytest.raises(AssertionError): + assert_close(actual, expected, dtype="float32") + + assert not (tmp_path / ".precision_debug").exists() + + +@pytest.mark.precision_debug +def test_precision_debug_marker_enables_report_output( + tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch +): + monkeypatch.chdir(tmp_path) + actual = torch.tensor([0.0, 1.0], dtype=torch.float32) + expected = torch.tensor([0.0, 0.0], dtype=torch.float32) + + with pytest.raises(AssertionError): + assert_close(actual, expected, dtype="float32") + + run_dir = _single_run_dir(tmp_path / ".precision_debug") + assert (run_dir / "report.txt").is_file() + assert (run_dir / "diff_map.txt").is_file() + assert (run_dir / "actual.pt").is_file() + assert (run_dir / "expected.pt").is_file() + + +def 
test_prec_assert_close_respects_equal_nan_in_report( + tmp_path: pathlib.Path, +): + output_dir = tmp_path / "precision_debug_manual" + actual = torch.tensor([float("nan")], dtype=torch.float32) + expected = torch.tensor([float("nan")], dtype=torch.float32) + + with pytest.raises(AssertionError): + prec_assert_close( + actual, + expected, + output_dir=str(output_dir), + save_tensors=False, + print_map=False, + equal_nan=False, + ) + + report_text = (_single_run_dir(output_dir) / "report.txt").read_text( + encoding="utf-8" + ) + assert "Mismatched: 1 (100.0000%)" in report_text diff --git a/testing/npuir/testcommon.py b/testing/npuir/testcommon.py index bd0c6d9a6..65ac30630 100644 --- a/testing/npuir/testcommon.py +++ b/testing/npuir/testcommon.py @@ -5,6 +5,7 @@ import torch import tilelang +from tilelang.utils.precision_debug import DEFAULT_TOLERANCE, prec_assert_close TorchDTypeLike = Union[str, torch.dtype] @@ -20,17 +21,6 @@ "bool": torch.bool, } -DEFAULT_TOLERANCE = { - "float16": (1e-3, 1e-3), - "bfloat16": (2e-2, 2e-2), - "float32": (1e-4, 1e-4), - "int8": (0.0, 0.0), - "int16": (0.0, 0.0), - "int32": (0.0, 0.0), - "int64": (0.0, 0.0), - "bool": (0.0, 0.0), -} - def resolve_dtype(dtype: TorchDTypeLike) -> torch.dtype: if isinstance(dtype, torch.dtype): @@ -134,17 +124,29 @@ def assert_close( atol: Optional[float] = None, equal_nan: bool = True, ) -> None: - if dtype is None: - dtype = actual.dtype - name = dtype_name(dtype) - default_rtol, default_atol = DEFAULT_TOLERANCE[name] - torch.testing.assert_close( - actual, - expected, - rtol=default_rtol if rtol is None else rtol, - atol=default_atol if atol is None else atol, - equal_nan=equal_nan, - ) + enabled = os.environ.get("TL_PREC_DEBUG", "0") == "1" + + if enabled: + prec_assert_close( + actual, + expected, + dtype=dtype, + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + ) + else: + if dtype is None: + dtype = actual.dtype + name = dtype_name(dtype) + default_rtol, default_atol = DEFAULT_TOLERANCE[name] + torch.testing.assert_close( + actual, + expected, + rtol=default_rtol if rtol is None else rtol, + atol=default_atol if atol is None else atol, + equal_nan=equal_nan, + ) def build_dtype_param_combos( diff --git a/tilelang/testing/__init__.py b/tilelang/testing/__init__.py index bbd934bed..7bfd8bf41 100644 --- a/tilelang/testing/__init__.py +++ b/tilelang/testing/__init__.py @@ -11,6 +11,7 @@ from tvm.testing.utils import _compose from tilelang.utils.tensor import torch_assert_close as torch_assert_close +from tilelang.utils.precision_debug import prec_assert_close as prec_assert_close # pytest.main() wrapper to allow running single test file @@ -45,7 +46,7 @@ def requires_cuda_compute_version(major_version, minor_version=0, mode="ge"): minor_version: int The minor version of the (major,minor) version tuple. - + mode: str The mode of the comparison. diff --git a/tilelang/utils/__init__.py b/tilelang/utils/__init__.py index e524f0c4e..d4f69b52c 100644 --- a/tilelang/utils/__init__.py +++ b/tilelang/utils/__init__.py @@ -26,6 +26,7 @@ get_npu_launcher_header, # noqa: F401 safe_copy, # noqa: F401 ) +from .precision_debug import prec_assert_close as prec_assert_close from .npu_arch import ( get_ascend_device_name, # noqa: F401 supports_native_bf16, # noqa: F401 diff --git a/tilelang/utils/precision_debug.py b/tilelang/utils/precision_debug.py new file mode 100644 index 000000000..66e3e7599 --- /dev/null +++ b/tilelang/utils/precision_debug.py @@ -0,0 +1,609 @@ +# Copyright (c) Tile-AI Corporation. +# Licensed under the MIT License. 
+"""
+Precision debug tool: drop-in replacement for torch.testing.assert_close.
+On mismatch, writes a text report, saves tensors, and prints an ASCII diff map.
+"""
+
+from __future__ import annotations
+
+import inspect
+import os
+from concurrent.futures import ThreadPoolExecutor
+from datetime import datetime
+from typing import Any, Optional, Union
+
+import numpy as np
+import torch
+
+# ---------------------------------------------------------------------------
+# Character-level diff mapping constants
+# ---------------------------------------------------------------------------
+_LEVEL_CHARS = ["·", "-", "~", "*", "#", "@"]
+_LEVEL_CHARS_NP = np.array(_LEVEL_CHARS)
+_LEVEL_THRESHOLDS = np.array([1.0, 3.0, 10.0, 50.0, 200.0])
+_PASS_CHARS = frozenset({"·", "|", " "})
+_PARALLEL_THRESHOLD = 512  # use thread pool when row count exceeds this
+
+# ---------------------------------------------------------------------------
+# Default tolerances based on dtype.
+# ---------------------------------------------------------------------------
+DEFAULT_TOLERANCE = {
+    "float16": (1e-3, 1e-3),
+    "bfloat16": (2e-2, 2e-2),
+    "float32": (1e-4, 1e-4),
+    "int8": (0.0, 0.0),
+    "int16": (0.0, 0.0),
+    "int32": (0.0, 0.0),
+    "int64": (0.0, 0.0),
+    "uint8": (0.0, 0.0),
+    "bool": (0.0, 0.0),
+}
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _prod(shape: tuple[int, ...]) -> int:
+    p = 1
+    for s in shape:
+        p *= int(s)
+    return p
+
+
+def _get_caller_info() -> tuple[str, str]:
+    """Return (caller_filename_no_ext, caller_func_name) for the assert_close caller."""
+    frame = inspect.currentframe()
+    if frame is None:
+        return "unknown", "unknown"
+    for _ in range(4):
+        frame = frame.f_back
+        if frame is None:
+            return "unknown", "unknown"
+    filename = os.path.splitext(os.path.basename(frame.f_code.co_filename))[0]
+    func = frame.f_code.co_name or "unknown"
+    return filename, func
+
+
+def _equalize_for_compare(
+    actual: torch.Tensor, expected: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Move to CPU and promote to a common float dtype for comparison."""
+    a, b = actual.detach(), expected.detach()
+    if a.device.type != "cpu":
+        a = a.cpu()
+    if b.device.type != "cpu":
+        b = b.cpu()
+    # Always compare in a floating dtype: integer/bool inputs would break the
+    # clamp/mean statistics below, and bfloat16 has no numpy equivalent.
+    # float64 keeps int64 values exact up to 2**53.
+    if (
+        a.dtype != b.dtype
+        or not a.dtype.is_floating_point
+        or a.dtype == torch.bfloat16
+    ):
+        dtype = torch.promote_types(
+            a.dtype if a.dtype.is_floating_point else torch.float64,
+            b.dtype if b.dtype.is_floating_point else torch.float64,
+        )
+        if dtype == torch.bfloat16:
+            dtype = torch.float32
+        a, b = a.to(dtype), b.to(dtype)
+    return a, b
+
+
+# ---------------------------------------------------------------------------
+# Statistics
+# ---------------------------------------------------------------------------
+
+
+def _compute_stats(
+    actual: torch.Tensor,
+    expected: torch.Tensor,
+    rtol: float,
+    atol: float,
+    *,
+    equal_nan: bool,
+) -> dict[str, Any]:
+    """Compute mismatch statistics between two tensors."""
+    a, b = _equalize_for_compare(actual, expected)
+    diff_abs = (a - b).abs()
+    safe_b = b.abs().clamp(min=1e-12)
+    diff_rel = (diff_abs / safe_b).nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
+    mismatched = ~torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
+
+    total = a.numel()
+    num_mismatched = int(mismatched.sum().item())
+
+    flat_abs = diff_abs.reshape(-1)
+    flat_rel = diff_rel.reshape(-1)
+    flat_mm = mismatched.reshape(-1)
+
+    max_abs = float(flat_abs.max().item())
+    max_rel = float(flat_rel.max().item())
+    mean_abs = float(flat_abs.mean().item())
+    median_abs = 
float(flat_abs.median().item()) + + max_abs_idx = list(np.unravel_index(int(flat_abs.argmax().item()), a.shape)) + max_rel_idx = list(np.unravel_index(int(flat_rel.argmax().item()), a.shape)) + + # Top-10 by absolute diff + if num_mismatched > 0: + mm_idx = flat_mm.nonzero(as_tuple=True)[0] + vals = flat_abs[mm_idx] + top10_flat = mm_idx[vals.argsort(descending=True)[:10]] + else: + top10_flat = flat_abs.argsort(descending=True)[:10] + + top10_list = [] + for fi in top10_flat.tolist(): + idx = list(np.unravel_index(fi, a.shape)) + av, bv = float(a.reshape(-1)[fi].item()), float(b.reshape(-1)[fi].item()) + top10_list.append( + { + "index": idx, + "actual": av, + "expected": bv, + "abs_diff": abs(av - bv), + "rel_diff": abs(av - bv) / (abs(bv) + 1e-12), + } + ) + + hist_edges = np.linspace(0, max_abs if max_abs > 0 else 1, 11) + hist_vals, _ = np.histogram(flat_abs.numpy(), bins=hist_edges) + hist_max = int(max(hist_vals.max(), 1)) + + return { + "shape": tuple(a.shape), + "dtype": str(a.dtype), + "device": str(actual.device), + "rtol": rtol, + "atol": atol, + "total": total, + "num_mismatched": num_mismatched, + "max_abs": max_abs, + "max_abs_idx": max_abs_idx, + "max_rel": max_rel, + "max_rel_idx": max_rel_idx, + "mean_abs": mean_abs, + "median_abs": median_abs, + "top10": top10_list, + "hist_edges": hist_edges, + "hist_vals": hist_vals, + "hist_max": hist_max, + "diff_abs": diff_abs, + "actual_cpu": a, + "expected_cpu": b, + } + + +# --------------------------------------------------------------------------- +# Text report (saved to file) +# --------------------------------------------------------------------------- + + +def _write_report(out_dir: str, stats: dict[str, Any], msg: str) -> None: + lines = [ + "Precision Debug Report", + "=" * 60, + f"shape: {stats['shape']}", + f"dtype: {stats['dtype']}", + f"device: {stats['device']}", + f"rtol: {stats['rtol']}, atol: {stats['atol']}", + "", + "Summary", + "-" * 40, + f"Total elements: {stats['total']}", + f"Mismatched: {stats['num_mismatched']}" + f" ({100.0 * stats['num_mismatched'] / max(1, stats['total']):.4f}%)", + "", + "Difference statistics", + "-" * 40, + f"Max abs diff: {stats['max_abs']} at index {stats['max_abs_idx']}", + f"Max rel diff: {stats['max_rel']} at index {stats['max_rel_idx']}", + f"Mean abs diff: {stats['mean_abs']}", + f"Median abs diff: {stats['median_abs']}", + "", + "Top-10 largest absolute differences", + "-" * 40, + ] + for i, row in enumerate(stats["top10"], 1): + lines.append( + f" {i}. 
idx={row['index']} actual={row['actual']} expected={row['expected']}" + f" abs={row['abs_diff']} rel={row['rel_diff']}" + ) + lines.extend(["", "Difference distribution (abs diff bins)", "-" * 40]) + edges, vals, hmax = stats["hist_edges"], stats["hist_vals"], stats["hist_max"] + for i in range(len(edges) - 1): + bar = "*" * (int(40 * vals[i] / hmax) if hmax > 0 else 0) + lines.append(f" [{edges[i]:.2e}, {edges[i + 1]:.2e}): {bar} ({vals[i]})") + if msg: + lines.extend(["", "Message", "-" * 40, msg]) + with open(os.path.join(out_dir, "report.txt"), "w", encoding="utf-8") as f: + f.write("\n".join(lines)) + + +# --------------------------------------------------------------------------- +# Vectorised diff-level computation & string building +# --------------------------------------------------------------------------- + + +def _arr_to_levels(arr: np.ndarray, atol: float) -> np.ndarray: + """Map abs-diff values to integer levels 0-5 (vectorised).""" + tol = max(atol, 1e-12) + ratios = arr / tol + levels = np.zeros(arr.shape, dtype=np.intp) + for i in range(len(_LEVEL_THRESHOLDS)): + levels[ratios >= _LEVEL_THRESHOLDS[i]] = i + 1 + return levels + + +def _levels_rows_to_strings( + levels: np.ndarray, + has_sep: bool, + outer: int, + group_size: int, +) -> list[str]: + """Convert rows of integer levels (2-D) to character strings. + + If *has_sep*, inserts ``|`` between every *group_size* characters. + """ + results: list[str] = [] + for row in levels: + if has_sep and outer > 1: + parts: list[str] = [] + for g in range(outer): + s = g * group_size + parts.append("".join(_LEVEL_CHARS_NP[row[s : s + group_size]])) + results.append("|".join(parts)) + else: + results.append("".join(_LEVEL_CHARS_NP[row])) + return results + + +def _parallel_build_strings( + levels: np.ndarray, + has_sep: bool, + outer: int, + group_size: int, +) -> list[str]: + """Thread-parallel version of :func:`_levels_rows_to_strings`.""" + total = len(levels) + n_workers = min(4, os.cpu_count() or 2) + chunk = max(1, (total + n_workers - 1) // n_workers) + with ThreadPoolExecutor(max_workers=n_workers) as pool: + futs = [ + pool.submit( + _levels_rows_to_strings, + levels[s : min(s + chunk, total)], + has_sep, + outer, + group_size, + ) + for s in range(0, total, chunk) + ] + out: list[str] = [] + for f in futs: + out.extend(f.result()) + return out + + +# --------------------------------------------------------------------------- +# Smart line-layout helpers +# --------------------------------------------------------------------------- + + +def _format_row_label( + idx: tuple[int, ...], + leading_shape: tuple[int, ...], + merged_shape: tuple[int, ...], +) -> str: + """Build a slice-notation label for one row. + + Example: shape=(2,64,8,128), split=3, idx=(1,32,3) → ``[1,32,3,:128]`` + """ + parts: list[str] = [] + for dim_i, val in enumerate(idx): + w = max(1, len(str(leading_shape[dim_i] - 1))) + parts.append(str(val).rjust(w)) + for d in merged_shape: + parts.append(f":{d}") + return "[" + ",".join(parts) + "]" + + +def _compute_label_width( + leading_shape: tuple[int, ...], + merged_shape: tuple[int, ...], +) -> int: + """Fixed character width of the widest possible row label.""" + parts: list[str] = [] + for d in leading_shape: + parts.append(str(max(0, d - 1))) + for d in merged_shape: + parts.append(f":{d}") + return len("[" + ",".join(parts) + "]") + + +def _choose_line_split(shape: tuple[int, ...], term_width: int = 200) -> int: + """Decide how to split *shape* into row-index dims and merged-per-row dims. 
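+
+    Example (illustrative): for ``shape=(1024, 1024)`` with the default
+    ``term_width=200``, the 1024-wide tail cannot fit on one line, so the
+    returned split is 1: rows iterate the first dim, and each row is
+    block-max downsampled to the terminal width further downstream.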
+ + Returns *split* such that ``shape[:split]`` iterates rows and + ``shape[split:]`` is flattened into one line. + """ + if not shape: + return 0 + ndim = len(shape) + min_row_chars = min(40, term_width // 4) + + # Pass 1: merge as many tail dims as fit without downsampling + best = ndim - 1 + for s in range(ndim - 1, -1, -1): + merged = shape[s:] + elems = _prod(merged) + lw = _compute_label_width(shape[:s], merged) + 3 # " " prefix + space + outer = merged[0] if len(merged) > 1 else 1 + n_seps = (outer - 1) if len(merged) > 1 else 0 + if elems + n_seps + lw <= term_width: + best = s + else: + break + + # Pass 2: if rows are still too short, allow moderate downsampling (≤20×) + if _prod(shape[best:]) < min_row_chars and best > 0: + for s in range(best - 1, -1, -1): + merged = shape[s:] + elems = _prod(merged) + lw = _compute_label_width(shape[:s], merged) + 3 + avail = max(1, term_width - lw) + if elems / avail <= 20: + best = s + else: + break + + return best + + +# --------------------------------------------------------------------------- +# Batch row-content builder (vectorised + optional threading) +# --------------------------------------------------------------------------- + + +def _build_all_row_contents( + diff_abs: torch.Tensor, + total_rows: int, + merged_shape: tuple[int, ...], + atol: float, + max_content: int, +) -> list[str]: + """Build the ASCII diff string for every row at once.""" + inner = _prod(merged_shape) if merged_shape else 1 + if inner == 0 or total_rows == 0: + return [""] * total_rows + + # --- reshape to 2-D and move to numpy once --- + flat = diff_abs.reshape(total_rows, inner).numpy() + + has_sep = len(merged_shape) > 1 + outer = int(merged_shape[0]) if has_sep else 1 + group_size = inner // outer if outer > 0 else inner + + # --- determine whether block-max downsampling is needed --- + if has_sep: + sep_count = outer - 1 + usable = max(1, max_content - sep_count) + per_group = max(1, usable // max(1, outer)) + need_ds = group_size > per_group + else: + per_group = max_content + need_ds = inner > max_content + + # --- downsample if necessary --- + if need_ds: + if has_sep: + grouped = flat.reshape(total_rows, outer, group_size) + blk = max(1, -(-group_size // per_group)) + cols = -(-group_size // blk) + ds = np.zeros((total_rows, outer, cols), dtype=flat.dtype) + for b in range(cols): + s, e = b * blk, min((b + 1) * blk, group_size) + ds[:, :, b] = np.nanmax(grouped[:, :, s:e], axis=2) + levels = _arr_to_levels(ds.reshape(total_rows, outer * cols), atol) + final_gs = cols + else: + blk = max(1, -(-inner // max_content)) + cols = -(-inner // blk) + ds = np.zeros((total_rows, cols), dtype=flat.dtype) + for b in range(cols): + s, e = b * blk, min((b + 1) * blk, inner) + ds[:, b] = np.nanmax(flat[:, s:e], axis=1) + levels = _arr_to_levels(ds, atol) + final_gs = cols + outer = 1 + has_sep = False + else: + levels = _arr_to_levels(flat, atol) + final_gs = group_size + + # --- build strings (parallel for large row counts) --- + if total_rows > _PARALLEL_THRESHOLD: + return _parallel_build_strings(levels, has_sep, outer, final_gs) + return _levels_rows_to_strings(levels, has_sep, outer, final_gs) + + +# --------------------------------------------------------------------------- +# ASCII diff map (console + file) +# --------------------------------------------------------------------------- + + +def _print_diff_map( + stats: dict[str, Any], + atol: float, + run_dir: str, + term_width: int = 200, +) -> None: + """Print an ASCII diff map to console and save to 
*diff_map.txt*.""" + diff_abs: torch.Tensor = stats["diff_abs"] + shape = tuple(diff_abs.shape) + ndim = len(shape) + + out: list[str] = [] + out.append( + f" Legend: · pass - 1~3x ~ 3~10x * 10~50x # 50~200x @ >200x" + f" (atol={atol})" + ) + + if ndim == 0: + tol = max(atol, 1e-12) + lv = int(sum(1 for t in _LEVEL_THRESHOLDS if float(diff_abs.item()) / tol >= t)) + out.append(f" [] {_LEVEL_CHARS[min(lv, 5)]}") + else: + split = _choose_line_split(shape, term_width=term_width) + leading_shape = shape[:split] + merged_shape = shape[split:] + + # --- layout description --- + dim_lead = ", ".join(f"d{i}={shape[i]}" for i in range(split)) + dim_tail = ", ".join(f"d{i}={shape[i]}" for i in range(split, ndim)) + if split == 0: + out.append(f" Shape: {shape} (all dims merged into one row)") + else: + out.append(f" Shape: {shape} Layout: [{dim_lead}] x ({dim_tail})") + + total_rows = _prod(leading_shape) if leading_shape else 1 + label_w = _compute_label_width(leading_shape, merged_shape) + max_content = max(10, term_width - label_w - 3) + + # --- vectorised row content --- + contents = _build_all_row_contents( + diff_abs, + total_rows, + merged_shape, + atol, + max_content, + ) + + # --- build labels --- + labels: list[str] = [] + for fi in range(total_rows): + if leading_shape: + idx = tuple(int(i) for i in np.unravel_index(fi, leading_shape)) + else: + idx = () + labels.append(_format_row_label(idx, leading_shape, merged_shape)) + + # --- run-length encode by content (compress identical rows) --- + # group = (first_flat, count, content, first_label, last_label, has_error) + groups: list[tuple[int, int, str, str, str, bool]] = [] + for i in range(total_rows): + c = contents[i] + err = any(ch not in _PASS_CHARS for ch in c) + if groups and groups[-1][2] == c: + g = groups[-1] + groups[-1] = (g[0], g[1] + 1, g[2], g[3], labels[i], g[5] or err) + else: + groups.append((i, 1, c, labels[i], labels[i], err)) + + # --- emit lines --- + for _, count, content, first_lbl, last_lbl, has_err in groups: + fl = first_lbl.ljust(label_w) + if count == 1: + out.append(f" {fl} {content}") + elif count == 2: + ll = last_lbl.ljust(label_w) + out.append(f" {fl} {content}") + out.append(f" {ll} {content}") + else: + out.append(f" {fl} {content}") + tag = "same pattern" if has_err else "all pass" + out.append(f" ... ({count - 1} more rows, {tag}, to {last_lbl}) ...") + + # --- write to file --- + with open(os.path.join(run_dir, "diff_map.txt"), "w", encoding="utf-8") as f: + f.write("\n".join(out)) + + # --- print to console --- + print("\n".join(out)) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def prec_assert_close( + actual: torch.Tensor, + expected: torch.Tensor, + *, + dtype: Optional[Union[str, torch.dtype]] = None, + rtol: Optional[float] = None, + atol: Optional[float] = None, + output_dir: str = ".precision_debug", + save_tensors: bool = True, + print_map: bool = True, + term_width: int = 200, + msg: str = "", + equal_nan: bool = True, +) -> None: + """Drop-in replacement for ``torch.testing.assert_close``. + + On mismatch the function: + 1. Writes a text report (``report.txt``). + 2. Saves actual / expected tensors (``actual.pt``, ``expected.pt``). + 3. Prints an ASCII diff map to the console (and ``diff_map.txt``). + 4. Raises ``AssertionError``. 
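+
+    A minimal usage sketch (tensor values are illustrative)::
+
+        import torch
+        from tilelang.utils import prec_assert_close
+
+        actual = torch.tensor([0.0, 1.0])
+        expected = torch.tensor([0.0, 0.0])
+        # Mismatch at index 1: raises AssertionError and writes report.txt,
+        # diff_map.txt, actual.pt and expected.pt under .precision_debug/.
+        prec_assert_close(actual, expected, dtype="float32")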
+ """ + if actual.shape != expected.shape: + raise AssertionError( + f"Shape mismatch: actual {actual.shape} vs expected {expected.shape}. " + + msg + ) + + # Handle dtype-based default tolerances (Compatibility with testing/npuir/testcommon.py) + if rtol is None or atol is None: + target_dtype = dtype if dtype is not None else actual.dtype + # Convert torch.dtype to string name + if isinstance(target_dtype, torch.dtype): + dname = str(target_dtype).split(".")[-1] + else: + dname = str(target_dtype) + + def_rtol, def_atol = DEFAULT_TOLERANCE.get(dname, (1e-5, 1e-5)) + if rtol is None: + rtol = def_rtol + if atol is None: + atol = def_atol + + a, b = _equalize_for_compare(actual, expected) + if torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan).all(): + return + + stats = _compute_stats(actual, expected, rtol, atol, equal_nan=equal_nan) + filename, func = _get_caller_info() + safe = "".join( + c if c.isalnum() or c in "._-" else "_" for c in f"{filename}_{func}" + ) + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + run_dir = os.path.join(output_dir, f"{safe}_{ts}") + os.makedirs(run_dir, exist_ok=True) + + _write_report(run_dir, stats, msg) + + if save_tensors: + torch.save(actual.cpu(), os.path.join(run_dir, "actual.pt")) + torch.save(expected.cpu(), os.path.join(run_dir, "expected.pt")) + + if print_map: + try: + _print_diff_map(stats, atol, run_dir, term_width=term_width) + except Exception as e: + with open(os.path.join(run_dir, "report.txt"), "a", encoding="utf-8") as f: + f.write(f"\n\nASCII diff map failed: {e}\n") + + print( + f"[PrecisionDebug] FAILED: shape={stats['shape']} dtype={stats['dtype']}\n" + f" Mismatched: {stats['num_mismatched']}/{stats['total']}" + f" ({100.0 * stats['num_mismatched'] / max(1, stats['total']):.2f}%)\n" + f" Max abs diff: {stats['max_abs']} at index {stats['max_abs_idx']}\n" + f" Max rel diff: {stats['max_rel']} at index {stats['max_rel_idx']}\n" + f" Mean abs diff: {stats['mean_abs']}\n" + f" Output saved to: {run_dir}" + ) + + raise AssertionError( + f"Tensors not close (rtol={rtol}, atol={atol}). " + f"Mismatched: {stats['num_mismatched']}/{stats['total']}. " + f"See {run_dir}. " + msg + )