Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,23 @@ agent:
step_limit: 200
use_skills: true
tool_profile: translation
# Latency benchmarking (median over N timed passes; no env vars). bench_iters
# omitted -> inherits the optimization default (DEFAULT_EVAL_BENCHMARK_ITERATIONS).
bench_warmup: 10
bench_iters: 30
reference_mode: compile_fallback # PyTorch at its best; eager | compile also valid

model:
model_class: amd_llm
model_name: claude-opus-4.6
model_name: claude-opus-4.8
api_key: null
model_kwargs:
temperature: 0.0
max_tokens: 16000
# Cost accounting rates (USD per million tokens) used to populate
# translation_cost_usd in translation_result.json. Defaults below are public
# Claude Opus rates; override per model/gateway as needed.
cost_per_mtok_input: 15.0
cost_per_mtok_output: 75.0
cost_per_mtok_cache_write: 18.75
cost_per_mtok_cache_read: 1.5
248 changes: 200 additions & 48 deletions src/minisweagent/run/preprocess/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,74 @@ def _parse_timing_from_harness_output(
)


# Default LLM pricing (USD per million tokens), Claude Opus public rates.
# Overridable per key via the model: section of the translation YAML
# (cost_per_mtok_input / _output / _cache_write / _cache_read).
_DEFAULT_COST_RATES_PER_MTOK = {
"input": 15.0,
"output": 75.0,
"cache_write": 18.75,
"cache_read": 1.50,
}


def _aggregate_trajectory_tokens(output_dir: Path) -> dict[str, int]:
"""Sum token usage across all round trajectories under *output_dir*.

Reads ``round_*/traj.json`` (JSON or concatenated JSONL) written by the
translation agent and accumulates Anthropic-style usage fields. Returns
zeros when no trajectory is found.
"""
agg = {"calls": 0, "input": 0, "output": 0, "cache_write": 0, "cache_read": 0}
decoder = json.JSONDecoder()

def _walk(obj):
if isinstance(obj, dict):
if "output_tokens" in obj:
agg["calls"] += 1
agg["input"] += int(obj.get("input_tokens") or 0)
agg["output"] += int(obj.get("output_tokens") or 0)
agg["cache_write"] += int(obj.get("cache_creation_input_tokens") or 0)
agg["cache_read"] += int(obj.get("cache_read_input_tokens") or 0)
for value in obj.values():
_walk(value)
elif isinstance(obj, list):
for value in obj:
_walk(value)

for traj in sorted(output_dir.glob("round_*/traj.json")):
try:
text = traj.read_text()
except OSError:
continue
idx, length = 0, len(text)
while idx < length:
while idx < length and text[idx] in " \t\r\n":
idx += 1
if idx >= length:
break
try:
obj, idx = decoder.raw_decode(text, idx)
except ValueError:
break
_walk(obj)
return agg


def _estimate_cost_usd(tokens: dict, rates_per_mtok: dict) -> float:
"""Estimate USD cost from a token breakdown and per-million-token rates."""
return round(
(
tokens.get("input", 0) * rates_per_mtok["input"]
+ tokens.get("output", 0) * rates_per_mtok["output"]
+ tokens.get("cache_write", 0) * rates_per_mtok["cache_write"]
+ tokens.get("cache_read", 0) * rates_per_mtok["cache_read"]
)
/ 1e6,
4,
)


def run_translation(
kernel_path: Path,
output_dir: Path,
Expand Down Expand Up @@ -144,6 +212,10 @@ def _print(msg: str) -> None:
"translation_rounds_used": 0,
"translation_pytorch_latency_ms": None,
"translation_flydsl_latency_ms": None,
"translation_speedup": None,
"translation_cost_usd": None,
"translation_tokens": None,
"translation_model_calls": None,
"translation_errors": [],
}

Expand All @@ -170,6 +242,22 @@ def _print(msg: str) -> None:
_print(f" [red]{msg}[/red]" if console else f" ERROR: {msg}")
return result

# -- Benchmark / reference settings (from the translation YAML, no env vars) --
# bench_iters defaults to the shared optimization constant so the two stages
# cannot drift; the generated harness itself reads no environment.
try:
from minisweagent.run.preprocess.harness_utils import (
DEFAULT_EVAL_BENCHMARK_ITERATIONS as _DEFAULT_BENCH_ITERS,
)
except Exception:
_DEFAULT_BENCH_ITERS = 30
# pop (not get): these are translation-harness settings, not agent fields,
# so they must not be splatted into TranslationAgentConfig(**agent_config).
bench_warmup = int(agent_config_dict.pop("bench_warmup", 10))
bench_iters = int(agent_config_dict.pop("bench_iters", _DEFAULT_BENCH_ITERS))
reference_mode = str(agent_config_dict.pop("reference_mode", "compile_fallback")).strip().lower()
_print(f" Latency bench: warmup={bench_warmup} iters={bench_iters} (median), reference_mode={reference_mode}")

# -- Resolve model --
# Precedence: explicit model object > explicit model_name > YAML config > factory default
_model = model
Expand Down Expand Up @@ -219,6 +307,9 @@ def _print(msg: str) -> None:
model=_model,
repo_root=repo_root,
output_dir=output_dir,
bench_warmup=bench_warmup,
bench_iters=bench_iters,
reference_mode=reference_mode,
)
except Exception as exc:
msg = f"Failed to create translation harness: {exc}"
Expand Down Expand Up @@ -309,13 +400,27 @@ def _print(msg: str) -> None:
)
assert isinstance(harness_result, dict)

# Always persist the PyTorch reference latency, even when the candidate
# is incorrect or the harness errors out. The harness prints the
# reference latency before running/comparing the candidate, so it is
# available in stdout regardless of correctness. (Candidate latency and
# speedup are only meaningful for a CORRECT candidate, so those are
# parsed in the success branch below.)
import re

_ref_only = re.search(
r"PyTorch reference latency:\s*([\d.]+)\s*ms",
harness_result.get("stdout", ""),
)
if _ref_only:
result["translation_pytorch_latency_ms"] = float(_ref_only.group(1))

if harness_result["success"]:
_print(f" Round {round_num}: CORRECT")
result["translation_success"] = True
result["translation_kernel_path"] = str(candidate_path)

# Parse timing from the validation run's stdout — the harness
# prints latencies and speedup when the candidate is tested.
# Parse full timing (reference + candidate + speedup) from stdout.
_parse_timing_from_harness_output(
harness_result.get("stdout", ""),
result,
Expand Down Expand Up @@ -427,6 +532,31 @@ def _print(msg: str) -> None:
if result["translation_success"]:
_print(f" Translation successful in {result['translation_rounds_used']} rounds ({elapsed:.1f}s)")

# -- Cost accounting (token-based estimate from the round trajectories) --
# Persisted regardless of success so failed/partial runs still record spend.
try:
rates = dict(_DEFAULT_COST_RATES_PER_MTOK)
for _key, _cfg_key in (
("input", "cost_per_mtok_input"),
("output", "cost_per_mtok_output"),
("cache_write", "cost_per_mtok_cache_write"),
("cache_read", "cost_per_mtok_cache_read"),
):
if model_config.get(_cfg_key) is not None:
rates[_key] = float(model_config[_cfg_key])
tokens = _aggregate_trajectory_tokens(output_dir)
result["translation_tokens"] = tokens
result["translation_model_calls"] = tokens["calls"] or getattr(_model, "n_calls", None)
result["translation_cost_usd"] = _estimate_cost_usd(tokens, rates)
result["translation_cost_rates_per_mtok"] = rates
_print(
f" Cost: ${result['translation_cost_usd']:.2f} "
f"({tokens['calls']} calls, in={tokens['input']} out={tokens['output']} "
f"cache_r={tokens['cache_read']} cache_w={tokens['cache_write']})"
)
except Exception as exc:
_print(f" Warning: cost accounting failed: {exc}")

# Write result metadata
(output_dir / "translation_result.json").write_text(json.dumps(result, indent=2, default=str))

Expand Down Expand Up @@ -743,6 +873,9 @@ def _create_translation_harness(
model,
repo_root: Path,
output_dir: Path,
bench_warmup: int = 10,
bench_iters: int = 30,
reference_mode: str = "compile_fallback",
) -> Path:
"""Create a comparison harness for translation validation.

Expand All @@ -755,6 +888,9 @@ def _create_translation_harness(
kernel_path=kernel_path,
candidate_path=candidate_path,
candidate_flag=pair.harness_candidate_flag,
bench_warmup=bench_warmup,
bench_iters=bench_iters,
reference_mode=reference_mode,
)
harness_path.write_text(harness_code)
logger.info("Created translation harness: %s", harness_path)
Expand All @@ -766,6 +902,9 @@ def _generate_minimal_translation_harness(
kernel_path: Path,
candidate_path: Path,
candidate_flag: str,
bench_warmup: int = 10,
bench_iters: int = 30,
reference_mode: str = "compile_fallback",
) -> str:
"""Generate a minimal Python harness that validates translation correctness.

Expand Down Expand Up @@ -817,27 +956,65 @@ def _is_native_pattern(module):
and not hasattr(module, "Model"))


# -- Benchmark settings (baked in from the translation YAML; no env reads) --
_BENCH_WARMUP = {bench_warmup}
_BENCH_ITERS = {bench_iters}
_REFERENCE_MODE = "{reference_mode}"


def _bench_median_ms(run_fn, warmup=_BENCH_WARMUP, iters=_BENCH_ITERS):
"""Median latency (ms) over ``iters`` timed calls after ``warmup`` warmups.

Uses CUDA events per iteration (no Triton). Returns (last_output, median_ms).
"""
out = None
with torch.no_grad():
for _ in range(warmup):
run_fn()
torch.cuda.synchronize()
samples = []
for _ in range(iters):
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
out = run_fn()
e.record()
torch.cuda.synchronize()
samples.append(s.elapsed_time(e))
samples.sort()
return out, samples[len(samples) // 2]


def _make_reference_callable(model, inputs):
"""Return (callable, mode_label) for the PyTorch reference, honoring _REFERENCE_MODE.

eager -> raw eager forward.
compile -> torch.compile, errors surface.
compile_fallback -> torch.compile, fall back to eager on any failure (PyTorch at its best).
"""
eager_fn = lambda: model(*inputs)
if _REFERENCE_MODE == "eager":
return eager_fn, "eager"
try:
cmodel = torch.compile(model)
with torch.no_grad():
cmodel(*inputs) # probe: triggers compilation outside the timed loop
return (lambda: cmodel(*inputs)), "compile"
except Exception as exc:
if _REFERENCE_MODE == "compile":
raise
print(f"Reference mode: compile failed ({{type(exc).__name__}}: {{exc}}); falling back to eager")
return eager_fn, "eager (compile fallback)"


def _run_native(module, inputs):
"""Run a native-pattern module (build_model + forward)."""
get_init_inputs = getattr(module, "get_init_inputs", None)
init_inputs = get_init_inputs() if get_init_inputs else []
state = module.build_model(*init_inputs)

# Warmup
with torch.no_grad():
for _ in range(3):
module.forward(state, *inputs)
torch.cuda.synchronize()

# Timed run
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
with torch.no_grad():
start.record()
output = module.forward(state, *inputs)
end.record()
torch.cuda.synchronize()
latency_ms = start.elapsed_time(end)
run_fn = lambda: module.forward(state, *inputs)
output, latency_ms = _bench_median_ms(run_fn)
return output, latency_ms


Expand All @@ -858,21 +1035,9 @@ def run_reference():
model = model.half()
inputs = [x.cuda().half() if isinstance(x, torch.Tensor) else x for x in inputs]

# Warmup
with torch.no_grad():
for _ in range(3):
model(*inputs)
torch.cuda.synchronize()

# Timed run
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
with torch.no_grad():
start.record()
ref_output = model(*inputs)
end.record()
torch.cuda.synchronize()
latency_ms = start.elapsed_time(end)
run_fn, _ref_mode = _make_reference_callable(model, inputs)
print(f"Reference mode: {{_ref_mode}}")
ref_output, latency_ms = _bench_median_ms(run_fn)

return model, inputs, ref_output, latency_ms

Expand All @@ -892,21 +1057,8 @@ def run_candidate(candidate_path: str, ref_inputs):

inputs = ref_inputs

# Warmup
with torch.no_grad():
for _ in range(3):
model(*inputs)
torch.cuda.synchronize()

# Timed run
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
with torch.no_grad():
start.record()
cand_output = model(*inputs)
end.record()
torch.cuda.synchronize()
latency_ms = start.elapsed_time(end)
run_fn = lambda: model(*inputs)
cand_output, latency_ms = _bench_median_ms(run_fn)

return cand_output, latency_ms

Expand Down Expand Up @@ -967,7 +1119,7 @@ def main():
print("CORRECTNESS: PASS")

speedup = ref_latency / cand_latency if cand_latency > 0 else float("inf")
print(f"Speedup: {{speedup:.2f}}x (ref={{ref_latency:.3f}}ms, cand={{cand_latency:.3f}}ms)")
print(f"Speedup: {{speedup:.2f}}x (ref={{ref_latency:.3f}}ms, cand={{cand_latency:.3f}}ms, median of {bench_iters})")

if speedup < 0.5:
print("WARNING: FlyDSL candidate is significantly slower than PyTorch reference")
Expand Down