Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 76 additions & 17 deletions golden/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ class RunConfig:
atol: Absolute tolerance for golden comparison.
compile_only: If ``True``, stop after code generation without
executing on device or validating against golden.
validate_ir: If ``True``, validate the generated IR against golden
before runtime execution.
validate_ir_only: If ``True``, validate the generated IR against
golden and skip runtime execution. Requires ``validate_ir=True``.
compile: Kwargs forwarded to :func:`pypto.ir.compile` (e.g.
``backend_type``, ``dump_passes``, ``output_dir``, ``strategy``,
``profiling``).
Expand All @@ -43,6 +47,8 @@ class RunConfig:
rtol: float = 1e-5
atol: float = 1e-5
compile_only: bool = False
validate_ir: bool = False
validate_ir_only: bool = False
compile: dict[str, Any] = field(default_factory=dict)
runtime: dict[str, Any] = field(default_factory=dict)

Expand Down Expand Up @@ -169,6 +175,28 @@ def __exit__(self_, *_exc):
def _fail(error: str) -> RunResult:
    """Build a failed RunResult, stamping the elapsed wall-clock time."""
    # `start` is captured from the enclosing run scope (set when the run began).
    elapsed = time.time() - start
    return RunResult(passed=False, error=error, execution_time=elapsed)

def _prepare_golden_outputs(input_values: dict[str, torch.Tensor] | None) -> dict[str, torch.Tensor]:
    """Load or compute golden outputs for validation stages.

    Runs inside the "compute golden" stage. When ``data_dir`` is set
    (cache hit), the saved output tensors are loaded from
    ``data_dir / 'out'``. Otherwise a scratch tensor dict is assembled
    from *input_values*, ``golden_fn`` mutates it in place, and the
    resulting outputs are persisted under ``work_dir/data/out`` before
    being returned.

    Raises:
        RuntimeError: if golden computation is required but
            *input_values* is ``None``.
    """
    with _stage("compute golden"):
        # Cache hit: reuse previously saved golden outputs.
        if data_dir is not None:
            print(f"[RUN] cache hit: {data_dir / 'out'}", flush=True)
            wanted = [spec.name for spec in tensor_specs if spec.is_output]
            return _load_tensors(data_dir, "out", wanted)

        if input_values is None:
            raise RuntimeError("input snapshot is unavailable for golden computation")

        # Outputs with no init value start zeroed; every other tensor is a
        # clone of the captured input so golden_fn cannot mutate the snapshot.
        scratch: dict[str, torch.Tensor] = {
            spec.name: (
                torch.zeros(spec.shape, dtype=spec.dtype)
                if spec.is_output and spec.init_value is None
                else input_values[spec.name].clone()
            )
            for spec in tensor_specs
        }
        golden_fn(scratch)  # golden_fn writes its results into scratch in place
        outputs = {spec.name: scratch[spec.name] for spec in tensor_specs if spec.is_output}
        _save_tensors(Path(work_dir) / "data" / "out", outputs)
        return outputs

# Compile
if runtime_dir is not None:
if config.compile_only:
Expand All @@ -187,12 +215,15 @@ def _fail(error: str) -> RunResult:

if config.compile_only:
total = time.time() - start
if config.validate_ir:
print("[RUN] compile_only is set, skipping IR validation as well.")
print(f"[RUN] PASS ({total:.2f}s)", flush=True)
return RunResult(passed=True, execution_time=total)
Comment on lines 216 to 221
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

compile_only skip message ignores validate_ir_only.

The notice only fires when config.validate_ir is true, but a user who sets validate_ir_only=True alongside compile_only=True will see no message and may not realize their IR-only intent was dropped. Broaden the condition (and/or reject the combination explicitly).

-            if config.validate_ir:
+            if config.validate_ir or config.validate_ir_only:
                 print("[RUN] compile_only is set, skipping IR validation as well.")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@golden/runner.py` around lines 216 - 221, The message shown when compile_only
is set should also account for validate_ir_only: update the condition around the
informational print to check either config.validate_ir or
config.validate_ir_only (or explicitly detect the invalid combination
config.compile_only && config.validate_ir_only and raise/print a warning) so
users who set validate_ir_only=True alongside compile_only=True see the correct
notice; modify the block around config.compile_only (where RunResult is
returned) to either broaden the if to (config.validate_ir or
config.validate_ir_only) or add an explicit rejection/clarifying message for
config.compile_only && config.validate_ir_only.


work_dir = compiled.output_dir

# Generate Inputs
input_snapshot: dict[str, torch.Tensor] | None = None
with _stage("generate inputs"):
if data_dir is not None:
missing = [
Expand All @@ -219,34 +250,62 @@ def _fail(error: str) -> RunResult:
}
_save_tensors(work_dir / "data" / "in", input_snapshot)

has_golden_source = golden_fn is not None or golden_data is not None
golden_outputs: dict[str, torch.Tensor] | None = None

if config.validate_ir:
passes_dir = Path(work_dir) / "passes_dump"
if not has_golden_source:
print(
"[RUN] validate ir skipped: no golden_fn or golden_data is provided.",
flush=True,
)
elif not passes_dir.is_dir():
print(
f"[RUN] validate ir skipped: passes_dump directory does not exist: {passes_dir}",
flush=True,
)
else:
try:
golden_outputs = _prepare_golden_outputs(input_snapshot)
except Exception as e:
return _fail(str(e))

with _stage("validate ir"):
try:
from pypto.debug.torch_codegen import validate_pass_ir_codegen_results

validate_pass_ir_codegen_results(
str(passes_dir),
{k: v.clone() for k, v in tensors.items()},
golden_outputs,
rtol=config.rtol,
atol=config.atol
)
except Exception as e:
return _fail(f"IR validation failed: {e}")
if config.validate_ir_only:
total = time.time() - start
print(f"[RUN] PASS ({total:.2f}s, runtime skipped due to validate_ir_only)", flush=True)
return RunResult(passed=True, execution_time=total)

# Runtime
with _stage("runtime"):
ordered = [tensors[spec.name] for spec in tensor_specs]
execute_compiled(work_dir, ordered, **config.runtime)

if golden_fn is None and golden_data is None:
if not has_golden_source:
total = time.time() - start
print(f"[RUN] PASS ({total:.2f}s, validation skipped: no golden_fn or golden_data)", flush=True)
return RunResult(passed=True, execution_time=total)

device_outputs = {spec.name: tensors[spec.name] for spec in tensor_specs if spec.is_output}

# Compute Golden (or load from cache)
with _stage("compute golden"):
if data_dir is not None:
print(f"[RUN] cache hit: {data_dir / 'out'}", flush=True)
output_names = [s.name for s in tensor_specs if s.is_output]
golden_outputs = _load_tensors(data_dir, "out", output_names)
else:
scratch: dict[str, torch.Tensor] = {}
for spec in tensor_specs:
if spec.is_output and spec.init_value is None:
scratch[spec.name] = torch.zeros(spec.shape, dtype=spec.dtype)
else:
scratch[spec.name] = input_snapshot[spec.name].clone()
golden_fn(scratch)
golden_outputs = {spec.name: scratch[spec.name] for spec in tensor_specs if spec.is_output}
_save_tensors(work_dir / "data" / "out", golden_outputs)
if golden_outputs is None:
try:
golden_outputs = _prepare_golden_outputs(input_snapshot)
except Exception as e:
return _fail(str(e))

# Validate
with _stage("validate"):
Expand Down
Loading