|
9 | 9 | from refua_campaign.orchestrator import _extract_first_json_object, _extract_json_plan |
10 | 10 | from refua_campaign.prompts import planner_suffix |
11 | 11 |
|
# Pipeline stage assigned to each known tool name; lower stages come earlier.
# Used by the stage-progression check, which rejects calls that skip a stage.
# NOTE(review): stage meanings are inferred from tool names and the policy
# error messages (0 = evidence gathering, 1 = spec validation,
# 2 = design/affinity, 3 = ADMET, 4 = clinical simulation, 5 = job) — confirm.
_TOOL_STAGE_INDEX: dict[str, int] = {
    "web_search": 0,
    "web_fetch": 0,
    "refua_data_list": 0,
    "refua_data_fetch": 0,
    "refua_data_materialize": 0,
    "refua_data_query": 0,
    "refua_validate_spec": 1,
    "refua_fold": 2,
    "refua_affinity": 2,
    "refua_antibody_design": 2,
    "refua_protein_properties": 2,
    "refua_admet_profile": 3,
    "refua_clinical_simulator": 4,
    "refua_job": 5,
}

# Tools that count as "evidence collection" for the
# require_evidence_before_hypothesis policy (currently the stage-0 tools).
_EVIDENCE_TOOLS: frozenset[str] = frozenset(
    {
        "web_search",
        "web_fetch",
        "refua_data_list",
        "refua_data_fetch",
        "refua_data_materialize",
        "refua_data_query",
    }
)

# Tools treated as "hypothesis-heavy" (design/admet/clinical) by the same
# policy. Note that refua_protein_properties is stage 2 but is not listed here.
_HYPOTHESIS_TOOLS: frozenset[str] = frozenset(
    {
        "refua_fold",
        "refua_affinity",
        "refua_antibody_design",
        "refua_admet_profile",
        "refua_clinical_simulator",
    }
)
| 47 | + |
12 | 48 |
|
@dataclass(frozen=True)
class PlanPolicy:
    """Immutable switches controlling how strictly a plan is checked."""

    # NOTE(review): presumably the maximum number of calls allowed in a plan;
    # the enforcing code is not visible here — confirm.
    max_calls: int = 10
    # When set, a plan whose first call is not refua_validate_spec is flagged
    # ("First call is not refua_validate_spec; ...").
    require_validate_first: bool = True
    # When set, stage-progression errors and warnings are added to the result.
    enforce_stage_progression: bool = False
    # When set, the first hypothesis-heavy call must be preceded by at least
    # one evidence-collection call, otherwise an error is recorded.
    require_evidence_before_hypothesis: bool = False
17 | 55 |
|
18 | 56 |
|
19 | 57 | @dataclass(frozen=True) |
@@ -242,7 +280,8 @@ def _critic_once( |
242 | 280 | instructions=( |
243 | 281 | "Return JSON only with shape " |
244 | 282 | '{"approved":bool,"issues":[...],"suggested_fixes":[...]}. ' |
245 | | - "Reject plans that are vague, unsafe, or non-executable." |
| 283 | + "Reject plans that are vague, unsafe, non-executable, skip staged " |
| 284 | + "validation, or make claims without evidence-linked calls." |
246 | 285 | ), |
247 | 286 | **self._request_kwargs(phase="critic-loop", objective=objective), |
248 | 287 | ) |
@@ -320,13 +359,100 @@ def evaluate_plan_policy( |
320 | 359 | "First call is not refua_validate_spec; high-cost calls may fail later." |
321 | 360 | ) |
322 | 361 |
|
| 362 | + ordered_tools = _ordered_plan_tools(calls) |
| 363 | + if policy.require_evidence_before_hypothesis: |
| 364 | + first_hypothesis_index = _first_tool_index(ordered_tools, _HYPOTHESIS_TOOLS) |
| 365 | + if first_hypothesis_index is not None: |
| 366 | + evidence_before = any( |
| 367 | + tool in _EVIDENCE_TOOLS for tool in ordered_tools[:first_hypothesis_index] |
| 368 | + ) |
| 369 | + if not evidence_before: |
| 370 | + errors.append( |
| 371 | + "Policy requires evidence collection before hypothesis-heavy calls " |
| 372 | + "(design/admet/clinical)." |
| 373 | + ) |
| 374 | + |
| 375 | + if policy.enforce_stage_progression: |
| 376 | + errors.extend(_stage_progression_errors(ordered_tools)) |
| 377 | + warnings.extend(_stage_progression_warnings(ordered_tools)) |
| 378 | + |
323 | 379 | return PolicyCheck( |
324 | 380 | approved=(len(errors) == 0), |
325 | 381 | errors=tuple(errors), |
326 | 382 | warnings=tuple(warnings), |
327 | 383 | ) |
328 | 384 |
|
329 | 385 |
|
def _ordered_plan_tools(calls: list[Any]) -> list[str]:
    """Return the tool names found in *calls*, in their original order.

    Entries that are not dicts, or whose ``"tool"`` value is missing, not a
    string, or blank, are skipped. Names are stripped of surrounding
    whitespace, so positions in the result may differ from positions in
    *calls* when entries were skipped.
    """
    tools: list[str] = []
    for entry in calls:
        if not isinstance(entry, dict):
            continue
        tool = entry.get("tool")
        if isinstance(tool, str) and tool.strip():
            tools.append(tool.strip())
    return tools
| 395 | + |
| 396 | + |
| 397 | +def _first_tool_index(tools: list[str], match: frozenset[str]) -> int | None: |
| 398 | + for idx, tool in enumerate(tools): |
| 399 | + if tool in match: |
| 400 | + return idx |
| 401 | + return None |
| 402 | + |
| 403 | + |
def _stage_progression_errors(tools: list[str]) -> list[str]:
    """Return error messages for calls that break the staged pipeline order.

    Each tool is mapped to a stage through ``_TOOL_STAGE_INDEX``; tools with
    no known stage are ignored. For every remaining call, up to three
    independent errors may be reported, in this order:

    * the call's stage is more than one above the highest stage seen so far;
    * the call is stage 2 or later and no stage-1 call has been seen;
    * the call is stage 4 or later and no stage-2 call has been seen.

    Call numbers in the messages are 1-based positions within *tools*.
    An empty *tools* list yields an empty result.
    """
    errors: list[str] = []
    highest_seen = -1  # so that only a stage-0 call can open the plan cleanly
    seen_validation = False
    seen_design = False

    for idx, tool in enumerate(tools, start=1):
        stage = _TOOL_STAGE_INDEX.get(tool)
        if stage is None:
            continue

        if stage > highest_seen + 1:
            errors.append(
                f"Call #{idx} ({tool}) jumps pipeline stages; add missing intermediate "
                "stage calls first."
            )
        # Record the stage even after a jump so one skipped stage is reported
        # once rather than on every later call.
        highest_seen = max(highest_seen, stage)

        # Flags are updated before the prerequisite checks, so a stage-2 call
        # counts as its own "design" prerequisite.
        if stage == 1:
            seen_validation = True
        if stage == 2:
            seen_design = True

        if stage >= 2 and not seen_validation:
            errors.append(
                f"Call #{idx} ({tool}) requires prior refua_validate_spec validation."
            )
        if stage >= 4 and not seen_design:
            errors.append(
                f"Call #{idx} ({tool}) requires prior design/affinity stage calls."
            )

    return errors
| 440 | + |
| 441 | + |
def _stage_progression_warnings(tools: list[str]) -> list[str]:
    """Return advisory warnings about plans that lack ADMET profiling.

    Two independent warnings may be produced, both only when
    ``refua_admet_profile`` is absent from *tools*:

    * a fold/affinity/antibody-design call is present;
    * ``refua_clinical_simulator`` is present.

    Order of *tools* is irrelevant; only presence is checked.
    """
    present = set(tools)
    if "refua_admet_profile" in present:
        # Both warnings are conditional on ADMET being missing.
        return []

    warnings: list[str] = []
    if present & {"refua_fold", "refua_affinity", "refua_antibody_design"}:
        warnings.append(
            "Design/affinity calls present without refua_admet_profile; "
            "safety triage may be incomplete."
        )
    if "refua_clinical_simulator" in present:
        warnings.append(
            "Clinical simulation is present without ADMET profiling evidence."
        )
    return warnings
| 454 | + |
| 455 | + |
330 | 456 | def _parse_critic_json(text: str) -> dict[str, Any]: |
331 | 457 | stripped = text.strip() |
332 | 458 | try: |
|
0 commit comments