Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/minisweagent/config/geak.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ run:
kill_buffer_s: 360 # forced os._exit() this long after opt_deadline
full:
total_s: 7200 # 2 hours total wall-clock
preprocess_soft_cap_s: 900
preprocess_soft_cap_s: 2400 # 40 min: translation + multi-round harness-gen + baseline
preprocess_hard_cap_fraction: 0.5 # -> 3600 s ceiling
finalize_grace_s: 300
kill_buffer_s: 360
Expand Down
10 changes: 10 additions & 0 deletions src/minisweagent/run/mini.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,16 @@ def _hard_kill_handler() -> None:
test_command = preprocess_ctx["test_command"]
if preprocess_ctx.get("repo_root") and repo is None:
repo = Path(preprocess_ctx["repo_root"])
elif preprocess_ctx.get("repo_root"):
# A PyTorch->FlyDSL translation retargets the optimization root to the
# per-run ``_opt_repo`` (where the translated kernel + staged reference
# live). Honor that even when ``--repo`` was passed, otherwise
# optimization/preflight are rooted at the source repo (which has NO
# translated kernel) and the harness fails to import it.
_pp_root = Path(preprocess_ctx["repo_root"])
if _pp_root.name == "_opt_repo" and (repo is None or Path(repo).resolve() != _pp_root.resolve()):
logger.info("Using per-run _opt_repo as optimization root (translation run): %s", _pp_root)
repo = _pp_root

# Resolve max_rounds via the documented precedence chain:
# CLI --max-rounds (if any future flag added) > config (mode preset) >
Expand Down
62 changes: 54 additions & 8 deletions src/minisweagent/run/preprocess_v3/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

import logging
import os
import re
import shlex
import statistics
import subprocess
Expand Down Expand Up @@ -77,16 +78,42 @@

#: Short timeout for the correctness gate that runs before baseline collection.
#: Goal: fail fast on a broken kernel rather than spending minutes running the
#: full benchmark loop. Override via ``GEAK_BENCH_TIMEOUT`` (or the legacy
#: ``GEAK_CORRECTNESS_GATE_TIMEOUT``).
#: full benchmark loop. ``GEAK_CORRECTNESS_GATE_TIMEOUT`` takes precedence; for
#: compiled kernels whose first ``--correctness`` run also builds the extension,
#: the larger ``GEAK_BENCH_TIMEOUT`` is honored as a fallback. Default 120s.
_CORRECTNESS_GATE_TIMEOUT_S = int(
os.environ.get(
"GEAK_BENCH_TIMEOUT",
os.environ.get("GEAK_CORRECTNESS_GATE_TIMEOUT", "120"),
"GEAK_CORRECTNESS_GATE_TIMEOUT",
os.environ.get("GEAK_BENCH_TIMEOUT", "120"),
)
)


#: Exception names in harness stderr/stdout that mean the harness could not
#: resolve (import/open) the kernel-under-test — a broken-harness failure that
#: yields an empty baseline. Surfaced precisely so a no-latency baseline reads
#: as "kernel not found at <path>" instead of a silent "produced no latency".
_KERNEL_RESOLUTION_MARKERS = ("FileNotFoundError", "ModuleNotFoundError", "ImportError")


def detect_kernel_resolution_failure(raw_outputs: list[dict[str, Any]]) -> str | None:
    """Return the first kernel-resolution error line from harness output, or ``None``.

    Scans each run's stderr followed by its stdout for an import /
    file-not-found error (the signature of a harness pointing at a
    non-existent kernel path) and returns the remainder of that line, starting
    at the exception name and stripped of surrounding whitespace — e.g.
    ``FileNotFoundError: [Errno 2] No such file or directory: '<path>'`` — so
    callers can report exactly which path failed to resolve rather than a
    generic "no latency" message.

    Args:
        raw_outputs: Per-run result mappings. Only the ``"stderr"`` and
            ``"stdout"`` keys are read; missing or ``None`` values are treated
            as empty text. ``None`` or an empty list yields ``None``.

    Returns:
        The earliest matching error line of the first run that contains any
        marker from ``_KERNEL_RESOLUTION_MARKERS``, or ``None`` when no run
        contains one.
    """
    for out in raw_outputs or ():
        blob = f"{out.get('stderr') or ''}\n{out.get('stdout') or ''}"
        # Pick the earliest occurrence across *all* markers rather than the
        # first marker in tuple order: the error that appears first in the
        # output is the one to report, whichever exception type it is.
        hits = [pos for pos in map(blob.find, _KERNEL_RESOLUTION_MARKERS) if pos != -1]
        if hits:
            return blob[min(hits):].splitlines()[0].strip()
    return None


@dataclass(frozen=True)
class BaselineMetrics:
"""Wall-clock benchmark statistics for a harness run.
Expand Down Expand Up @@ -327,6 +354,7 @@ def collect_baseline_metrics(
repeats: int = 5,
work_dir: Path | None = None,
gpu_id: int = 0,
skip_correctness_gate: bool = False,
) -> BaselineMetrics:
"""Run the harness ``repeats`` times in ``--benchmark`` mode.

Expand All @@ -349,6 +377,20 @@ def collect_baseline_metrics(
gpu_id:
``HIP_VISIBLE_DEVICES`` value for each invocation.
Defaults to GPU 0 to match the legacy default.
skip_correctness_gate:
When ``True``, skip the up-front ``--correctness`` gate and go
straight to the benchmark loop. Use this when correctness has
already been validated upstream on an authoritative harness —
notably after a successful PyTorch→FlyDSL translation, which runs
its own correctness + performance-regression check. The gate
re-checks correctness on the (stricter) harness-generator harness
and trips on *any* non-zero exit (timeout / env / multi-shape
miss), not just real numeric mismatches, so re-gating an
already-validated kernel discards good candidates. The global
``GEAK_SKIP_CORRECTNESS_GATE=1`` env var still forces a skip
regardless of this flag; this parameter scopes the skip to a
single call (e.g. translation runs) without disabling the gate
for user-supplied harnesses.

Returns:
A :class:`BaselineMetrics` summarising the run.
Expand Down Expand Up @@ -378,10 +420,13 @@ def collect_baseline_metrics(

# Correctness gate: a quick ``--correctness`` invocation up front so that a
# broken kernel fails in ~5-30 s rather than after a full benchmark + profile
# cycle (~5+ min). Mirrors the legacy harness validation shape; can be
# disabled via ``GEAK_SKIP_CORRECTNESS_GATE=1`` when you explicitly want
# baseline numbers from a correctness-failing kernel.
if not os.environ.get("GEAK_SKIP_CORRECTNESS_GATE"):
# cycle (~5+ min). Mirrors the legacy harness validation shape. Skipped when:
# * ``skip_correctness_gate=True`` — correctness already validated upstream
# (e.g. translation, which runs its own correctness + perf-regression
# gate). Scoped to this call, so user-supplied harnesses still gate.
# * ``GEAK_SKIP_CORRECTNESS_GATE=1`` — global override for when you
# explicitly want baseline numbers from a correctness-failing kernel.
if not skip_correctness_gate and not os.environ.get("GEAK_SKIP_CORRECTNESS_GATE"):
gate = _run_benchmark_once(
harness_path,
work_dir=work_dir,
Expand Down Expand Up @@ -689,4 +734,5 @@ def collect_profile(
"ProfileResult",
"collect_baseline_metrics",
"collect_profile",
"detect_kernel_resolution_failure",
]
Loading