Skip to content

Commit 4cc9e30

Browse files
committed
feat: Enhance portfolio ranking and budget allocation for disease programs
- Added expected_value, allocation_fraction, recommended_budget, and decision fields to RankedDisease. - Updated rank_disease_programs to calculate expected value and allocate budgets based on total_budget. - Introduced _with_decision helper function to encapsulate decision-making logic. - Improved JSON serialization in RankedDisease.to_json method. - Enhanced prompts with additional guidance for evidence roles and staged progression. - Introduced regulatory_bridge for building regulatory bundles and verifying evidence. - Added translational_handoff for generating structured handoff documents for promising cures. - Implemented parallel execution for safe tools in refua_mcp_adapter. - Added tests for campaign state persistence, evidence quality summarization, and translational handoff. - Updated CLI tests to include new flags for regulatory bundles and state file management. - Enhanced policy checks for enforcing evidence collection before hypothesis generation.
1 parent f609276 commit 4cc9e30

19 files changed

Lines changed: 2089 additions & 44 deletions

src/refua_campaign/autonomy.py

Lines changed: 127 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,49 @@
99
from refua_campaign.orchestrator import _extract_first_json_object, _extract_json_plan
1010
from refua_campaign.prompts import planner_suffix
1111

12+
_TOOL_STAGE_INDEX: dict[str, int] = {
13+
"web_search": 0,
14+
"web_fetch": 0,
15+
"refua_data_list": 0,
16+
"refua_data_fetch": 0,
17+
"refua_data_materialize": 0,
18+
"refua_data_query": 0,
19+
"refua_validate_spec": 1,
20+
"refua_fold": 2,
21+
"refua_affinity": 2,
22+
"refua_antibody_design": 2,
23+
"refua_protein_properties": 2,
24+
"refua_admet_profile": 3,
25+
"refua_clinical_simulator": 4,
26+
"refua_job": 5,
27+
}
28+
_EVIDENCE_TOOLS: frozenset[str] = frozenset(
29+
{
30+
"web_search",
31+
"web_fetch",
32+
"refua_data_list",
33+
"refua_data_fetch",
34+
"refua_data_materialize",
35+
"refua_data_query",
36+
}
37+
)
38+
_HYPOTHESIS_TOOLS: frozenset[str] = frozenset(
39+
{
40+
"refua_fold",
41+
"refua_affinity",
42+
"refua_antibody_design",
43+
"refua_admet_profile",
44+
"refua_clinical_simulator",
45+
}
46+
)
47+
1248

1349
@dataclass(frozen=True)
1450
class PlanPolicy:
1551
max_calls: int = 10
1652
require_validate_first: bool = True
53+
enforce_stage_progression: bool = False
54+
require_evidence_before_hypothesis: bool = False
1755

1856

1957
@dataclass(frozen=True)
@@ -242,7 +280,8 @@ def _critic_once(
242280
instructions=(
243281
"Return JSON only with shape "
244282
'{"approved":bool,"issues":[...],"suggested_fixes":[...]}. '
245-
"Reject plans that are vague, unsafe, or non-executable."
283+
"Reject plans that are vague, unsafe, non-executable, skip staged "
284+
"validation, or make claims without evidence-linked calls."
246285
),
247286
**self._request_kwargs(phase="critic-loop", objective=objective),
248287
)
@@ -320,13 +359,100 @@ def evaluate_plan_policy(
320359
"First call is not refua_validate_spec; high-cost calls may fail later."
321360
)
322361

362+
ordered_tools = _ordered_plan_tools(calls)
363+
if policy.require_evidence_before_hypothesis:
364+
first_hypothesis_index = _first_tool_index(ordered_tools, _HYPOTHESIS_TOOLS)
365+
if first_hypothesis_index is not None:
366+
evidence_before = any(
367+
tool in _EVIDENCE_TOOLS for tool in ordered_tools[:first_hypothesis_index]
368+
)
369+
if not evidence_before:
370+
errors.append(
371+
"Policy requires evidence collection before hypothesis-heavy calls "
372+
"(design/admet/clinical)."
373+
)
374+
375+
if policy.enforce_stage_progression:
376+
errors.extend(_stage_progression_errors(ordered_tools))
377+
warnings.extend(_stage_progression_warnings(ordered_tools))
378+
323379
return PolicyCheck(
324380
approved=(len(errors) == 0),
325381
errors=tuple(errors),
326382
warnings=tuple(warnings),
327383
)
328384

329385

386+
def _ordered_plan_tools(calls: list[Any]) -> list[str]:
387+
tools: list[str] = []
388+
for entry in calls:
389+
if not isinstance(entry, dict):
390+
continue
391+
tool = entry.get("tool")
392+
if isinstance(tool, str) and tool.strip():
393+
tools.append(tool.strip())
394+
return tools
395+
396+
397+
def _first_tool_index(tools: list[str], match: frozenset[str]) -> int | None:
398+
for idx, tool in enumerate(tools):
399+
if tool in match:
400+
return idx
401+
return None
402+
403+
404+
def _stage_progression_errors(tools: list[str]) -> list[str]:
405+
if not tools:
406+
return []
407+
408+
errors: list[str] = []
409+
highest_seen = -1
410+
seen_validation = False
411+
seen_design = False
412+
413+
for idx, tool in enumerate(tools, start=1):
414+
stage = _TOOL_STAGE_INDEX.get(tool)
415+
if stage is None:
416+
continue
417+
418+
if stage > highest_seen + 1:
419+
errors.append(
420+
f"Call #{idx} ({tool}) jumps pipeline stages; add missing intermediate "
421+
"stage calls first."
422+
)
423+
highest_seen = max(highest_seen, stage)
424+
425+
if stage == 1:
426+
seen_validation = True
427+
if stage == 2:
428+
seen_design = True
429+
430+
if stage >= 2 and not seen_validation:
431+
errors.append(
432+
f"Call #{idx} ({tool}) requires prior refua_validate_spec validation."
433+
)
434+
if stage >= 4 and not seen_design:
435+
errors.append(
436+
f"Call #{idx} ({tool}) requires prior design/affinity stage calls."
437+
)
438+
439+
return errors
440+
441+
442+
def _stage_progression_warnings(tools: list[str]) -> list[str]:
443+
warnings: list[str] = []
444+
if any(tool in {"refua_fold", "refua_affinity", "refua_antibody_design"} for tool in tools):
445+
if "refua_admet_profile" not in tools:
446+
warnings.append(
447+
"Design/affinity calls present without refua_admet_profile; safety triage may be incomplete."
448+
)
449+
if "refua_clinical_simulator" in tools and "refua_admet_profile" not in tools:
450+
warnings.append(
451+
"Clinical simulation is present without ADMET profiling evidence."
452+
)
453+
return warnings
454+
455+
330456
def _parse_critic_json(text: str) -> dict[str, Any]:
331457
stripped = text.strip()
332458
try:

0 commit comments

Comments
 (0)