Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,13 @@ def _lookup_cost(
return None


def _cost_or_task(cost_row: dict[str, Any] | None, key: str, task_cost: float, default: float = 0.0) -> float:
def _cost_or_task(cost_row: dict[str, Any] | None, key: str, default: float = 0.0) -> float:
if not cost_row:
return task_cost if key in {"cost", "used_cost"} else default
return default

value = cost_row.get(key)
if value is None:
return task_cost if key in {"cost", "used_cost"} else default
return default

return float(value)

Expand Down Expand Up @@ -311,21 +311,17 @@ def build_report_data(jsonl_dir: Path, include_failed_runs: bool = False) -> dic
}

cost_row = _lookup_cost(costs_index, run_id=run_id, process=process, process_short=process_short, hash_short=hash_short)
task_cost = float(t.get("cost") or 0.0)

if cost_row:
run_cost_acc[run_group_key]["cost"] += _cost_or_task(cost_row, "cost", task_cost)
run_cost_acc[run_group_key]["used_cost"] += _cost_or_task(cost_row, "used_cost", task_cost)
run_cost_acc[run_group_key]["unused_cost"] += _cost_or_task(cost_row, "unused_cost", task_cost, default=0.0)
else:
run_cost_acc[run_group_key]["cost"] += task_cost
run_cost_acc[run_group_key]["used_cost"] += task_cost
run_cost_acc[run_group_key]["cost"] += _cost_or_task(cost_row, "cost")
run_cost_acc[run_group_key]["used_cost"] += _cost_or_task(cost_row, "used_cost")
run_cost_acc[run_group_key]["unused_cost"] += _cost_or_task(cost_row, "unused_cost")

if has_cost_rows:
overview_key = (group, process_short)
cost_group_acc[overview_key]["total_cost"] += _cost_or_task(cost_row, "cost", task_cost)
cost_group_acc[overview_key]["used_cost"] += _cost_or_task(cost_row, "used_cost", task_cost)
cost_group_acc[overview_key]["unused_cost"] += _cost_or_task(cost_row, "unused_cost", task_cost, default=0.0)
cost_group_acc[overview_key]["total_cost"] += _cost_or_task(cost_row, "cost")
cost_group_acc[overview_key]["used_cost"] += _cost_or_task(cost_row, "used_cost")
cost_group_acc[overview_key]["unused_cost"] += _cost_or_task(cost_row, "unused_cost")
cost_group_acc[overview_key]["n_tasks"] += 1

status = t.get("status")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ def test_build_report_data_has_all_sections(tmp_path, make_run, flat_task, write
}


def test_run_costs_without_cur_uses_task_cost(tmp_path, make_run, flat_task, write_run_json):
def test_run_costs_without_cur_are_zero(tmp_path, make_run, flat_task, write_run_json):
data_dir = tmp_path / "data"
jsonl_dir = tmp_path / "jsonl_bundle"
write_run_json(data_dir, [make_run(tasks=[flat_task(cost=4.2)])])
normalize_jsonl(data_dir, jsonl_dir)

data = build_report_data(jsonl_dir)
assert data["run_costs"][0]["cost"] == 4.2
assert data["run_costs"][0]["cost"] == 0.0
assert data["run_costs"][0]["used_cost"] is None


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def extract_tasks(runs: list[dict[str, Any]]) -> list[dict[str, Any]]:
"peak_rss": task.get("peakRss", 0),
"read_bytes": task.get("readBytes", 0),
"write_bytes": task.get("writeBytes", 0),
"cost": task.get("cost"),
"cost": None,
"executor": task.get("executor", ""),
"machine_type": task.get("machineType", ""),
"cloud_zone": task.get("cloudZone", ""),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_cached_count_extracted(make_run, flat_task):
def test_nested_tasks_unwrapped(make_run, nested_task):
run = make_run(tasks=[nested_task(cost=2.0), nested_task(cost=3.0)])
rows = extract_tasks([run])
assert sum(r["cost"] for r in rows) == 5.0
assert all(r["cost"] is None for r in rows)


def test_failed_tasks_filtered(make_run, flat_task):
Expand Down
Loading