diff --git a/golden/runner.py b/golden/runner.py index 2135db5..d1b28b2 100644 --- a/golden/runner.py +++ b/golden/runner.py @@ -33,6 +33,10 @@ class RunConfig: atol: Absolute tolerance for golden comparison. compile_only: If ``True``, stop after code generation without executing on device or validating against golden. + validate_ir: If ``True``, validate the generated IR against golden + before runtime execution. + validate_ir_only: If ``True``, validate the generated IR against + golden and skip runtime execution. Requires ``validate_ir=True``. compile: Kwargs forwarded to :func:`pypto.ir.compile` (e.g. ``backend_type``, ``dump_passes``, ``output_dir``, ``strategy``, ``profiling``). @@ -43,6 +47,8 @@ class RunConfig: rtol: float = 1e-5 atol: float = 1e-5 compile_only: bool = False + validate_ir: bool = False + validate_ir_only: bool = False compile: dict[str, Any] = field(default_factory=dict) runtime: dict[str, Any] = field(default_factory=dict) @@ -169,6 +175,28 @@ def __exit__(self_, *_exc): def _fail(error: str) -> RunResult: return RunResult(passed=False, error=error, execution_time=time.time() - start) + def _prepare_golden_outputs(input_values: dict[str, torch.Tensor] | None) -> dict[str, torch.Tensor]: + """Load or compute golden outputs for validation stages.""" + with _stage("compute golden"): + if data_dir is not None: + print(f"[RUN] cache hit: {data_dir / 'out'}", flush=True) + output_names = [s.name for s in tensor_specs if s.is_output] + return _load_tensors(data_dir, "out", output_names) + + if input_values is None: + raise RuntimeError("input snapshot is unavailable for golden computation") + + scratch: dict[str, torch.Tensor] = {} + for spec in tensor_specs: + if spec.is_output and spec.init_value is None: + scratch[spec.name] = torch.zeros(spec.shape, dtype=spec.dtype) + else: + scratch[spec.name] = input_values[spec.name].clone() + golden_fn(scratch) + golden_outputs = {spec.name: scratch[spec.name] for spec in tensor_specs if spec.is_output} 
+ _save_tensors(Path(work_dir) / "data" / "out", golden_outputs) + return golden_outputs + # Compile if runtime_dir is not None: if config.compile_only: @@ -187,12 +215,15 @@ def _fail(error: str) -> RunResult: if config.compile_only: total = time.time() - start + if config.validate_ir: + print("[RUN] compile_only is set, skipping IR validation as well.", flush=True) print(f"[RUN] PASS ({total:.2f}s)", flush=True) return RunResult(passed=True, execution_time=total) work_dir = compiled.output_dir # Generate Inputs + input_snapshot: dict[str, torch.Tensor] | None = None with _stage("generate inputs"): if data_dir is not None: missing = [ @@ -219,34 +250,62 @@ def _fail(error: str) -> RunResult: } _save_tensors(work_dir / "data" / "in", input_snapshot) + has_golden_source = golden_fn is not None or golden_data is not None + golden_outputs: dict[str, torch.Tensor] | None = None + + if config.validate_ir: + passes_dir = Path(work_dir) / "passes_dump" + if not has_golden_source: + print( + "[RUN] validate ir skipped: no golden_fn or golden_data is provided.", + flush=True, + ) + elif not passes_dir.is_dir(): + print( + f"[RUN] validate ir skipped: passes_dump directory does not exist: {passes_dir}", + flush=True, + ) + else: + try: + golden_outputs = _prepare_golden_outputs(input_snapshot) + except Exception as e: + return _fail(str(e)) + + with _stage("validate ir"): + try: + from pypto.debug.torch_codegen import validate_pass_ir_codegen_results + + validate_pass_ir_codegen_results( + str(passes_dir), + {k: v.clone() for k, v in tensors.items()}, + golden_outputs, + rtol=config.rtol, + atol=config.atol, + ) + except Exception as e: + return _fail(f"IR validation failed: {e}") + if config.validate_ir_only: + total = time.time() - start + print(f"[RUN] PASS ({total:.2f}s, runtime skipped due to validate_ir_only)", flush=True) + return RunResult(passed=True, execution_time=total) + # Runtime with _stage("runtime"): ordered = [tensors[spec.name] for spec in tensor_specs]
execute_compiled(work_dir, ordered, **config.runtime) - if golden_fn is None and golden_data is None: + if not has_golden_source: total = time.time() - start print(f"[RUN] PASS ({total:.2f}s, validation skipped: no golden_fn or golden_data)", flush=True) return RunResult(passed=True, execution_time=total) device_outputs = {spec.name: tensors[spec.name] for spec in tensor_specs if spec.is_output} - # Compute Golden (or load from cache) - with _stage("compute golden"): - if data_dir is not None: - print(f"[RUN] cache hit: {data_dir / 'out'}", flush=True) - output_names = [s.name for s in tensor_specs if s.is_output] - golden_outputs = _load_tensors(data_dir, "out", output_names) - else: - scratch: dict[str, torch.Tensor] = {} - for spec in tensor_specs: - if spec.is_output and spec.init_value is None: - scratch[spec.name] = torch.zeros(spec.shape, dtype=spec.dtype) - else: - scratch[spec.name] = input_snapshot[spec.name].clone() - golden_fn(scratch) - golden_outputs = {spec.name: scratch[spec.name] for spec in tensor_specs if spec.is_output} - _save_tensors(work_dir / "data" / "out", golden_outputs) + if golden_outputs is None: + try: + golden_outputs = _prepare_golden_outputs(input_snapshot) + except Exception as e: + return _fail(str(e)) # Validate with _stage("validate"):