diff --git a/core/framework/config.py b/core/framework/config.py index 236f9be69a..5ad69e551b 100644 --- a/core/framework/config.py +++ b/core/framework/config.py @@ -13,6 +13,7 @@ from typing import Any from framework.graph.edge import DEFAULT_MAX_TOKENS +from framework.llm.codex_backend import CODEX_API_BASE, build_codex_litellm_kwargs # --------------------------------------------------------------------------- # Low-level config file access @@ -125,7 +126,6 @@ def get_worker_api_key() -> str | None: return token except ImportError: pass - api_key_env_var = worker_llm.get("api_key_env_var") if api_key_env_var: return os.environ.get(api_key_env_var) @@ -141,7 +141,7 @@ def get_worker_api_base() -> str | None: return get_api_base() if worker_llm.get("use_codex_subscription"): - return "https://chatgpt.com/backend-api/codex" + return CODEX_API_BASE if worker_llm.get("use_kimi_code_subscription"): return "https://api.kimi.com/coding" if worker_llm.get("use_antigravity_subscription"): @@ -169,23 +169,14 @@ def get_worker_llm_extra_kwargs() -> dict[str, Any]: if worker_llm.get("use_codex_subscription"): api_key = get_worker_api_key() if api_key: - headers: dict[str, str] = { - "Authorization": f"Bearer {api_key}", - "User-Agent": "CodexBar", - } + account_id = None try: from framework.runner.runner import get_codex_account_id account_id = get_codex_account_id() - if account_id: - headers["ChatGPT-Account-Id"] = account_id except ImportError: pass - return { - "extra_headers": headers, - "store": False, - "allowed_openai_params": ["store"], - } + return build_codex_litellm_kwargs(api_key, account_id=account_id) return {} @@ -274,7 +265,6 @@ def get_api_key() -> str | None: return token except ImportError: pass - # Standard env-var path (covers ZAI Code and all API-key providers) api_key_env_var = llm.get("api_key_env_var") if api_key_env_var: @@ -380,7 +370,7 @@ def get_api_base() -> str | None: llm = get_hive_config().get("llm", {}) if 
llm.get("use_codex_subscription"): # Codex subscription routes through the ChatGPT backend, not api.openai.com. - return "https://chatgpt.com/backend-api/codex" + return CODEX_API_BASE if llm.get("use_kimi_code_subscription"): # Kimi Code uses an Anthropic-compatible endpoint (no /v1 suffix). return "https://api.kimi.com/coding" @@ -415,23 +405,14 @@ def get_llm_extra_kwargs() -> dict[str, Any]: if llm.get("use_codex_subscription"): api_key = get_api_key() if api_key: - headers: dict[str, str] = { - "Authorization": f"Bearer {api_key}", - "User-Agent": "CodexBar", - } + account_id = None try: from framework.runner.runner import get_codex_account_id account_id = get_codex_account_id() - if account_id: - headers["ChatGPT-Account-Id"] = account_id except ImportError: pass - return { - "extra_headers": headers, - "store": False, - "allowed_openai_params": ["store"], - } + return build_codex_litellm_kwargs(api_key, account_id=account_id) return {} diff --git a/core/framework/graph/conversation.py b/core/framework/graph/conversation.py index b1bbba400b..785617cb85 100644 --- a/core/framework/graph/conversation.py +++ b/core/framework/graph/conversation.py @@ -351,13 +351,15 @@ def __init__( def system_prompt(self) -> str: return self._system_prompt - def update_system_prompt(self, new_prompt: str) -> None: + def update_system_prompt(self, new_prompt: str, output_keys: list[str] | None = None) -> None: """Update the system prompt. Used in continuous conversation mode at phase transitions to swap Layer 3 (focus) while preserving the conversation history. 
""" self._system_prompt = new_prompt + if output_keys is not None: + self._output_keys = output_keys self._meta_persisted = False # re-persist with new prompt def set_current_phase(self, phase_id: str) -> None: @@ -771,7 +773,7 @@ async def compact( delete_before = recent_messages[0].seq if recent_messages else self._next_seq await self._store.delete_parts_before(delete_before) await self._store.write_part(summary_msg.seq, summary_msg.to_storage_dict()) - await self._store.write_cursor({"next_seq": self._next_seq}) + await self._write_cursor_update({"next_seq": self._next_seq}) self._messages = [summary_msg] + recent_messages self._last_api_input_tokens = None # reset; next LLM call will recalibrate @@ -975,7 +977,7 @@ async def compact_preserving_structure( # Write kept structural messages (they may have been modified) for msg in kept_structural: await self._store.write_part(msg.seq, msg.to_storage_dict()) - await self._store.write_cursor({"next_seq": self._next_seq}) + await self._write_cursor_update({"next_seq": self._next_seq}) # Reassemble: reference + kept structural (in original order) + recent self._messages = [ref_msg] + kept_structural + recent_messages @@ -1012,7 +1014,7 @@ async def clear(self) -> None: """Remove all messages, keep system prompt, preserve ``_next_seq``.""" if self._store: await self._store.delete_parts_before(self._next_seq) - await self._store.write_cursor({"next_seq": self._next_seq}) + await self._write_cursor_update({"next_seq": self._next_seq}) self._messages.clear() self._last_api_input_tokens = None @@ -1047,6 +1049,14 @@ def export_summary(self) -> str: # --- Persistence internals --------------------------------------------- + async def _write_cursor_update(self, data: dict[str, Any]) -> None: + """Merge cursor updates instead of clobbering existing crash-recovery state.""" + if self._store is None: + return + cursor = await self._store.read_cursor() or {} + cursor.update(data) + await self._store.write_cursor(cursor) + async 
def _persist(self, message: Message) -> None: """Write-through a single message. No-op when store is None.""" if self._store is None: @@ -1054,7 +1064,7 @@ async def _persist(self, message: Message) -> None: if not self._meta_persisted: await self._persist_meta() await self._store.write_part(message.seq, message.to_storage_dict()) - await self._store.write_cursor({"next_seq": self._next_seq}) + await self._write_cursor_update({"next_seq": self._next_seq}) async def _persist_meta(self) -> None: """Lazily write conversation metadata to the store (called once).""" diff --git a/core/framework/graph/event_loop_node.py b/core/framework/graph/event_loop_node.py index 9aec856776..07f0de58a6 100644 --- a/core/framework/graph/event_loop_node.py +++ b/core/framework/graph/event_loop_node.py @@ -12,85 +12,19 @@ from __future__ import annotations import asyncio +import contextvars import json import logging -import os import re import time from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field from datetime import UTC, datetime from pathlib import Path -from typing import Any +from typing import Any, Literal, Protocol, runtime_checkable from framework.graph.conversation import ConversationStore, NodeConversation -from framework.graph.event_loop import types as event_loop_types -from framework.graph.event_loop.compaction import ( - build_emergency_summary, - build_llm_compaction_prompt, - compact, - format_messages_for_summary, - llm_compact, -) -from framework.graph.event_loop.cursor_persistence import ( - RestoredState, - check_pause, - drain_injection_queue, - drain_trigger_queue, - restore, - write_cursor, -) -from framework.graph.event_loop.event_publishing import ( - generate_action_plan, - log_skip_judge, - publish_context_usage, - publish_iteration, - publish_judge_verdict, - publish_llm_turn_complete, - publish_loop_completed, - publish_loop_started, - publish_output_key_set, - publish_stalled, - publish_text_delta, - 
publish_tool_completed, - publish_tool_started, - run_hooks, -) -from framework.graph.event_loop.judge_pipeline import ( - SubagentJudge as SharedSubagentJudge, - judge_turn, -) -from framework.graph.event_loop.stall_detector import ( - fingerprint_tool_calls, - is_stalled, - is_tool_doom_loop, - ngram_similarity, -) -from framework.graph.event_loop.subagent_executor import execute_subagent -from framework.graph.event_loop.synthetic_tools import ( - build_ask_user_multiple_tool, - build_ask_user_tool, - build_delegate_tool, - build_escalate_tool, - build_report_to_parent_tool, - build_set_output_tool, - handle_set_output, -) -from framework.graph.event_loop.tool_result_handler import ( - build_json_preview, - execute_tool, - extract_json_metadata, - is_transient_error, - record_learning, - restore_spill_counter, - truncate_tool_result, -) -from framework.graph.event_loop.types import ( - JudgeProtocol, - JudgeVerdict, - TriggerEvent, -) from framework.graph.node import NodeContext, NodeProtocol, NodeResult -from framework.llm.capabilities import supports_image_tool_results from framework.llm.provider import Tool, ToolResult, ToolUse from framework.llm.stream_events import ( FinishEvent, @@ -104,46 +38,18 @@ logger = logging.getLogger(__name__) -async def _describe_images_as_text(image_content: list[dict[str, Any]]) -> str | None: - """Describe images using the best available vision model.""" - import litellm +@dataclass +class TriggerEvent: + """A framework-level trigger signal (timer tick or webhook hit). - blocks: list[dict[str, Any]] = [ - { - "type": "text", - "text": ( - "Describe the following image(s) concisely but with enough detail " - "that a text-only AI assistant can understand the content and context." 
- ), - } - ] - blocks.extend(image_content) - - candidates: list[str] = [] - if os.environ.get("OPENAI_API_KEY"): - candidates.append("gpt-4o-mini") - if os.environ.get("ANTHROPIC_API_KEY"): - candidates.append("claude-3-haiku-20240307") - if os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY"): - candidates.append("gemini/gemini-1.5-flash") - - for model in candidates: - try: - response = await litellm.acompletion( - model=model, - messages=[{"role": "user", "content": blocks}], - max_tokens=512, - ) - description = (response.choices[0].message.content or "").strip() - if description: - count = len(image_content) - label = "image" if count == 1 else f"{count} images" - return f"[{label} attached — description: {description}]" - except Exception as exc: - logger.debug("Vision fallback model '%s' failed: %s", model, exc) - continue + Triggers are queued separately from user messages / external events + and drained atomically so the LLM sees all pending triggers at once. + """ - return None + trigger_type: str # "timer" | "webhook" + source_id: str # entry point ID or webhook route ID + payload: dict[str, Any] = field(default_factory=dict) + timestamp: float = field(default_factory=time.time) # Pattern for detecting context-window-exceeded errors across LLM providers. @@ -164,6 +70,77 @@ def _is_context_too_large_error(exc: BaseException) -> bool: return bool(_CONTEXT_TOO_LARGE_RE.search(str(exc))) +_ASSISTANT_INPUT_PATTERNS = ( + re.compile(r"\?\s*$"), + re.compile(r"\b(let me know|tell me|share|choose|pick|approve|confirm|which|what)\b", re.I), + re.compile(r"\bready to proceed\b", re.I), +) + +_QUESTION_WIDGET_FALLBACK = "Choose an option below." 
+ + +def _assistant_text_requires_input(text: str) -> bool: + """Heuristic: only count auto-blocked assistant text that clearly invites a reply.""" + stripped = text.strip() + if not stripped: + return False + return any(pattern.search(stripped) for pattern in _ASSISTANT_INPUT_PATTERNS) + + +def _split_prompt_for_question_widget( + prompt: str, + *, + stream_id: str, + has_structured_choices: bool, +) -> tuple[str | None, str]: + """Split long queen prompts into visible output plus a short widget question. + + Queen often asks follow-up questions after a worker completes by sending the + full result plus the next-action question inside a single ask_user prompt. + Other models tend to improvise around that, but Codex follows it literally, + which leaves the result hidden inside the QuestionWidget. Surface long + result-bearing prompts as normal output first, then keep the widget focused + on the actual choice prompt. + """ + + stripped = prompt.strip() + if stream_id != "queen" or not has_structured_choices or not stripped: + return None, stripped + + if len(stripped) <= 180 and "\n" not in stripped: + return None, stripped + + paragraphs = [part.strip() for part in re.split(r"\n\s*\n", stripped) if part.strip()] + if len(paragraphs) >= 2: + tail = paragraphs[-1] + body = "\n\n".join(paragraphs[:-1]).strip() + if body and (_assistant_text_requires_input(tail) or tail.endswith("?")): + return body, tail + + lines = [line.strip() for line in stripped.splitlines() if line.strip()] + if len(lines) >= 2: + tail = lines[-1] + body = "\n".join(lines[:-1]).strip() + if body and (_assistant_text_requires_input(tail) or tail.endswith("?")): + return body, tail + + last_q = stripped.rfind("?") + if last_q != -1 and last_q >= len(stripped) - 220: + split_idx = max( + stripped.rfind("\n", 0, last_q), + stripped.rfind(". ", 0, last_q), + stripped.rfind("! ", 0, last_q), + ) + if split_idx != -1: + start = split_idx + (2 if stripped[split_idx : split_idx + 2] in {". ", "! 
"} else 1) + body = stripped[:start].strip() + tail = stripped[start:].strip() + if body and tail: + return body, tail + + return stripped, _QUESTION_WIDGET_FALLBACK + + # --------------------------------------------------------------------------- # Escalation receiver (temporary routing target for subagent → user input) # --------------------------------------------------------------------------- @@ -190,7 +167,7 @@ async def inject_event( content: str, *, is_client_input: bool = False, - image_content: list[dict] | None = None, + image_content: list[dict[str, Any]] | None = None, ) -> None: """Called by ExecutionStream.inject_input() when the user responds.""" self._response = content @@ -213,12 +190,268 @@ class TurnCancelled(Exception): pass -# Re-export shared event-loop types from the legacy parent module. -SubagentJudge = SharedSubagentJudge -LoopConfig = event_loop_types.LoopConfig -HookContext = event_loop_types.HookContext -HookResult = event_loop_types.HookResult -OutputAccumulator = event_loop_types.OutputAccumulator +@dataclass +class JudgeVerdict: + """Result of judge evaluation for the event loop.""" + + action: Literal["ACCEPT", "RETRY", "ESCALATE"] + # None = no evaluation happened (skip_judge, tool-continue); not logged. + # "" = evaluated but no feedback; logged with default text. + # "..." = evaluated with feedback; logged as-is. + feedback: str | None = None + + +@runtime_checkable +class JudgeProtocol(Protocol): + """Protocol for event-loop judges. + + Implementations evaluate the current state of the event loop and + decide whether to accept the output, retry with feedback, or escalate. + """ + + async def evaluate(self, context: dict[str, Any]) -> JudgeVerdict: ... + + +class SubagentJudge: + """Judge for subagent execution. + + Accepts immediately when all required output keys are filled, + regardless of whether real tool calls were also made in the same turn. 
+ On RETRY, reminds the subagent of its specific task with progressive + urgency based on remaining iterations. + """ + + def __init__(self, task: str, max_iterations: int = 10): + self._task = task + self._max_iterations = max_iterations + + async def evaluate(self, context: dict[str, Any]) -> JudgeVerdict: + missing = context.get("missing_keys", []) + if not missing: + return JudgeVerdict(action="ACCEPT", feedback="") + + iteration = context.get("iteration", 0) + remaining = self._max_iterations - iteration - 1 + + if remaining <= 3: + urgency = ( + f"URGENT: Only {remaining} iterations left. " + f"Stop all other work and call set_output NOW for: {missing}" + ) + elif remaining <= self._max_iterations // 2: + urgency = ( + f"WARNING: {remaining} iterations remaining. " + f"You must call set_output for: {missing}" + ) + else: + urgency = f"Missing output keys: {missing}. Use set_output to provide them." + + return JudgeVerdict(action="RETRY", feedback=f"Your task: {self._task}\n{urgency}") + + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + + +@dataclass +class LoopConfig: + """Configuration for the event loop.""" + + max_iterations: int = 50 + max_tool_calls_per_turn: int = 30 + judge_every_n_turns: int = 1 + stall_detection_threshold: int = 3 + stall_similarity_threshold: float = 0.85 + max_context_tokens: int = 32_000 + store_prefix: str = "" + + # Overflow margin for max_tool_calls_per_turn. Tool calls are only + # discarded when the count exceeds max_tool_calls_per_turn * (1 + margin). + # Default 0.5 means 50% wiggle room (e.g. limit=10 → hard cutoff at 15). + tool_call_overflow_margin: float = 0.5 + + # --- Tool result context management --- + # When a tool result exceeds this character count, it is truncated in the + # conversation context. 
If *spillover_dir* is set the full result is + # written to a file and the truncated message includes the filename so + # the agent can retrieve it with load_data(). If *spillover_dir* is + # ``None`` the result is simply truncated with an explanatory note. + max_tool_result_chars: int = 30_000 + spillover_dir: str | None = None # Path string; created on first use + + # --- set_output value spilling --- + # When a set_output value exceeds this character count it is auto-saved + # to a file in *spillover_dir* and the stored value is replaced with a + # lightweight file reference. This keeps shared memory / adapt.md / + # transition markers small and forces the next node to load the full + # data from the file. Set to 0 to disable. + max_output_value_chars: int = 2_000 + + # --- Stream retry (transient error recovery within EventLoopNode) --- + # When _run_single_turn() raises a transient error (network, rate limit, + # server error), retry up to this many times with exponential backoff + # before re-raising. Set to 0 to disable. + max_stream_retries: int = 3 + stream_retry_backoff_base: float = 2.0 + stream_retry_max_delay: float = 60.0 # cap per-retry sleep + + # --- Tool doom loop detection --- + # Detect when the LLM calls the same tool(s) with identical args for + # N consecutive turns. For client-facing nodes, blocks for user input. + # For non-client-facing nodes, injects a warning into the conversation. + tool_doom_loop_threshold: int = 3 + + # --- Client-facing auto-block grace period --- + # When a client-facing node produces text-only turns (no tools, no + # set_output), the judge is skipped for this many consecutive auto-block + # turns. After the grace period, the judge runs to apply RETRY pressure + # on models stuck in a clarification loop. Explicit ask_user() calls + # always skip the judge regardless of this setting. 
+ cf_grace_turns: int = 1 + tool_doom_loop_enabled: bool = True + + # --- Per-tool-call timeout --- + # Maximum seconds a single tool call may take before being killed. + # Prevents hung MCP servers (especially browser/GCU tools) from + # blocking the entire event loop indefinitely. 0 = no timeout. + tool_call_timeout_seconds: float = 60.0 + + # --- Subagent delegation timeout --- + # Maximum seconds a delegate_to_sub_agent call may run before being + # killed. Subagents run a full event-loop so they naturally take + # longer than a single tool call — default is 10 minutes. 0 = no timeout. + subagent_timeout_seconds: float = 600.0 + + # --- Lifecycle hooks --- + # Hooks are async callables keyed by event name. Supported events: + # "session_start" — fires once after the first user message is added, + # before the first LLM turn. trigger = initial message. + # "external_message" — fires when inject_notification() delivers a message. + # trigger = injected message text. + # Each hook receives a HookContext and may return a HookResult to patch + # the system prompt and/or inject a follow-up user message. + hooks: dict[str, list] = None # dict[str, list[HookFn]] (None → no hooks) + + def __post_init__(self) -> None: + if self.hooks is None: + object.__setattr__(self, "hooks", {}) + + +# --------------------------------------------------------------------------- +# Hook types +# --------------------------------------------------------------------------- + + +@dataclass +class HookContext: + """Context passed to every lifecycle hook.""" + + event: str # event name, e.g. 
@dataclass
class OutputAccumulator:
    """Accumulates output key-value pairs with optional write-through persistence.

    Values are stored in memory and optionally written through to a
    ConversationStore's cursor data for crash recovery.

    When *spillover_dir* and *max_value_chars* are set, large values are
    automatically saved to files and replaced with lightweight file
    references. This guarantees auto-spill fires on **every** ``set()``
    call regardless of code path (resume, checkpoint restore, etc.).
    """

    values: dict[str, Any] = field(default_factory=dict)
    store: ConversationStore | None = None
    spillover_dir: str | None = None
    max_value_chars: int = 0  # 0 = disabled

    async def set(self, key: str, value: Any) -> None:
        """Set a key-value pair, auto-spilling large values to files.

        When the serialised value exceeds *max_value_chars*, the data is
        saved to ``{spillover_dir}/output_{key}{ext}`` and *value* is
        replaced with a compact file-reference string. When a store is
        attached, the (possibly replaced) value is also merged into the
        ``"outputs"`` entry of the store's cursor.
        """
        value = self._auto_spill(key, value)
        self.values[key] = value
        if self.store is not None:
            # Read-modify-write so other cursor fields are preserved.
            cursor = await self.store.read_cursor() or {}
            outputs = cursor.get("outputs", {})
            outputs[key] = value
            cursor["outputs"] = outputs
            await self.store.write_cursor(cursor)

    @staticmethod
    def _spill_filename(key: str, ext: str) -> str:
        """Return the spill file name for *key*.

        Output keys originate from LLM tool calls, so anything other than
        letters, digits, ``_``, ``.`` and ``-`` is replaced with ``_``. This
        keeps the file inside the spillover directory even if the key contains
        path separators. Plain identifier keys are unchanged.
        """
        safe_key = re.sub(r"[^A-Za-z0-9_.-]", "_", str(key))
        return f"output_{safe_key}{ext}"

    def _auto_spill(self, key: str, value: Any) -> Any:
        """Save large values to a file and return a reference string.

        Returns *value* unchanged when spilling is disabled (no directory or
        a non-positive limit) or when the serialised value fits the limit.
        """
        if self.max_value_chars <= 0 or not self.spillover_dir:
            return value

        val_str = json.dumps(value, ensure_ascii=False) if not isinstance(value, str) else value
        if len(val_str) <= self.max_value_chars:
            return value

        spill_dir = Path(self.spillover_dir)
        spill_dir.mkdir(parents=True, exist_ok=True)
        is_structured = isinstance(value, (dict, list))
        filename = self._spill_filename(key, ".json" if is_structured else ".txt")
        target = spill_dir / filename
        write_content = (
            json.dumps(value, indent=2, ensure_ascii=False) if is_structured else str(value)
        )
        target.write_text(write_content, encoding="utf-8")
        file_size = target.stat().st_size
        logger.info(
            "set_output value auto-spilled: key=%s, %d chars → %s (%d bytes)",
            key,
            len(val_str),
            filename,
            file_size,
        )
        return (
            f"[Saved to '{filename}' ({file_size:,} bytes). "
            f"Use load_data(filename='{filename}') "
            f"to access full data.]"
        )

    def get(self, key: str) -> Any | None:
        """Get a value by key, or None if not present."""
        return self.values.get(key)

    def to_dict(self) -> dict[str, Any]:
        """Return a copy of all accumulated values."""
        return dict(self.values)

    def has_all_keys(self, required: list[str]) -> bool:
        """Check if all required keys have been set (non-None)."""
        return all(self.values.get(key) is not None for key in required)

    @classmethod
    async def restore(
        cls,
        store: ConversationStore,
        *,
        spillover_dir: str | None = None,
        max_value_chars: int = 0,
    ) -> OutputAccumulator:
        """Restore an OutputAccumulator from a store's cursor data.

        Args:
            store: Store whose cursor may hold a previously persisted
                ``"outputs"`` mapping.
            spillover_dir: Optional spill directory for the restored
                accumulator, so auto-spill keeps working after a resume.
            max_value_chars: Optional spill threshold (0 = disabled).

        Returns:
            An accumulator seeded with a copy of the persisted outputs, or an
            empty one when the cursor holds none.
        """
        cursor = await store.read_cursor()
        values: dict[str, Any] = {}
        if cursor and "outputs" in cursor:
            # Copy so later in-memory writes do not alias the cursor payload.
            values = dict(cursor["outputs"])
        return cls(
            values=values,
            store=store,
            spillover_dir=spillover_dir,
            max_value_chars=max_value_chars,
        )
restored = await self._restore(ctx) @@ -379,29 +619,15 @@ async def execute(self, ctx: NodeContext) -> NodeResult: start_iteration = restored.start_iteration _restored_recent_responses = restored.recent_responses _restored_tool_fingerprints = restored.recent_tool_fingerprints + _restored_output_fingerprints = restored.recent_output_fingerprints + _restored_no_progress_turns = restored.no_progress_turns - # Refresh the system prompt with full composition including - # execution preamble and node-type preamble. The stored - # prompt may be stale after code changes or when runtime- - # injected context (e.g. worker identity) has changed. - from framework.graph.prompt_composer import ( - EXECUTION_SCOPE_PREAMBLE, - compose_system_prompt, - ) - - _exec_preamble = None - if ( - not ctx.is_subagent_mode - and ctx.node_spec.node_type in ("event_loop", "gcu") - and ctx.node_spec.output_keys - ): - _exec_preamble = EXECUTION_SCOPE_PREAMBLE - - _node_type_preamble = None - if ctx.node_spec.node_type == "gcu": - from framework.graph.gcu import GCU_BROWSER_SYSTEM_PROMPT - - _node_type_preamble = GCU_BROWSER_SYSTEM_PROMPT + # Refresh the system prompt with full 3-layer composition. + # The stored prompt may be stale after code changes or when + # runtime-injected context (e.g. worker identity) has changed. + # On resume, we rebuild identity + narrative + focus so the LLM + # understands the session history, not just the node directive. 
+ from framework.graph.prompt_composer import compose_system_prompt _current_prompt = compose_system_prompt( identity_prompt=ctx.identity_prompt or None, @@ -410,8 +636,6 @@ async def execute(self, ctx: NodeContext) -> NodeResult: accounts_prompt=ctx.accounts_prompt or None, skills_catalog_prompt=ctx.skills_catalog_prompt or None, protocols_prompt=ctx.protocols_prompt or None, - execution_preamble=_exec_preamble, - node_type_preamble=_node_type_preamble, ) if conversation.system_prompt != _current_prompt: conversation.update_system_prompt(_current_prompt) @@ -419,6 +643,8 @@ async def execute(self, ctx: NodeContext) -> NodeResult: else: _restored_recent_responses = [] _restored_tool_fingerprints = [] + _restored_output_fingerprints = [] + _restored_no_progress_turns = 0 # Fresh conversation: either isolated mode or first node in continuous mode. from framework.graph.prompt_composer import ( @@ -596,6 +822,8 @@ async def execute(self, ctx: NodeContext) -> NodeResult: # 5. Stall / doom loop detection state (restored from cursor if resuming) recent_responses: list[str] = _restored_recent_responses recent_tool_fingerprints: list[list[tuple[str, str]]] = _restored_tool_fingerprints + recent_output_fingerprints: list[list[tuple[str, str]]] = _restored_output_fingerprints + no_progress_turns: int = _restored_no_progress_turns _consecutive_empty_turns: int = 0 # 6. Main loop @@ -631,7 +859,7 @@ async def execute(self, ctx: NodeContext) -> NodeResult: ) # 6b. Drain injection queue - await self._drain_injection_queue(conversation, ctx) + await self._drain_injection_queue(conversation) # 6b1. 
Drain trigger queue (framework-level signals) await self._drain_trigger_queue(conversation) @@ -713,7 +941,12 @@ async def execute(self, ctx: NodeContext) -> NodeResult: request_messages, reported_to_parent, ) = await self._run_single_turn( - ctx, conversation, tools, iteration, accumulator + ctx, + conversation, + tools, + iteration, + accumulator, + allow_tool_turn_auto_complete=(not _saw_transient_stream_retry), ) logger.info( "[%s] iter=%d: LLM done — text=%d chars, real_tools=%d, " @@ -807,6 +1040,7 @@ async def execute(self, ctx: NodeContext) -> NodeResult: "split into multiple calls.]" ) + _saw_transient_stream_retry = True await asyncio.sleep(delay) continue # retry same iteration @@ -1040,6 +1274,70 @@ async def execute(self, ctx: NodeContext) -> NodeResult: else: _consecutive_empty_turns = 0 + # 6e''''. Straightforward worker turns can auto-complete as soon + # as required outputs are set. Recovery/escalation paths should + # still flow through the normal post-tool continuation or judge. 
+ if ( + outputs_set + and not ctx.node_spec.client_facing + and not user_input_requested + and not queen_input_requested + and (self._can_auto_complete_outputs(ctx) or _received_queen_guidance) + and not reported_to_parent + and not _saw_transient_stream_retry + and not any( + tc.get("is_error") and tc.get("tool_name") != "set_output" + for tc in logged_tool_calls + ) + and accumulator is not None + and self._outputs_ready_for_auto_complete(ctx, accumulator) + ): + logger.info( + "[%s] iter=%d: required outputs satisfied via set_output — auto-completing", + node_id, + iteration, + ) + await self._publish_loop_completed(stream_id, node_id, iteration + 1, execution_id) + latency_ms = int((time.time() - start_time) * 1000) + _accept_count += 1 + if ctx.runtime_logger: + iter_latency_ms = int((time.time() - iter_start) * 1000) + ctx.runtime_logger.log_step( + node_id=node_id, + node_type="event_loop", + step_index=iteration, + verdict="ACCEPT", + verdict_feedback=("Required outputs satisfied; auto-completed node."), + tool_calls=logged_tool_calls, + llm_text=assistant_text, + input_tokens=turn_tokens.get("input", 0), + output_tokens=turn_tokens.get("output", 0), + latency_ms=iter_latency_ms, + ) + ctx.runtime_logger.log_node_complete( + node_id=node_id, + node_name=ctx.node_spec.name, + node_type="event_loop", + success=True, + total_steps=iteration + 1, + tokens_used=total_input_tokens + total_output_tokens, + input_tokens=total_input_tokens, + output_tokens=total_output_tokens, + latency_ms=latency_ms, + exit_status="success", + accept_count=_accept_count, + retry_count=_retry_count, + escalate_count=_escalate_count, + continue_count=_continue_count, + ) + return NodeResult( + success=True, + output=accumulator.to_dict(), + tokens_used=total_input_tokens + total_output_tokens, + latency_ms=latency_ms, + conversation=conversation if _is_continuous else None, + ) + # 6f. 
Stall detection recent_responses.append(assistant_text) if len(recent_responses) > self._config.stall_detection_threshold: @@ -1158,10 +1456,144 @@ async def execute(self, ctx: NodeContext) -> NodeResult: else: await conversation.add_user_message(warning_msg) recent_tool_fingerprints.clear() + recent_output_fingerprints.clear() else: # Text-only turn breaks the doom loop chain recent_tool_fingerprints.clear() + # 6f''. Repeated identical set_output-only turns + # If a worker keeps writing the same outputs without any real tool + # work in between, it is usually stuck in a broken node prompt/flow. + set_output_calls = [ + tc for tc in logged_tool_calls if tc.get("tool_name") == "set_output" + ] + if set_output_calls and not mcp_tool_calls and stream_id not in ("queen", "judge"): + output_fps = self._fingerprint_set_output_calls(set_output_calls) + recent_output_fingerprints.append(output_fps) + threshold = self._config.tool_doom_loop_threshold + if len(recent_output_fingerprints) > threshold: + recent_output_fingerprints.pop(0) + is_output_doom, output_desc = self._is_output_doom_loop(recent_output_fingerprints) + if is_output_doom: + logger.warning("[%s] %s", node_id, output_desc) + if self._event_bus: + await self._event_bus.emit_tool_doom_loop( + stream_id=stream_id, + node_id=node_id, + description=output_desc, + execution_id=execution_id, + ) + latency_ms = int((time.time() - start_time) * 1000) + _continue_count += 1 + if ctx.runtime_logger: + iter_latency_ms = int((time.time() - iter_start) * 1000) + ctx.runtime_logger.log_step( + node_id=node_id, + node_type="event_loop", + step_index=iteration, + verdict="CONTINUE", + verdict_feedback=output_desc, + tool_calls=logged_tool_calls, + llm_text=assistant_text, + input_tokens=turn_tokens.get("input", 0), + output_tokens=turn_tokens.get("output", 0), + latency_ms=iter_latency_ms, + ) + ctx.runtime_logger.log_node_complete( + node_id=node_id, + node_name=ctx.node_spec.name, + node_type="event_loop", + 
success=False, + error=output_desc, + total_steps=iteration + 1, + tokens_used=total_input_tokens + total_output_tokens, + input_tokens=total_input_tokens, + output_tokens=total_output_tokens, + latency_ms=latency_ms, + exit_status="output_doom_loop", + accept_count=_accept_count, + retry_count=_retry_count, + escalate_count=_escalate_count, + continue_count=_continue_count, + ) + return NodeResult( + success=False, + error=output_desc, + output=accumulator.to_dict(), + tokens_used=total_input_tokens + total_output_tokens, + latency_ms=latency_ms, + conversation=conversation if _is_continuous else None, + ) + else: + recent_output_fingerprints.clear() + + # 6f'''. Semantic no-progress churn detection + missing = self._get_missing_output_keys( + accumulator, + ctx.node_spec.output_keys, + ctx.node_spec.nullable_output_keys, + ) + if self._is_no_progress_turn( + missing=missing, + outputs_set=outputs_set, + real_tool_results=real_tool_results, + user_input_requested=user_input_requested, + queen_input_requested=queen_input_requested, + reported_to_parent=reported_to_parent, + ): + no_progress_turns += 1 + else: + no_progress_turns = 0 + + if no_progress_turns >= self._config.stall_detection_threshold: + desc = ( + f"No-progress loop detected: {no_progress_turns} consecutive turns " + f"with missing required outputs {missing} and no tool/input/escalation progress" + ) + logger.warning("[%s] %s", node_id, desc) + await self._publish_stalled(stream_id, node_id, execution_id) + latency_ms = int((time.time() - start_time) * 1000) + _continue_count += 1 + if ctx.runtime_logger: + iter_latency_ms = int((time.time() - iter_start) * 1000) + ctx.runtime_logger.log_step( + node_id=node_id, + node_type="event_loop", + step_index=iteration, + verdict="CONTINUE", + verdict_feedback=desc, + tool_calls=logged_tool_calls, + llm_text=assistant_text, + input_tokens=turn_tokens.get("input", 0), + output_tokens=turn_tokens.get("output", 0), + latency_ms=iter_latency_ms, + ) + 
ctx.runtime_logger.log_node_complete( + node_id=node_id, + node_name=ctx.node_spec.name, + node_type="event_loop", + success=False, + error=desc, + total_steps=iteration + 1, + tokens_used=total_input_tokens + total_output_tokens, + input_tokens=total_input_tokens, + output_tokens=total_output_tokens, + latency_ms=latency_ms, + exit_status="no_progress_loop", + accept_count=_accept_count, + retry_count=_retry_count, + escalate_count=_escalate_count, + continue_count=_continue_count, + ) + return NodeResult( + success=False, + error=desc, + output=accumulator.to_dict(), + tokens_used=total_input_tokens + total_output_tokens, + latency_ms=latency_ms, + conversation=conversation if _is_continuous else None, + ) + # 6g. Write cursor checkpoint (includes stall/doom state for resume) await self._write_cursor( ctx, @@ -1170,6 +1602,8 @@ async def execute(self, ctx: NodeContext) -> NodeResult: iteration, recent_responses=recent_responses, recent_tool_fingerprints=recent_tool_fingerprints, + recent_output_fingerprints=recent_output_fingerprints, + no_progress_turns=no_progress_turns, ) # 6h'. 
Client-facing input blocking @@ -1309,6 +1743,11 @@ async def execute(self, ctx: NodeContext) -> NodeResult: prompt=_cf_prompt, options=ask_user_options, questions=multi_qs, + auto_blocked=_cf_auto, + assistant_text_present=bool(_cf_auto and assistant_text.strip()), + assistant_text_requires_input=bool( + _cf_auto and _assistant_text_requires_input(assistant_text) + ), ) # Emit deferred tool_call_completed for ask_user / ask_user_multiple deferred = getattr(self, "_deferred_tool_complete", None) @@ -1369,6 +1808,7 @@ async def execute(self, ctx: NodeContext) -> NodeResult: ) recent_responses.clear() + no_progress_turns = 0 # -- Judge-skip decision after client-facing blocking -- # @@ -1511,7 +1951,9 @@ async def execute(self, ctx: NodeContext) -> NodeResult: ) recent_responses.clear() + no_progress_turns = 0 _cf_text_only_streak = 0 + _received_queen_guidance = True _continue_count += 1 self._log_skip_judge( ctx, @@ -1779,7 +2221,8 @@ async def inject_event( human user (e.g. /chat endpoint), False for external events (e.g. worker question forwarded by the frontend). Controls message formatting in _drain_injection_queue, not wake behavior. - image_content: Optional list of OpenAI-style image blocks to attach. + image_content: Optional multimodal payload blocks that should be + attached to the injected user message. """ await self._injection_queue.put((content, is_client_input, image_content)) self._input_ready.set() @@ -1822,6 +2265,9 @@ async def _await_user_input( options: list[str] | None = None, questions: list[dict] | None = None, emit_client_request: bool = True, + auto_blocked: bool = False, + assistant_text_present: bool = False, + assistant_text_requires_input: bool = False, ) -> bool: """Block until user input arrives or shutdown is signaled. 
@@ -1862,6 +2308,9 @@ async def _await_user_input( execution_id=ctx.execution_id or "", options=options, questions=questions, + auto_blocked=auto_blocked, + assistant_text_present=assistant_text_present, + assistant_text_requires_input=assistant_text_requires_input, ) self._awaiting_input = True @@ -1882,6 +2331,8 @@ async def _run_single_turn( tools: list[Tool], iteration: int, accumulator: OutputAccumulator, + *, + allow_tool_turn_auto_complete: bool, ) -> tuple[ str, list[dict], @@ -2141,26 +2592,25 @@ async def _do_stream( # --- Framework-level set_output handling --- _tc_start = time.time() _tc_ts = datetime.now(UTC).isoformat() - result = self._handle_set_output(tc.tool_input, ctx.node_spec.output_keys) + value = self._normalize_set_output_value(tc.tool_input.get("value", "")) + key = tc.tool_input.get("key", "") + result = self._handle_set_output( + tc.tool_input, + ctx.node_spec.output_keys, + missing_keys=self._get_missing_output_keys( + accumulator, + ctx.node_spec.output_keys, + ctx.node_spec.nullable_output_keys, + ), + current_value=accumulator.get(key), + normalized_value=value, + ) result = ToolResult( tool_use_id=tc.tool_use_id, content=result.content, is_error=result.is_error, ) if not result.is_error: - value = tc.tool_input.get("value", "") - # Parse JSON strings into native types so downstream - # consumers get lists/dicts instead of serialised JSON, - # and the hallucination validator skips non-string values. - if isinstance(value, str): - try: - parsed = json.loads(value) - if isinstance(parsed, (list, dict, bool, int, float)): - value = parsed - except (json.JSONDecodeError, TypeError): - pass - key = tc.tool_input.get("key", "") - # Auto-spill happens inside accumulator.set() # — it fires on every code path (fresh, resume, # restore) and prevents overwrite regression. 
@@ -2234,16 +2684,26 @@ async def _do_stream( user_input_requested = True - # Free-form ask_user (no options): stream the question - # text as a chat message so the user can see it. When - # options are present the QuestionWidget shows the - # question, but without options nothing renders it. - if ask_user_options is None and ask_user_prompt and ctx.node_spec.client_facing: + display_prompt: str | None = None + if ask_user_prompt and ctx.node_spec.client_facing: + if ask_user_options is None: + display_prompt = ask_user_prompt + else: + display_prompt, ask_user_prompt = _split_prompt_for_question_widget( + ask_user_prompt, + stream_id=stream_id, + has_structured_choices=True, + ) + + # Free-form ask_user already needs a visible chat + # message; for queen prompts with choices, display_prompt + # lets us surface long worker results before the widget. + if display_prompt: await self._publish_text_delta( stream_id, node_id, - content=ask_user_prompt, - snapshot=ask_user_prompt, + content=display_prompt, + snapshot=display_prompt, ctx=ctx, execution_id=execution_id, iteration=iteration, @@ -2358,27 +2818,6 @@ async def _do_stream( results_by_id[tc.tool_use_id] = result elif tc.tool_name == "delegate_to_sub_agent": - # Guard: in continuous mode the LLM may see delegate - # calls from a previous node's conversation history and - # attempt to re-use the tool on a node that doesn't own - # it. Only accept if the tool was actually offered. - if not any(t.name == "delegate_to_sub_agent" for t in tools): - logger.warning( - "[%s] LLM called delegate_to_sub_agent but tool " - "was not offered to this node — rejecting", - node_id, - ) - result = ToolResult( - tool_use_id=tc.tool_use_id, - content=( - "ERROR: delegate_to_sub_agent is not available " - "on this node. This tool belongs to a different " - "node in the workflow." 
- ), - is_error=True, - ) - results_by_id[tc.tool_use_id] = result - continue # --- Framework-level subagent delegation --- # Queue for parallel execution in Phase 2 logger.info( @@ -2624,20 +3063,10 @@ async def _timed_subagent( real_tool_results.append(tool_entry) logged_tool_calls.append(tool_entry) - image_content = result.image_content - if image_content and ctx.llm and not supports_image_tool_results(ctx.llm.model): - logger.info( - "Stripping image_content from tool result; " - "model '%s' does not support images in tool results", - ctx.llm.model, - ) - image_content = None - await conversation.add_tool_result( tool_use_id=tc.tool_use_id, content=result.content, is_error=result.is_error, - image_content=image_content, is_skill_content=result.is_skill_content, ) if ( @@ -2767,6 +3196,28 @@ async def _timed_subagent( reported_to_parent, ) + if ( + outputs_set_this_turn + and accumulator is not None + and allow_tool_turn_auto_complete + and self._can_auto_complete_outputs(ctx) + and self._outputs_ready_for_auto_complete(ctx, accumulator) + ): + return ( + final_text, + real_tool_results, + outputs_set_this_turn, + token_counts, + logged_tool_calls, + user_input_requested, + ask_user_prompt, + ask_user_options, + queen_input_requested, + final_system_prompt, + final_messages, + reported_to_parent, + ) + # Tool calls processed -- loop back to stream with updated conversation inner_turn += 1 @@ -2777,38 +3228,385 @@ async def _timed_subagent( # ------------------------------------------------------------------- def _build_ask_user_tool(self) -> Tool: - """Build the synthetic ask_user tool. Delegates to synthetic_tools module.""" - return build_ask_user_tool() + """Build the synthetic ask_user tool for explicit user-input requests. + + Client-facing nodes call ask_user() when they need to pause and wait + for user input. Text-only turns WITHOUT ask_user flow through without + blocking, allowing progress updates and summaries to stream freely. 
+ """ + return Tool( + name="ask_user", + description=( + "You MUST call this tool whenever you need the user's response. " + "Always call it after greeting the user, asking a question, or " + "requesting approval. Do NOT call it for status updates or " + "summaries that don't require a response. " + "Always include 2-3 predefined options. The UI automatically " + "appends an 'Other' free-text input after your options, so NEVER " + "include catch-all options like 'Custom idea', 'Something else', " + "'Other', or 'None of the above' — the UI handles that. " + "When the question primarily needs a typed answer but you must " + "include options, make one option signal that typing is expected " + "(e.g. 'I\\'ll type my response'). This helps users discover the " + "free-text input. " + "The ONLY exception: omit options when the question demands a " + "free-form answer the user must type out (e.g. 'Describe your " + "agent idea', 'Paste the error message'). " + 'Example: {"question": "What would you like to do?", "options": ' + '["Build a new agent", "Modify existing agent", "Run tests"]} ' + "Free-form example: " + '{"question": "Describe the agent you want to build."}' + ), + parameters={ + "type": "object", + "properties": { + "question": { + "type": "string", + "description": "The question or prompt shown to the user.", + }, + "options": { + "type": "array", + "items": {"type": "string"}, + "description": ( + "2-3 specific predefined choices. Include in most cases. " + 'Example: ["Option A", "Option B", "Option C"]. ' + "The UI always appends an 'Other' free-text input, so " + "do NOT include catch-alls like 'Custom idea' or 'Other'. " + "Omit ONLY when the user must type a free-form answer." + ), + "minItems": 2, + "maxItems": 3, + }, + }, + "required": ["question"], + }, + ) def _build_ask_user_multiple_tool(self) -> Tool: - """Build the synthetic ask_user_multiple tool. 
Delegates to synthetic_tools module.""" - return build_ask_user_multiple_tool() + """Build the synthetic ask_user_multiple tool for batched questions. + + Queen-only tool that presents multiple questions at once so the user + can answer them all in a single interaction rather than one at a time. + """ + return Tool( + name="ask_user_multiple", + description=( + "Ask the user multiple questions at once. Use this instead of " + "ask_user when you have 2 or more questions to ask in the same " + "turn — it lets the user answer everything in one go rather than " + "going back and forth. Each question can have its own predefined " + "options (2-3 choices) or be free-form. The UI renders all " + "questions together with a single Submit button. " + "ALWAYS prefer this over ask_user when you have multiple things " + "to clarify. " + "IMPORTANT: Do NOT repeat the questions in your text response — " + "the widget renders them. Keep your text to a brief intro only. " + 'Example: {"questions": [' + ' {"id": "scope", "prompt": "What scope?", "options": ["Full", "Partial"]},' + ' {"id": "format", "prompt": "Output format?", "options": ["PDF", "CSV", "JSON"]},' + ' {"id": "details", "prompt": "Any special requirements?"}' + "]}" + ), + parameters={ + "type": "object", + "properties": { + "questions": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": ( + "Short identifier for this question (used in the response)." + ), + }, + "prompt": { + "type": "string", + "description": "The question text shown to the user.", + }, + "options": { + "type": "array", + "items": {"type": "string"}, + "description": ( + "2-3 predefined choices. The UI appends an " + "'Other' free-text input automatically. " + "Omit only when the user must type a free-form answer." 
+ ), + "minItems": 2, + "maxItems": 3, + }, + }, + "required": ["id", "prompt"], + }, + "minItems": 2, + "maxItems": 8, + "description": "List of questions to present to the user.", + }, + }, + "required": ["questions"], + }, + ) def _build_set_output_tool(self, output_keys: list[str] | None) -> Tool | None: - """Build the synthetic set_output tool. Delegates to synthetic_tools module.""" - return build_set_output_tool(output_keys) + """Build the synthetic set_output tool for explicit output declaration.""" + if not output_keys: + return None + return Tool( + name="set_output", + description=( + "Set an output value for this node. Call once per output key. " + "Use this for brief notes, counts, status, and file references — " + "NOT for large data payloads. When a tool result was saved to a " + "data file, pass the filename as the value " + "(e.g. 'google_sheets_get_values_1.txt') so the next phase can " + "load the full data. Values exceeding ~2000 characters are " + "auto-saved to data files. " + f"Valid keys: {output_keys}" + ), + parameters={ + "type": "object", + "properties": { + "key": { + "type": "string", + "description": f"Output key. Must be one of: {output_keys}", + "enum": output_keys, + }, + "value": { + "type": "string", + "description": ( + "The output value — a brief note, count, status, " + "or data filename reference." + ), + }, + }, + "required": ["key", "value"], + }, + ) def _build_escalate_tool(self) -> Tool: - """Build the synthetic escalate tool. Delegates to synthetic_tools module.""" - return build_escalate_tool() + """Build the synthetic escalate tool for worker -> queen handoff.""" + return Tool( + name="escalate", + description=( + "Escalate to the queen when requesting user input, " + "blocked by errors, missing " + "credentials, or ambiguous constraints that require supervisor " + "guidance. Include a concise reason and optional context. " + "The node will pause until the queen injects guidance." 
+ ), + parameters={ + "type": "object", + "properties": { + "reason": { + "type": "string", + "description": ( + "Short reason for escalation (e.g. 'Tool repeatedly failing')." + ), + }, + "context": { + "type": "string", + "description": "Optional diagnostic details for the queen.", + }, + }, + "required": ["reason"], + }, + ) def _build_delegate_tool( self, sub_agents: list[str], node_registry: dict[str, Any] ) -> Tool | None: - """Build the synthetic delegate_to_sub_agent tool. Delegates to synthetic_tools module.""" - return build_delegate_tool(sub_agents, node_registry) + """Build the synthetic delegate_to_sub_agent tool for subagent invocation. + + Args: + sub_agents: List of node IDs that can be invoked as subagents. + node_registry: Map of node_id -> NodeSpec for looking up subagent descriptions. + + Returns: + Tool definition if sub_agents is non-empty, None otherwise. + """ + if not sub_agents: + return None + + agent_descriptions = [] + for agent_id in sub_agents: + spec = node_registry.get(agent_id) + if spec: + desc = getattr(spec, "description", "(no description)") + agent_descriptions.append(f"- {agent_id}: {desc}") + else: + agent_descriptions.append(f"- {agent_id}: (not found in registry)") + + return Tool( + name="delegate_to_sub_agent", + description=( + "Delegate a task to a specialized sub-agent. The sub-agent runs " + "autonomously with read-only access to current memory and returns " + "its result. Use this to parallelize work or leverage specialized capabilities.\n\n" + "Available sub-agents:\n" + "\n".join(agent_descriptions) + ), + parameters={ + "type": "object", + "properties": { + "agent_id": { + "type": "string", + "description": f"The sub-agent to invoke. Must be one of: {sub_agents}", + "enum": sub_agents, + }, + "task": { + "type": "string", + "description": ( + "The task description for the sub-agent to execute. " + "Be specific about what you want the sub-agent to do and " + "what information to return." 
+ ), + }, + }, + "required": ["agent_id", "task"], + }, + ) def _build_report_to_parent_tool(self) -> Tool: - """Build the synthetic report_to_parent tool. Delegates to synthetic_tools module.""" - return build_report_to_parent_tool() + """Build the synthetic report_to_parent tool for sub-agent progress reports. + + Sub-agents call this to send one-way progress updates, partial findings, + or status reports to the parent node (and external observers via event bus) + without blocking execution. + + When ``wait_for_response`` is True, the sub-agent blocks until the parent + relays the user's response — used for escalation (e.g. login pages, CAPTCHAs). + + When ``mark_complete`` is True, the sub-agent terminates immediately after + sending the report — no need to call set_output for each output key. + """ + return Tool( + name="report_to_parent", + description=( + "Send a report to the parent agent. By default this is fire-and-forget: " + "the parent receives the report but does not respond. " + "Set wait_for_response=true to BLOCK until the user replies — use this " + "when you need human intervention (e.g. login pages, CAPTCHAs, " + "authentication walls). The user's response is returned as the tool result. " + "Set mark_complete=true to finish your task and terminate immediately " + "after sending the report — use this when your findings are in the " + "message/data fields and you don't need to call set_output." + ), + parameters={ + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "A human-readable status or progress message.", + }, + "data": { + "type": "object", + "description": "Optional structured data to include with the report.", + }, + "wait_for_response": { + "type": "boolean", + "description": ( + "If true, block execution until the user responds. " + "Use for escalation scenarios requiring human intervention." 
+ ), + "default": False, + }, + "mark_complete": { + "type": "boolean", + "description": ( + "If true, terminate the sub-agent immediately after sending " + "this report. The report message and data are delivered to the " + "parent as the final result. No set_output calls are needed." + ), + "default": False, + }, + }, + "required": ["message"], + }, + ) def _handle_set_output( self, tool_input: dict[str, Any], output_keys: list[str] | None, + *, + missing_keys: list[str] | None = None, + current_value: Any = None, + normalized_value: Any = None, ) -> ToolResult: - """Handle set_output tool call. Delegates to synthetic_tools module.""" - return handle_set_output(tool_input, output_keys) + """Handle set_output tool call. Returns ToolResult (sync).""" + key = tool_input.get("key", "") + value = tool_input.get("value", "") + valid_keys = output_keys or [] + + # Recover from truncated JSON (max_tokens hit mid-argument). + # The _raw key is set by litellm when json.loads fails. + if not key and "_raw" in tool_input: + import re + + raw = tool_input["_raw"] + key_match = re.search(r'"key"\s*:\s*"(\w+)"', raw) + if key_match: + key = key_match.group(1) + val_match = re.search(r'"value"\s*:\s*"', raw) + if val_match: + start = val_match.end() + value = raw[start:].rstrip() + for suffix in ('"}\n', '"}', '"'): + if value.endswith(suffix): + value = value[: -len(suffix)] + break + if key: + logger.warning( + "Recovered set_output args from truncated JSON: key=%s, value_len=%d", + key, + len(value), + ) + # Re-inject so the caller sees proper key/value + tool_input["key"] = key + tool_input["value"] = value + + if key not in valid_keys: + return ToolResult( + tool_use_id="", + content=f"Invalid output key '{key}'. 
Valid keys: {valid_keys}", + is_error=True, + ) + + candidate = ( + normalized_value + if normalized_value is not None + else self._normalize_set_output_value(value) + ) + if current_value is not None and current_value == candidate: + remaining = [k for k in (missing_keys or []) if k != key] + suffix = f" Remaining required outputs: {remaining}." if remaining else "" + return ToolResult( + tool_use_id="", + content=( + f"Output '{key}' is already set to the same value. " + "Do not repeat identical set_output calls; set remaining outputs or finish." + f"{suffix}" + ), + is_error=True, + ) + + return ToolResult( + tool_use_id="", + content=f"Output '{key}' set successfully.", + is_error=False, + ) + + @staticmethod + def _normalize_set_output_value(value: Any) -> Any: + """Parse JSON-ish scalar/list/dict strings into native values when safe.""" + if isinstance(value, str): + try: + parsed = json.loads(value) + if isinstance(parsed, (list, dict, bool, int, float)): + return parsed + except (json.JSONDecodeError, TypeError): + return value + return value # ------------------------------------------------------------------- # Judge evaluation @@ -2823,20 +3621,120 @@ async def _judge_turn( tool_results: list[dict], iteration: int, ) -> JudgeVerdict: - """Evaluate the current state. Delegates to judge_pipeline module.""" - return await judge_turn( - mark_complete_flag=self._mark_complete_flag, - judge=self._judge, - ctx=ctx, - conversation=conversation, - accumulator=accumulator, - assistant_text=assistant_text, - tool_results=tool_results, - iteration=iteration, - get_missing_output_keys_fn=self._get_missing_output_keys, - max_context_tokens=self._config.max_context_tokens, + """Evaluate the current state using judge or implicit logic. + + Evaluation levels (in order): + 0. Short-circuits: mark_complete, skip_judge, tool-continue. + 1. Custom judge (JudgeProtocol) — full authority when set. + 2. 
Implicit judge — output-key check + optional conversation-aware + quality gate (when ``success_criteria`` is defined). + + Returns a JudgeVerdict. ``feedback=None`` means no real evaluation + happened (skip_judge, tool-continue); the caller must not inject a + feedback message. Any non-None feedback (including ``""``) means a + real evaluation occurred and will be logged into the conversation. + """ + + # --- Level 0: short-circuits (no evaluation) ----------------------- + + if self._mark_complete_flag: + return JudgeVerdict(action="ACCEPT") + + if ctx.node_spec.skip_judge: + return JudgeVerdict(action="RETRY") # feedback=None → not logged + + # --- Level 1: custom judge ----------------------------------------- + + if self._judge is not None: + context = { + "assistant_text": assistant_text, + "tool_calls": tool_results, + "output_accumulator": accumulator.to_dict(), + "accumulator": accumulator, + "iteration": iteration, + "conversation_summary": conversation.export_summary(), + "output_keys": ctx.node_spec.output_keys, + "missing_keys": self._get_missing_output_keys( + accumulator, ctx.node_spec.output_keys, ctx.node_spec.nullable_output_keys + ), + } + verdict = await self._judge.evaluate(context) + # Ensure evaluated RETRY always carries feedback for logging. + if verdict.action == "RETRY" and not verdict.feedback: + return JudgeVerdict(action="RETRY", feedback="Custom judge returned RETRY.") + return verdict + + # --- Level 2: implicit judge --------------------------------------- + + # Real tool calls were made — let the agent keep working. + if tool_results: + return JudgeVerdict(action="RETRY") # feedback=None → not logged + + missing = self._get_missing_output_keys( + accumulator, ctx.node_spec.output_keys, ctx.node_spec.nullable_output_keys ) + if missing: + return JudgeVerdict( + action="RETRY", + feedback=( + f"Task incomplete. Required outputs not yet produced: {missing}. " + f"Follow your system prompt instructions to complete the work." 
+ ), + ) + + # All output keys present — run safety checks before accepting. + + output_keys = ctx.node_spec.output_keys or [] + nullable_keys = set(ctx.node_spec.nullable_output_keys or []) + + # All-nullable with nothing set → node produced nothing useful. + all_nullable = output_keys and nullable_keys >= set(output_keys) + none_set = not any(accumulator.get(k) is not None for k in output_keys) + if all_nullable and none_set: + return JudgeVerdict( + action="RETRY", + feedback=( + f"No output keys have been set yet. " + f"Use set_output to set at least one of: {output_keys}" + ), + ) + + # Client-facing with no output keys → continuous interaction node. + # Inject tool-use pressure instead of auto-accepting. + if not output_keys and ctx.node_spec.client_facing: + return JudgeVerdict( + action="RETRY", + feedback=( + "STOP describing what you will do. " + "You have FULL access to all tools — file creation, " + "shell commands, MCP tools — and you CAN call them " + "directly in your response. Respond ONLY with tool " + "calls, no prose. Execute the task now." 
+ ), + ) + + # Level 2b: conversation-aware quality check (if success_criteria set) + if ctx.node_spec.success_criteria and ctx.llm: + from framework.graph.conversation_judge import evaluate_phase_completion + + verdict = await evaluate_phase_completion( + llm=ctx.llm, + conversation=conversation, + phase_name=ctx.node_spec.name, + phase_description=ctx.node_spec.description, + success_criteria=ctx.node_spec.success_criteria, + accumulator_state=accumulator.to_dict(), + max_context_tokens=self._config.max_context_tokens, + ) + if verdict.action != "ACCEPT": + return JudgeVerdict( + action=verdict.action, + feedback=verdict.feedback or "Phase criteria not met.", + ) + + return JudgeVerdict(action="ACCEPT", feedback="") + # ------------------------------------------------------------------- # Helpers # ------------------------------------------------------------------- @@ -2890,40 +3788,295 @@ def _get_missing_output_keys( skip = set(nullable_keys) if nullable_keys else set() return [k for k in output_keys if k not in skip and accumulator.get(k) is None] + def _outputs_ready_for_auto_complete( + self, + ctx: NodeContext, + accumulator: OutputAccumulator, + ) -> bool: + """Whether a worker node has produced all meaningful required outputs.""" + output_keys = ctx.node_spec.output_keys or [] + if not output_keys: + return False + + missing = self._get_missing_output_keys( + accumulator, + output_keys, + ctx.node_spec.nullable_output_keys, + ) + if missing: + return False + + nullable_keys = set(ctx.node_spec.nullable_output_keys or []) + all_nullable = nullable_keys >= set(output_keys) + none_set = not any(accumulator.get(k) is not None for k in output_keys) + return not (all_nullable and none_set) + + def _can_auto_complete_outputs(self, ctx: NodeContext) -> bool: + """Only simple nodes should bypass the normal judge path. + + Nodes with custom judges or success criteria still need the normal + evaluation path so they can enforce retry/accept semantics. 
+ """ + return ( + self._judge is None + and not ctx.node_spec.sub_agents + and not ctx.node_spec.success_criteria + and not self._mark_complete_flag + ) + @staticmethod def _ngram_similarity(s1: str, s2: str, n: int = 2) -> float: - """Jaccard similarity of n-gram sets. Delegates to stall_detector module.""" - return ngram_similarity(s1, s2, n) + """Jaccard similarity of n-gram sets. + + Returns 0.0-1.0, where 1.0 is exact match. + Fast: O(len(s) + len(s2)) using set operations. + """ + + def _ngrams(s: str) -> set[str]: + return {s[i : i + n] for i in range(len(s) - n + 1) if s.strip()} + + if not s1 or not s2: + return 0.0 + + ngrams1, ngrams2 = _ngrams(s1.lower()), _ngrams(s2.lower()) + if not ngrams1 or not ngrams2: + return 0.0 + + intersection = len(ngrams1 & ngrams2) + union = len(ngrams1 | ngrams2) + return intersection / union if union else 0.0 def _is_stalled(self, recent_responses: list[str]) -> bool: - """Detect stall using n-gram similarity. Delegates to stall_detector module.""" - return is_stalled( - recent_responses, - self._config.stall_detection_threshold, - self._config.stall_similarity_threshold, - ) + """Detect stall using n-gram similarity. + + Detects when ALL N consecutive responses are mutually similar + (>= threshold). A single dissimilar response resets the signal. + This catches phrases like "I'm still stuck" vs "I'm stuck" + without false-positives on "attempt 1" vs "attempt 2". + """ + if len(recent_responses) < self._config.stall_detection_threshold: + return False + if not recent_responses[0]: + return False + + threshold = self._config.stall_similarity_threshold + # Every consecutive pair must be similar + for i in range(1, len(recent_responses)): + if self._ngram_similarity(recent_responses[i], recent_responses[i - 1]) < threshold: + return False + return True @staticmethod def _is_transient_error(exc: BaseException) -> bool: - """Classify whether an exception is transient. 
Delegates to tool_result_handler module.""" - return is_transient_error(exc) + """Classify whether an exception is transient (retryable) vs permanent. + + Transient: network errors, rate limits, server errors, timeouts. + Permanent: auth errors, bad requests, context window exceeded. + """ + try: + from litellm.exceptions import ( + APIConnectionError, + BadGatewayError, + InternalServerError, + RateLimitError, + ServiceUnavailableError, + ) + + transient_types: tuple[type[BaseException], ...] = ( + RateLimitError, + APIConnectionError, + InternalServerError, + BadGatewayError, + ServiceUnavailableError, + TimeoutError, + ConnectionError, + OSError, + ) + except ImportError: + transient_types = (TimeoutError, ConnectionError, OSError) + + if isinstance(exc, transient_types): + return True + + # RuntimeError from StreamErrorEvent with "Stream error:" prefix + if isinstance(exc, RuntimeError): + error_str = str(exc).lower() + transient_keywords = [ + "rate limit", + "429", + "timeout", + "connection", + "internal server", + "502", + "503", + "504", + "service unavailable", + "bad gateway", + "overloaded", + "failed to parse tool call", + ] + return any(kw in error_str for kw in transient_keywords) + + return False @staticmethod def _fingerprint_tool_calls( tool_results: list[dict], ) -> list[tuple[str, str]]: - """Create deterministic fingerprints. Delegates to stall_detector module.""" - return fingerprint_tool_calls(tool_results) + """Create deterministic fingerprints for a turn's tool calls. + + Each fingerprint is (tool_name, canonical_args_json). Order-sensitive + so [search("a"), fetch("b")] != [fetch("b"), search("a")]. 
+ """ + fingerprints = [] + for tr in tool_results: + name = tr.get("tool_name", "") + args = tr.get("tool_input", {}) + try: + canonical = json.dumps(args, sort_keys=True, default=str) + except (TypeError, ValueError): + canonical = str(args) + fingerprints.append((name, canonical)) + return fingerprints + + @staticmethod + def _fingerprint_set_output_calls( + tool_results: list[dict], + ) -> list[tuple[str, str]]: + """Create deterministic fingerprints for set_output-only turns.""" + fingerprints: list[tuple[str, str]] = [] + for tr in tool_results: + args = tr.get("tool_input", {}) + key = str(args.get("key", "")) + value = args.get("value") + try: + canonical = json.dumps(value, sort_keys=True, default=str) + except (TypeError, ValueError): + canonical = str(value) + fingerprints.append((key, canonical)) + return fingerprints def _is_tool_doom_loop( self, recent_tool_fingerprints: list[list[tuple[str, str]]], ) -> tuple[bool, str]: - """Detect doom loop. Delegates to stall_detector module.""" - return is_tool_doom_loop( - recent_tool_fingerprints=recent_tool_fingerprints, - threshold=self._config.tool_doom_loop_threshold, - enabled=self._config.tool_doom_loop_enabled, + """Detect doom loop via exact fingerprint match. + + Detects when N consecutive turns invoke the same tools with + identical (canonicalized) arguments. Different arguments mean + different work, so only exact matches count. + + Returns (is_doom_loop, description). 
+ """ + if not self._config.tool_doom_loop_enabled: + return False, "" + threshold = self._config.tool_doom_loop_threshold + if len(recent_tool_fingerprints) < threshold: + return False, "" + first = recent_tool_fingerprints[0] + if not first: + return False, "" + + # All turns in the window must match the first exactly + if all(fp == first for fp in recent_tool_fingerprints[1:]): + tool_names = [name for name, _ in first] + desc = ( + f"Doom loop detected: {len(recent_tool_fingerprints)} " + f"identical consecutive tool calls ({', '.join(tool_names)})" + ) + return True, desc + return False, "" + + def _is_output_doom_loop( + self, + recent_output_fingerprints: list[list[tuple[str, str]]], + ) -> tuple[bool, str]: + """Detect repeated identical set_output-only turns.""" + threshold = self._config.tool_doom_loop_threshold + if len(recent_output_fingerprints) < threshold: + return False, "" + first = recent_output_fingerprints[0] + if not first: + return False, "" + if all(fp == first for fp in recent_output_fingerprints[1:]): + keys = [key for key, _ in first] + desc = ( + f"Output doom loop detected: {len(recent_output_fingerprints)} " + f"identical consecutive set_output turn(s) ({', '.join(keys)})" + ) + return True, desc + if self._is_meta_reset_output_loop(recent_output_fingerprints): + key = first[0][0] + desc = ( + f"Output doom loop detected: {len(recent_output_fingerprints)} " + f"consecutive '{key}' resets asking for a fresh payload instead of " + "progressing the node" + ) + return True, desc + return False, "" + + @staticmethod + def _is_no_progress_turn( + *, + missing: list[str], + outputs_set: list[str], + real_tool_results: list[dict], + user_input_requested: bool, + queen_input_requested: bool, + reported_to_parent: bool, + ) -> bool: + """Whether the turn made no contract-level progress. + + This is intentionally state-based instead of text-based. 
async def _execute_tool(self, tc: ToolCallEvent) -> ToolResult:
    """Run one tool call and return its result, honouring the configured timeout.

    Resolution order:

    1. No executor configured: return an error result.
    2. File-read tools (``view_file`` / ``load_data`` / ``read_file``) whose
       ``path`` resolves inside one of ``self._skill_dirs`` are read
       directly from disk, bypassing the session sandbox (AS-9). Reads of
       a path ending in ``SKILL.md`` are flagged ``is_skill_content``
       (AS-10).
    3. Everything else is handed to ``self._tool_executor`` on a worker
       thread so sync executors (MCP STDIO tools that block on
       ``future.result()``) don't freeze the event loop.

    Args:
        tc: The tool call to execute.

    Returns:
        The executor's result, the skill file content, or an ``is_error``
        result when no executor is configured, the skill file cannot be
        read, or the call exceeds ``tool_call_timeout_seconds``.
    """
    if self._tool_executor is None:
        return ToolResult(
            tool_use_id=tc.tool_use_id,
            content=f"No tool executor configured for '{tc.tool_name}'",
            is_error=True,
        )

    # AS-9: Intercept file-read tools for skill directories — bypass session sandbox
    _SKILL_READ_TOOLS = {"view_file", "load_data", "read_file"}
    skill_dirs = getattr(self, "_skill_dirs", [])
    if tc.tool_name in _SKILL_READ_TOOLS and skill_dirs:
        _path = tc.tool_input.get("path", "")
        if _path:
            import os
            from pathlib import Path as _Path

            _resolved = os.path.realpath(os.path.abspath(_path))
            _resolved_path = _Path(_resolved)
            # Compare whole path components, not string prefixes: with a
            # plain startswith() check "/skills-evil/x" would be accepted
            # for the skill dir "/skills" and escape the sandbox bypass.
            if any(_resolved_path.is_relative_to(os.path.realpath(d)) for d in skill_dirs):
                try:
                    _content = _resolved_path.read_text(encoding="utf-8")
                    _is_skill_md = _resolved.endswith("SKILL.md")
                    return ToolResult(
                        tool_use_id=tc.tool_use_id,
                        content=_content,
                        is_skill_content=_is_skill_md,  # AS-10: protect SKILL.md reads
                    )
                except Exception as _exc:
                    # Deliberately broad: any read failure (missing file,
                    # directory, decode error) is reported back to the model.
                    return ToolResult(
                        tool_use_id=tc.tool_use_id,
                        content=f"Could not read skill resource '{_path}': {_exc}",
                        is_error=True,
                    )

    tool_use = ToolUse(id=tc.tool_use_id, name=tc.tool_name, input=tc.tool_input)
    timeout = self._config.tool_call_timeout_seconds

    async def _run() -> ToolResult:
        # Offload the executor call to a thread. Sync MCP executors
        # block on future.result() — running in a thread keeps the
        # event loop free so asyncio.wait_for can fire the timeout.
        loop = asyncio.get_running_loop()
        # Copy the context so context variables set by the caller are
        # visible inside the worker thread.
        ctx = contextvars.copy_context()
        result = await loop.run_in_executor(
            None,
            lambda: ctx.run(self._tool_executor, tool_use),
        )
        # Async executors return a coroutine — await it on the loop
        if asyncio.iscoroutine(result) or asyncio.isfuture(result):
            result = await result
        return result

    try:
        if timeout > 0:
            result = await asyncio.wait_for(_run(), timeout=timeout)
        else:
            # A non-positive timeout disables the limit entirely.
            result = await _run()
    except TimeoutError:
        # NOTE(review): wait_for stops waiting, but a sync executor already
        # running in the worker thread presumably keeps going — confirm.
        logger.warning("Tool '%s' timed out after %.0fs", tc.tool_name, timeout)
        return ToolResult(
            tool_use_id=tc.tool_use_id,
            content=(
                f"Tool '{tc.tool_name}' timed out after {timeout:.0f}s. "
                "The operation took too long and was cancelled. "
                "Try a simpler request or a different approach."
            ),
            is_error=True,
        )
    return result
""" - return record_learning( - key=key, - value=value, - spillover_dir=self._config.spillover_dir, - ) + if not self._config.spillover_dir: + return + try: + adapt_path = Path(self._config.spillover_dir) / "adapt.md" + adapt_path.parent.mkdir(parents=True, exist_ok=True) + content = adapt_path.read_text(encoding="utf-8") if adapt_path.exists() else "" + + if "## Outputs" not in content: + content += "\n\n## Outputs\n" + + # Truncate long values for memory (full value is in shared memory) + v_str = str(value) + if len(v_str) > 500: + v_str = v_str[:500] + "…" + + entry = f"- {key}: {v_str}\n" + + # Replace existing entry for same key (update, not duplicate) + lines = content.splitlines(keepends=True) + replaced = False + for i, line in enumerate(lines): + if line.startswith(f"- {key}:"): + lines[i] = entry + replaced = True + break + if replaced: + content = "".join(lines) + else: + content += entry + + adapt_path.write_text(content, encoding="utf-8") + except Exception as e: + logger.warning("Failed to record learning for key=%s: %s", key, e) def _next_spill_filename(self, tool_name: str) -> str: """Return a short, monotonic filename for a tool result spill.""" @@ -2965,9 +4208,22 @@ def _next_spill_filename(self, tool_name: str) -> str: def _restore_spill_counter(self) -> None: """Scan spillover_dir for existing spill files and restore the counter.""" - self._spill_counter = restore_spill_counter( - spillover_dir=self._config.spillover_dir, - ) + spill_dir = self._config.spillover_dir + if not spill_dir: + return + spill_path = Path(spill_dir) + if not spill_path.is_dir(): + return + max_n = 0 + for f in spill_path.iterdir(): + if not f.is_file(): + continue + m = re.search(r"_(\d+)\.txt$", f.name) + if m: + max_n = max(max_n, int(m.group(1))) + if max_n > self._spill_counter: + self._spill_counter = max_n + logger.info("Restored spill counter to %d from existing files", max_n) # ------------------------------------------------------------------ # JSON metadata / 
smart preview helpers for truncation @@ -2982,9 +4238,53 @@ def _extract_json_metadata(parsed: Any, *, _depth: int = 0, _max_depth: int = 3) Returns an empty string for simple scalars. """ - return extract_json_metadata( - parsed=parsed, - ) + if _depth >= _max_depth: + if isinstance(parsed, dict): + return f"dict with {len(parsed)} keys" + if isinstance(parsed, list): + return f"list of {len(parsed)} items" + return type(parsed).__name__ + + if isinstance(parsed, dict): + if not parsed: + return "empty dict" + lines: list[str] = [] + indent = " " * (_depth + 1) + for key, value in list(parsed.items())[:20]: + if isinstance(value, list): + line = f'{indent}"{key}": list of {len(value)} items' + if value: + first = value[0] + if isinstance(first, dict): + sample_keys = list(first.keys())[:10] + line += f" (each item: dict with keys {sample_keys})" + elif isinstance(first, list): + line += f" (each item: list of {len(first)} elements)" + lines.append(line) + elif isinstance(value, dict): + child = EventLoopNode._extract_json_metadata( + value, _depth=_depth + 1, _max_depth=_max_depth + ) + lines.append(f'{indent}"{key}": {child}') + else: + lines.append(f'{indent}"{key}": {type(value).__name__}') + if len(parsed) > 20: + lines.append(f"{indent}... and {len(parsed) - 20} more keys") + return "\n".join(lines) + + if isinstance(parsed, list): + if not parsed: + return "empty list" + desc = f"list of {len(parsed)} items" + first = parsed[0] + if isinstance(first, dict): + sample_keys = list(first.keys())[:10] + desc += f" (each item: dict with keys {sample_keys})" + elif isinstance(first, list): + desc += f" (each item: list of {len(first)} elements)" + return desc + + return "" @staticmethod def _build_json_preview(parsed: Any, *, max_chars: int = 5000) -> str | None: @@ -2995,10 +4295,54 @@ def _build_json_preview(parsed: Any, *, max_chars: int = 5000) -> str | None: Returns ``None`` if no truncation was needed (no large arrays). 
""" - return build_json_preview( - parsed=parsed, - max_chars=max_chars, - ) + _LARGE_ARRAY_THRESHOLD = 10 + + def _truncate_arrays(obj: Any) -> tuple[Any, bool]: + """Return (truncated_copy, was_truncated).""" + if isinstance(obj, list) and len(obj) > _LARGE_ARRAY_THRESHOLD: + n = len(obj) + head = obj[:3] + tail = obj[-1:] + marker = f"... ({n - 4} more items omitted, {n} total) ..." + return head + [marker] + tail, True + if isinstance(obj, dict): + changed = False + out: dict[str, Any] = {} + for k, v in obj.items(): + new_v, did = _truncate_arrays(v) + out[k] = new_v + changed = changed or did + return (out, True) if changed else (obj, False) + return obj, False + + preview_obj, was_truncated = _truncate_arrays(parsed) + if not was_truncated: + return None # No large arrays — caller should use raw slicing + + try: + result = json.dumps(preview_obj, indent=2, ensure_ascii=False) + except (TypeError, ValueError): + return None + + if len(result) > max_chars: + # Even 3+1 items too big — try just 1 item + def _minimal_arrays(obj: Any) -> Any: + if isinstance(obj, list) and len(obj) > _LARGE_ARRAY_THRESHOLD: + n = len(obj) + return obj[:1] + [f"... 
({n - 1} more items omitted, {n} total) ..."] + if isinstance(obj, dict): + return {k: _minimal_arrays(v) for k, v in obj.items()} + return obj + + preview_obj = _minimal_arrays(parsed) + try: + result = json.dumps(preview_obj, indent=2, ensure_ascii=False) + except (TypeError, ValueError): + return None + if len(result) > max_chars: + result = result[:max_chars] + "…" + + return result def _truncate_tool_result( self, @@ -3017,13 +4361,175 @@ def _truncate_tool_result( - Errors: pass through unchanged - load_data results: truncate with pagination hint (no re-spill) """ - return truncate_tool_result( - result=result, - tool_name=tool_name, - max_tool_result_chars=self._config.max_tool_result_chars, - spillover_dir=self._config.spillover_dir, - next_spill_filename_fn=self._next_spill_filename, - ) + limit = self._config.max_tool_result_chars + + # Errors always pass through unchanged + if result.is_error: + return result + + # load_data reads FROM spilled files — never re-spill (circular). + # Just truncate with a pagination hint if the result is too large. + if tool_name == "load_data": + if limit <= 0 or len(result.content) <= limit: + return result # Small load_data result — pass through as-is + # Large load_data result — truncate with smart preview + PREVIEW_CAP = min(5000, max(limit - 500, limit // 2)) + + metadata_str = "" + smart_preview: str | None = None + try: + parsed_ld = json.loads(result.content) + metadata_str = self._extract_json_metadata(parsed_ld) + smart_preview = self._build_json_preview(parsed_ld, max_chars=PREVIEW_CAP) + except (json.JSONDecodeError, TypeError, ValueError): + pass + + if smart_preview is not None: + preview_block = smart_preview + else: + preview_block = result.content[:PREVIEW_CAP] + "…" + + header = ( + f"[{tool_name} result: {len(result.content):,} chars — " + f"too large for context. 
Use offset_bytes/limit_bytes " + f"parameters to read smaller chunks.]" + ) + if metadata_str: + header += f"\n\nData structure:\n{metadata_str}" + header += ( + "\n\nWARNING: This is an INCOMPLETE preview. " + "Do NOT draw conclusions or counts from it." + ) + + truncated = f"{header}\n\nPreview (small sample only):\n{preview_block}" + logger.info( + "%s result truncated: %d → %d chars (use offset/limit to paginate)", + tool_name, + len(result.content), + len(truncated), + ) + return ToolResult( + tool_use_id=result.tool_use_id, + content=truncated, + is_error=False, + ) + + spill_dir = self._config.spillover_dir + if spill_dir: + spill_path = Path(spill_dir) + spill_path.mkdir(parents=True, exist_ok=True) + filename = self._next_spill_filename(tool_name) + + # Pretty-print JSON content so load_data's line-based + # pagination works correctly. + write_content = result.content + parsed_json: Any = None # track for metadata extraction + try: + parsed_json = json.loads(result.content) + write_content = json.dumps(parsed_json, indent=2, ensure_ascii=False) + except (json.JSONDecodeError, TypeError, ValueError): + pass # Not JSON — write as-is + + (spill_path / filename).write_text(write_content, encoding="utf-8") + + if limit > 0 and len(result.content) > limit: + # Large result: build a small, metadata-rich preview so the + # LLM cannot mistake it for the complete dataset. 
+ PREVIEW_CAP = 5000 + + # Extract structural metadata (array lengths, key names) + metadata_str = "" + smart_preview: str | None = None + if parsed_json is not None: + metadata_str = self._extract_json_metadata(parsed_json) + smart_preview = self._build_json_preview(parsed_json, max_chars=PREVIEW_CAP) + + if smart_preview is not None: + preview_block = smart_preview + else: + preview_block = result.content[:PREVIEW_CAP] + "…" + + # Assemble header with structural info + warning + header = ( + f"[Result from {tool_name}: {len(result.content):,} chars — " + f"too large for context, saved to '{filename}'.]" + ) + if metadata_str: + header += f"\n\nData structure:\n{metadata_str}" + header += ( + f"\n\nWARNING: The preview below is INCOMPLETE. " + f"Do NOT draw conclusions or counts from it. " + f"Use load_data(filename='{filename}') to read the " + f"full data from session storage before analysis. " + "Do NOT open this filename from the workspace or current directory." + ) + + content = f"{header}\n\nPreview (small sample only):\n{preview_block}" + logger.info( + "Tool result spilled to file: %s (%d chars → %s)", + tool_name, + len(result.content), + filename, + ) + else: + # Small result: keep full content + annotation + content = f"{result.content}\n\n[Saved to '{filename}']" + logger.info( + "Tool result saved to file: %s (%d chars → %s)", + tool_name, + len(result.content), + filename, + ) + + return ToolResult( + tool_use_id=result.tool_use_id, + content=content, + is_error=False, + ) + + # No spillover_dir — truncate in-place if needed + if limit > 0 and len(result.content) > limit: + PREVIEW_CAP = min(5000, max(limit - 500, limit // 2)) + + metadata_str = "" + smart_preview: str | None = None + try: + parsed_inline = json.loads(result.content) + metadata_str = self._extract_json_metadata(parsed_inline) + smart_preview = self._build_json_preview(parsed_inline, max_chars=PREVIEW_CAP) + except (json.JSONDecodeError, TypeError, ValueError): + pass + + if 
smart_preview is not None: + preview_block = smart_preview + else: + preview_block = result.content[:PREVIEW_CAP] + "…" + + header = ( + f"[Result from {tool_name}: {len(result.content):,} chars — " + f"truncated to fit context budget.]" + ) + if metadata_str: + header += f"\n\nData structure:\n{metadata_str}" + header += ( + "\n\nWARNING: This is an INCOMPLETE preview. " + "Do NOT draw conclusions or counts from the preview alone." + ) + + truncated = f"{header}\n\n{preview_block}" + logger.info( + "Tool result truncated in-place: %s (%d → %d chars)", + tool_name, + len(result.content), + len(truncated), + ) + return ToolResult( + tool_use_id=result.tool_use_id, + content=truncated, + is_error=False, + ) + + return result # --- Compaction ----------------------------------------------------------- @@ -3048,15 +4554,84 @@ async def _compact( does not fully resolve the budget. 4. Emergency deterministic summary only if LLM failed or unavailable. """ - return await compact( - ctx=ctx, - conversation=conversation, - accumulator=accumulator, - config=self._config, - event_bus=self._event_bus, - char_limit=self._LLM_COMPACT_CHAR_LIMIT, - max_depth=self._LLM_COMPACT_MAX_DEPTH, + ratio_before = conversation.usage_ratio() + phase_grad = getattr(ctx, "continuous_mode", False) + + # Capture pre-compaction message inventory when over budget, + # since compaction mutates the conversation in place. 
+ pre_inventory: list[dict[str, Any]] | None = None + if ratio_before >= 1.0: + pre_inventory = self._build_message_inventory(conversation) + + # --- Step 1: Prune old tool results (free, no LLM) --- + protect = max(2000, self._config.max_context_tokens // 12) + pruned = await conversation.prune_old_tool_results( + protect_tokens=protect, + min_prune_tokens=max(1000, protect // 3), + ) + if pruned > 0: + logger.info( + "Pruned %d old tool results: %.0f%% -> %.0f%%", + pruned, + ratio_before * 100, + conversation.usage_ratio() * 100, + ) + if not conversation.needs_compaction(): + await self._log_compaction(ctx, conversation, ratio_before, pre_inventory) + return + + # --- Step 2: Standard structure-preserving compaction (free, no LLM) --- + # Removes freeform text to spillover files; keeps tool-call pairs in context. + spill_dir = self._config.spillover_dir + if spill_dir: + await conversation.compact_preserving_structure( + spillover_dir=spill_dir, + keep_recent=4, + phase_graduated=phase_grad, + ) + if not conversation.needs_compaction(): + await self._log_compaction(ctx, conversation, ratio_before, pre_inventory) + return + + # --- Step 3: LLM summary compaction --- + # Structural compaction alone did not hit target. Generate an LLM summary + # and place it as the first message — more reliable for token reduction + # than offloading more content to files. 
+ if ctx.llm is not None: + logger.info( + "LLM summary compaction triggered (%.0f%% usage)", + conversation.usage_ratio() * 100, + ) + try: + summary = await self._llm_compact( + ctx, + list(conversation.messages), + accumulator, + ) + await conversation.compact( + summary, + keep_recent=2, + phase_graduated=phase_grad, + ) + except Exception as e: + logger.warning("LLM compaction failed: %s", e) + + if not conversation.needs_compaction(): + await self._log_compaction(ctx, conversation, ratio_before, pre_inventory) + return + + # --- Step 4: Emergency deterministic summary (LLM failed/unavailable) --- + logger.warning( + "Emergency compaction (%.0f%% usage)", + conversation.usage_ratio() * 100, + ) + summary = self._build_emergency_summary(ctx, accumulator, conversation) + await conversation.compact( + summary, + keep_recent=1, + phase_graduated=phase_grad, ) + await self._log_compaction(ctx, conversation, ratio_before, pre_inventory) # --- LLM compaction with binary-search splitting ---------------------- @@ -3074,22 +4649,101 @@ async def _llm_compact( in half and each half is summarised independently. Tool history is appended once at the top-level call (``_depth == 0``). 
""" - return await llm_compact( - ctx=ctx, - messages=messages, - accumulator=accumulator, - _depth=_depth, - char_limit=self._LLM_COMPACT_CHAR_LIMIT, - max_depth=self._LLM_COMPACT_MAX_DEPTH, - max_context_tokens=self._config.max_context_tokens, + from framework.graph.conversation import extract_tool_call_history + + if _depth > self._LLM_COMPACT_MAX_DEPTH: + raise RuntimeError(f"LLM compaction recursion limit ({self._LLM_COMPACT_MAX_DEPTH})") + + formatted = self._format_messages_for_summary(messages) + + # Proactive split: avoid wasting an API call on oversized input + if len(formatted) > self._LLM_COMPACT_CHAR_LIMIT and len(messages) > 1: + summary = await self._llm_compact_split( + ctx, + messages, + accumulator, + _depth, + ) + else: + prompt = self._build_llm_compaction_prompt( + ctx, + accumulator, + formatted, + ) + summary_budget = max(1024, self._config.max_context_tokens // 2) + try: + response = await ctx.llm.acomplete( + messages=[{"role": "user", "content": prompt}], + system=( + "You are a conversation compactor for an AI agent. " + "Write a detailed summary that allows the agent to " + "continue its work. Preserve user-stated rules, " + "constraints, and account/identity preferences verbatim." 
+ ), + max_tokens=summary_budget, + ) + summary = response.content + except Exception as e: + if _is_context_too_large_error(e) and len(messages) > 1: + logger.info( + "LLM context too large (depth=%d, msgs=%d) — splitting", + _depth, + len(messages), + ) + summary = await self._llm_compact_split( + ctx, + messages, + accumulator, + _depth, + ) + else: + raise + + # Append tool history at top level only + if _depth == 0: + tool_history = extract_tool_call_history(messages) + if tool_history and "TOOLS ALREADY CALLED" not in summary: + summary += "\n\n" + tool_history + + return summary + + async def _llm_compact_split( + self, + ctx: NodeContext, + messages: list, + accumulator: OutputAccumulator | None, + _depth: int, + ) -> str: + """Split messages in half and summarise each half independently.""" + mid = max(1, len(messages) // 2) + s1 = await self._llm_compact(ctx, messages[:mid], None, _depth + 1) + s2 = await self._llm_compact( + ctx, + messages[mid:], + accumulator, + _depth + 1, ) + return s1 + "\n\n" + s2 # --- Compaction helpers ------------------------------------------------ @staticmethod def _format_messages_for_summary(messages: list) -> str: """Format messages as text for LLM summarisation.""" - return format_messages_for_summary(messages) + lines: list[str] = [] + for m in messages: + if m.role == "tool": + content = m.content[:500] + if len(m.content) > 500: + content += "..." 
+ lines.append(f"[tool result]: {content}") + elif m.role == "assistant" and m.tool_calls: + names = [tc.get("function", {}).get("name", "?") for tc in m.tool_calls] + text = m.content[:200] if m.content else "" + lines.append(f"[assistant (calls: {', '.join(names)})]: {text}") + else: + lines.append(f"[{m.role}]: {m.content}") + return "\n\n".join(lines) def _build_llm_compaction_prompt( self, @@ -3098,13 +4752,229 @@ def _build_llm_compaction_prompt( formatted_messages: str, ) -> str: """Build prompt for LLM compaction targeting 50% of token budget.""" - return build_llm_compaction_prompt( - ctx, - accumulator, - formatted_messages, - max_context_tokens=self._config.max_context_tokens, + spec = ctx.node_spec + ctx_lines = [f"NODE: {spec.name} (id={spec.id})"] + if spec.description: + ctx_lines.append(f"PURPOSE: {spec.description}") + if spec.success_criteria: + ctx_lines.append(f"SUCCESS CRITERIA: {spec.success_criteria}") + + if accumulator: + acc = accumulator.to_dict() + done = {k: v for k, v in acc.items() if v is not None} + todo = [k for k, v in acc.items() if v is None] + if done: + ctx_lines.append( + "OUTPUTS ALREADY SET:\n" + + "\n".join(f" {k}: {str(v)[:150]}" for k, v in done.items()) + ) + if todo: + ctx_lines.append(f"OUTPUTS STILL NEEDED: {', '.join(todo)}") + elif spec.output_keys: + ctx_lines.append(f"OUTPUTS STILL NEEDED: {', '.join(spec.output_keys)}") + + target_tokens = self._config.max_context_tokens // 2 + target_chars = target_tokens * 4 + node_ctx = "\n".join(ctx_lines) + + return ( + "You are compacting an AI agent's conversation history. " + "The agent is still working and needs to continue.\n\n" + f"AGENT CONTEXT:\n{node_ctx}\n\n" + f"CONVERSATION MESSAGES:\n{formatted_messages}\n\n" + "INSTRUCTIONS:\n" + f"Write a summary of approximately {target_chars} characters " + f"(~{target_tokens} tokens).\n" + "1. Preserve ALL user-stated rules, constraints, and preferences " + "verbatim.\n" + "2. 
Preserve key decisions made and results obtained.\n" + "3. Preserve in-progress work state so the agent can continue.\n" + "4. Be detailed enough that the agent can resume without " + "re-doing work.\n" + ) + + @staticmethod + def _build_message_inventory( + conversation: NodeConversation, + ) -> list[dict[str, Any]]: + """Build a per-message size inventory for debug logging.""" + inventory: list[dict[str, Any]] = [] + for m in conversation.messages: + content_chars = len(m.content) + tc_chars = 0 + tool_name = None + if m.tool_calls: + for tc in m.tool_calls: + args = tc.get("function", {}).get("arguments", "") + tc_chars += len(args) if isinstance(args, str) else len(json.dumps(args)) + names = [tc.get("function", {}).get("name", "?") for tc in m.tool_calls] + tool_name = ", ".join(names) + elif m.role == "tool" and m.tool_use_id: + for prev in conversation.messages: + if prev.tool_calls: + for tc in prev.tool_calls: + if tc.get("id") == m.tool_use_id: + tool_name = tc.get("function", {}).get("name", "?") + break + if tool_name: + break + entry: dict[str, Any] = { + "seq": m.seq, + "role": m.role, + "content_chars": content_chars, + } + if tc_chars: + entry["tool_call_args_chars"] = tc_chars + if tool_name: + entry["tool"] = tool_name + if m.is_error: + entry["is_error"] = True + if m.phase_id: + entry["phase"] = m.phase_id + if content_chars > 2000: + entry["preview"] = m.content[:200] + "…" + inventory.append(entry) + return inventory + + async def _log_compaction( + self, + ctx: NodeContext, + conversation: NodeConversation, + ratio_before: float, + pre_inventory: list[dict[str, Any]] | None = None, + ) -> None: + """Log compaction result to runtime logger, event bus, and debug file.""" + import os as _os + + ratio_after = conversation.usage_ratio() + before_pct = round(ratio_before * 100) + after_pct = round(ratio_after * 100) + + # Determine label from what happened + if after_pct >= before_pct - 1: + level = "prune_only" + elif ratio_after <= 0.6: + level = 
@staticmethod
def _write_compaction_debug_log(
    ctx: NodeContext,
    before_pct: int,
    after_pct: int,
    level: str,
    inventory: list[dict[str, Any]] | None,
) -> None:
    """Write a Markdown compaction report to ``~/.hive/compaction_log/``.

    The report contains the node identity, compaction level and usage
    before/after. When ``inventory`` is given it also contains a table of
    all messages ranked by size and the previews of large messages.

    This is best-effort debug output: filesystem errors (including
    failure to create the log directory) are logged at debug level and
    never raised.

    Args:
        ctx: Node context supplying ``node_id``, ``node_spec.name`` and
            ``stream_id``.
        before_pct: Context usage before compaction, in percent.
        after_pct: Context usage after compaction, in percent.
        level: Compaction level label.
        inventory: Pre-compaction message inventory, or ``None``.
    """

    def _entry_chars(entry: dict[str, Any]) -> int:
        """Total size of one inventory entry (content + tool-call args)."""
        return entry.get("content_chars", 0) + entry.get("tool_call_args_chars", 0)

    # Take the timestamp once so the file name and the header agree.
    now = datetime.now(UTC)
    log_dir = Path.home() / ".hive" / "compaction_log"
    ts = now.strftime("%Y%m%dT%H%M%S_%f")
    node_label = ctx.node_id.replace("/", "_")
    log_path = log_dir / f"{ts}_{node_label}.md"

    lines: list[str] = [
        f"# Compaction Debug — {ctx.node_id}",
        f"**Time:** {now.isoformat()}",
        f"**Node:** {ctx.node_spec.name} (`{ctx.node_id}`)",
    ]
    if ctx.stream_id:
        lines.append(f"**Stream:** {ctx.stream_id}")
    lines.append(f"**Level:** {level}")
    lines.append(f"**Usage:** {before_pct}% → {after_pct}%")
    lines.append("")

    if inventory:
        total_chars = sum(_entry_chars(e) for e in inventory)
        lines.append(
            f"## Pre-Compaction Message Inventory "
            f"({len(inventory)} messages, {total_chars:,} total chars)"
        )
        lines.append("")
        # Largest messages first.
        ranked = sorted(inventory, key=_entry_chars, reverse=True)
        lines.append("| # | seq | role | tool | chars | % of total | flags |")
        lines.append("|---|-----|------|------|------:|------------|-------|")
        for i, entry in enumerate(ranked, 1):
            chars = _entry_chars(entry)
            pct = (chars / total_chars * 100) if total_chars else 0
            tool = entry.get("tool", "")
            flags = []
            if entry.get("is_error"):
                flags.append("error")
            if entry.get("phase"):
                flags.append(f"phase={entry['phase']}")
            lines.append(
                f"| {i} | {entry['seq']} | {entry['role']} | {tool} "
                f"| {chars:,} | {pct:.1f}% | {', '.join(flags)} |"
            )

        large = [e for e in ranked if e.get("preview")]
        if large:
            lines.append("")
            lines.append("### Large message previews")
            for entry in large:
                lines.append(
                    f"\n**seq={entry['seq']}** ({entry['role']}, {entry.get('tool', '')}):"
                )
                lines.append(f"```\n{entry['preview']}\n```")
        lines.append("")

    try:
        # Directory creation is inside the try as well: an unwritable home
        # directory must not turn a debug aid into a compaction failure.
        log_dir.mkdir(parents=True, exist_ok=True)
        log_path.write_text("\n".join(lines), encoding="utf-8")
        logger.debug("Compaction debug log written to %s", log_path)
    except OSError:
        logger.debug("Failed to write compaction debug log to %s", log_path)
""" - return build_emergency_summary(ctx, accumulator, conversation, self._config) + parts = [ + "EMERGENCY COMPACTION — previous conversation was too large " + "and has been replaced with this summary.\n" + ] + + # 1. Node identity + spec = ctx.node_spec + parts.append(f"NODE: {spec.name} (id={spec.id})") + if spec.description: + parts.append(f"PURPOSE: {spec.description}") + + # 2. Inputs the node received + input_lines = [] + for key in spec.input_keys: + value = ctx.input_data.get(key) or ctx.memory.read(key) + if value is not None: + # Truncate long values but keep them recognisable + v_str = str(value) + if len(v_str) > 200: + v_str = v_str[:200] + "…" + input_lines.append(f" {key}: {v_str}") + if input_lines: + parts.append("INPUTS:\n" + "\n".join(input_lines)) + + # 3. Output accumulator state (what's been set so far) + if accumulator: + acc_state = accumulator.to_dict() + set_keys = {k: v for k, v in acc_state.items() if v is not None} + missing = [k for k, v in acc_state.items() if v is None] + if set_keys: + lines = [f" {k}: {str(v)[:150]}" for k, v in set_keys.items()] + parts.append("OUTPUTS ALREADY SET:\n" + "\n".join(lines)) + if missing: + parts.append(f"OUTPUTS STILL NEEDED: {', '.join(missing)}") + elif spec.output_keys: + parts.append(f"OUTPUTS STILL NEEDED: {', '.join(spec.output_keys)}") + + # 4. Available tools reminder + if spec.tools: + parts.append(f"AVAILABLE TOOLS: {', '.join(spec.tools)}") + + # 5. Spillover files — list actual files so the LLM can load + # them immediately instead of having to call list_data_files first. + # Inline adapt.md (agent memory) directly — it contains user rules + # and identity preferences that must survive emergency compaction. 
+ if self._config.spillover_dir: + try: + from pathlib import Path + + data_dir = Path(self._config.spillover_dir) + if data_dir.is_dir(): + # Inline adapt.md content directly + adapt_path = data_dir / "adapt.md" + if adapt_path.is_file(): + adapt_text = adapt_path.read_text(encoding="utf-8").strip() + if adapt_text: + parts.append(f"AGENT MEMORY (adapt.md):\n{adapt_text}") + + all_files = sorted( + f.name for f in data_dir.iterdir() if f.is_file() and f.name != "adapt.md" + ) + # Separate conversation history files from regular data files + conv_files = [f for f in all_files if re.match(r"conversation_\d+\.md$", f)] + data_files = [f for f in all_files if f not in conv_files] + + if conv_files: + conv_list = "\n".join( + f" - {f} (full path: {data_dir / f})" for f in conv_files + ) + parts.append( + "CONVERSATION HISTORY (freeform messages saved during compaction — " + "use load_data('') to review earlier dialogue):\n" + conv_list + ) + if data_files: + file_list = "\n".join( + f" - {f} (full path: {data_dir / f})" for f in data_files[:30] + ) + parts.append( + "DATA FILES (use load_data('') to read):\n" + file_list + ) + if not all_files: + parts.append( + "NOTE: Large tool results may have been saved to files. " + "Use list_directory to check the data directory." + ) + except Exception: + parts.append( + "NOTE: Large tool results were saved to files. " + "Use read_file(path='') to read them." + ) + + # 6. Tool call history (prevent re-calling tools) + if conversation is not None: + tool_history = self._extract_tool_call_history(conversation) + if tool_history: + parts.append(tool_history) + + parts.append( + "\nContinue working towards setting the remaining outputs. " + "Use your tools and the inputs above." 
+ ) + return "\n\n".join(parts) # ------------------------------------------------------------------- # Persistence: restore, cursor, injection, pause # ------------------------------------------------------------------- + @dataclass + class _RestoredState: + """State recovered from a previous checkpoint.""" + + conversation: NodeConversation + accumulator: OutputAccumulator + start_iteration: int + recent_responses: list[str] + recent_tool_fingerprints: list[list[tuple[str, str]]] + recent_output_fingerprints: list[list[tuple[str, str]]] + no_progress_turns: int + async def _restore( self, ctx: NodeContext, - ) -> RestoredState | None: + ) -> _RestoredState | None: """Attempt to restore from a previous checkpoint. - Returns a ``RestoredState`` with conversation, accumulator, iteration + Returns a ``_RestoredState`` with conversation, accumulator, iteration counter, and stall/doom-loop detection state — everything needed to resume exactly where execution stopped. """ - return await restore( - conversation_store=self._conversation_store, - ctx=ctx, - config=self._config, + if self._conversation_store is None: + return None + + # In isolated mode, filter parts by phase_id so the node only sees + # its own messages in the shared flat conversation store. In + # continuous mode (or when _restore is called for timer-resume) + # load all parts — the full conversation threads across nodes. 
+ _is_continuous = getattr(ctx, "continuous_mode", False) + phase_filter = None if _is_continuous else ctx.node_id + conversation = await NodeConversation.restore( + self._conversation_store, + phase_id=phase_filter, + ) + if conversation is None: + return None + + accumulator = await OutputAccumulator.restore(self._conversation_store) + accumulator.spillover_dir = self._config.spillover_dir + accumulator.max_value_chars = self._config.max_output_value_chars + + cursor = await self._conversation_store.read_cursor() + start_iteration = cursor.get("iteration", 0) + 1 if cursor else 0 + + # Restore stall/doom-loop detection state + recent_responses: list[str] = cursor.get("recent_responses", []) if cursor else [] + raw_fps = cursor.get("recent_tool_fingerprints", []) if cursor else [] + recent_tool_fingerprints: list[list[tuple[str, str]]] = [ + [tuple(pair) for pair in fps] # type: ignore[misc] + for fps in raw_fps + ] + raw_output_fps = cursor.get("recent_output_fingerprints", []) if cursor else [] + recent_output_fingerprints: list[list[tuple[str, str]]] = [ + [tuple(pair) for pair in fps] # type: ignore[misc] + for fps in raw_output_fps + ] + no_progress_turns = int(cursor.get("no_progress_turns", 0)) if cursor else 0 + + logger.info( + f"Restored event loop: iteration={start_iteration}, " + f"messages={conversation.message_count}, " + f"outputs={list(accumulator.values.keys())}, " + f"stall_window={len(recent_responses)}, " + f"doom_window={len(recent_tool_fingerprints)}, " + f"output_doom_window={len(recent_output_fingerprints)}, " + f"no_progress_turns={no_progress_turns}" + ) + return EventLoopNode._RestoredState( + conversation=conversation, + accumulator=accumulator, + start_iteration=start_iteration, + recent_responses=recent_responses, + recent_tool_fingerprints=recent_tool_fingerprints, + recent_output_fingerprints=recent_output_fingerprints, + no_progress_turns=no_progress_turns, ) async def _write_cursor( @@ -3150,30 +5183,67 @@ async def _write_cursor( 
*, recent_responses: list[str] | None = None, recent_tool_fingerprints: list[list[tuple[str, str]]] | None = None, + recent_output_fingerprints: list[list[tuple[str, str]]] | None = None, + no_progress_turns: int | None = None, ) -> None: """Write checkpoint cursor for crash recovery. Persists iteration counter, accumulator outputs, and stall/doom-loop detection state so that resume picks up exactly where execution stopped. """ - return await write_cursor( - conversation_store=self._conversation_store, - ctx=ctx, - conversation=conversation, - accumulator=accumulator, - iteration=iteration, - recent_responses=recent_responses, - recent_tool_fingerprints=recent_tool_fingerprints, - ) + if self._conversation_store: + cursor = await self._conversation_store.read_cursor() or {} + cursor.update( + { + "iteration": iteration, + "node_id": ctx.node_id, + "next_seq": conversation.next_seq, + "outputs": accumulator.to_dict(), + } + ) + # Persist stall/doom-loop detection state for reliable resume + if recent_responses is not None: + cursor["recent_responses"] = recent_responses + if recent_tool_fingerprints is not None: + # Convert list[list[tuple]] → list[list[list]] for JSON + cursor["recent_tool_fingerprints"] = [ + [list(pair) for pair in fps] for fps in recent_tool_fingerprints + ] + if recent_output_fingerprints is not None: + cursor["recent_output_fingerprints"] = [ + [list(pair) for pair in fps] for fps in recent_output_fingerprints + ] + if no_progress_turns is not None: + cursor["no_progress_turns"] = no_progress_turns + await self._conversation_store.write_cursor(cursor) - async def _drain_injection_queue(self, conversation: NodeConversation, ctx: NodeContext) -> int: + async def _drain_injection_queue(self, conversation: NodeConversation) -> int: """Drain all pending injected events as user messages. 
Returns count.""" - return await drain_injection_queue( - queue=self._injection_queue, - conversation=conversation, - ctx=ctx, - describe_images_as_text_fn=_describe_images_as_text, - ) + count = 0 + while not self._injection_queue.empty(): + try: + content, is_client_input, image_content = self._injection_queue.get_nowait() + logger.info( + "[drain] injected message (client_input=%s): %s", + is_client_input, + content[:200] if content else "(empty)", + ) + # Real user input is stored as-is; external events get a prefix + if is_client_input: + await conversation.add_user_message( + content, + is_client_input=True, + image_content=image_content, + ) + else: + await conversation.add_user_message( + f"[External event]: {content}", + image_content=image_content, + ) + count += 1 + except asyncio.QueueEmpty: + break + return count async def _drain_trigger_queue(self, conversation: NodeConversation) -> int: """Drain all pending trigger events as a single batched user message. @@ -3181,10 +5251,27 @@ async def _drain_trigger_queue(self, conversation: NodeConversation) -> int: Multiple triggers are merged so the LLM sees them atomically and can reason about all pending triggers before acting. 
""" - return await drain_trigger_queue( - queue=self._trigger_queue, - conversation=conversation, - ) + triggers: list[TriggerEvent] = [] + while not self._trigger_queue.empty(): + try: + triggers.append(self._trigger_queue.get_nowait()) + except asyncio.QueueEmpty: + break + + if not triggers: + return 0 + + parts: list[str] = [] + for t in triggers: + task = t.payload.get("task", "") + task_line = f"\nTask: {task}" if task else "" + payload_str = json.dumps(t.payload, default=str) + parts.append(f"[TRIGGER: {t.trigger_type}/{t.source_id}]{task_line}\n{payload_str}") + + combined = "\n\n".join(parts) + logger.info("[drain] %d trigger(s): %s", len(triggers), combined[:200]) + await conversation.add_user_message(combined) + return len(triggers) async def _check_pause( self, @@ -3198,11 +5285,25 @@ async def _check_pause( Note: This check happens BEFORE starting iteration N, after completing N-1. If paused, the node exits having completed {iteration} iterations (0 to iteration-1). """ - return await check_pause( - ctx=ctx, - conversation=conversation, - iteration=iteration, - ) + # Check executor-level pause event (for /pause command, Ctrl+Z) + if ctx.pause_event and ctx.pause_event.is_set(): + completed = iteration # 0-indexed: iteration=3 means 3 iterations completed (0,1,2) + logger.info(f"⏸ Pausing after {completed} iteration(s) completed (executor-level)") + return True + + # Check context-level pause flags (legacy/alternative methods) + pause_requested = ctx.input_data.get("pause_requested", False) + if not pause_requested: + try: + pause_requested = ctx.memory.read("pause_requested") or False + except (PermissionError, KeyError): + pause_requested = False + if pause_requested: + completed = iteration + logger.info(f"⏸ Pausing after {completed} iteration(s) completed (context-level)") + return True + + return False # ------------------------------------------------------------------- # EventBus publishing helpers @@ -3211,13 +5312,13 @@ async def _check_pause( 
async def _publish_loop_started( self, stream_id: str, node_id: str, execution_id: str = "" ) -> None: - return await publish_loop_started( - event_bus=self._event_bus, - stream_id=stream_id, - node_id=node_id, - max_iterations=self._config.max_iterations, - execution_id=execution_id, - ) + if self._event_bus: + await self._event_bus.emit_node_loop_started( + stream_id=stream_id, + node_id=node_id, + max_iterations=self._config.max_iterations, + execution_id=execution_id, + ) async def _generate_action_plan( self, @@ -3230,13 +5331,41 @@ async def _generate_action_plan( Runs as a fire-and-forget task so it never blocks the main loop. """ - return await generate_action_plan( - event_bus=self._event_bus, - ctx=ctx, - stream_id=stream_id, - node_id=node_id, - execution_id=execution_id, - ) + try: + system_prompt = ctx.node_spec.system_prompt or "" + # Trim to keep the prompt small + prompt_summary = system_prompt[:500] + if len(system_prompt) > 500: + prompt_summary += "..." + + tool_names = [t.name for t in ctx.available_tools] + output_keys = ctx.node_spec.output_keys or [] + + prompt = ( + f'You are about to work on a task as node "{node_id}".\n\n' + f"System prompt:\n{prompt_summary}\n\n" + f"Tools available: {tool_names}\n" + f"Required outputs: {output_keys}\n\n" + f"Write a brief action plan (2-5 bullet points) describing " + f"what you will do to complete this task. Be specific and concise.\n" + f"Return ONLY the plan text, no preamble." 
+ ) + + response = await ctx.llm.acomplete( + messages=[{"role": "user", "content": prompt}], + max_tokens=1024, + ) + + plan = response.content.strip() + if plan and self._event_bus: + await self._event_bus.emit_node_action_plan( + stream_id=stream_id, + node_id=node_id, + plan=plan, + execution_id=execution_id, + ) + except Exception as e: + logger.warning("Action plan generation failed for node '%s': %s", node_id, e) async def _run_hooks( self, @@ -3252,12 +5381,30 @@ async def _run_hooks( Hooks run in registration order; each sees the prompt as left by the previous hook. """ - return await run_hooks( - hooks_config=self._config.hooks, - event=event, - conversation=conversation, - trigger=trigger, - ) + hook_list = self._config.hooks.get(event, []) + if not hook_list: + return + for hook in hook_list: + ctx = HookContext( + event=event, + trigger=trigger, + system_prompt=conversation.system_prompt, + ) + try: + result = await hook(ctx) + except Exception: + import logging + + logging.getLogger(__name__).warning( + "Hook '%s' raised an exception", event, exc_info=True + ) + continue + if result is None: + continue + if result.system_prompt: + conversation.update_system_prompt(result.system_prompt) + if result.inject: + await conversation.add_user_message(result.inject) async def _publish_context_usage( self, @@ -3266,11 +5413,27 @@ async def _publish_context_usage( trigger: str, ) -> None: """Emit a CONTEXT_USAGE_UPDATED event with current context window state.""" - return await publish_context_usage( - event_bus=self._event_bus, - ctx=ctx, - conversation=conversation, - trigger=trigger, + if not self._event_bus: + return + from framework.runtime.event_bus import AgentEvent, EventType + + estimated = conversation.estimate_tokens() + max_tokens = conversation._max_context_tokens + ratio = estimated / max_tokens if max_tokens > 0 else 0.0 + await self._event_bus.publish( + AgentEvent( + type=EventType.CONTEXT_USAGE_UPDATED, + stream_id=ctx.stream_id or ctx.node_id, 
+ node_id=ctx.node_id, + data={ + "usage_ratio": round(ratio, 4), + "usage_pct": round(ratio * 100), + "message_count": conversation.message_count, + "estimated_tokens": estimated, + "max_context_tokens": max_tokens, + "trigger": trigger, + }, + ) ) async def _publish_iteration( @@ -3281,14 +5444,14 @@ async def _publish_iteration( execution_id: str = "", extra_data: dict | None = None, ) -> None: - return await publish_iteration( - event_bus=self._event_bus, - stream_id=stream_id, - node_id=node_id, - iteration=iteration, - execution_id=execution_id, - extra_data=extra_data, - ) + if self._event_bus: + await self._event_bus.emit_node_loop_iteration( + stream_id=stream_id, + node_id=node_id, + iteration=iteration, + execution_id=execution_id, + extra_data=extra_data, + ) async def _publish_llm_turn_complete( self, @@ -3302,18 +5465,18 @@ async def _publish_llm_turn_complete( execution_id: str = "", iteration: int | None = None, ) -> None: - return await publish_llm_turn_complete( - event_bus=self._event_bus, - stream_id=stream_id, - node_id=node_id, - stop_reason=stop_reason, - model=model, - input_tokens=input_tokens, - output_tokens=output_tokens, - cached_tokens=cached_tokens, - execution_id=execution_id, - iteration=iteration, - ) + if self._event_bus: + await self._event_bus.emit_llm_turn_complete( + stream_id=stream_id, + node_id=node_id, + stop_reason=stop_reason, + model=model, + input_tokens=input_tokens, + output_tokens=output_tokens, + cached_tokens=cached_tokens, + execution_id=execution_id, + iteration=iteration, + ) def _log_skip_judge( self, @@ -3327,35 +5490,39 @@ def _log_skip_judge( iter_start: float, ) -> None: """Log a CONTINUE step that skips judge evaluation (e.g., waiting for input).""" - return log_skip_judge( - ctx=ctx, - node_id=node_id, - iteration=iteration, - feedback=feedback, - tool_calls=tool_calls, - llm_text=llm_text, - turn_tokens=turn_tokens, - iter_start=iter_start, - ) + if ctx.runtime_logger: + ctx.runtime_logger.log_step( + 
node_id=node_id, + node_type="event_loop", + step_index=iteration, + verdict="CONTINUE", + verdict_feedback=feedback, + tool_calls=tool_calls, + llm_text=llm_text, + input_tokens=turn_tokens.get("input", 0), + output_tokens=turn_tokens.get("output", 0), + latency_ms=int((time.time() - iter_start) * 1000), + ) async def _publish_loop_completed( self, stream_id: str, node_id: str, iterations: int, execution_id: str = "" ) -> None: - return await publish_loop_completed( - event_bus=self._event_bus, - stream_id=stream_id, - node_id=node_id, - iterations=iterations, - execution_id=execution_id, - ) + if self._event_bus: + await self._event_bus.emit_node_loop_completed( + stream_id=stream_id, + node_id=node_id, + iterations=iterations, + execution_id=execution_id, + ) async def _publish_stalled(self, stream_id: str, node_id: str, execution_id: str = "") -> None: - return await publish_stalled( - event_bus=self._event_bus, - stream_id=stream_id, - node_id=node_id, - execution_id=execution_id, - ) + if self._event_bus: + await self._event_bus.emit_node_stalled( + stream_id=stream_id, + node_id=node_id, + reason="Consecutive similar responses detected", + execution_id=execution_id, + ) async def _publish_text_delta( self, @@ -3368,17 +5535,26 @@ async def _publish_text_delta( iteration: int | None = None, inner_turn: int = 0, ) -> None: - return await publish_text_delta( - event_bus=self._event_bus, - stream_id=stream_id, - node_id=node_id, - content=content, - snapshot=snapshot, - ctx=ctx, - execution_id=execution_id, - iteration=iteration, - inner_turn=inner_turn, - ) + if self._event_bus: + if ctx.node_spec.client_facing: + await self._event_bus.emit_client_output_delta( + stream_id=stream_id, + node_id=node_id, + content=content, + snapshot=snapshot, + execution_id=execution_id, + iteration=iteration, + inner_turn=inner_turn, + ) + else: + await self._event_bus.emit_llm_text_delta( + stream_id=stream_id, + node_id=node_id, + content=content, + snapshot=snapshot, + 
execution_id=execution_id, + inner_turn=inner_turn, + ) async def _publish_tool_started( self, @@ -3389,15 +5565,15 @@ async def _publish_tool_started( tool_input: dict, execution_id: str = "", ) -> None: - return await publish_tool_started( - event_bus=self._event_bus, - stream_id=stream_id, - node_id=node_id, - tool_use_id=tool_use_id, - tool_name=tool_name, - tool_input=tool_input, - execution_id=execution_id, - ) + if self._event_bus: + await self._event_bus.emit_tool_call_started( + stream_id=stream_id, + node_id=node_id, + tool_use_id=tool_use_id, + tool_name=tool_name, + tool_input=tool_input, + execution_id=execution_id, + ) async def _publish_tool_completed( self, @@ -3409,16 +5585,16 @@ async def _publish_tool_completed( is_error: bool, execution_id: str = "", ) -> None: - return await publish_tool_completed( - event_bus=self._event_bus, - stream_id=stream_id, - node_id=node_id, - tool_use_id=tool_use_id, - tool_name=tool_name, - result=result, - is_error=is_error, - execution_id=execution_id, - ) + if self._event_bus: + await self._event_bus.emit_tool_call_completed( + stream_id=stream_id, + node_id=node_id, + tool_use_id=tool_use_id, + tool_name=tool_name, + result=result, + is_error=is_error, + execution_id=execution_id, + ) async def _publish_judge_verdict( self, @@ -3430,16 +5606,16 @@ async def _publish_judge_verdict( iteration: int = 0, execution_id: str = "", ) -> None: - return await publish_judge_verdict( - event_bus=self._event_bus, - stream_id=stream_id, - node_id=node_id, - action=action, - feedback=feedback, - judge_type=judge_type, - iteration=iteration, - execution_id=execution_id, - ) + if self._event_bus: + await self._event_bus.emit_judge_verdict( + stream_id=stream_id, + node_id=node_id, + action=action, + feedback=feedback, + judge_type=judge_type, + iteration=iteration, + execution_id=execution_id, + ) async def _publish_output_key_set( self, @@ -3448,13 +5624,10 @@ async def _publish_output_key_set( key: str, execution_id: str = "", 
) -> None: - return await publish_output_key_set( - event_bus=self._event_bus, - stream_id=stream_id, - node_id=node_id, - key=key, - execution_id=execution_id, - ) + if self._event_bus: + await self._event_bus.emit_output_key_set( + stream_id=stream_id, node_id=node_id, key=key, execution_id=execution_id + ) # ------------------------------------------------------------------- # Subagent Execution @@ -3490,16 +5663,341 @@ async def _execute_subagent( - data: Subagent's output (free-form JSON) - metadata: Execution metadata (success, tokens, latency) """ - return await execute_subagent( - ctx=ctx, - agent_id=agent_id, - task=task, - accumulator=accumulator, - event_bus=self._event_bus, - config=self._config, + from framework.graph.node import NodeContext, SharedMemory + + # Log subagent invocation start + logger.info( + "\n" + "=" * 60 + "\n" + "🤖 SUBAGENT INVOCATION\n" + "=" * 60 + "\n" + "Parent Node: %s\n" + "Subagent ID: %s\n" + "Task: %s\n" + "=" * 60, + ctx.node_id, + agent_id, + task[:500] + "..." if len(task) > 500 else task, + ) + + # 1. Validate agent exists in registry + if agent_id not in ctx.node_registry: + return ToolResult( + tool_use_id="", + content=json.dumps( + { + "message": f"Sub-agent '{agent_id}' not found in registry", + "data": None, + "metadata": {"agent_id": agent_id, "success": False, "error": "not_found"}, + } + ), + is_error=True, + ) + + subagent_spec = ctx.node_registry[agent_id] + + # 2. Create read-only memory snapshot + # Start with everything the parent can read from shared memory. + parent_data = ctx.memory.read_all() + + # Merge in-flight outputs from the parent's accumulator. + # set_output() writes to the accumulator but shared memory is only + # updated after the parent node completes — so the subagent would + # otherwise miss any keys the parent set before delegating. 
+ if accumulator: + for key, value in accumulator.to_dict().items(): + if key not in parent_data: + parent_data[key] = value + + subagent_memory = SharedMemory() + for key, value in parent_data.items(): + subagent_memory.write(key, value, validate=False) + + # Allow reads for parent data AND the subagent's declared input_keys + # (input_keys may reference keys that exist but weren't in read_all, + # or keys that were just written by the accumulator). + read_keys = set(parent_data.keys()) | set(subagent_spec.input_keys or []) + scoped_memory = subagent_memory.with_permissions( + read_keys=list(read_keys), + write_keys=[], # Read-only! + ) + + # 2b. Set up report callback (one-way channel to parent / event bus) + subagent_reports: list[dict] = [] + + async def _report_callback( + message: str, + data: dict | None = None, + *, + wait_for_response: bool = False, + ) -> str | None: + subagent_reports.append({"message": message, "data": data, "timestamp": time.time()}) + if self._event_bus: + await self._event_bus.emit_subagent_report( + stream_id=ctx.node_id, + node_id=f"{ctx.node_id}:subagent:{agent_id}", + subagent_id=agent_id, + message=message, + data=data, + execution_id=ctx.execution_id, + ) + + if not wait_for_response: + return None + + if not self._event_bus: + logger.warning( + "Subagent '%s' requested user response but no event_bus available", + agent_id, + ) + return None + + # Create isolated receiver and register for input routing + import uuid + + escalation_id = f"{ctx.node_id}:escalation:{uuid.uuid4().hex[:8]}" + receiver = _EscalationReceiver() + registry = ctx.shared_node_registry + + registry[escalation_id] = receiver + try: + # Escalate to the queen instead of asking the user directly. + # The queen handles the request and injects the response via + # inject_worker_message(), which finds this receiver through + # its _awaiting_input flag. 
+ await self._event_bus.emit_escalation_requested( + stream_id=ctx.stream_id or ctx.node_id, + node_id=escalation_id, + reason=f"Subagent report (wait_for_response) from {agent_id}", + context=message, + execution_id=ctx.execution_id, + ) + # Block until queen responds + return await receiver.wait() + finally: + registry.pop(escalation_id, None) + + # 3. Filter tools for subagent + # Use the full tool catalog (ctx.all_tools) so subagents can access tools + # that aren't in the parent node's filtered set (e.g. browser tools for a + # GCU subagent when the parent only has web_scrape/save_data). + # Falls back to ctx.available_tools if all_tools is empty (e.g. in tests). + subagent_tool_names = set(subagent_spec.tools or []) + tool_source = ctx.all_tools if ctx.all_tools else ctx.available_tools + + # GCU auto-population: GCU nodes declare tools=[] because the runner + # auto-populates them at setup time. But that expansion doesn't reach + # subagents invoked via delegate_to_sub_agent — the subagent spec still + # has the original empty list. When a GCU subagent has no declared + # tools, include all catalog tools so browser tools are available. 
+ if subagent_spec.node_type == "gcu" and not subagent_tool_names: + subagent_tools = [t for t in tool_source if t.name != "delegate_to_sub_agent"] + else: + subagent_tools = [ + t + for t in tool_source + if t.name in subagent_tool_names and t.name != "delegate_to_sub_agent" + ] + + missing = subagent_tool_names - {t.name for t in subagent_tools} + if missing: + logger.warning( + "Subagent '%s' requested tools not found in catalog: %s", + agent_id, + sorted(missing), + ) + + logger.info( + "📦 Subagent '%s' configuration:\n" + " - System prompt: %s\n" + " - Tools available (%d): %s\n" + " - Memory keys inherited: %s", + agent_id, + (subagent_spec.system_prompt[:200] + "...") + if subagent_spec.system_prompt and len(subagent_spec.system_prompt) > 200 + else subagent_spec.system_prompt, + len(subagent_tools), + [t.name for t in subagent_tools], + list(parent_data.keys()), + ) + + # 4. Build subagent context + max_iter = min(self._config.max_iterations, 10) + subagent_ctx = NodeContext( + runtime=ctx.runtime, + node_id=f"{ctx.node_id}:subagent:{agent_id}", + node_spec=subagent_spec, + memory=scoped_memory, + input_data={"task": task, **parent_data}, + llm=ctx.llm, + available_tools=subagent_tools, + goal_context=( + f"Your specific task: {task}\n\n" + f"COMPLETION REQUIREMENTS:\n" + f"When your task is done, you MUST call set_output() " + f"for each required key: {subagent_spec.output_keys}\n" + f"Alternatively, call report_to_parent(mark_complete=true) " + f"with your findings in message/data.\n" + f"You have a maximum of {max_iter} turns to complete this task." + ), + goal=ctx.goal, + max_tokens=ctx.max_tokens, + runtime_logger=ctx.runtime_logger, + is_subagent_mode=True, # Prevents nested delegation + report_callback=_report_callback, + node_registry={}, # Empty - no nested subagents + shared_node_registry=ctx.shared_node_registry, # For escalation routing + ) + + # 5. 
Create and execute subagent EventLoopNode + # Derive a conversation store for the subagent from the parent's store. + # Each invocation gets a unique path so that repeated delegate calls + # (e.g. one per profile) don't restore a stale completed conversation. + self._subagent_instance_counter.setdefault(agent_id, 0) + self._subagent_instance_counter[agent_id] += 1 + subagent_instance = str(self._subagent_instance_counter[agent_id]) + + subagent_conv_store = None + if self._conversation_store is not None: + from framework.storage.conversation_store import FileConversationStore + + parent_base = getattr(self._conversation_store, "_base", None) + if parent_base is not None: + # Store subagent conversations parallel to the parent node, + # not nested inside it. e.g. conversations/{node}:subagent:{agent_id}:{instance}/ + conversations_dir = parent_base.parent # e.g. conversations/ + subagent_dir_name = f"{agent_id}-{subagent_instance}" + subagent_store_path = conversations_dir / subagent_dir_name + subagent_conv_store = FileConversationStore(base_path=subagent_store_path) + + # Derive a subagent-scoped spillover dir so large tool results + # (e.g. browser_snapshot) get written to disk instead of being + # silently truncated. Each instance gets its own directory to + # avoid file collisions between concurrent subagents. 
+ subagent_spillover = None + if self._config.spillover_dir: + subagent_spillover = str( + Path(self._config.spillover_dir) / agent_id / subagent_instance + ) + + subagent_node = EventLoopNode( + event_bus=self._event_bus, # Subagent events visible to Queen via shared bus + judge=SubagentJudge(task=task, max_iterations=max_iter), + config=LoopConfig( + max_iterations=max_iter, # Tighter budget + max_tool_calls_per_turn=self._config.max_tool_calls_per_turn, + tool_call_overflow_margin=self._config.tool_call_overflow_margin, + max_context_tokens=self._config.max_context_tokens, + stall_detection_threshold=self._config.stall_detection_threshold, + max_tool_result_chars=self._config.max_tool_result_chars, + spillover_dir=subagent_spillover, + ), tool_executor=self._tool_executor, - conversation_store=self._conversation_store, - subagent_instance_counter=self._subagent_instance_counter, - event_loop_node_cls=type(self), - escalation_receiver_cls=_EscalationReceiver, + conversation_store=subagent_conv_store, ) + + # Inject a unique GCU browser profile for this subagent so that + # concurrent GCU subagents (run via asyncio.gather) each get their own + # isolated BrowserContext. asyncio.gather copies the current context + # for each coroutine, so the reset token is safe to call in finally. 
+ _profile_token = None + try: + from gcu.browser.session import set_active_profile as _set_gcu_profile + + _profile_token = _set_gcu_profile(f"{agent_id}-{subagent_instance}") + except ImportError: + pass # GCU tools not installed; no-op + + try: + logger.info("🚀 Starting subagent '%s' execution...", agent_id) + start_time = time.time() + result = await subagent_node.execute(subagent_ctx) + latency_ms = int((time.time() - start_time) * 1000) + + separator = "-" * 60 + logger.info( + "\n%s\n" + "✅ SUBAGENT '%s' COMPLETED\n" + "%s\n" + "Success: %s\n" + "Latency: %dms\n" + "Tokens used: %s\n" + "Output keys: %s\n" + "%s", + separator, + agent_id, + separator, + result.success, + latency_ms, + result.tokens_used, + list(result.output.keys()) if result.output else [], + separator, + ) + + result_json = { + "message": ( + f"Sub-agent '{agent_id}' completed successfully" + if result.success + else f"Sub-agent '{agent_id}' failed: {result.error}" + ), + "data": result.output, + "reports": subagent_reports if subagent_reports else None, + "metadata": { + "agent_id": agent_id, + "success": result.success, + "tokens_used": result.tokens_used, + "latency_ms": latency_ms, + "report_count": len(subagent_reports), + }, + } + + return ToolResult( + tool_use_id="", + content=json.dumps(result_json, indent=2, default=str), + is_error=not result.success, + ) + + except Exception as e: + logger.exception( + "\n" + "!" * 60 + "\n❌ SUBAGENT '%s' FAILED\nError: %s\n" + "!" * 60, + agent_id, + str(e), + ) + result_json = { + "message": f"Sub-agent '{agent_id}' raised exception: {e}", + "data": None, + "metadata": { + "agent_id": agent_id, + "success": False, + "error": str(e), + }, + } + return ToolResult( + tool_use_id="", + content=json.dumps(result_json, indent=2), + is_error=True, + ) + finally: + # Restore the GCU profile context that was set before this subagent ran. 
+ if _profile_token is not None: + from gcu.browser.session import _active_profile as _gcu_profile_var + + _gcu_profile_var.reset(_profile_token) + + # Stop the browser session for this subagent's profile so tabs are + # closed immediately rather than accumulating until server shutdown. + if self._tool_executor is not None: + _subagent_profile = f"{agent_id}-{subagent_instance}" + try: + _stop_use = ToolUse( + id="gcu-cleanup", + name="browser_stop", + input={"profile": _subagent_profile}, + ) + _stop_result = self._tool_executor(_stop_use) + if asyncio.iscoroutine(_stop_result) or asyncio.isfuture(_stop_result): + await _stop_result + except Exception as _gcu_exc: + logger.warning( + "GCU browser_stop failed for profile %r: %s", + _subagent_profile, + _gcu_exc, + ) diff --git a/core/framework/graph/executor.py b/core/framework/graph/executor.py index 68f4e2b116..66b6f4ff59 100644 --- a/core/framework/graph/executor.py +++ b/core/framework/graph/executor.py @@ -1474,7 +1474,22 @@ async def execute( narrative=narrative, accounts_prompt=_node_accounts, ) - continuous_conversation.update_system_prompt(new_system) + continuous_conversation.update_system_prompt( + new_system, + output_keys=list(next_spec.output_keys or []), + ) + + # Stamp the next phase before inserting the transition + # marker so the marker itself is preserved with the + # phase it introduces during compaction/restore. 
+ continuous_conversation.set_current_phase(next_spec.id) + + transition_tool_names = set(cumulative_tool_names) + transition_tool_names.update(next_spec.tools or []) + if next_spec.output_keys: + transition_tool_names.add("set_output") + if next_spec.client_facing: + transition_tool_names.update({"ask_user", "ask_user_multiple"}) # Insert transition marker into conversation data_dir = str(self._storage_path / "data") if self._storage_path else None @@ -1482,7 +1497,7 @@ async def execute( previous_node=node_spec, next_node=next_spec, memory=memory, - cumulative_tool_names=sorted(cumulative_tool_names), + cumulative_tool_names=sorted(transition_tool_names), data_dir=data_dir, adapt_content=_adapt_text, ) @@ -1491,9 +1506,6 @@ async def execute( is_transition_marker=True, ) - # Set current phase for phase-aware compaction - continuous_conversation.set_current_phase(next_spec.id) - # Phase-boundary compaction (same flow as EventLoopNode._compact) if continuous_conversation.usage_ratio() > 0.5: await continuous_conversation.prune_old_tool_results( diff --git a/core/framework/graph/prompt_composer.py b/core/framework/graph/prompt_composer.py index 29e26914b5..427aed398f 100644 --- a/core/framework/graph/prompt_composer.py +++ b/core/framework/graph/prompt_composer.py @@ -152,8 +152,6 @@ def compose_system_prompt( accounts_prompt: str | None = None, skills_catalog_prompt: str | None = None, protocols_prompt: str | None = None, - execution_preamble: str | None = None, - node_type_preamble: str | None = None, ) -> str: """Compose the multi-layer system prompt. @@ -164,10 +162,6 @@ def compose_system_prompt( accounts_prompt: Connected accounts block (sits between identity and narrative). skills_catalog_prompt: Available skills catalog XML (Agent Skills standard). protocols_prompt: Default skill operational protocols section. - execution_preamble: EXECUTION_SCOPE_PREAMBLE for worker nodes - (prepended before focus so the LLM knows its pipeline scope). 
- node_type_preamble: Node-type-specific preamble, e.g. GCU browser - best-practices prompt (prepended before focus). Returns: Composed system prompt with all layers present, plus current datetime. @@ -194,15 +188,6 @@ def compose_system_prompt( if narrative: parts.append(f"\n--- Context (what has happened so far) ---\n{narrative}") - # Execution scope preamble (worker nodes — tells the LLM it is one - # step in a multi-node pipeline and should not overreach) - if execution_preamble: - parts.append(f"\n{execution_preamble}") - - # Node-type preamble (e.g. GCU browser best-practices) - if node_type_preamble: - parts.append(f"\n{node_type_preamble}") - # Layer 3: Focus (current phase directive) if focus_prompt: parts.append(f"\n--- Current Focus ---\n{focus_prompt}") @@ -320,7 +305,8 @@ def build_transition_marker( file_size = (data_path / filename).stat().st_size val_str = ( f"[Saved to '{filename}' ({file_size:,} bytes). " - f"Use load_data(filename='{filename}') to access.]" + f"Use load_data(filename='{filename}') to access from session data. " + "Do NOT open it as a workspace file or expect it in the current directory.]" ) except Exception: val_str = val_str[:300] + "..." diff --git a/core/framework/llm/codex_adapter.py b/core/framework/llm/codex_adapter.py new file mode 100644 index 0000000000..0378c6727b --- /dev/null +++ b/core/framework/llm/codex_adapter.py @@ -0,0 +1,255 @@ +"""Codex adapter for Hive's LiteLLM provider. + +Codex CLI is tool-first and event-structured: tool invocations and tool results +are emitted as explicit response items, not as plain-text workflow narration. +This adapter keeps the ChatGPT Codex backend aligned with Hive's normal +provider contract by normalizing Codex request shaping and response recovery at +the provider boundary. 
+""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from framework.llm.codex_backend import ( + build_codex_extra_headers, + is_codex_api_base, + merge_codex_allowed_openai_params, + normalize_codex_api_base, +) +from framework.llm.provider import Tool + +if TYPE_CHECKING: + from collections.abc import Callable + + from framework.llm.litellm import LiteLLMProvider + from framework.llm.stream_events import StreamEvent + +logger = logging.getLogger(__name__) + +_CODEX_CRITICAL_TOOL_NAMES = frozenset( + { + "ask_user", + "ask_user_multiple", + "set_output", + "escalate", + "save_agent_draft", + "confirm_and_build", + "initialize_and_build_agent", + } +) +_CODEX_SYSTEM_CHUNK_CHARS = 3500 +_CODEX_SYSTEM_PREAMBLE = """# Codex Execution Contract +Follow the system sections below in order. +- Obey every CRITICAL, MUST, NEVER, and ONLY instruction exactly. +- When tools are available, emit structured tool calls instead of replying with plain-text promises. +- Do not skip required workflow boundaries or approval gates. 
+""" + + +class CodexResponsesAdapter: + """Normalize the ChatGPT Codex backend to Hive's standard provider semantics.""" + + def __init__(self, provider: LiteLLMProvider): + self._provider = provider + + @property + def enabled(self) -> bool: + """Return True when the provider targets the ChatGPT Codex backend.""" + return is_codex_api_base(self._provider.api_base) + + def chunk_system_prompt(self, system: str) -> list[str]: + """Break large system prompts into smaller Codex-friendly chunks.""" + normalized = system.replace("\r\n", "\n").strip() + if not normalized: + return [] + + sections: list[str] = [] + current: list[str] = [] + for line in normalized.splitlines(): + if line.startswith("#") and current: + sections.append("\n".join(current).strip()) + current = [line] + else: + current.append(line) + if current: + sections.append("\n".join(current).strip()) + + chunks: list[str] = [] + for section in sections: + if len(section) <= _CODEX_SYSTEM_CHUNK_CHARS: + chunks.append(section) + continue + + paragraphs = [ + paragraph.strip() for paragraph in section.split("\n\n") if paragraph.strip() + ] + current_chunk = "" + for paragraph in paragraphs: + candidate = paragraph if not current_chunk else f"{current_chunk}\n\n{paragraph}" + if current_chunk and len(candidate) > _CODEX_SYSTEM_CHUNK_CHARS: + chunks.append(current_chunk) + current_chunk = paragraph + else: + current_chunk = candidate + if current_chunk: + chunks.append(current_chunk) + + return chunks or [normalized] + + def build_system_messages( + self, + system: str, + *, + json_mode: bool, + ) -> list[dict[str, Any]]: + """Build Codex system messages in the tool-first format Codex CLI expects.""" + system_messages: list[dict[str, Any]] = [] + if system: + chunks = self.chunk_system_prompt(system) + if len(chunks) > 1 or len(chunks[0]) > _CODEX_SYSTEM_CHUNK_CHARS: + system_messages.append({"role": "system", "content": _CODEX_SYSTEM_PREAMBLE}) + for chunk in chunks: + system_messages.append({"role": 
"system", "content": chunk}) + else: + system_messages.append({"role": "system", "content": "You are a helpful assistant."}) + + if json_mode: + system_messages.append( + {"role": "system", "content": "Please respond with a valid JSON object."} + ) + return system_messages + + def derive_tool_choice( + self, + messages: list[dict[str, Any]], + tools: list[Tool] | None, + ) -> str | dict[str, Any] | None: + """Force structured tool use when Codex sees critical framework tools.""" + if not tools: + return None + + tool_names = {tool.name for tool in tools} + if not (tool_names & _CODEX_CRITICAL_TOOL_NAMES): + return None + + last_role = next( + (m.get("role") for m in reversed(messages) if m.get("role") != "system"), + None, + ) + if last_role == "assistant": + return None + return "required" + + def harden_request_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]: + """Strip unsupported params and inject the Codex backend headers.""" + cleaned = dict(kwargs) + cleaned["api_base"] = normalize_codex_api_base( + cleaned.get("api_base") or self._provider.api_base + ) + cleaned["store"] = False + cleaned["allowed_openai_params"] = merge_codex_allowed_openai_params( + cleaned.get("allowed_openai_params") + ) + cleaned.pop("max_tokens", None) + cleaned.pop("stream_options", None) + + extra_headers = dict(cleaned.get("extra_headers") or {}) + if "ChatGPT-Account-Id" not in extra_headers: + try: + from framework.runner.runner import get_codex_account_id + + account_id = get_codex_account_id() + if account_id: + extra_headers["ChatGPT-Account-Id"] = account_id + except Exception: + logger.debug("Could not populate ChatGPT-Account-Id", exc_info=True) + + cleaned["extra_headers"] = build_codex_extra_headers( + self._provider.api_key, + account_id=extra_headers.get("ChatGPT-Account-Id"), + extra_headers=extra_headers, + ) + return cleaned + + async def recover_empty_stream( + self, + kwargs: dict[str, Any], + *, + last_role: str | None, + acompletion: Callable[..., Any], + 
) -> list[StreamEvent] | None: + """Try a non-stream completion when Codex returns an empty stream.""" + fallback_kwargs = dict(kwargs) + fallback_kwargs.pop("stream", None) + fallback_kwargs.pop("stream_options", None) + fallback_kwargs = self._provider._sanitize_request_kwargs(fallback_kwargs, stream=False) + + try: + response = await acompletion(**fallback_kwargs) + except Exception as exc: + logger.debug( + "[stream-recover] %s non-stream fallback after empty %s stream failed: %s", + self._provider.model, + last_role, + exc, + ) + return None + + events = self._provider._build_stream_events_from_nonstream_response(response) + if events: + logger.info( + "[stream-recover] %s recovered empty %s stream via non-stream completion", + self._provider.model, + last_role, + ) + return events + return None + + def merge_tool_call_chunk( + self, + tool_calls_acc: dict[int, dict[str, str]], + tc: Any, + last_tool_idx: int, + ) -> int: + """Merge a streamed tool-call chunk, compensating for broken bridge indexes.""" + idx = tc.index if hasattr(tc, "index") and tc.index is not None else 0 + tc_id = getattr(tc, "id", None) or "" + func = getattr(tc, "function", None) + func_name = getattr(func, "name", "") if func is not None else "" + func_args = getattr(func, "arguments", "") if func is not None else "" + + if tc_id: + existing_idx = next( + (key for key, value in tool_calls_acc.items() if value["id"] == tc_id), + None, + ) + if existing_idx is not None: + idx = existing_idx + elif idx in tool_calls_acc and tool_calls_acc[idx]["id"] not in ("", tc_id): + idx = max(tool_calls_acc.keys(), default=-1) + 1 + last_tool_idx = idx + elif func_name: + if ( + last_tool_idx in tool_calls_acc + and tool_calls_acc[last_tool_idx]["name"] + and tool_calls_acc[last_tool_idx]["name"] != func_name + and tool_calls_acc[last_tool_idx]["arguments"] + ): + idx = max(tool_calls_acc.keys(), default=-1) + 1 + last_tool_idx = idx + else: + idx = last_tool_idx if tool_calls_acc else idx + else: + 
idx = last_tool_idx if tool_calls_acc else idx + + if idx not in tool_calls_acc: + tool_calls_acc[idx] = {"id": "", "name": "", "arguments": ""} + if tc_id: + tool_calls_acc[idx]["id"] = tc_id + if func_name: + tool_calls_acc[idx]["name"] = func_name + if func_args: + tool_calls_acc[idx]["arguments"] += func_args + return idx diff --git a/core/framework/llm/codex_backend.py b/core/framework/llm/codex_backend.py new file mode 100644 index 0000000000..076fa536b3 --- /dev/null +++ b/core/framework/llm/codex_backend.py @@ -0,0 +1,85 @@ +"""Shared helpers for Codex's ChatGPT-backed transport. + +Codex CLI talks to the ChatGPT Codex backend, which is not the standard +platform OpenAI API. Hive keeps its normal provider contract by centralizing +the transport-specific headers and request kwargs here. +""" + +from __future__ import annotations + +from typing import Any +from urllib.parse import urlparse, urlunparse + +CODEX_API_BASE = "https://chatgpt.com/backend-api/codex" +CODEX_USER_AGENT = "CodexBar" +CODEX_ALLOWED_OPENAI_PARAMS = ("store",) +_CODEX_HOST = "chatgpt.com" +_CODEX_PATH = "/backend-api/codex" + + +def is_codex_api_base(api_base: str | None) -> bool: + """Return True when *api_base* targets the ChatGPT Codex backend.""" + if not api_base: + return False + parsed = urlparse(api_base) + path = parsed.path.rstrip("/") + return ( + parsed.scheme in {"http", "https"} + and parsed.hostname == _CODEX_HOST + and (path == _CODEX_PATH or path == f"{_CODEX_PATH}/responses") + ) + + +def normalize_codex_api_base(api_base: str | None) -> str | None: + """Normalize ChatGPT Codex backend URLs to the stable base endpoint.""" + if not api_base: + return api_base + parsed = urlparse(api_base) + path = parsed.path.rstrip("/") + if not is_codex_api_base(api_base): + return api_base.rstrip("/") + if path.endswith("/responses"): + path = path[: -len("/responses")] + normalized = parsed._replace(path=path, params="", query="", fragment="") + return 
urlunparse(normalized).rstrip("/") + + +def merge_codex_allowed_openai_params(params: list[str] | tuple[str, ...] | None) -> list[str]: + """Ensure Codex-required pass-through params are always present.""" + allowed = set(params or []) + allowed.update(CODEX_ALLOWED_OPENAI_PARAMS) + return sorted(allowed) + + +def build_codex_extra_headers( + api_key: str | None, + *, + account_id: str | None = None, + extra_headers: dict[str, str] | None = None, +) -> dict[str, str]: + """Build headers for the ChatGPT Codex backend.""" + headers = dict(extra_headers or {}) + if api_key: + headers.setdefault("Authorization", f"Bearer {api_key}") + headers.setdefault("User-Agent", CODEX_USER_AGENT) + if account_id: + headers.setdefault("ChatGPT-Account-Id", account_id) + return headers + + +def build_codex_litellm_kwargs( + api_key: str | None, + *, + account_id: str | None = None, + extra_headers: dict[str, str] | None = None, +) -> dict[str, Any]: + """Return the LiteLLM kwargs required by the ChatGPT Codex backend.""" + return { + "extra_headers": build_codex_extra_headers( + api_key, + account_id=account_id, + extra_headers=extra_headers, + ), + "store": False, + "allowed_openai_params": list(CODEX_ALLOWED_OPENAI_PARAMS), + } diff --git a/core/framework/llm/litellm.py b/core/framework/llm/litellm.py index 7697cdd83f..d15dbc93cf 100644 --- a/core/framework/llm/litellm.py +++ b/core/framework/llm/litellm.py @@ -28,6 +28,8 @@ RateLimitError = Exception # type: ignore[assignment, misc] from framework.config import HIVE_LLM_ENDPOINT as HIVE_API_BASE +from framework.llm.codex_adapter import CodexResponsesAdapter +from framework.llm.codex_backend import normalize_codex_api_base from framework.llm.provider import LLMProvider, LLMResponse, Tool from framework.llm.stream_events import StreamEvent @@ -512,7 +514,9 @@ def __init__( api_base = api_base.rstrip("/")[:-3] self.model = model self.api_key = api_key - self.api_base = api_base or self._default_api_base_for_model(_original_model) + 
self.api_base = normalize_codex_api_base( + api_base or self._default_api_base_for_model(_original_model) + ) self.extra_kwargs = kwargs # Detect Claude Code OAuth subscription by checking the api_key prefix. self._claude_code_oauth = bool(api_key and api_key.startswith("sk-ant-oat")) @@ -520,13 +524,11 @@ def __init__( # Anthropic requires a specific User-Agent for OAuth requests. eh = self.extra_kwargs.setdefault("extra_headers", {}) eh.setdefault("user-agent", CLAUDE_CODE_USER_AGENT) - # The Codex ChatGPT backend (chatgpt.com/backend-api/codex) rejects - # several standard OpenAI params: max_output_tokens, stream_options. - self._codex_backend = bool( - self.api_base and "chatgpt.com/backend-api/codex" in self.api_base - ) # Antigravity routes through a local OpenAI-compatible proxy — no patches needed. self._antigravity = bool(self.api_base and "localhost:8069" in self.api_base) + self._codex_adapter = CodexResponsesAdapter(self) + # Backward-compatible alias for existing tests/callers. 
+ self._codex_backend = self._codex_adapter.enabled if litellm is None: raise ImportError( @@ -553,6 +555,132 @@ def _default_api_base_for_model(model: str) -> str | None: return HIVE_API_BASE return None + @staticmethod + def _normalize_codex_api_base(api_base: str | None) -> str | None: + """Normalize ChatGPT Codex backend URLs to the stable base endpoint.""" + return normalize_codex_api_base(api_base) + + def _chunk_codex_system_prompt(self, system: str) -> list[str]: + """Break large system prompts into smaller Codex-friendly chunks.""" + return self._codex_adapter.chunk_system_prompt(system) + + def _build_request_messages( + self, + messages: list[dict[str, Any]], + system: str, + *, + json_mode: bool, + ) -> list[dict[str, Any]]: + """Build request messages, including Codex-specific prompt chunking.""" + full_messages: list[dict[str, Any]] = [] + if self._claude_code_oauth: + billing = _claude_code_billing_header(messages) + full_messages.append({"role": "system", "content": billing}) + + system_messages: list[dict[str, Any]] = [] + if system: + if self._codex_backend: + system_messages.extend( + self._codex_adapter.build_system_messages(system, json_mode=json_mode) + ) + else: + sys_msg: dict[str, Any] = {"role": "system", "content": system} + if _model_supports_cache_control(self.model): + sys_msg["cache_control"] = {"type": "ephemeral"} + system_messages.append(sys_msg) + elif self._codex_backend: + system_messages.extend( + self._codex_adapter.build_system_messages("", json_mode=json_mode) + ) + + if json_mode and not self._codex_backend: + json_instruction = "Please respond with a valid JSON object." 
+ if system_messages: + system_messages[0] = { + **system_messages[0], + "content": f"{system_messages[0]['content']}\n\n{json_instruction}", + } + else: + system_messages.append({"role": "system", "content": json_instruction}) + + full_messages.extend(system_messages) + full_messages.extend(messages) + + return [ + m + for m in full_messages + if not ( + m.get("role") == "assistant" and not m.get("content") and not m.get("tool_calls") + ) + ] + + def _derive_codex_tool_choice( + self, + messages: list[dict[str, Any]], + tools: list[Tool] | None, + ) -> str | dict[str, Any] | None: + """Force tool use for Codex when critical framework tools are available.""" + if not self._codex_backend: + return None + return self._codex_adapter.derive_tool_choice(messages, tools) + + def _sanitize_request_kwargs( + self, + kwargs: dict[str, Any], + *, + stream: bool, + ) -> dict[str, Any]: + """Normalize provider kwargs, with extra hardening for Codex.""" + cleaned = dict(kwargs) + if cleaned.get("metadata") is None: + cleaned.pop("metadata", None) + + if self._codex_backend: + cleaned = self._codex_adapter.harden_request_kwargs(cleaned) + + if stream: + cleaned["stream"] = True + return cleaned + + def _build_completion_kwargs( + self, + messages: list[dict[str, Any]], + system: str, + *, + tools: list[Tool] | None, + max_tokens: int, + response_format: dict[str, Any] | None, + json_mode: bool, + stream: bool, + ) -> dict[str, Any]: + """Build request kwargs for completion/stream calls.""" + full_messages = self._build_request_messages(messages, system, json_mode=json_mode) + kwargs: dict[str, Any] = { + "model": self.model, + "messages": full_messages, + **self.extra_kwargs, + } + if not stream: + kwargs["max_tokens"] = max_tokens + else: + kwargs["max_tokens"] = max_tokens + if not self._is_anthropic_model(): + kwargs["stream_options"] = {"include_usage": True} + + if self.api_key: + kwargs["api_key"] = self.api_key + if self.api_base: + kwargs["api_base"] = self.api_base + if 
tools: + kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools] + tool_choice = self._derive_codex_tool_choice(full_messages, tools) + if tool_choice is not None: + kwargs["tool_choice"] = tool_choice + if response_format: + kwargs["response_format"] = response_format + + return self._sanitize_request_kwargs(kwargs, stream=stream) + def _completion_with_rate_limit_retry( self, max_retries: int | None = None, **kwargs: Any ) -> Any: @@ -691,42 +819,15 @@ def complete( ) ) - # Prepare messages with system prompt - full_messages = [] - if system: - full_messages.append({"role": "system", "content": system}) - full_messages.extend(messages) - - # Add JSON mode via prompt engineering (works across all providers) - if json_mode: - json_instruction = "\n\nPlease respond with a valid JSON object." - # Append to system message if present, otherwise add as system message - if full_messages and full_messages[0]["role"] == "system": - full_messages[0]["content"] += json_instruction - else: - full_messages.insert(0, {"role": "system", "content": json_instruction.strip()}) - - # Build kwargs - kwargs: dict[str, Any] = { - "model": self.model, - "messages": full_messages, - "max_tokens": max_tokens, - **self.extra_kwargs, - } - - if self.api_key: - kwargs["api_key"] = self.api_key - if self.api_base: - kwargs["api_base"] = self.api_base - - # Add tools if provided - if tools: - kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools] - - # Add response_format for structured output - # LiteLLM passes this through to the underlying provider - if response_format: - kwargs["response_format"] = response_format + kwargs = self._build_completion_kwargs( + messages, + system, + tools=tools, + max_tokens=max_tokens, + response_format=response_format, + json_mode=json_mode, + stream=False, + ) # Make the call response = self._completion_with_rate_limit_retry(max_retries=max_retries, **kwargs) @@ -887,40 +988,15 @@ async def acomplete( json_mode=json_mode, ) return await 
self._collect_stream_to_response(stream_iter) - - full_messages: list[dict[str, Any]] = [] - if self._claude_code_oauth: - billing = _claude_code_billing_header(messages) - full_messages.append({"role": "system", "content": billing}) - if system: - sys_msg: dict[str, Any] = {"role": "system", "content": system} - if _model_supports_cache_control(self.model): - sys_msg["cache_control"] = {"type": "ephemeral"} - full_messages.append(sys_msg) - full_messages.extend(messages) - - if json_mode: - json_instruction = "\n\nPlease respond with a valid JSON object." - if full_messages and full_messages[0]["role"] == "system": - full_messages[0]["content"] += json_instruction - else: - full_messages.insert(0, {"role": "system", "content": json_instruction.strip()}) - - kwargs: dict[str, Any] = { - "model": self.model, - "messages": full_messages, - "max_tokens": max_tokens, - **self.extra_kwargs, - } - - if self.api_key: - kwargs["api_key"] = self.api_key - if self.api_base: - kwargs["api_base"] = self.api_base - if tools: - kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools] - if response_format: - kwargs["response_format"] = response_format + kwargs = self._build_completion_kwargs( + messages, + system, + tools=tools, + max_tokens=max_tokens, + response_format=response_format, + json_mode=json_mode, + stream=False, + ) response = await self._acompletion_with_rate_limit_retry(max_retries=max_retries, **kwargs) @@ -1170,17 +1246,92 @@ def _repair_truncated_tool_arguments(self, raw_arguments: str) -> dict[str, Any] return parsed return None + @staticmethod + def _normalize_pythonish_tool_arguments(raw_arguments: str) -> str: + """Convert common JSON-like literals into a form ast.literal_eval can parse.""" + replacements = { + "true": "True", + "false": "False", + "null": "None", + } + out: list[str] = [] + token: list[str] = [] + in_string = False + string_quote = "" + escaped = False + + def flush_token() -> None: + if not token: + return + word = "".join(token) 
+ out.append(replacements.get(word, word)) + token.clear() + + for char in raw_arguments: + if in_string: + out.append(char) + if escaped: + escaped = False + elif char == "\\": + escaped = True + elif char == string_quote: + in_string = False + continue + + if char in {'"', "'"}: + flush_token() + in_string = True + string_quote = char + out.append(char) + continue + + if char.isalpha(): + token.append(char) + continue + + flush_token() + out.append(char) + + flush_token() + return "".join(out) + + @staticmethod + def _strip_tool_argument_fence(raw_arguments: str) -> str: + """Remove surrounding fenced-code markers from streamed tool arguments.""" + stripped = raw_arguments.strip() + if not stripped.startswith("```") or not stripped.endswith("```"): + return stripped + + lines = stripped.splitlines() + if len(lines) >= 2: + return "\n".join(lines[1:-1]).strip() + return stripped.strip("`").strip() + + def _parse_pythonish_tool_arguments(self, raw_arguments: str) -> dict[str, Any] | None: + """Parse single-quoted / trailing-comma argument payloads safely.""" + stripped = self._strip_tool_argument_fence(raw_arguments) + if not stripped or stripped[0] != "{": + return None + candidate = self._close_truncated_json_fragment(stripped) + candidate = self._normalize_pythonish_tool_arguments(candidate) + try: + parsed = ast.literal_eval(candidate) + except (SyntaxError, ValueError): + return None + return parsed if isinstance(parsed, dict) else None + def _parse_tool_call_arguments(self, raw_arguments: str, tool_name: str) -> dict[str, Any]: """Parse streamed tool arguments, repairing truncation when possible.""" + stripped = self._strip_tool_argument_fence(raw_arguments) try: - parsed = json.loads(raw_arguments) if raw_arguments else {} + parsed = json.loads(stripped) if stripped else {} except json.JSONDecodeError: parsed = None if isinstance(parsed, dict): return parsed - repaired = self._repair_truncated_tool_arguments(raw_arguments) + repaired = 
self._repair_truncated_tool_arguments(stripped) if repaired is not None: logger.warning( "[tool-args] Recovered truncated arguments for %s on %s", @@ -1189,6 +1340,15 @@ def _parse_tool_call_arguments(self, raw_arguments: str, tool_name: str) -> dict ) return repaired + pythonish = self._parse_pythonish_tool_arguments(stripped) + if pythonish is not None: + logger.warning( + "[tool-args] Recovered malformed arguments for %s on %s", + tool_name, + self.model, + ) + return pythonish + raise ValueError( f"Failed to parse tool call arguments for '{tool_name}' (likely truncated JSON)." ) @@ -1516,6 +1676,139 @@ async def _stream_via_nonstream_completion( model=response.model, ) + def _build_stream_events_from_nonstream_response( + self, + response: Any, + ) -> list[StreamEvent]: + """Convert a non-stream completion response into stream events.""" + from framework.llm.stream_events import ( + FinishEvent, + TextDeltaEvent, + TextEndEvent, + ToolCallEvent, + ) + + choices = getattr(response, "choices", None) or [] + if not choices: + output_text = getattr(response, "output_text", "") or "" + if not output_text: + return [] + from framework.llm.stream_events import FinishEvent, TextDeltaEvent, TextEndEvent + + usage = getattr(response, "usage", None) + return [ + TextDeltaEvent(content=output_text, snapshot=output_text), + TextEndEvent(full_text=output_text), + FinishEvent( + stop_reason="stop", + input_tokens=getattr(usage, "prompt_tokens", 0) or 0 if usage else 0, + output_tokens=getattr(usage, "completion_tokens", 0) or 0 if usage else 0, + model=getattr(response, "model", None) or self.model, + ), + ] + + choice = choices[0] + message = getattr(choice, "message", None) + content = self._extract_message_text(message) + tool_calls = getattr(message, "tool_calls", None) or [] + + events: list[StreamEvent] = [] + for tc in tool_calls: + parsed_args = self._coerce_tool_input( + tc.function.arguments if tc.function else {}, + tc.function.name if tc.function else "", + ) + 
events.append( + ToolCallEvent( + tool_use_id=getattr(tc, "id", ""), + tool_name=tc.function.name if tc.function else "", + tool_input=parsed_args, + ) + ) + + if content: + events.append(TextDeltaEvent(content=content, snapshot=content)) + events.append(TextEndEvent(full_text=content)) + + usage = getattr(response, "usage", None) + input_tokens = getattr(usage, "prompt_tokens", 0) or 0 if usage else 0 + output_tokens = getattr(usage, "completion_tokens", 0) or 0 if usage else 0 + cached_tokens = 0 + if usage: + details = getattr(usage, "prompt_tokens_details", None) + cached_tokens = ( + getattr(details, "cached_tokens", 0) or 0 + if details is not None + else getattr(usage, "cache_read_input_tokens", 0) or 0 + ) + + events.append( + FinishEvent( + stop_reason=getattr(choice, "finish_reason", None) + or ("tool_calls" if tool_calls else "stop"), + input_tokens=input_tokens, + output_tokens=output_tokens, + cached_tokens=cached_tokens, + model=getattr(response, "model", None) or self.model, + ) + ) + return events + + @staticmethod + def _extract_message_text(message: Any) -> str: + """Extract text from a provider message object across response shapes.""" + if message is None: + return "" + content = getattr(message, "content", "") + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for block in content: + if isinstance(block, str): + parts.append(block) + elif isinstance(block, dict): + text = block.get("text") or block.get("content") or "" + if isinstance(text, str): + parts.append(text) + else: + text = getattr(block, "text", "") or getattr(block, "content", "") + if isinstance(text, str): + parts.append(text) + return "".join(parts) + return str(content or "") + + def _coerce_tool_input(self, raw_arguments: Any, tool_name: str) -> dict[str, Any]: + """Normalize raw tool-call arguments from either string or object forms.""" + if isinstance(raw_arguments, dict): + return raw_arguments + if raw_arguments in 
(None, ""): + return {} + return self._parse_tool_call_arguments(str(raw_arguments), tool_name) + + async def _recover_empty_codex_stream( + self, + kwargs: dict[str, Any], + last_role: str | None, + ) -> list[StreamEvent] | None: + """Try a non-stream completion when Codex returns an empty stream.""" + if not self._codex_backend: + return None + return await self._codex_adapter.recover_empty_stream( + kwargs, + last_role=last_role, + acompletion=litellm.acompletion, # type: ignore[union-attr] + ) + + def _merge_tool_call_chunk( + self, + tool_calls_acc: dict[int, dict[str, str]], + tc: Any, + last_tool_idx: int, + ) -> int: + """Merge a streamed tool-call chunk, compensating for broken Codex indexes.""" + return self._codex_adapter.merge_tool_call_chunk(tool_calls_acc, tc, last_tool_idx) + async def stream( self, messages: list[dict[str, Any]], @@ -1567,65 +1860,16 @@ async def stream( yield event return - full_messages: list[dict[str, Any]] = [] - if self._claude_code_oauth: - billing = _claude_code_billing_header(messages) - full_messages.append({"role": "system", "content": billing}) - if system: - sys_msg: dict[str, Any] = {"role": "system", "content": system} - if _model_supports_cache_control(self.model): - sys_msg["cache_control"] = {"type": "ephemeral"} - full_messages.append(sys_msg) - full_messages.extend(messages) - - # Codex Responses API requires an `instructions` field (system prompt). - # Inject a minimal one when callers don't provide a system message. - if self._codex_backend and not any(m["role"] == "system" for m in full_messages): - full_messages.insert(0, {"role": "system", "content": "You are a helpful assistant."}) - - # Add JSON mode via prompt engineering (works across all providers) - if json_mode: - json_instruction = "\n\nPlease respond with a valid JSON object." 
- if full_messages and full_messages[0]["role"] == "system": - full_messages[0]["content"] += json_instruction - else: - full_messages.insert(0, {"role": "system", "content": json_instruction.strip()}) - - # Remove ghost empty assistant messages (content="" and no tool_calls). - # These arise when a model returns an empty stream after a tool result - # (an "expected" no-op turn). Keeping them in history confuses some - # models (notably Codex/gpt-5.3) and causes cascading empty streams. - full_messages = [ - m - for m in full_messages - if not ( - m.get("role") == "assistant" and not m.get("content") and not m.get("tool_calls") - ) - ] - - kwargs: dict[str, Any] = { - "model": self.model, - "messages": full_messages, - "max_tokens": max_tokens, - "stream": True, - **self.extra_kwargs, - } - # stream_options is OpenAI-specific; Anthropic rejects it with 400. - # Only include it for providers that support it. - if not self._is_anthropic_model(): - kwargs["stream_options"] = {"include_usage": True} - if self.api_key: - kwargs["api_key"] = self.api_key - if self.api_base: - kwargs["api_base"] = self.api_base - if tools: - kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools] - if response_format: - kwargs["response_format"] = response_format - # The Codex ChatGPT backend (Responses API) rejects several params. - if self._codex_backend: - kwargs.pop("max_tokens", None) - kwargs.pop("stream_options", None) + kwargs = self._build_completion_kwargs( + messages, + system, + tools=tools, + max_tokens=max_tokens, + response_format=response_format, + json_mode=json_mode, + stream=True, + ) + full_messages = kwargs["messages"] for attempt in range(RATE_LIMIT_MAX_RETRIES + 1): # Post-stream events (ToolCall, TextEnd, Finish) are buffered @@ -1683,43 +1927,17 @@ async def stream( # argument deltas that arrive with id=None. 
if delta and delta.tool_calls: for tc in delta.tool_calls: - idx = tc.index if hasattr(tc, "index") and tc.index is not None else 0 - - if tc.id: - # New tool call announced (or done event re-sent). - # Check if this id already has a slot. - existing_idx = next( - (k for k, v in tool_calls_acc.items() if v["id"] == tc.id), - None, - ) - if existing_idx is not None: - idx = existing_idx - elif idx in tool_calls_acc and tool_calls_acc[idx]["id"] not in ( - "", - tc.id, - ): - # Slot taken by a different call — assign new index - idx = max(tool_calls_acc.keys()) + 1 - _last_tool_idx = idx - else: - # Argument delta with no id — route to last opened slot - idx = _last_tool_idx - - if idx not in tool_calls_acc: - tool_calls_acc[idx] = {"id": "", "name": "", "arguments": ""} - if tc.id: - tool_calls_acc[idx]["id"] = tc.id - if tc.function: - if tc.function.name: - tool_calls_acc[idx]["name"] = tc.function.name - if tc.function.arguments: - tool_calls_acc[idx]["arguments"] += tc.function.arguments + _last_tool_idx = self._merge_tool_call_chunk( + tool_calls_acc, + tc, + _last_tool_idx, + ) # --- Finish --- if choice.finish_reason: stream_finish_reason = choice.finish_reason for _idx, tc_data in sorted(tool_calls_acc.items()): - parsed_args = self._parse_tool_call_arguments( + parsed_args = self._coerce_tool_input( tc_data.get("arguments", ""), tc_data.get("name", ""), ) @@ -1852,6 +2070,11 @@ async def stream( (m["role"] for m in reversed(full_messages) if m.get("role") != "system"), None, ) + recovered_events = await self._recover_empty_codex_stream(kwargs, last_role) + if recovered_events: + for event in recovered_events: + yield event + return if attempt < EMPTY_STREAM_MAX_RETRIES: token_count, token_method = _estimate_tokens( self.model, diff --git a/core/framework/runner/runner.py b/core/framework/runner/runner.py index 901fd3605d..73a9fa8add 100644 --- a/core/framework/runner/runner.py +++ b/core/framework/runner/runner.py @@ -22,6 +22,7 @@ ) from 
framework.graph.executor import ExecutionResult from framework.graph.node import NodeSpec +from framework.llm.codex_backend import CODEX_API_BASE, build_codex_litellm_kwargs from framework.llm.provider import LLMProvider, Tool from framework.runner.preload_validation import run_preload_validation from framework.runner.tool_registry import ToolRegistry @@ -327,17 +328,68 @@ def _read_codex_auth_file() -> dict | None: return None +def _get_jwt_claims(token: str) -> dict | None: + """Decode JWT claims without verification for local expiry/account inspection.""" + import base64 + + try: + parts = token.split(".") + if len(parts) != 3: + return None + payload = parts[1] + padding = 4 - len(payload) % 4 + if padding != 4: + payload += "=" * padding + decoded = base64.urlsafe_b64decode(payload) + claims = json.loads(decoded) + return claims if isinstance(claims, dict) else None + except Exception: + return None + + +def _get_codex_token_expiry(auth_data: dict) -> float | None: + """Return the best-known expiry timestamp for a Codex access token.""" + from datetime import datetime + + tokens = auth_data.get("tokens", {}) + access_token = tokens.get("access_token") + explicit = ( + auth_data.get("expires_at") + or auth_data.get("expiresAt") + or tokens.get("expires_at") + or tokens.get("expiresAt") + ) + if isinstance(explicit, (int, float)): + return float(explicit) + if isinstance(explicit, str): + try: + return datetime.fromisoformat(explicit.replace("Z", "+00:00")).timestamp() + except (ValueError, TypeError): + pass + + if isinstance(access_token, str): + claims = _get_jwt_claims(access_token) or {} + exp = claims.get("exp") + if isinstance(exp, (int, float)): + return float(exp) + return None + + def _is_codex_token_expired(auth_data: dict) -> bool: """Check whether the Codex token is expired or close to expiry. The Codex auth.json has no explicit ``expiresAt`` field, so we infer expiry as ``last_refresh + _CODEX_TOKEN_LIFETIME_SECS``. 
Falls back - to the file mtime when ``last_refresh`` is absent. + to JWT ``exp`` or file age heuristics when no explicit timestamp exists. """ import time from datetime import datetime now = time.time() + explicit_expiry = _get_codex_token_expiry(auth_data) + if explicit_expiry is not None: + return now >= (explicit_expiry - _TOKEN_REFRESH_BUFFER_SECS) + last_refresh = auth_data.get("last_refresh") if last_refresh is None: @@ -431,6 +483,8 @@ def get_codex_token() -> str | None: Returns: The access token if available, None otherwise. """ + import time + # Try Keychain first, then file auth_data = _read_codex_keychain() or _read_codex_auth_file() if not auth_data: @@ -441,15 +495,20 @@ def get_codex_token() -> str | None: if not access_token: return None + explicit_expiry = _get_codex_token_expiry(auth_data) + is_expired = _is_codex_token_expired(auth_data) + # Check if token is still valid - if not _is_codex_token_expired(auth_data): + if not is_expired: return access_token # Token is expired or near expiry — attempt refresh refresh_token = tokens.get("refresh_token") if not refresh_token: logger.warning("Codex token expired and no refresh token available") - return access_token # Return expired token; it may still work briefly + if explicit_expiry is not None and time.time() >= explicit_expiry: + return None + return access_token logger.info("Codex token expired or near expiry, refreshing...") token_data = _refresh_codex_token(refresh_token) @@ -460,6 +519,8 @@ def get_codex_token() -> str | None: # Refresh failed — return the existing token and warn logger.warning("Codex token refresh failed. Run 'codex' to re-authenticate.") + if explicit_expiry is not None and time.time() >= explicit_expiry: + return None return access_token @@ -471,26 +532,12 @@ def _get_account_id_from_jwt(access_token: str) -> str | None: This is used as a fallback when the auth.json doesn't store the account_id explicitly. 
""" - import base64 - - try: - parts = access_token.split(".") - if len(parts) != 3: - return None - payload = parts[1] - # Add base64 padding - padding = 4 - len(payload) % 4 - if padding != 4: - payload += "=" * padding - decoded = base64.urlsafe_b64decode(payload) - claims = json.loads(decoded) - auth = claims.get("https://api.openai.com/auth") - if isinstance(auth, dict): - account_id = auth.get("chatgpt_account_id") - if isinstance(account_id, str) and account_id: - return account_id - except Exception: - pass + claims = _get_jwt_claims(access_token) or {} + auth = claims.get("https://api.openai.com/auth") + if isinstance(auth, dict): + account_id = auth.get("chatgpt_account_id") + if isinstance(account_id, str) and account_id: + return account_id return None @@ -1558,20 +1605,20 @@ def _setup(self, event_bus=None) -> None: # OpenAI Codex subscription routes through the ChatGPT backend # (chatgpt.com/backend-api/codex/responses), NOT the standard # OpenAI API. The consumer OAuth token lacks platform API scopes. - extra_headers: dict[str, str] = { - "Authorization": f"Bearer {api_key}", - "User-Agent": "CodexBar", - } account_id = get_codex_account_id() - if account_id: - extra_headers["ChatGPT-Account-Id"] = account_id self._llm = LiteLLMProvider( model=self.model, api_key=api_key, - api_base="https://chatgpt.com/backend-api/codex", - extra_headers=extra_headers, - store=False, - allowed_openai_params=["store"], + api_base=CODEX_API_BASE, + **build_codex_litellm_kwargs(api_key, account_id=account_id), + ) + elif api_key and use_kimi_code: + # Kimi Code subscription uses the Kimi coding API (OpenAI-compatible). + # The api_base is set automatically by LiteLLMProvider for kimi/ models. + self._llm = LiteLLMProvider( + model=self.model, + api_key=api_key, + api_base=api_base, ) elif api_key and use_kimi_code: # Kimi Code subscription uses the Kimi coding API (OpenAI-compatible). 
diff --git a/core/framework/runtime/event_bus.py b/core/framework/runtime/event_bus.py index 90aa186aad..d9e22f7847 100644 --- a/core/framework/runtime/event_bus.py +++ b/core/framework/runtime/event_bus.py @@ -535,8 +535,8 @@ async def run_handler(handler: EventHandler) -> None: async with self._semaphore: try: await handler(event) - except Exception: - logger.exception(f"Handler error for {event.type}") + except Exception as e: + logger.error(f"Handler error for {event.type}: {e}") # Run all handlers concurrently await asyncio.gather(*[run_handler(h) for h in handlers], return_exceptions=True) @@ -901,6 +901,9 @@ async def emit_client_input_requested( execution_id: str | None = None, options: list[str] | None = None, questions: list[dict] | None = None, + auto_blocked: bool = False, + assistant_text_present: bool = False, + assistant_text_requires_input: bool = False, ) -> None: """Emit client input requested event (client_facing=True nodes). @@ -917,6 +920,12 @@ async def emit_client_input_requested( data["options"] = options if questions: data["questions"] = questions + if auto_blocked: + data["auto_blocked"] = True + if assistant_text_present: + data["assistant_text_present"] = True + if assistant_text_requires_input: + data["assistant_text_requires_input"] = True await self.publish( AgentEvent( type=EventType.CLIENT_INPUT_REQUESTED, diff --git a/core/framework/runtime/execution_stream.py b/core/framework/runtime/execution_stream.py index cb6852df37..96d2467be9 100644 --- a/core/framework/runtime/execution_stream.py +++ b/core/framework/runtime/execution_stream.py @@ -446,7 +446,9 @@ async def inject_input( node = executor.node_registry.get(node_id) if node is not None and hasattr(node, "inject_event"): await node.inject_event( - content, is_client_input=is_client_input, image_content=image_content + content, + is_client_input=is_client_input, + image_content=image_content, ) return True return False @@ -1019,6 +1021,22 @@ async def _write_session_state( else: 
status = SessionStatus.ACTIVE + persisted_input_data = dict(ctx.input_data or {}) + entry_node_id = getattr(self.entry_spec, "entry_node", None) or getattr( + self.graph, "entry_node", None + ) + entry_input_keys: list[str] = [] + if entry_node_id and hasattr(self.graph, "get_node"): + entry_node = self.graph.get_node(entry_node_id) + entry_input_keys = list(getattr(entry_node, "input_keys", []) or []) + + if result and isinstance(result.output, dict): + for key in entry_input_keys: + if persisted_input_data.get(key) in (None, ""): + value = result.output.get(key) + if value not in (None, ""): + persisted_input_data[key] = value + # Create SessionState if result: # Create from execution result @@ -1029,7 +1047,7 @@ async def _write_session_state( stream_id=self.stream_id, correlation_id=ctx.correlation_id, started_at=ctx.started_at.isoformat(), - input_data=ctx.input_data, + input_data=persisted_input_data, agent_id=self.graph.id, entry_point=self.entry_spec.id, ) @@ -1064,7 +1082,7 @@ async def _write_session_state( ), progress=progress, memory=ss.get("memory", {}), - input_data=ctx.input_data, + input_data=persisted_input_data, ) # Handle error case diff --git a/core/framework/server/app.py b/core/framework/server/app.py index fe698aa37b..6205764888 100644 --- a/core/framework/server/app.py +++ b/core/framework/server/app.py @@ -48,7 +48,20 @@ def validate_agent_path(agent_path: str | Path) -> Path: Raises: ValueError: If the path is outside all allowed roots. """ - resolved = Path(agent_path).expanduser().resolve() + raw_path = str(agent_path).strip() + if not raw_path: + raise ValueError( + "agent_path must be inside an allowed directory " + "(exports/, examples/, or ~/.hive/agents/)" + ) + + candidate = Path(agent_path).expanduser() + if not candidate.is_absolute(): + # Resolve relative paths from the repository root so server-side + # validation is independent of the process working directory. 
+ candidate = _REPO_ROOT / candidate + + resolved = candidate.resolve() for root in _get_allowed_agent_roots(): if resolved.is_relative_to(root) and resolved != root: return resolved @@ -281,6 +294,8 @@ def _setup_static_serving(app: web.Application) -> None: async def handle_spa(request: web.Request) -> web.FileResponse: """Serve static files with SPA fallback to index.html.""" rel_path = request.match_info.get("path", "") + if rel_path == "api" or rel_path.startswith("api/"): + raise web.HTTPNotFound() file_path = (dist_dir / rel_path).resolve() if file_path.is_file() and file_path.is_relative_to(dist_dir): diff --git a/core/framework/server/queen_orchestrator.py b/core/framework/server/queen_orchestrator.py index 4ea1a9db82..66da5e3a73 100644 --- a/core/framework/server/queen_orchestrator.py +++ b/core/framework/server/queen_orchestrator.py @@ -7,6 +7,7 @@ from __future__ import annotations import asyncio +import json import logging from pathlib import Path from typing import TYPE_CHECKING, Any @@ -16,6 +17,185 @@ logger = logging.getLogger(__name__) +_PRIMARY_RESULT_KEYS = ( + "result", + "answer", + "final_answer", + "final_result", + "final_output", + "response", + "summary", + "report", +) +_NON_PRIMARY_KEY_TOKENS = ( + "task", + "request", + "prompt", + "input", + "brief", + "context", + "plan", + "step", + "given", + "target", + "relationship", + "symbolic", + "metadata", + "artifact", + "file", + "path", + "url", + "link", +) +_ARTIFACT_SUFFIXES = { + ".csv", + ".doc", + ".docx", + ".gif", + ".html", + ".jpeg", + ".jpg", + ".json", + ".md", + ".pdf", + ".png", + ".ppt", + ".pptx", + ".svg", + ".txt", + ".xlsx", + ".xml", + ".yaml", + ".yml", +} +_PRIMARY_RESULT_CHAR_LIMIT = 8000 + + +def _stringify_worker_output_value(value: Any) -> str: + """Convert worker output values to user-presentable text.""" + if value is None: + return "" + if isinstance(value, str): + return value + if isinstance(value, (dict, list)): + try: + return json.dumps(value, 
ensure_ascii=False, indent=2) + except TypeError: + return str(value) + return str(value) + + +def _looks_like_artifact_reference(text: str) -> bool: + stripped = text.strip() + if not stripped: + return True + if stripped.startswith("[Saved to "): + return True + if "\n" in stripped: + return False + if stripped.startswith(("/", "./", "../")): + return True + return Path(stripped).suffix.lower() in _ARTIFACT_SUFFIXES + + +def _is_non_primary_key(key: str) -> bool: + lowered = key.lower() + return any(token in lowered for token in _NON_PRIMARY_KEY_TOKENS) + + +def _select_primary_worker_result(output: dict[str, Any]) -> tuple[str, str] | None: + """Pick the best worker output to relay verbatim, if one exists.""" + if not output: + return None + + normalized: dict[str, str] = {} + for key, value in output.items(): + text = _stringify_worker_output_value(value).strip() + if text: + normalized[key] = text + + for key in _PRIMARY_RESULT_KEYS: + text = normalized.get(key, "") + if text and not _looks_like_artifact_reference(text): + return key, text + + candidates: list[tuple[str, str]] = [] + for key, text in normalized.items(): + if _is_non_primary_key(key) or _looks_like_artifact_reference(text): + continue + candidates.append((key, text)) + + if len(candidates) == 1: + return candidates[0] + + for key, text in candidates: + lowered = key.lower() + if any(token in lowered for token in ("result", "answer", "summary", "report", "response")): + return key, text + return None + + +def _format_worker_output_summary(output: dict[str, Any]) -> str: + """Build a concise summary of worker output keys for queen handoff.""" + if not output: + return " (no output keys set)" + + lines: list[str] = [] + for key, value in output.items(): + preview = _stringify_worker_output_value(value).strip() or "(empty)" + if len(preview) > 200: + preview = preview[:200] + "..." 
+ lines.append(f" {key}: {preview}") + return "\n".join(lines) + + +def _build_worker_terminal_notification(output: dict[str, Any]) -> str: + """Format a worker-completed notification for the queen.""" + summary = _format_worker_output_summary(output) + primary = _select_primary_worker_result(output) + if primary is None: + return ( + "[WORKER_TERMINAL] Worker finished successfully.\n" + f"Output summary:\n{summary}\n" + "Report this to the user. Ask if they want to continue with another run." + ) + + key, text = primary + if len(text) > _PRIMARY_RESULT_CHAR_LIMIT: + text = text[:_PRIMARY_RESULT_CHAR_LIMIT].rstrip() + "\n...[truncated]" + return ( + "[WORKER_TERMINAL] Worker finished successfully.\n" + f"Output summary:\n{summary}\n" + f"Primary result key: {key}\n" + "[PRIMARY_RESULT_BEGIN]\n" + f"{text}\n" + "[PRIMARY_RESULT_END]\n" + "Show the PRIMARY_RESULT to the user exactly as written between " + "[PRIMARY_RESULT_BEGIN] and [PRIMARY_RESULT_END] before any commentary. " + "Do not paraphrase, compress, or reformat it. After that, briefly mention " + "any important artifacts or other output keys if useful, then ask if they " + "want to continue with another run." + ) + + +def _client_input_counts_as_planning_ask(event: Any) -> bool: + """Return True when a queen input-request should satisfy planning ask rounds. + + Explicit ask_user / ask_user_multiple calls always count. We also count + queen auto-blocks that followed assistant text which clearly invited a + reply, which covers Codex-style plain-text planning questions that failed + to call ask_user. Empty/status-only auto-blocks do not count. 
+ """ + data = getattr(event, "data", None) or {} + if data.get("prompt") or data.get("questions") or data.get("options"): + return True + if not data.get("auto_blocked"): + return False + requires_input = data.get("assistant_text_requires_input") + if requires_input is None: + requires_input = bool(data.get("assistant_text_present") and data.get("prompt")) + return bool(requires_input) + async def create_queen( session: Session, @@ -62,7 +242,6 @@ async def create_queen( from framework.agents.queen.nodes.thinking_hook import select_expert_persona from framework.graph.event_loop_node import HookContext, HookResult from framework.graph.executor import GraphExecutor - from framework.runner.mcp_registry import MCPRegistry from framework.runner.tool_registry import ToolRegistry from framework.runtime.core import Runtime from framework.runtime.event_bus import AgentEvent, EventType @@ -87,16 +266,6 @@ async def create_queen( except Exception: logger.warning("Queen: MCP config failed to load", exc_info=True) - try: - registry = MCPRegistry() - registry.initialize() - registry_configs = registry.load_agent_selection(queen_pkg_dir) - if registry_configs: - results = queen_registry.load_registry_servers(registry_configs) - logger.info("Queen: loaded MCP registry servers: %s", results) - except Exception: - logger.warning("Queen: MCP registry config failed to load", exc_info=True) - # ---- Phase state -------------------------------------------------- initial_phase = "staging" if worker_identity else "planning" phase_state = QueenPhaseState(phase=initial_phase, event_bus=session.event_bus) @@ -108,14 +277,7 @@ async def create_queen( async def _track_planning_asks(event: AgentEvent) -> None: if phase_state.phase != "planning": return - # Only count explicit ask_user / ask_user_multiple calls, not - # auto-block (text-only turns emit CLIENT_INPUT_REQUESTED with - # an empty prompt and no options/questions). 
- data = event.data or {} - has_prompt = bool(data.get("prompt")) - has_questions = bool(data.get("questions")) - has_options = bool(data.get("options")) - if has_prompt or has_questions or has_options: + if _client_input_counts_as_planning_ask(event): phase_state.planning_ask_rounds += 1 session.event_bus.subscribe( @@ -233,15 +395,11 @@ async def _track_planning_asks(event: AgentEvent) -> None: # ---- Default skill protocols ------------------------------------- try: - from framework.skills.manager import SkillsManager, SkillsManagerConfig + from framework.skills.manager import SkillsManager - # Pass project_root so user-scope skills (~/.hive/skills/, ~/.agents/skills/) - # are discovered. Queen has no agent-specific project root, so we use its - # own directory — the value just needs to be non-None to enable user-scope scanning. - _queen_skills_mgr = SkillsManager(SkillsManagerConfig(project_root=Path(__file__).parent)) + _queen_skills_mgr = SkillsManager() _queen_skills_mgr.load() phase_state.protocols_prompt = _queen_skills_mgr.protocols_prompt - phase_state.skills_catalog_prompt = _queen_skills_mgr.skills_catalog_prompt except Exception: logger.debug("Queen skill loading failed (non-fatal)", exc_info=True) @@ -326,20 +484,10 @@ async def _on_worker_done(event): # Mark worker as configured after first successful run session.worker_configured = True output = event.data.get("output", {}) - output_summary = "" - if output: - for key, value in output.items(): - val_str = str(value) - if len(val_str) > 200: - val_str = val_str[:200] + "..." - output_summary += f"\n {key}: {val_str}" - _out = output_summary or " (no output keys set)" - notification = ( - "[WORKER_TERMINAL] Worker finished successfully.\n" - f"Output:{_out}\n" - "Report this to the user. " - "Ask if they want to continue with another run." - ) + # Keep the worker's primary result intact during the + # queen handoff so the user sees the actual answer, + # not just a paraphrased digest of it. 
+ notification = _build_worker_terminal_notification(output) else: # EXECUTION_FAILED error = event.data.get("error", "Unknown error") notification = ( diff --git a/core/framework/server/routes_execution.py b/core/framework/server/routes_execution.py index 8b0d7bb570..af23cf8ace 100644 --- a/core/framework/server/routes_execution.py +++ b/core/framework/server/routes_execution.py @@ -8,11 +8,113 @@ from aiohttp import web from framework.credentials.validation import validate_agent_credentials +from framework.runtime.event_bus import AgentEvent, EventType from framework.server.app import resolve_session, safe_path_segment, sessions_dir from framework.server.routes_sessions import _credential_error_response +from framework.server.session_manager import ( + _run_validation_report_sync, + _validation_blocks_stage_or_run, + _validation_failures, +) logger = logging.getLogger(__name__) +_TERMINAL_STOP_MARKERS = ( + "done for now", + "stop here", + "stop for now", + "end session", + "finish and close session", + "finish and close", +) + + +def _normalize_choice_text(text: str) -> str: + lowered = str(text or "").strip().lower() + return " ".join(lowered.replace("_", " ").split()) + + +def _looks_like_terminal_stop_reply(text: str) -> bool: + normalized = _normalize_choice_text(text) + return any(marker in normalized for marker in _TERMINAL_STOP_MARKERS) + + +def _queen_is_waiting_on_terminal_followup(session: Any) -> bool: + """Return True when the latest queen question offered a terminal stop option.""" + bus = getattr(session, "event_bus", None) + if bus is None or not hasattr(bus, "get_history"): + return False + + events = bus.get_history( + event_type=EventType.CLIENT_INPUT_REQUESTED, + stream_id="queen", + limit=5, + ) + for event in events: + data = getattr(event, "data", None) or {} + options = [str(opt) for opt in (data.get("options") or []) if opt] + for question in data.get("questions") or []: + options.extend(str(opt) for opt in (question.get("options") or []) 
if opt) + if options: + return any(_looks_like_terminal_stop_reply(opt) for opt in options) + return False + + +async def _acknowledge_terminal_queen_choice(session: Any, message: str) -> None: + """Emit a final acknowledgment when the user chooses to stop.""" + ack = "Okay, stopping here. I’ll wait for your next message." + bus = getattr(session, "event_bus", None) + if bus is None: + return + + if hasattr(bus, "publish"): + await bus.publish( + AgentEvent( + type=EventType.CLIENT_INPUT_RECEIVED, + stream_id="queen", + node_id="queen", + execution_id=session.id, + data={"content": message}, + ) + ) + if hasattr(bus, "emit_client_output_delta"): + await bus.emit_client_output_delta( + "queen", + "queen", + ack, + ack, + execution_id=session.id, + ) + + +async def _worker_validation_error(session) -> web.Response | None: + """Return a 409 response when the loaded worker is invalid.""" + report = getattr(session, "worker_validation_report", None) + if report is None and getattr(session, "worker_path", None): + loop = asyncio.get_running_loop() + report = await loop.run_in_executor( + None, lambda: _run_validation_report_sync(str(session.worker_path)) + ) + session.worker_validation_report = report + session.worker_validation_failures = _validation_failures(report) + + if _validation_blocks_stage_or_run(report): + failures = getattr(session, "worker_validation_failures", None) or _validation_failures( + report + ) + worker_name = getattr(getattr(session, "worker_path", None), "name", "") or "current worker" + return web.json_response( + { + "error": ( + f"Worker '{worker_name}' failed validation and cannot be executed. " + "Fix the package and reload it before running or resuming." + ), + "validation_failures": failures, + }, + status=409, + ) + return None + async def handle_trigger(request: web.Request) -> web.Response: """POST /api/sessions/{session_id}/trigger — start an execution. 
@@ -26,6 +128,10 @@ async def handle_trigger(request: web.Request) -> web.Response: if not session.worker_runtime: return web.json_response({"error": "No worker loaded in this session"}, status=503) + validation_err = await _worker_validation_error(session) + if validation_err is not None: + return validation_err + # Validate credentials before running — deferred from load time to avoid # showing the modal before the user clicks Run. Runs in executor because # validate_agent_credentials makes blocking HTTP health-check calls. @@ -53,11 +159,7 @@ async def handle_trigger(request: web.Request) -> web.Response: body = await request.json() entry_point_id = body.get("entry_point_id", "default") input_data = body.get("input_data", {}) - session_state = body.get("session_state") or {} - - # Scope the worker execution to the live session ID - if "resume_session_id" not in session_state: - session_state["resume_session_id"] = session.id + session_state = body.get("session_state") or None execution_id = await session.worker_runtime.trigger( entry_point_id, @@ -108,10 +210,7 @@ async def handle_chat(request: web.Request) -> web.Response: The input box is permanently connected to the queen agent. Worker input is handled separately via /worker-input. - Body: {"message": "hello", "images": [{"type": "image_url", "image_url": {"url": "data:..."}}]} - - The optional ``images`` field accepts a list of OpenAI-format image_url - content blocks. The frontend encodes images as base64 data URIs. 
+ Body: {"message": "hello"} """ session, err = resolve_session(request) if err: @@ -119,29 +218,34 @@ async def handle_chat(request: web.Request) -> web.Response: body = await request.json() message = body.get("message", "") - image_content = body.get("images") or None # list[dict] | None - if not message and not image_content: + if not message: return web.json_response({"error": "message is required"}, status=400) + manager: Any = request.app["manager"] + + if _looks_like_terminal_stop_reply(message) and _queen_is_waiting_on_terminal_followup(session): + await _acknowledge_terminal_queen_choice(session, message) + await manager.suspend_queen(session) + return web.json_response( + { + "status": "queen", + "delivered": True, + } + ) + queen_executor = session.queen_executor if queen_executor is not None: node = queen_executor.node_registry.get("queen") if node is not None and hasattr(node, "inject_event"): - await node.inject_event(message, is_client_input=True, image_content=image_content) - # Publish to EventBus so the session event log captures user messages - from framework.runtime.event_bus import AgentEvent, EventType - + await node.inject_event(message, is_client_input=True) await session.event_bus.publish( AgentEvent( type=EventType.CLIENT_INPUT_RECEIVED, stream_id="queen", node_id="queen", execution_id=session.id, - data={ - "content": message, - "image_count": len(image_content) if image_content else 0, - }, + data={"content": message}, ) ) return web.json_response( @@ -152,7 +256,6 @@ async def handle_chat(request: web.Request) -> web.Response: ) # Queen is dead — try to revive her - manager: Any = request.app["manager"] try: await manager.revive_queen(session, initial_prompt=message) return web.json_response( @@ -274,6 +377,10 @@ async def handle_resume(request: web.Request) -> web.Response: if not session.worker_runtime: return web.json_response({"error": "No worker loaded in this session"}, status=503) + validation_err = await 
_worker_validation_error(session) + if validation_err is not None: + return validation_err + body = await request.json() worker_session_id = body.get("session_id") checkpoint_id = body.get("checkpoint_id") @@ -419,9 +526,14 @@ async def handle_stop(request: web.Request) -> web.Response: if hasattr(node, "cancel_current_turn"): node.cancel_current_turn() - cancelled = await stream.cancel_execution( - execution_id, reason="Execution stopped by user" - ) + try: + cancelled = await stream.cancel_execution( + execution_id, reason="Execution stopped by user" + ) + except TypeError: + # Backward compatibility for older stream/test doubles that + # still expose cancel_execution(execution_id) only. + cancelled = await stream.cancel_execution(execution_id) if cancelled: # Cancel queen's in-progress LLM turn if session.queen_executor: diff --git a/core/framework/server/routes_sessions.py b/core/framework/server/routes_sessions.py index e8f66ec7cf..89c74d9dc2 100644 --- a/core/framework/server/routes_sessions.py +++ b/core/framework/server/routes_sessions.py @@ -28,8 +28,6 @@ import json import logging import shutil -import subprocess -import sys import time from pathlib import Path @@ -42,22 +40,54 @@ sessions_dir, validate_agent_path, ) -from framework.server.session_manager import SessionManager +from framework.server.session_manager import ( + SessionManager, + WorkerValidationError, + _run_validation_report_sync, + _validation_blocks_stage_or_run, + _validation_failures, +) logger = logging.getLogger(__name__) +async def _worker_validation_error(session) -> web.Response | None: + """Return a 409 response when the loaded worker is invalid.""" + report = getattr(session, "worker_validation_report", None) + if report is None and getattr(session, "worker_path", None): + loop = asyncio.get_running_loop() + report = await loop.run_in_executor( + None, lambda: _run_validation_report_sync(str(session.worker_path)) + ) + session.worker_validation_report = report + 
session.worker_validation_failures = _validation_failures(report) + + if _validation_blocks_stage_or_run(report): + failures = getattr(session, "worker_validation_failures", None) or _validation_failures( + report + ) + worker_name = getattr(getattr(session, "worker_path", None), "name", "") or "current worker" + return web.json_response( + { + "error": ( + f"Worker '{worker_name}' failed validation and cannot be executed. " + "Fix the package and reload it before running or restoring." + ), + "validation_failures": failures, + }, + status=409, + ) + return None + + def _get_manager(request: web.Request) -> SessionManager: return request.app["manager"] def _session_to_live_dict(session) -> dict: """Serialize a live Session to the session-primary JSON shape.""" - from framework.llm.capabilities import supports_image_tool_results - info = session.worker_info phase_state = getattr(session, "phase_state", None) - queen_model: str = getattr(getattr(session, "runner", None), "model", "") or "" return { "session_id": session.id, "worker_id": session.worker_id, @@ -73,7 +103,6 @@ def _session_to_live_dict(session) -> dict: "queen_phase": phase_state.phase if phase_state else ("staging" if session.worker_runtime else "planning"), - "queen_supports_images": supports_image_tool_results(queen_model) if queen_model else True, } @@ -311,6 +340,11 @@ async def handle_load_worker(request: web.Request) -> web.Response: model=model, ) except ValueError as e: + if isinstance(e, WorkerValidationError): + return web.json_response( + {"error": str(e), "validation_failures": e.failures}, + status=409, + ) return web.json_response({"error": str(e)}, status=409) except FileNotFoundError: return web.json_response({"error": f"Agent not found: {agent_path}"}, status=404) @@ -729,6 +763,10 @@ async def handle_restore_checkpoint(request: web.Request) -> web.Response: if not session.worker_runtime: return web.json_response({"error": "No worker loaded in this session"}, status=503) + 
validation_err = await _worker_validation_error(session) + if validation_err is not None: + return validation_err + ws_id = request.match_info.get("ws_id") or request.match_info.get("session_id", "") ws_id = safe_path_segment(ws_id) checkpoint_id = safe_path_segment(request.match_info["checkpoint_id"]) @@ -984,29 +1022,6 @@ async def handle_discover(request: web.Request) -> web.Response: return web.json_response(result) -async def handle_reveal_session_folder(request: web.Request) -> web.Response: - """POST /api/sessions/{session_id}/reveal — open session data folder in the OS file manager.""" - manager: SessionManager = request.app["manager"] - session_id = request.match_info["session_id"] - - session = manager.get_session(session_id) - storage_session_id = (session.queen_resume_from or session.id) if session else session_id - folder = Path.home() / ".hive" / "queen" / "session" / storage_session_id - folder.mkdir(parents=True, exist_ok=True) - - try: - if sys.platform == "darwin": - subprocess.Popen(["open", str(folder)]) - elif sys.platform == "win32": - subprocess.Popen(["explorer", str(folder)]) - else: - subprocess.Popen(["xdg-open", str(folder)]) - except Exception as exc: - return web.json_response({"error": str(exc)}, status=500) - - return web.json_response({"path": str(folder)}) - - # ------------------------------------------------------------------ # Route registration # ------------------------------------------------------------------ @@ -1031,7 +1046,6 @@ def register_routes(app: web.Application) -> None: app.router.add_delete("/api/sessions/{session_id}/worker", handle_unload_worker) # Session info - app.router.add_post("/api/sessions/{session_id}/reveal", handle_reveal_session_folder) app.router.add_get("/api/sessions/{session_id}/stats", handle_session_stats) app.router.add_get("/api/sessions/{session_id}/entry-points", handle_session_entry_points) app.router.add_patch( diff --git a/core/framework/server/session_manager.py 
b/core/framework/server/session_manager.py index ded3a7d7dd..3b8d1655c5 100644 --- a/core/framework/server/session_manager.py +++ b/core/framework/server/session_manager.py @@ -12,6 +12,8 @@ import asyncio import json import logging +import subprocess +import textwrap import time import uuid from dataclasses import dataclass, field @@ -22,6 +24,134 @@ from framework.runtime.triggers import TriggerDefinition logger = logging.getLogger(__name__) +REPO_ROOT = Path(__file__).resolve().parents[3] +CODER_TOOLS_SERVER = REPO_ROOT / "tools" / "coder_tools_server.py" + + +def _parse_validation_report(raw: str | dict[str, Any] | None) -> dict[str, Any]: + """Best-effort parse of validate_agent_package output.""" + if isinstance(raw, dict): + return raw + if not isinstance(raw, str): + return {} + + cleaned = raw.strip() + if "\n\n[Saved to " in cleaned: + cleaned = cleaned.split("\n\n[Saved to ", 1)[0].strip() + if not cleaned: + return {} + + try: + return json.loads(cleaned) + except json.JSONDecodeError: + start = cleaned.find("{") + end = cleaned.rfind("}") + if start != -1 and end != -1 and end > start: + try: + return json.loads(cleaned[start : end + 1]) + except json.JSONDecodeError: + return {} + return {} + + +def _validation_failures(report: dict[str, Any] | None) -> list[str]: + """Extract readable failure summaries from a validation report.""" + if not isinstance(report, dict): + return [] + + steps = report.get("steps") or {} + failures: list[str] = [] + for step_name, step in steps.items(): + if not isinstance(step, dict) or step.get("passed", False): + continue + if step.get("errors"): + errors = step["errors"] + if isinstance(errors, list): + failures.extend(f"{step_name}: {err}" for err in errors) + continue + if step.get("missing_tools"): + missing = step["missing_tools"] + if isinstance(missing, list): + failures.extend(f"{step_name}: missing tool {tool}" for tool in missing) + continue + detail = step.get("error") or step.get("output") or "validation 
failed" + failures.append(f"{step_name}: {detail}") + if not failures and report.get("summary"): + failures.append(str(report["summary"])) + return failures + + +def _validation_blocks_stage_or_run(report: dict[str, Any] | None) -> bool: + """Return True when a validation report contains any failed step.""" + if not isinstance(report, dict): + return False + steps = report.get("steps") + if not isinstance(steps, dict): + return bool(report.get("valid") is False) + return any(isinstance(step, dict) and not step.get("passed", False) for step in steps.values()) + + +def _run_validation_report_sync(agent_ref: str | Path) -> dict[str, Any]: + """Run validate_agent_package in an isolated subprocess. + + Accepts either a built-agent package name (for exports/) or a full + allowed agent path such as examples/templates/. + """ + if not agent_ref: + return {} + agent_ref_str = str(agent_ref) + + script = textwrap.dedent( + """ + import importlib.util + import sys + + server_path = sys.argv[1] + agent_ref = sys.argv[2] + spec = importlib.util.spec_from_file_location("coder_tools_server", server_path) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + import json + print(json.dumps(module._validate_agent_package_impl(agent_ref), default=str)) + """ + ) + try: + proc = subprocess.run( + ["uv", "run", "python", "-c", script, str(CODER_TOOLS_SERVER), agent_ref_str], + capture_output=True, + text=True, + timeout=120, + cwd=REPO_ROOT, + stdin=subprocess.DEVNULL, + ) + except (OSError, subprocess.SubprocessError) as exc: + return { + "valid": False, + "summary": f"validate_agent_package failed for '{agent_ref_str}'", + "steps": {"validator_subprocess": {"passed": False, "error": str(exc)[:2000]}}, + } + if proc.returncode != 0: + detail = proc.stderr.strip() or proc.stdout.strip() or "validation subprocess failed" + return { + "valid": False, + "summary": f"validate_agent_package failed for '{agent_ref_str}'", + 
"steps": {"validator_subprocess": {"passed": False, "error": detail[:2000]}}, + } + return _parse_validation_report(proc.stdout) + + +class WorkerValidationError(ValueError): + """Raised when a worker package fails validation before load/run.""" + + def __init__(self, agent_name: str, report: dict[str, Any]): + self.agent_name = agent_name + self.report = report + self.failures = _validation_failures(report) + super().__init__( + f"Worker '{agent_name}' failed validation: " + + ("; ".join(self.failures) if self.failures else "validation failed") + ) @dataclass @@ -41,6 +171,8 @@ class Session: runner: Any | None = None # AgentRunner worker_runtime: Any | None = None # AgentRuntime worker_info: Any | None = None # AgentInfo + worker_validation_report: dict[str, Any] | None = None + worker_validation_failures: list[str] = field(default_factory=list) # Queen phase state (building/staging/running) phase_state: Any = None # QueenPhaseState # Worker handoff subscription @@ -83,6 +215,47 @@ def __init__(self, model: str | None = None, credential_store=None) -> None: self._credential_store = credential_store self._lock = asyncio.Lock() + async def suspend_queen(self, session: Session) -> None: + """Park the queen until the user sends a fresh message. + + This is lighter than stopping the full session: it tears down the + queen executor and its subscriptions, but preserves the live session, + loaded worker, and persisted history. The next `/chat` call will + revive the queen via the normal code path. 
+ """ + if session.worker_handoff_sub is not None: + try: + session.event_bus.unsubscribe(session.worker_handoff_sub) + except Exception: + pass + session.worker_handoff_sub = None + + if session.memory_consolidation_sub is not None: + try: + session.event_bus.unsubscribe(session.memory_consolidation_sub) + except Exception: + pass + session.memory_consolidation_sub = None + + executor = session.queen_executor + if executor is not None: + node = executor.node_registry.get("queen") + if node is not None and hasattr(node, "signal_shutdown"): + node.signal_shutdown() + + if session.queen_task is not None: + task = session.queen_task + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + except Exception: + logger.debug("Queen task exited with error during suspend", exc_info=True) + session.queen_task = None + + session.queen_executor = None + # ------------------------------------------------------------------ # Session lifecycle # ------------------------------------------------------------------ @@ -96,7 +269,8 @@ async def _create_session_core( Internal helper — use create_session() or create_session_with_worker(). 
""" - from framework.config import RuntimeConfig, get_hive_config + from framework.config import RuntimeConfig + from framework.llm.litellm import LiteLLMProvider from framework.runtime.event_bus import EventBus ts = datetime.now().strftime("%Y%m%d_%H%M%S") @@ -110,20 +284,12 @@ async def _create_session_core( rc = RuntimeConfig(model=model or self._model or RuntimeConfig().model) # Session owns these — shared with queen and worker - llm_config = get_hive_config().get("llm", {}) - if llm_config.get("use_antigravity_subscription"): - from framework.llm.antigravity import AntigravityProvider - - llm = AntigravityProvider(model=rc.model) - else: - from framework.llm.litellm import LiteLLMProvider - - llm = LiteLLMProvider( - model=rc.model, - api_key=rc.api_key, - api_base=rc.api_base, - **rc.extra_kwargs, - ) + llm = LiteLLMProvider( + model=rc.model, + api_key=rc.api_key, + api_base=rc.api_base, + **rc.extra_kwargs, + ) event_bus = EventBus() session = Session( @@ -294,6 +460,11 @@ async def _load_worker_core( try: # Blocking I/O — load in executor loop = asyncio.get_running_loop() + validation_report = await loop.run_in_executor( + None, lambda: _run_validation_report_sync(str(agent_path)) + ) + if _validation_blocks_stage_or_run(validation_report): + raise WorkerValidationError(agent_path.name, validation_report) # Prioritize: explicit model arg > worker-specific model > session default from framework.config import ( @@ -320,25 +491,17 @@ async def _load_worker_core( # with the correct worker credentials so _setup() doesn't fall back # to the queen's llm config (which may be a different provider). 
if worker_model and not model: - from framework.config import get_hive_config - - worker_llm_cfg = get_hive_config().get("worker_llm", {}) - if worker_llm_cfg.get("use_antigravity_subscription"): - from framework.llm.antigravity import AntigravityProvider - - runner._llm = AntigravityProvider(model=resolved_model) - else: - from framework.llm.litellm import LiteLLMProvider - - worker_api_key = get_worker_api_key() - worker_api_base = get_worker_api_base() - worker_extra = get_worker_llm_extra_kwargs() - runner._llm = LiteLLMProvider( - model=resolved_model, - api_key=worker_api_key, - api_base=worker_api_base, - **worker_extra, - ) + from framework.llm.litellm import LiteLLMProvider + + worker_api_key = get_worker_api_key() + worker_api_base = get_worker_api_base() + worker_extra = get_worker_llm_extra_kwargs() + runner._llm = LiteLLMProvider( + model=resolved_model, + api_key=worker_api_key, + api_base=worker_api_base, + **worker_extra, + ) # Setup with session's event bus if runner._agent_runtime is None: @@ -383,6 +546,8 @@ async def _load_worker_core( session.runner = runner session.worker_runtime = runtime session.worker_info = info + session.worker_validation_report = validation_report + session.worker_validation_failures = _validation_failures(validation_report) # Subscribe to execution completion for per-run digest generation self._subscribe_worker_digest(session) @@ -637,6 +802,8 @@ async def unload_worker(self, session_id: str) -> bool: session.runner = None session.worker_runtime = None session.worker_info = None + session.worker_validation_report = None + session.worker_validation_failures = [] # Notify queen await self._notify_queen_worker_unloaded(session) @@ -820,12 +987,25 @@ async def _inject_digest_to_queen(run_id: str) -> None: return await node.inject_event(f"[WORKER_DIGEST]\n{content}") - async def _consolidate_and_notify(run_id: str, outcome_event: Any) -> None: - """Write the digest then push it to the queen.""" + async def 
_consolidate_and_notify( + run_id: str, + outcome_event: Any, + *, + inject_to_queen: bool, + ) -> None: + """Write the digest and optionally push it into the queen. + + Final worker completion/failure already emits a richer + [WORKER_TERMINAL] handoff with the real primary result. Injecting the + final digest as a second queen event causes Codex to replace that + result with a bland generic follow-up prompt. Keep writing digests to + disk for memory/history, but only inject mid-run snapshots. + """ from framework.agents.worker_memory import consolidate_worker_run await consolidate_worker_run(_agent_name, run_id, outcome_event, _bus, _llm) - await _inject_digest_to_queen(run_id) + if inject_to_queen: + await _inject_digest_to_queen(run_id) async def _on_worker_event(event: Any) -> None: if event.stream_id == "queen": @@ -851,7 +1031,7 @@ async def _on_worker_event(event: Any) -> None: run_id = getattr(event, "run_id", None) or _resolve_run_id(exec_id) if run_id: asyncio.create_task( - _consolidate_and_notify(run_id, event), + _consolidate_and_notify(run_id, event, inject_to_queen=False), name=f"worker-digest-final-{run_id}", ) @@ -872,7 +1052,7 @@ async def _on_worker_event(event: Any) -> None: if run_id: _last_digest[exec_id] = now asyncio.create_task( - _consolidate_and_notify(run_id, None), + _consolidate_and_notify(run_id, None, inject_to_queen=True), name=f"worker-digest-{run_id}", ) @@ -1047,17 +1227,10 @@ async def _start_queen( _consolidation_session_dir = queen_dir async def _on_compaction(_event) -> None: - # Only consolidate on queen compactions — worker and subagent - # compactions are frequent and don't warrant a memory update. 
- if getattr(_event, "stream_id", None) != "queen": - return from framework.agents.queen.queen_memory import consolidate_queen_memory - asyncio.create_task( - consolidate_queen_memory( - session.id, _consolidation_session_dir, _consolidation_llm - ), - name=f"queen-memory-consolidation-{session.id}", + await consolidate_queen_memory( + session.id, _consolidation_session_dir, _consolidation_llm ) from framework.runtime.event_bus import EventType as _ET diff --git a/core/framework/server/tests/test_api.py b/core/framework/server/tests/test_api.py index 4815192f0a..48c06d0083 100644 --- a/core/framework/server/tests/test_api.py +++ b/core/framework/server/tests/test_api.py @@ -14,6 +14,7 @@ import pytest from aiohttp.test_utils import TestClient, TestServer +from framework.runtime.event_bus import AgentEvent, EventType from framework.runtime.triggers import TriggerDefinition from framework.server.app import create_app from framework.server.session_manager import Session @@ -190,6 +191,8 @@ def _make_session( runner=runner, worker_runtime=rt, worker_info=MockAgentInfo(), + worker_validation_report={"valid": True, "steps": {}}, + worker_validation_failures=[], ) @@ -556,6 +559,7 @@ class TestExecution: @pytest.mark.asyncio async def test_trigger(self): session = _make_session() + session.worker_runtime.trigger = AsyncMock(return_value="exec_test_123") app = _make_app_with_session(session) async with TestClient(TestServer(app)) as client: resp = await client.post( @@ -565,6 +569,11 @@ async def test_trigger(self): assert resp.status == 200 data = await resp.json() assert data["execution_id"] == "exec_test_123" + session.worker_runtime.trigger.assert_awaited_once_with( + "default", + {"msg": "hi"}, + session_state=None, + ) @pytest.mark.asyncio async def test_trigger_not_found(self): @@ -576,6 +585,25 @@ async def test_trigger_not_found(self): ) assert resp.status == 404 + @pytest.mark.asyncio + async def test_trigger_blocks_invalid_loaded_worker(self): + session = 
_make_session() + session.worker_validation_report = { + "valid": False, + "steps": {"behavior_validation": {"passed": False}}, + } + session.worker_validation_failures = ["behavior_validation: placeholder prompt"] + app = _make_app_with_session(session) + async with TestClient(TestServer(app)) as client: + resp = await client.post( + "/api/sessions/test_agent/trigger", + json={"entry_point_id": "default", "input_data": {"msg": "hi"}}, + ) + assert resp.status == 409 + data = await resp.json() + assert "failed validation" in data["error"] + assert data["validation_failures"] == ["behavior_validation: placeholder prompt"] + @pytest.mark.asyncio async def test_inject(self): session = _make_session() @@ -616,8 +644,8 @@ async def test_chat_goes_to_queen_when_not_waiting(self): assert data["delivered"] is True @pytest.mark.asyncio - async def test_chat_injects_when_node_waiting(self): - """When a node is awaiting input, /chat should inject instead of trigger.""" + async def test_chat_still_goes_to_queen_when_node_waiting(self): + """The main chat channel stays wired to Queen even if a worker is waiting.""" session = _make_session() session.worker_runtime.find_awaiting_node = lambda: ("chat_node", "primary") app = _make_app_with_session(session) @@ -628,6 +656,83 @@ async def test_chat_injects_when_node_waiting(self): ) assert resp.status == 200 data = await resp.json() + assert data["status"] == "queen" + assert data["delivered"] is True + + @pytest.mark.asyncio + async def test_chat_done_for_now_parks_queen_without_new_followup(self): + """Terminal stop choices should acknowledge once and park the queen.""" + session = _make_session() + session.event_bus.get_history.return_value = [ + AgentEvent( + type=EventType.CLIENT_INPUT_REQUESTED, + stream_id="queen", + node_id="queen", + execution_id=session.id, + data={"options": ["Run again with same input", "Done for now"]}, + ) + ] + session.event_bus.emit_client_output_delta = AsyncMock() + app = 
_make_app_with_session(session) + + async with TestClient(TestServer(app)) as client: + resp = await client.post( + "/api/sessions/test_agent/chat", + json={"message": "No, stop here"}, + ) + assert resp.status == 200 + data = await resp.json() + assert data["status"] == "queen" + assert data["delivered"] is True + + queen_node = session.queen_executor + assert queen_node is None + session.event_bus.emit_client_output_delta.assert_awaited_once() + + @pytest.mark.asyncio + async def test_chat_non_terminal_choice_still_goes_to_queen(self): + """Non-terminal follow-up choices should still be injected into the queen.""" + session = _make_session() + session.event_bus.get_history.return_value = [ + AgentEvent( + type=EventType.CLIENT_INPUT_REQUESTED, + stream_id="queen", + node_id="queen", + execution_id=session.id, + data={"options": ["Run again with same input", "Done for now"]}, + ) + ] + session.event_bus.emit_client_output_delta = AsyncMock() + app = _make_app_with_session(session) + + async with TestClient(TestServer(app)) as client: + resp = await client.post( + "/api/sessions/test_agent/chat", + json={"message": "Run again with same input"}, + ) + assert resp.status == 200 + data = await resp.json() + assert data["status"] == "queen" + assert data["delivered"] is True + queen_node = session.queen_executor.node_registry["queen"] + queen_node.inject_event.assert_awaited_once_with( + "Run again with same input", + is_client_input=True, + ) + session.event_bus.emit_client_output_delta.assert_not_called() + + @pytest.mark.asyncio + async def test_worker_input_injects_when_node_waiting(self): + session = _make_session() + session.worker_runtime.find_awaiting_node = lambda: ("chat_node", "primary") + app = _make_app_with_session(session) + async with TestClient(TestServer(app)) as client: + resp = await client.post( + "/api/sessions/test_agent/worker-input", + json={"message": "user reply"}, + ) + assert resp.status == 200 + data = await resp.json() assert 
data["status"] == "injected" assert data["node_id"] == "chat_node" assert data["delivered"] is True @@ -715,8 +820,6 @@ async def test_resume_from_session_state(self, sample_session, tmp_agent_dir): assert resp.status == 200 data = await resp.json() assert data["execution_id"] == "exec_test_123" - assert data["resumed_from"] == session_id - assert data["checkpoint_id"] is None @pytest.mark.asyncio async def test_resume_with_checkpoint(self, sample_session, tmp_agent_dir): @@ -761,6 +864,31 @@ async def test_resume_session_not_found(self): ) assert resp.status == 404 + @pytest.mark.asyncio + async def test_resume_blocks_invalid_loaded_worker(self, sample_session, tmp_agent_dir): + session_id, session_dir, state = sample_session + tmp_path, agent_name, base = tmp_agent_dir + + session = _make_session(tmp_dir=tmp_path / ".hive" / "agents" / agent_name) + session.worker_validation_report = { + "valid": False, + "steps": {"tool_validation": {"passed": False}}, + } + session.worker_validation_failures = ["tool_validation: missing tool execute_command_tool"] + app = _make_app_with_session(session) + + async with TestClient(TestServer(app)) as client: + resp = await client.post( + "/api/sessions/test_agent/resume", + json={"session_id": session_id}, + ) + assert resp.status == 409 + data = await resp.json() + assert "failed validation" in data["error"] + assert data["validation_failures"] == [ + "tool_validation: missing tool execute_command_tool" + ] + class TestStop: @pytest.mark.asyncio @@ -800,6 +928,19 @@ async def test_stop_missing_execution_id(self): ) assert resp.status == 400 + @pytest.mark.asyncio + async def test_stop_ignores_worker_validation_failure(self): + session = _make_session() + session.worker_validation_failures = ["behavior_validation: broken"] + session.worker_runtime._mock_streams["default"]._execution_tasks["exec_abc"] = MagicMock() + app = _make_app_with_session(session) + async with TestClient(TestServer(app)) as client: + resp = await 
client.post( + "/api/sessions/test_agent/stop", + json={"execution_id": "exec_abc"}, + ) + assert resp.status == 200 + class TestReplay: @pytest.mark.asyncio diff --git a/core/framework/tools/queen_lifecycle_tools.py b/core/framework/tools/queen_lifecycle_tools.py index 0aa16a39f6..a681d6530c 100644 --- a/core/framework/tools/queen_lifecycle_tools.py +++ b/core/framework/tools/queen_lifecycle_tools.py @@ -36,6 +36,7 @@ import asyncio import json import logging +import re import time from dataclasses import dataclass, field from datetime import UTC, datetime @@ -53,6 +54,7 @@ save_flowchart_file, synthesize_draft_from_runtime, ) +from framework.tools.worker_monitoring_tools import read_worker_health_snapshot if TYPE_CHECKING: from framework.runner.tool_registry import ToolRegistry @@ -61,6 +63,16 @@ logger = logging.getLogger(__name__) +_NON_ACCEPT_JUDGE_ACTIONS = frozenset({"RETRY", "CONTINUE", "ESCALATE"}) +_HEALTH_SIGNAL_DESCRIPTIONS: dict[str, str] = { + "failed_session": "worker session is marked failed", + "stalled": "worker appears stalled with no meaningful progress for 5+ minutes", + "slow_progress": "worker progress has slowed for 2+ minutes without completing", + "long_non_accept_streak": "worker has a sustained non-ACCEPT judge streak", + "judge_pressure": "worker is under repeated non-ACCEPT judge pressure", + "recent_non_accept_churn": "recent judge verdicts are all non-ACCEPT, indicating churn", +} + @dataclass class WorkerSessionAdapter: @@ -118,8 +130,6 @@ class QueenPhaseState: # Default skill operational protocols — appended to every phase prompt protocols_prompt: str = "" - # Community skills catalog (XML) — appended after protocols - skills_catalog_prompt: str = "" def get_current_tools(self) -> list: """Return tools for the current phase.""" @@ -146,8 +156,6 @@ def get_current_prompt(self) -> str: memory = format_for_injection() parts = [base] - if self.skills_catalog_prompt: - parts.append(self.skills_catalog_prompt) if self.protocols_prompt: 
parts.append(self.protocols_prompt) if memory: @@ -750,6 +758,438 @@ def _update_meta_json(session_manager, manager_session_id, updates: dict) -> Non pass +def _parse_validation_report(raw: Any) -> dict | None: + """Best-effort parse of validate_agent_package output.""" + if isinstance(raw, dict): + return raw + if hasattr(raw, "content"): + raw = raw.content + if raw is None: + return None + text = str(raw).strip() + if not text: + return None + candidates = [text] + if "\n\n[Saved to" in text: + candidates.append(text.split("\n\n[Saved to", 1)[0].strip()) + start = text.find("{") + end = text.rfind("}") + if start != -1 and end != -1 and end > start: + candidates.append(text[start : end + 1]) + try: + for candidate in candidates: + try: + return json.loads(candidate) + except (TypeError, json.JSONDecodeError): + continue + except TypeError: + pass + return None + + +def _validation_failures(report: dict | None) -> list[str]: + """Flatten failed validation steps into readable messages.""" + if not report: + return [] + failures: list[str] = [] + for step_name, step in (report.get("steps") or {}).items(): + if step.get("passed"): + continue + detail = step.get("output") or step.get("error") or "failed" + failures.append(f"{step_name}: {detail}") + return failures + + +def _validation_blocks_stage_or_run(report: dict | None) -> bool: + """Return True when validation results should block staging or execution.""" + if not report: + return False + return any( + isinstance(step, dict) and not step.get("passed", False) + for step in (report.get("steps") or {}).values() + ) + + +def _invalid_validation_report(reason: str) -> dict: + """Build a structured validation failure when validator output is unusable.""" + return { + "valid": False, + "summary": reason, + "steps": { + "validator_subprocess": { + "passed": False, + "error": reason, + } + }, + } + + +_STRUCTURED_TASK_PAIR_RE = re.compile( + 
# ``key: value`` / ``key=value`` pair anywhere in a line. The key may be wrapped
# in square brackets. The value is matched lazily and stops right before the
# next ``key:`` / ``key=`` token (or at end of input), so several pairs can
# share one line.
_STRUCTURED_TASK_PAIR_RE = re.compile(
    r"\[?(?P<key>[A-Za-z_][A-Za-z0-9_]*)\]?\s*(?::|=)\s*(?P<value>.*?)"
    r"(?=(?:\s+\[?[A-Za-z_][A-Za-z0-9_]*\]?\s*(?::|=))|$)"
)
# A whole line of the form ``[- | *] key: value`` (bullet and brackets optional).
_STRUCTURED_TASK_LINE_RE = re.compile(
    r"^\s*(?:[-*]\s*)?\[?(?P<key>[A-Za-z_][A-Za-z0-9_]*)\]?\s*(?::|=)\s*(?P<value>.*?)\s*$"
)
# A number followed by a parenthesised note, e.g. ``10 (files)``.
_NUMERIC_WITH_SUFFIX_RE = re.compile(r"^(?P<number>-?\d+(?:\.\d+)?)\s*\([^)]*\)\s*$")
# A number at the start of the text, optionally preceded by whitespace.
_LEADING_NUMERIC_RE = re.compile(r"^\s*(?P<number>-?\d+(?:\.\d+)?)\b")
_RERUN_WITH_DEFAULTS_RE = re.compile(
    r"\b(?:run\s+again|rerun|continue)\b.*\b(?:same\s+(?:default|defaults|settings|inputs)|"
    r"defaults?)\b",
    re.IGNORECASE,
)
_PATH_INPUT_KEY_HINTS = (
    "_dir",
    "_path",
    "_folder",
    "_root",
)
_NUMERIC_INPUT_KEY_HINTS = (
    "_threshold",
    "_count",
    "_limit",
    "_max",
    "_min",
    "_ratio",
    "_size",
)


def _coerce_task_value(raw: str) -> Any:
    """Best-effort coerce simple structured task values from text.

    Surrounding whitespace and trailing commas are dropped. A value such as
    ``"10 (files)"`` is reduced to its numeric part. The remaining text is
    decoded as JSON when possible (numbers, booleans, ``null``, quoted strings,
    arrays, objects) and returned verbatim otherwise.
    """
    text = raw.strip().rstrip(",")
    if not text:
        return ""
    numeric_match = _NUMERIC_WITH_SUFFIX_RE.match(text)
    if numeric_match:
        text = numeric_match.group("number")
    try:
        return json.loads(text)
    except (TypeError, json.JSONDecodeError):
        return text


def _parse_structured_task_payload(task: str) -> dict[str, Any]:
    """Extract ``key: value`` pairs or JSON objects from a task string.

    Resolution order:

    1. If the whole text is a JSON object, return it as-is.
    2. Otherwise scan line by line: a line with several inline pairs
       contributes all of them; a ``key: value`` line contributes one pair;
       an indented line following a pair is appended (newline-joined) to that
       pair's value when the value is a string; a blank line ends the
       continuation.
    3. If no line produced a pair, fall back to inline pairs found anywhere in
       the text, ignoring pairs with an empty value.

    Returns an empty dict for empty input or when nothing can be extracted.
    """
    text = (task or "").strip()
    if not text:
        return {}

    try:
        parsed = json.loads(text)
    except (TypeError, json.JSONDecodeError):
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    payload: dict[str, Any] = {}
    current_key: str | None = None
    for raw_line in text.splitlines():
        stripped = raw_line.strip()
        if not stripped:
            current_key = None
            continue

        inline_matches = list(_STRUCTURED_TASK_PAIR_RE.finditer(stripped))
        if len(inline_matches) > 1:
            for match in inline_matches:
                payload[match.group("key")] = _coerce_task_value(match.group("value"))
            current_key = inline_matches[-1].group("key")
            continue

        line_match = _STRUCTURED_TASK_LINE_RE.match(raw_line)
        if line_match:
            key = line_match.group("key")
            payload[key] = _coerce_task_value(line_match.group("value"))
            current_key = key
            continue

        # Indented text directly after a pair continues that pair's value.
        if current_key and raw_line.startswith((" ", "\t")):
            existing = payload.get(current_key, "")
            if isinstance(existing, str):
                payload[current_key] = f"{existing}\n{stripped}".strip()
            continue
        current_key = None

    if payload:
        return payload

    # Nothing matched line-wise (e.g. "please run with threshold: 5"): look
    # for pairs embedded in free-form prose instead.
    for match in _STRUCTURED_TASK_PAIR_RE.finditer(text):
        value_text = match.group("value")
        if not value_text.strip():
            continue
        payload[match.group("key")] = _coerce_task_value(value_text)
    return payload


def _looks_like_path_input_key(key: str) -> bool:
    """Return True when ``key`` is named like a filesystem location input."""
    lowered = key.lower()
    return lowered.endswith(_PATH_INPUT_KEY_HINTS) or lowered in {
        "path",
        "dir",
        "folder",
        "root",
    }


def _looks_like_numeric_input_key(key: str) -> bool:
    """Return True when ``key`` is named like a numeric tuning input."""
    lowered = key.lower()
    return lowered.endswith(_NUMERIC_INPUT_KEY_HINTS) or lowered in {
        "threshold",
        "count",
        "limit",
        "max",
        "min",
        "ratio",
        "size",
    }
in number else int(number) + + if _looks_like_path_input_key(key): + candidate = Path(text).expanduser() + if not candidate.is_absolute(): + candidate = (Path.cwd() / candidate).resolve() + return str(candidate) + + return text + + +def _should_backfill_from_recent_input(key: str, value: Any) -> bool: + """Return True when a recent session input should replace the current value.""" + if value is None: + return True + if isinstance(value, str): + text = value.strip() + if not text: + return True + if _looks_like_numeric_input_key(key): + return _LEADING_NUMERIC_RE.fullmatch(text) is None + return False + + +def _load_recent_worker_input_defaults( + runtime: Any, + input_keys: list[str], + session_id: str | None = None, +) -> dict[str, Any]: + """Load the best recent worker input payload from unified session state. + + When a current session is known, only that session is considered so reruns + cannot inherit structured defaults from an unrelated historical session. + Otherwise we fall back to the latest available session as a best-effort + compatibility path for legacy callers. 
+ """ + store = getattr(runtime, "_session_store", None) + sessions_dir = Path(getattr(store, "sessions_dir", "")) if store is not None else None + if sessions_dir is None or not sessions_dir.exists(): + return {} + allowed_key_set = set(input_keys) + + candidate_state_paths: list[Path] + if session_id: + state_path = sessions_dir / session_id / "state.json" + candidate_state_paths = [state_path] if state_path.exists() else [] + else: + candidate_state_paths = [] + + if not candidate_state_paths: + candidate_state_paths = sorted( + sessions_dir.glob("session_*/state.json"), + key=lambda path: path.stat().st_mtime if path.exists() else 0, + reverse=True, + ) + + work_keys = [key for key in input_keys if key not in {"user_request", "task", "feedback"}] + if not work_keys: + best_payload: dict[str, Any] = {} + best_updated_at = "" + for state_path in candidate_state_paths: + try: + raw_state = json.loads(state_path.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError): + continue + input_data = raw_state.get("input_data") or {} + if not isinstance(input_data, dict) or not input_data: + continue + input_data = {key: value for key, value in input_data.items() if key in allowed_key_set} + if not input_data: + continue + updated_at = str((raw_state.get("timestamps") or {}).get("updated_at") or "") + if updated_at >= best_updated_at: + best_payload = dict(input_data) + best_updated_at = updated_at + return best_payload + + best_payload: dict[str, Any] = {} + best_score = -1 + best_updated_at = "" + + for state_path in candidate_state_paths: + try: + raw_state = json.loads(state_path.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError): + continue + + input_data = raw_state.get("input_data") or {} + if not isinstance(input_data, dict): + continue + + merged_input = {key: value for key, value in input_data.items() if key in allowed_key_set} + result_output = (raw_state.get("result") or {}).get("output") or {} + if isinstance(result_output, 
dict): + for key in work_keys: + if merged_input.get(key) in (None, "") and result_output.get(key) not in (None, ""): + merged_input[key] = result_output[key] + + score = sum(1 for key in work_keys if merged_input.get(key) not in (None, "")) + if score <= 0: + continue + + updated_at = str((raw_state.get("timestamps") or {}).get("updated_at") or "") + if score > best_score or (score == best_score and updated_at > best_updated_at): + best_payload = merged_input + best_score = score + best_updated_at = updated_at + + return best_payload + + +async def _preflight_worker_run(session: Any, runtime: Any, timeout_seconds: int) -> None: + """Validate credentials and refresh MCP servers before a worker run.""" + loop = asyncio.get_running_loop() + + async def _preflight(): + cred_error: CredentialError | None = None + try: + await loop.run_in_executor( + None, + lambda: validate_credentials( + runtime.graph.nodes, + interactive=False, + skip=False, + ), + ) + except CredentialError as e: + cred_error = e + + runner = getattr(session, "runner", None) + if runner: + try: + await loop.run_in_executor( + None, + lambda: runner._tool_registry.resync_mcp_servers_if_needed(), + ) + except Exception as e: + logger.warning("MCP resync failed: %s", e) + + if cred_error is not None: + raise cred_error + + try: + await asyncio.wait_for(_preflight(), timeout=timeout_seconds) + except TimeoutError: + logger.warning( + "worker run preflight timed out after %ds — proceeding", + timeout_seconds, + ) + + +def _get_default_entry_input_keys(runtime: Any) -> list[str]: + """Return the loaded worker's default entry node input keys, if available.""" + try: + entry_points = runtime.get_entry_points() + except Exception: + return [] + if not entry_points: + return [] + + graph = getattr(runtime, "graph", None) + if graph is None or not hasattr(graph, "get_node"): + return [] + + entry_spec = entry_points[0] + entry_node_id = getattr(entry_spec, "entry_node", None) or getattr(graph, "entry_node", 
def _build_worker_input_data(
    runtime: Any,
    task: str,
    session_id: str | None = None,
) -> dict[str, Any]:
    """Turn the queen's task text into the entry-node inputs of the loaded worker.

    Workers without declared entry input keys receive the raw task under both
    ``user_request`` and ``task`` plus any structured fields parsed from it.
    Workers with declared keys receive only those keys: the free-form keys get
    the raw task, the remaining keys come from the parsed task and/or the most
    recent session inputs.
    """
    parsed_fields = _parse_structured_task_payload(task)
    entry_keys = _get_default_entry_input_keys(runtime)

    if not entry_keys:
        # Legacy workers expect a single free-form task string; newer ones may
        # still pick up structured fields directly.
        return {"user_request": task, "task": task, **parsed_fields}

    result: dict[str, Any] = {
        name: task for name in ("user_request", "task") if name in entry_keys
    }

    field_keys = [name for name in entry_keys if name not in {"user_request", "task", "feedback"}]
    # Recent defaults are consulted only when the task names at least one work
    # field explicitly, or asks for a rerun with the same defaults.
    wants_recent_defaults = any(name in parsed_fields for name in field_keys) or bool(
        _RERUN_WITH_DEFAULTS_RE.search(task or "")
    )
    recent: dict[str, Any] = {}
    if wants_recent_defaults:
        recent = _load_recent_worker_input_defaults(runtime, entry_keys, session_id=session_id)

    for name in field_keys:
        if name not in parsed_fields:
            if recent.get(name) not in (None, ""):
                result[name] = _normalize_worker_input_value(name, recent[name])
            continue
        result[name] = _normalize_worker_input_value(name, parsed_fields[name])
        # An unusable explicit value is replaced by the recent one, if any.
        if (
            name in recent
            and _should_backfill_from_recent_input(name, result[name])
            and recent.get(name) not in (None, "")
        ):
            result[name] = _normalize_worker_input_value(name, recent[name])

    seed_keys = [name for name in entry_keys if name != "feedback"]
    if any(name in result for name in seed_keys):
        return result

    # Nothing was populated: hand the raw task to the only plausible key.
    if len(seed_keys) == 1:
        result[seed_keys[0]] = task
    elif "user_request" in entry_keys and "user_request" not in result:
        result["user_request"] = task
    elif "task" in entry_keys and "task" not in result:
        result["task"] = task
    return result
- loop = asyncio.get_running_loop() - - async def _preflight(): - cred_error: CredentialError | None = None - try: - await loop.run_in_executor( - None, - lambda: validate_credentials( - runtime.graph.nodes, - interactive=False, - skip=False, - ), - ) - except CredentialError as e: - cred_error = e - - runner = getattr(session, "runner", None) - if runner: - try: - await loop.run_in_executor( - None, - lambda: runner._tool_registry.resync_mcp_servers_if_needed(), - ) - except Exception as e: - logger.warning("MCP resync failed: %s", e) - - # Re-raise CredentialError after MCP resync so both steps - # get a chance to run before we bail. - if cred_error is not None: - raise cred_error - - try: - await asyncio.wait_for(_preflight(), timeout=_START_PREFLIGHT_TIMEOUT) - except TimeoutError: - logger.warning( - "start_worker preflight timed out after %ds — proceeding with trigger", - _START_PREFLIGHT_TIMEOUT, - ) - except CredentialError: - raise # handled below + await _preflight_worker_run(session, runtime, _START_PREFLIGHT_TIMEOUT) # Resume timers in case they were paused by a previous stop_worker runtime.resume_timers() - # Get session state from any prior execution for memory continuity - session_state = runtime._get_primary_session_state("default") or {} - - # Use the shared session ID so queen, judge, and worker all - # scope their conversations to the same session. - if session_id: - session_state["resume_session_id"] = session_id - exec_id = await runtime.trigger( entry_point_id="default", - input_data={"user_request": task}, - session_state=session_state, + input_data=_build_worker_input_data(runtime, task, session_id=session_id), + # Worker runs should start from the explicit input payload for + # this run, not inherit another execution's shared session. 
def _get_recent_judge_pressure(bus: EventBus, streak_threshold: int = 4) -> tuple[bool, str]:
    """Report whether the latest judge verdicts form a long non-ACCEPT run.

    Reads up to eight JUDGE_VERDICT events from ``bus`` and walks them in the
    order returned, collecting actions until the first one that is not a
    known non-ACCEPT action. Returns ``(True, description)`` when that run
    holds at least ``streak_threshold`` verdicts, otherwise ``(False, "")``.
    """
    no_pressure = (False, "")

    recent_events = bus.get_history(event_type=EventType.JUDGE_VERDICT, limit=8)
    if len(recent_events) < streak_threshold:
        return no_pressure

    run: list[str] = []
    for event in recent_events:
        verdict = str(event.data.get("action", "")).upper()
        # ACCEPT and any unrecognized action both end the run.
        if verdict not in _NON_ACCEPT_JUDGE_ACTIONS:
            break
        run.append(verdict)

    if len(run) < streak_threshold:
        return no_pressure

    # Collapse consecutive repeats, e.g. RETRY, RETRY, ESCALATE -> RETRY, ESCALATE.
    transitions = [
        verdict for index, verdict in enumerate(run) if index == 0 or run[index - 1] != verdict
    ]
    chain = " -> ".join(transitions)
    return True, f"{len(run)} consecutive non-ACCEPT judge verdict(s): {chain}"
lines.append("Health signals:") + for signal in issue_signals: + desc = _HEALTH_SIGNAL_DESCRIPTIONS.get(signal, signal.replace("_", " ")) + if ( + signal in {"stalled", "slow_progress"} + and health_snapshot.get("stall_minutes") is not None + ): + desc += f" ({health_snapshot['stall_minutes']} min since last step)" + elif ( + signal in {"long_non_accept_streak", "judge_pressure"} + and health_snapshot.get("steps_since_last_accept") is not None + ): + desc += ( + " (" + f"{health_snapshot['steps_since_last_accept']} non-ACCEPT step(s)" + " since last ACCEPT)" + ) + elif signal == "recent_non_accept_churn" and health_snapshot.get( + "recent_verdicts" + ): + verdicts = ", ".join(health_snapshot["recent_verdicts"][-4:]) + desc += f" ({verdicts})" + lines.append(f" {signal}: {desc}") + if total == 0: - return "No issues detected. No retries, stalls, or constraint violations." + return "No issues detected. No runtime issue signals were found." header = f"{total} issue(s) detected." return header + "\n\n" + "\n".join(lines) @@ -3086,8 +3592,9 @@ async def get_worker_status(focus: str | None = None, last_n: int = 20) -> str: try: if focus is None: # Default: brief prose summary - red_flags = _detect_red_flags(bus) if bus else 0 - return _format_summary(preamble, red_flags) + health_snapshot = _get_worker_health_snapshot() + red_flags = _detect_red_flags(bus, health_snapshot) if bus else 0 + return _format_summary(preamble, red_flags, health_snapshot) if bus is None: return ( @@ -3102,7 +3609,7 @@ async def get_worker_status(focus: str | None = None, last_n: int = 20) -> str: elif focus == "tools": return _format_tools(bus, last_n) elif focus == "issues": - return _format_issues(bus) + return _format_issues(bus, _get_worker_health_snapshot()) elif focus == "progress": return await _format_progress(runtime, bus) elif focus == "full": @@ -3399,14 +3906,6 @@ async def load_built_agent(agent_path: str) -> str: available immediately. 
The user will see the agent's graph and can interact with it without opening a new tab. """ - runtime = _get_runtime() - if runtime is not None: - try: - await session_manager.unload_worker(manager_session_id) - except Exception as e: - logger.error("Failed to unload existing worker: %s", e, exc_info=True) - return json.dumps({"error": f"Failed to unload existing worker: {e}"}) - try: resolved_path = validate_agent_path(agent_path) except ValueError as e: @@ -3414,6 +3913,19 @@ async def load_built_agent(agent_path: str) -> str: if not resolved_path.exists(): return json.dumps({"error": f"Agent path does not exist: {agent_path}"}) + validation_report = await _run_package_validation(str(resolved_path)) + if _validation_blocks_stage_or_run(validation_report): + failures = _validation_failures(validation_report) + return json.dumps( + { + "error": ( + f"Cannot load agent '{resolved_path.name}' because validation failed. " + "Fix the package and re-run validate_agent_package() before loading." + ), + "validation_failures": failures, + } + ) + # Pre-check: verify the module exports goal/nodes/edges before # attempting the full load. This gives the queen an actionable # error message instead of a cryptic ImportError or TypeError. 
@@ -3459,6 +3971,14 @@ async def load_built_agent(agent_path: str) -> str: } ) + runtime = _get_runtime() + if runtime is not None: + try: + await session_manager.unload_worker(manager_session_id) + except Exception as e: + logger.error("Failed to unload existing worker: %s", e, exc_info=True) + return json.dumps({"error": f"Failed to unload existing worker: {e}"}) + try: updated_session = await session_manager.load_worker( manager_session_id, @@ -3607,63 +4127,123 @@ async def run_agent_with_input(task: str) -> str: if runtime is None: return json.dumps({"error": "No worker loaded in this session."}) + worker_path = getattr(session, "worker_path", None) + worker_name = Path(worker_path).name if worker_path else "" + validation_report = await _run_package_validation( + str(worker_path) if worker_path else worker_name + ) + if _validation_blocks_stage_or_run(validation_report): + failures = _validation_failures(validation_report) + return json.dumps( + { + "error": ( + f"Cannot run agent '{worker_name or 'current worker'}' because validation " + "is failing. Fix the package and reload it before running." + ), + "validation_failures": failures, + } + ) + try: - # Pre-flight: validate credentials and resync MCP servers. 
- loop = asyncio.get_running_loop() + await _preflight_worker_run(session, runtime, _START_PREFLIGHT_TIMEOUT) - async def _preflight(): - cred_error: CredentialError | None = None - try: - await loop.run_in_executor( - None, - lambda: validate_credentials( - runtime.graph.nodes, - interactive=False, - skip=False, - ), - ) - except CredentialError as e: - cred_error = e + # Resume timers in case they were paused by a previous stop + runtime.resume_timers() - runner = getattr(session, "runner", None) - if runner: - try: - await loop.run_in_executor( - None, - lambda: runner._tool_registry.resync_mcp_servers_if_needed(), - ) - except Exception as e: - logger.warning("MCP resync failed: %s", e) + exec_id = await runtime.trigger( + entry_point_id="default", + input_data=_build_worker_input_data(runtime, task, session_id=session_id), + # Fresh manual worker runs avoid stale state leaking from a + # previous execution into Codex's next tool/planning turn. + session_state=None, + ) - if cred_error is not None: - raise cred_error + # Switch to running phase + if phase_state is not None: + await phase_state.switch_to_running() + _update_meta_json(session_manager, manager_session_id, {"phase": "running"}) - try: - await asyncio.wait_for(_preflight(), timeout=_START_PREFLIGHT_TIMEOUT) - except TimeoutError: - logger.warning( - "run_agent_with_input preflight timed out after %ds — proceeding", - _START_PREFLIGHT_TIMEOUT, + return json.dumps( + { + "status": "started", + "phase": "running", + "execution_id": exec_id, + "task": task, + } + ) + except CredentialError as e: + error_payload = credential_errors_to_json(e) + error_payload["agent_path"] = str(getattr(session, "worker_path", "") or "") + + bus = getattr(session, "event_bus", None) + if bus is not None: + await bus.publish( + AgentEvent( + type=EventType.CREDENTIALS_REQUIRED, + stream_id="queen", + data=error_payload, + ) ) - except CredentialError: - raise # handled below + return json.dumps(error_payload) + except 
Exception as e: + return json.dumps({"error": f"Failed to start worker: {e}"}) - # Resume timers in case they were paused by a previous stop - runtime.resume_timers() + async def rerun_worker_with_last_input() -> str: + """Rerun the loaded worker using the last complete structured input payload.""" + runtime = _get_runtime() + if runtime is None: + return json.dumps({"error": "No worker loaded in this session."}) - # Get session state from any prior execution for memory continuity - session_state = runtime._get_primary_session_state("default") or {} + worker_path = getattr(session, "worker_path", None) + worker_name = Path(worker_path).name if worker_path else "" + validation_report = await _run_package_validation( + str(worker_path) if worker_path else worker_name + ) + if _validation_blocks_stage_or_run(validation_report): + failures = _validation_failures(validation_report) + return json.dumps( + { + "error": ( + f"Cannot rerun agent '{worker_name or 'current worker'}' " + "because validation " + "is failing. Fix the package and reload it before running." 
+ ), + "validation_failures": failures, + } + ) - if session_id: - session_state["resume_session_id"] = session_id + allowed_keys = _get_default_entry_input_keys(runtime) + input_data = { + key: _normalize_worker_input_value(key, value) + for key, value in _load_recent_worker_input_defaults( + runtime, + allowed_keys, + session_id=session_id, + ).items() + } + work_keys = [key for key in allowed_keys if key not in {"user_request", "task", "feedback"}] + if work_keys: + missing = [key for key in work_keys if input_data.get(key) in (None, "")] + if missing: + return json.dumps( + { + "error": "No complete previous worker input is available for a " + "same-defaults rerun.", + "missing_inputs": missing, + } + ) + + try: + await _preflight_worker_run(session, runtime, _START_PREFLIGHT_TIMEOUT) + + runtime.resume_timers() exec_id = await runtime.trigger( entry_point_id="default", - input_data={"user_request": task}, - session_state=session_state, + input_data=input_data, + session_state=None, ) - # Switch to running phase if phase_state is not None: await phase_state.switch_to_running() _update_meta_json(session_manager, manager_session_id, {"phase": "running"}) @@ -3673,7 +4253,7 @@ async def _preflight(): "status": "started", "phase": "running", "execution_id": exec_id, - "task": task, + "input_data": input_data, } ) except CredentialError as e: @@ -3691,7 +4271,7 @@ async def _preflight(): ) return json.dumps(error_payload) except Exception as e: - return json.dumps({"error": f"Failed to start worker: {e}"}) + return json.dumps({"error": f"Failed to rerun worker: {e}"}) _run_input_tool = Tool( name="run_agent_with_input", @@ -3716,6 +4296,21 @@ async def _preflight(): ) tools_registered += 1 + _rerun_tool = Tool( + name="rerun_worker_with_last_input", + description=( + "Rerun the loaded worker using the most recent complete structured input payload. " + "Use this when the user asks to run again with the same defaults or same input." 
+ ), + parameters={"type": "object", "properties": {}}, + ) + registry.register( + "rerun_worker_with_last_input", + _rerun_tool, + lambda _inputs: rerun_worker_with_last_input(), + ) + tools_registered += 1 + # --- set_trigger ----------------------------------------------------------- async def set_trigger( diff --git a/core/framework/tools/worker_monitoring_tools.py b/core/framework/tools/worker_monitoring_tools.py index 9882378fd6..5f71058ff3 100644 --- a/core/framework/tools/worker_monitoring_tools.py +++ b/core/framework/tools/worker_monitoring_tools.py @@ -36,6 +36,172 @@ # How many tool_log steps to include in the health summary _DEFAULT_LAST_N_STEPS = 40 +_NON_ACCEPT_VERDICTS = frozenset({"RETRY", "CONTINUE", "ESCALATE"}) + + +def classify_worker_health( + *, + session_status: str, + recent_verdicts: list[str], + steps_since_last_accept: int, + stall_minutes: float | None, +) -> tuple[str, list[str]]: + """Classify worker health from persisted run evidence. + + Keeping this logic at module scope lets Queen-facing status views reuse the + exact same health signals as the monitoring tool instead of drifting into a + separate, weaker interpretation of worker state. 
+ """ + issue_signals: list[str] = [] + + if session_status == "failed": + issue_signals.append("failed_session") + + if stall_minutes is not None: + if stall_minutes >= 5: + issue_signals.append("stalled") + elif stall_minutes >= 2: + issue_signals.append("slow_progress") + + if steps_since_last_accept >= 6: + issue_signals.append("long_non_accept_streak") + elif steps_since_last_accept >= 4: + issue_signals.append("judge_pressure") + + if len(recent_verdicts) >= 4 and all(v in _NON_ACCEPT_VERDICTS for v in recent_verdicts[-4:]): + issue_signals.append("recent_non_accept_churn") + + issue_signals = list(dict.fromkeys(issue_signals)) + + if any(sig in issue_signals for sig in ("failed_session", "stalled", "long_non_accept_streak")): + return "critical", issue_signals + if issue_signals: + return "warning", issue_signals + return "healthy", issue_signals + + +def read_worker_health_snapshot( + storage_path: Path, + *, + session_id: str | None = None, + last_n_steps: int = _DEFAULT_LAST_N_STEPS, + default_session_id: str | None = None, + worker_agent_id: str | None = None, + worker_graph_id: str | None = None, +) -> dict[str, object]: + """Read persisted worker logs and return the structured health snapshot. + + This is the shared source of truth for worker-health reporting. The + monitoring tool returns it as JSON, while Queen's user-facing summaries can + consume the same dict directly to avoid underreporting issue signals. + """ + storage_path = Path(storage_path) + resolved_worker_agent_id = worker_agent_id or storage_path.name + resolved_worker_graph_id = worker_graph_id or storage_path.name + + # Auto-discover the most recent session if not specified. 
+ if not session_id or session_id == "auto": + sessions_dir = storage_path / "sessions" + if not sessions_dir.exists(): + return {"error": "No sessions found — worker has not started yet"} + + if default_session_id and (sessions_dir / default_session_id).is_dir(): + session_id = default_session_id + else: + candidates = [ + d for d in sessions_dir.iterdir() if d.is_dir() and (d / "state.json").exists() + ] + if not candidates: + return {"error": "No sessions found — worker has not started yet"} + + def _sort_key(d: Path): + try: + state = json.loads((d / "state.json").read_text(encoding="utf-8")) + priority = 0 if state.get("status", "") in ("in_progress", "running") else 1 + return (priority, -d.stat().st_mtime) + except Exception: + return (2, 0) + + candidates.sort(key=_sort_key) + session_id = candidates[0].name + + session_dir = storage_path / "sessions" / str(session_id) + tool_logs_path = session_dir / "logs" / "tool_logs.jsonl" + state_path = session_dir / "state.json" + if not session_dir.exists() or not state_path.exists(): + return {"error": f"No persisted worker state found for session '{session_id}'"} + + session_status = "unknown" + if state_path.exists(): + try: + state = json.loads(state_path.read_text(encoding="utf-8")) + session_status = state.get("status", "unknown") + except Exception: + pass + + steps: list[dict] = [] + if tool_logs_path.exists(): + try: + with open(tool_logs_path, encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + try: + steps.append(json.loads(line)) + except json.JSONDecodeError: + continue + except OSError as e: + return {"error": f"Could not read tool logs: {e}"} + + total_steps = len(steps) + recent = steps[-last_n_steps:] if len(steps) > last_n_steps else steps + recent_verdicts = [s.get("verdict", "") for s in recent if s.get("verdict")] + + steps_since_last_accept = 0 + for verdict in reversed(recent_verdicts): + if verdict == "ACCEPT": + break + steps_since_last_accept += 1 + + 
last_step_time_iso: str | None = None + stall_minutes: float | None = None + if steps and tool_logs_path.exists(): + try: + mtime = tool_logs_path.stat().st_mtime + last_step_time_iso = datetime.fromtimestamp(mtime, UTC).isoformat() + elapsed = (datetime.now(UTC).timestamp() - mtime) / 60 + stall_minutes = round(elapsed, 1) if elapsed >= 1.0 else None + except OSError: + pass + + evidence_snippet = "" + for step in reversed(recent): + text = step.get("llm_text", "") + if text: + evidence_snippet = text[:500] + break + + health_status, issue_signals = classify_worker_health( + session_status=session_status, + recent_verdicts=recent_verdicts, + steps_since_last_accept=steps_since_last_accept, + stall_minutes=stall_minutes, + ) + + return { + "worker_agent_id": resolved_worker_agent_id, + "worker_graph_id": resolved_worker_graph_id, + "session_id": session_id, + "session_status": session_status, + "health_status": health_status, + "issue_signals": issue_signals, + "total_steps": total_steps, + "recent_verdicts": recent_verdicts, + "steps_since_last_accept": steps_since_last_accept, + "last_step_time_iso": last_step_time_iso, + "stall_minutes": stall_minutes, + "evidence_snippet": evidence_snippet, + } def register_worker_monitoring_tools( @@ -91,6 +257,8 @@ async def get_worker_health_summary( Returns a JSON object with: - session_id: the session inspected (useful when auto-discovered) - session_status: "running"|"completed"|"failed"|"in_progress"|"unknown" + - health_status: "healthy"|"warning"|"critical" + - issue_signals: list of detected warning/attention categories - total_steps: total number of log steps recorded so far - recent_verdicts: list of last N verdict strings (ACCEPT/RETRY/CONTINUE/ESCALATE) - steps_since_last_accept: consecutive non-ACCEPT steps from the end @@ -98,120 +266,22 @@ async def get_worker_health_summary( - stall_minutes: wall-clock minutes since last step (null if < 1 min) - evidence_snippet: last LLM text from the most recent step 
(truncated) """ - # Auto-discover the most recent session if not specified - if not session_id or session_id == "auto": - sessions_dir = storage_path / "sessions" - if not sessions_dir.exists(): - return json.dumps({"error": "No sessions found — worker has not started yet"}) - - # Prefer the queen's own session ID (set at registration time) over - # mtime-based discovery, which can pick a stale orphaned session after - # a cold-restore when a newer-but-empty session directory exists. - if default_session_id and (sessions_dir / default_session_id).is_dir(): - session_id = default_session_id - else: - candidates = [ - d for d in sessions_dir.iterdir() if d.is_dir() and (d / "state.json").exists() - ] - if not candidates: - return json.dumps({"error": "No sessions found — worker has not started yet"}) - - def _sort_key(d: Path): - try: - state = json.loads((d / "state.json").read_text(encoding="utf-8")) - # in_progress/running sorts before completed/failed - priority = 0 if state.get("status", "") in ("in_progress", "running") else 1 - return (priority, -d.stat().st_mtime) - except Exception: - return (2, 0) - - candidates.sort(key=_sort_key) - session_id = candidates[0].name - - # Resolve log paths - session_dir = storage_path / "sessions" / session_id - tool_logs_path = session_dir / "logs" / "tool_logs.jsonl" - state_path = session_dir / "state.json" - - # Read session status - session_status = "unknown" - if state_path.exists(): - try: - state = json.loads(state_path.read_text(encoding="utf-8")) - session_status = state.get("status", "unknown") - except Exception: - pass - - # Read tool logs - steps: list[dict] = [] - if tool_logs_path.exists(): - try: - with open(tool_logs_path, encoding="utf-8") as f: - for line in f: - line = line.strip() - if line: - try: - steps.append(json.loads(line)) - except json.JSONDecodeError: - continue - except OSError as e: - return json.dumps({"error": f"Could not read tool logs: {e}"}) - - total_steps = len(steps) - recent = 
steps[-last_n_steps:] if len(steps) > last_n_steps else steps - - # Extract verdict sequence - recent_verdicts = [s.get("verdict", "") for s in recent if s.get("verdict")] - - # Count consecutive non-ACCEPT from the end - steps_since_last_accept = 0 - for v in reversed(recent_verdicts): - if v == "ACCEPT": - break - steps_since_last_accept += 1 - - # Timing: use tool_logs file mtime as proxy for last step time - last_step_time_iso: str | None = None - stall_minutes: float | None = None - if steps and tool_logs_path.exists(): - try: - mtime = tool_logs_path.stat().st_mtime - last_step_time_iso = datetime.fromtimestamp(mtime, UTC).isoformat() - elapsed = (datetime.now(UTC).timestamp() - mtime) / 60 - stall_minutes = round(elapsed, 1) if elapsed >= 1.0 else None - except OSError: - pass - - # Evidence snippet: last LLM text - evidence_snippet = "" - for step in reversed(recent): - text = step.get("llm_text", "") - if text: - evidence_snippet = text[:500] - break - - return json.dumps( - { - "worker_agent_id": _worker_agent_id, - "worker_graph_id": _worker_graph_id, - "session_id": session_id, - "session_status": session_status, - "total_steps": total_steps, - "recent_verdicts": recent_verdicts, - "steps_since_last_accept": steps_since_last_accept, - "last_step_time_iso": last_step_time_iso, - "stall_minutes": stall_minutes, - "evidence_snippet": evidence_snippet, - }, - ensure_ascii=False, + snapshot = read_worker_health_snapshot( + storage_path, + session_id=session_id, + last_n_steps=last_n_steps, + default_session_id=default_session_id, + worker_agent_id=_worker_agent_id, + worker_graph_id=_worker_graph_id, ) + return json.dumps(snapshot, ensure_ascii=False) _health_summary_tool = Tool( name="get_worker_health_summary", description=( "Read the worker agent's execution logs and return a compact health snapshot. 
" "Returns worker_agent_id and worker_graph_id (use these for ticket identity fields), " - "recent verdicts, step count, time since last step, and " + "health_status, issue_signals, recent verdicts, step count, time since last step, and " "a snippet of the most recent LLM output. " "session_id is optional — omit it to auto-discover the most recent active session." ), diff --git a/core/frontend/src/lib/run-inputs.test.ts b/core/frontend/src/lib/run-inputs.test.ts new file mode 100644 index 0000000000..cf68698481 --- /dev/null +++ b/core/frontend/src/lib/run-inputs.test.ts @@ -0,0 +1,135 @@ +import { describe, expect, it } from "vitest"; + +import type { NodeSpec } from "@/api/types"; +import type { GraphNode } from "@/components/graph-types"; + +import { + buildStructuredRunQuestions, + canShowRunButton, + getStructuredRunInputKeys, + hasAllStructuredRunInputs, + trimStructuredRunInputs, +} from "./run-inputs"; + +function makeNodeSpec(overrides: Partial): NodeSpec { + return { + id: "node-1", + name: "Node 1", + description: "", + node_type: "event_loop", + input_keys: [], + output_keys: [], + nullable_output_keys: [], + tools: [], + routes: {}, + max_retries: 0, + max_node_visits: 0, + client_facing: false, + success_criteria: null, + system_prompt: "", + sub_agents: [], + ...overrides, + }; +} + +function makeGraphNode(overrides: Partial): GraphNode { + return { + id: "node-1", + label: "Node 1", + status: "pending", + ...overrides, + }; +} + +describe("getStructuredRunInputKeys", () => { + it("returns structured input keys from the first non-trigger graph node", () => { + const nodeSpecs = [ + makeNodeSpec({ + id: "receive-runtime-inputs", + input_keys: ["target_dir", "review_dir", "word_threshold"], + }), + ]; + const graphNodes = [ + makeGraphNode({ id: "__trigger_default", nodeType: "trigger" }), + makeGraphNode({ id: "receive-runtime-inputs", nodeType: "execution" }), + ]; + + expect(getStructuredRunInputKeys(nodeSpecs, graphNodes)).toEqual([ + "target_dir", 
+ "review_dir", + "word_threshold", + ]); + }); + + it("filters out generic task-style entry keys", () => { + const nodeSpecs = [ + makeNodeSpec({ + id: "entry", + input_keys: ["user_request", "task", "feedback", "target_dir"], + }), + ]; + + expect(getStructuredRunInputKeys(nodeSpecs, [])).toEqual(["target_dir"]); + }); +}); + +describe("hasAllStructuredRunInputs", () => { + it("requires every structured key to be present and non-blank", () => { + expect( + hasAllStructuredRunInputs(["target_dir", "word_threshold"], { + target_dir: "/tmp/project", + word_threshold: "800", + }), + ).toBe(true); + + expect( + hasAllStructuredRunInputs(["target_dir", "word_threshold"], { + target_dir: " ", + word_threshold: "800", + }), + ).toBe(false); + + expect( + hasAllStructuredRunInputs(["target_dir", "word_threshold"], { + target_dir: "/tmp/project", + }), + ).toBe(false); + }); +}); + +describe("buildStructuredRunQuestions", () => { + it("creates free-text prompts for each required run input", () => { + expect(buildStructuredRunQuestions(["target_dir", "review_dir"])).toEqual([ + { id: "target_dir", prompt: "Provide target_dir for this run." }, + { id: "review_dir", prompt: "Provide review_dir for this run." 
}, + ]); + }); +}); + +describe("canShowRunButton", () => { + it("only exposes Run when a worker session is ready and staged/running", () => { + expect(canShowRunButton("sess-1", true, "staging", true)).toBe(true); + expect(canShowRunButton("sess-1", true, "running", true)).toBe(true); + + expect(canShowRunButton("sess-1", true, "planning", true)).toBe(false); + expect(canShowRunButton("sess-1", true, "building", true)).toBe(false); + expect(canShowRunButton("sess-1", false, "staging", true)).toBe(false); + expect(canShowRunButton("sess-1", true, "staging", false)).toBe(false); + expect(canShowRunButton(null, true, "staging", true)).toBe(false); + }); +}); + +describe("trimStructuredRunInputs", () => { + it("drops stale keys that are no longer part of the current schema", () => { + expect( + trimStructuredRunInputs(["target_dir", "word_threshold"], { + target_dir: "/tmp/project", + word_threshold: 800, + stale_key: "old", + }), + ).toEqual({ + target_dir: "/tmp/project", + word_threshold: 800, + }); + }); +}); diff --git a/core/frontend/src/lib/run-inputs.ts b/core/frontend/src/lib/run-inputs.ts new file mode 100644 index 0000000000..23dc1477ba --- /dev/null +++ b/core/frontend/src/lib/run-inputs.ts @@ -0,0 +1,56 @@ +import type { NodeSpec } from "@/api/types"; +import type { GraphNode } from "@/components/graph-types"; + +const GENERIC_ENTRY_KEYS = new Set(["task", "user_request", "feedback"]); +const RUNNABLE_PHASES = new Set(["staging", "running"]); + +type QueenPhase = "planning" | "building" | "staging" | "running"; + +function isMeaningfulValue(value: unknown): boolean { + if (typeof value === "string") return value.trim().length > 0; + return value !== undefined && value !== null; +} + +export function getStructuredRunInputKeys( + nodeSpecs: NodeSpec[], + graphNodes: GraphNode[], +): string[] { + const entryNodeId = + graphNodes.find((node) => node.nodeType !== "trigger")?.id ?? 
nodeSpecs[0]?.id; + if (!entryNodeId) return []; + + const entrySpec = nodeSpecs.find((node) => node.id === entryNodeId) ?? nodeSpecs[0]; + return (entrySpec?.input_keys ?? []).filter((key) => !GENERIC_ENTRY_KEYS.has(key)); +} + +export function hasAllStructuredRunInputs( + keys: string[], + inputData: Record | null | undefined, +): inputData is Record { + if (!inputData) return false; + return keys.every((key) => isMeaningfulValue(inputData[key])); +} + +export function buildStructuredRunQuestions(keys: string[]) { + return keys.map((key) => ({ + id: key, + prompt: `Provide ${key} for this run.`, + })); +} + +export function trimStructuredRunInputs( + keys: string[], + inputData: Record | null | undefined, +): Record { + if (!inputData) return {}; + return Object.fromEntries(keys.flatMap((key) => (key in inputData ? [[key, inputData[key]]] : []))); +} + +export function canShowRunButton( + sessionId: string | null | undefined, + ready: boolean | null | undefined, + queenPhase: QueenPhase | null | undefined, + topologyReady: boolean, +): boolean { + return Boolean(sessionId && ready && topologyReady && queenPhase && RUNNABLE_PHASES.has(queenPhase)); +} diff --git a/core/frontend/src/pages/workspace.tsx b/core/frontend/src/pages/workspace.tsx index 25d396228c..8f4297fbd2 100644 --- a/core/frontend/src/pages/workspace.tsx +++ b/core/frontend/src/pages/workspace.tsx @@ -18,6 +18,13 @@ import type { LiveSession, AgentEvent, DiscoverEntry, NodeSpec, DraftGraph as Dr import { sseEventToChatMessage, formatAgentDisplayName } from "@/lib/chat-helpers"; import { topologyToGraphNodes } from "@/lib/graph-converter"; import { cronToLabel } from "@/lib/graphUtils"; +import { + buildStructuredRunQuestions, + canShowRunButton, + getStructuredRunInputKeys, + hasAllStructuredRunInputs, + trimStructuredRunInputs, +} from "@/lib/run-inputs"; import { ApiError } from "@/api/client"; const makeId = () => Math.random().toString(36).slice(2, 9); @@ -351,7 +358,9 @@ interface 
AgentBackendState { /** Multiple questions from ask_user_multiple */ pendingQuestions: { id: string; prompt: string; options?: string[] }[] | null; /** Whether the pending question came from queen or worker */ - pendingQuestionSource: "queen" | "worker" | null; + pendingQuestionSource: "queen" | "worker" | "run" | null; + /** Last structured input payload successfully used to start the worker. */ + lastRunInputData: Record | null; /** Per-node context window usage (from context_usage_updated events) */ contextUsage: Record; /** Whether the queen's LLM supports image content — false disables the attach button */ @@ -393,6 +402,7 @@ function defaultAgentState(): AgentBackendState { pendingOptions: null, pendingQuestions: null, pendingQuestionSource: null, + lastRunInputData: null, contextUsage: {}, queenSupportsImages: true, }; @@ -693,15 +703,71 @@ export default function Workspace() { } }, [sessionsByAgent, activeSessionByAgent, activeWorker, agentStates]); + const appendSystemMessage = useCallback((agentType: string, content: string) => { + setSessionsByAgent((prev) => { + const sessions = prev[agentType] || []; + const activeId = activeSessionRef.current[agentType] || sessions[0]?.id; + return { + ...prev, + [agentType]: sessions.map((s) => { + if (s.id !== activeId) return s; + const errorMsg: ChatMessage = { + id: makeId(), + agent: "System", + agentColor: "", + content, + timestamp: "", + type: "system", + thread: agentType, + createdAt: Date.now(), + }; + return { ...s, messages: [...s.messages, errorMsg] }; + }), + }; + }); + }, []); + const handleRun = useCallback(async () => { const state = agentStates[activeWorker]; if (!state?.sessionId || !state?.ready) return; + + const sessions = sessionsRef.current[activeWorker] || []; + const activeId = activeSessionRef.current[activeWorker] || sessions[0]?.id; + const activeSession = sessions.find((s) => s.id === activeId) || sessions[0]; + const requiredRunKeys = getStructuredRunInputKeys( + state.nodeSpecs, + 
activeSession?.graphNodes || [], + ); + + if ( + requiredRunKeys.length > 0 && + !hasAllStructuredRunInputs(requiredRunKeys, state.lastRunInputData) + ) { + updateAgentState(activeWorker, { + awaitingInput: true, + pendingQuestion: null, + pendingOptions: null, + pendingQuestions: buildStructuredRunQuestions(requiredRunKeys), + pendingQuestionSource: "run", + workerRunState: "idle", + }); + return; + } + + const inputData = + requiredRunKeys.length > 0 + ? trimStructuredRunInputs(requiredRunKeys, state.lastRunInputData) + : {}; + // Reset dismissed banner so a repeated 424 re-shows it setDismissedBanner(null); try { updateAgentState(activeWorker, { workerRunState: "deploying" }); - const result = await executionApi.trigger(state.sessionId, "default", {}); - updateAgentState(activeWorker, { currentExecutionId: result.execution_id }); + const result = await executionApi.trigger(state.sessionId, "default", inputData); + updateAgentState(activeWorker, { + currentExecutionId: result.execution_id, + lastRunInputData: inputData, + }); } catch (err) { // 424 = credentials required — open the credentials modal if (err instanceof ApiError && err.status === 424) { @@ -714,25 +780,23 @@ export default function Workspace() { } const errMsg = err instanceof Error ? 
err.message : String(err); - setSessionsByAgent((prev) => { - const sessions = prev[activeWorker] || []; - const activeId = activeSessionRef.current[activeWorker] || sessions[0]?.id; - return { - ...prev, - [activeWorker]: sessions.map((s) => { - if (s.id !== activeId) return s; - const errorMsg: ChatMessage = { - id: makeId(), agent: "System", agentColor: "", - content: `Failed to trigger run: ${errMsg}`, - timestamp: "", type: "system", thread: activeWorker, createdAt: Date.now(), - }; - return { ...s, messages: [...s.messages, errorMsg] }; - }), - }; - }); + appendSystemMessage(activeWorker, `Failed to trigger run: ${errMsg}`); updateAgentState(activeWorker, { workerRunState: "idle" }); } - }, [agentStates, activeWorker, updateAgentState]); + }, [agentStates, activeWorker, appendSystemMessage, updateAgentState]); + + const canRunLoadedWorker = canShowRunButton( + activeAgentState?.sessionId, + activeAgentState?.ready, + activeAgentState?.queenPhase, + Boolean( + activeAgentState?.nodeSpecs?.length || + sessionsByAgent[activeWorker]?.some( + (session) => + session.id === activeAgentState?.sessionId && session.graphNodes.length > 0, + ), + ), + ); // --- Fetch discovered agents for NewTabPopover --- const [discoverAgents, setDiscoverAgents] = useState([]); @@ -2826,6 +2890,55 @@ export default function Workspace() { // --- handleMultiQuestionAnswer: submit answers to ask_user_multiple --- const handleMultiQuestionAnswer = useCallback((answers: Record) => { + const state = agentStates[activeWorker]; + if (state?.pendingQuestionSource === "run") { + if (!state.sessionId || !state.ready) return; + updateAgentState(activeWorker, { + pendingQuestion: null, + pendingOptions: null, + pendingQuestions: null, + pendingQuestionSource: null, + awaitingInput: false, + workerRunState: "deploying", + }); + const requiredRunKeys = getStructuredRunInputKeys( + state.nodeSpecs, + sessionsRef.current[activeWorker]?.find((s) => s.id === state.sessionId)?.graphNodes || [], + ); + 
const trimmedAnswers = trimStructuredRunInputs(requiredRunKeys, answers); + executionApi.trigger(state.sessionId, "default", trimmedAnswers).then((result) => { + updateAgentState(activeWorker, { + currentExecutionId: result.execution_id, + lastRunInputData: trimmedAnswers, + }); + }).catch((err: unknown) => { + if (err instanceof ApiError && err.status === 424) { + const errBody = err.body as Record; + const credPath = (errBody?.agent_path as string) || null; + if (credPath) setCredentialAgentPath(credPath); + updateAgentState(activeWorker, { + workerRunState: "idle", + error: "credentials_required", + lastRunInputData: trimmedAnswers, + }); + setCredentialsOpen(true); + return; + } + + const errMsg = err instanceof Error ? err.message : String(err); + appendSystemMessage(activeWorker, `Failed to trigger run: ${errMsg}`); + updateAgentState(activeWorker, { + workerRunState: "idle", + awaitingInput: true, + pendingQuestion: null, + pendingOptions: null, + pendingQuestions: buildStructuredRunQuestions(requiredRunKeys), + pendingQuestionSource: "run", + }); + }); + return; + } + updateAgentState(activeWorker, { pendingQuestion: null, pendingOptions: null, pendingQuestions: null, pendingQuestionSource: null, @@ -2835,7 +2948,7 @@ export default function Workspace() { ([id, answer]) => `[${id}]: ${answer}`, ); handleSend(lines.join("\n"), activeWorker); - }, [activeWorker, handleSend, updateAgentState]); + }, [activeWorker, agentStates, appendSystemMessage, handleSend, updateAgentState]); // --- handleQuestionDismiss: user closed the question widget without answering --- // Injects a dismiss signal so the blocked node can continue. 
@@ -2854,6 +2967,11 @@ export default function Workspace() { awaitingInput: false, }); + if (source === "run") { + updateAgentState(activeWorker, { workerRunState: "idle" }); + return; + } + // Unblock the waiting node with a dismiss signal const dismissMsg = `[User dismissed the question: "${question}"]`; if (source === "worker") { @@ -3145,7 +3263,7 @@ export default function Workspace() { : null } building={activeAgentState?.queenBuilding} - onRun={handleRun} + onRun={canRunLoadedWorker ? handleRun : undefined} onPause={handlePause} runState={activeAgentState?.workerRunState ?? "idle"} flowchartMap={activeAgentState?.flowchartMap ?? undefined} diff --git a/core/tests/test_codex_control_parity.py b/core/tests/test_codex_control_parity.py new file mode 100644 index 0000000000..ed238f0500 --- /dev/null +++ b/core/tests/test_codex_control_parity.py @@ -0,0 +1,306 @@ +from __future__ import annotations + +import asyncio +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from framework.graph.event_loop_node import EventLoopNode, LoopConfig +from framework.graph.node import NodeContext, NodeSpec, SharedMemory +from framework.llm.provider import LLMProvider +from framework.llm.stream_events import FinishEvent, TextDeltaEvent, ToolCallEvent +from framework.runtime.event_bus import AgentEvent, EventBus, EventType +from framework.server.queen_orchestrator import _client_input_counts_as_planning_ask +from framework.tools.queen_lifecycle_tools import QueenPhaseState + + +class MockStreamingLLM(LLMProvider): + """Minimal streaming LLM for Codex-vs-control parity checks.""" + + def __init__(self, scenarios: list[list[Any]] | None = None): + self.scenarios = scenarios or [] + self._call_index = 0 + + async def stream( + self, + messages: list[dict[str, Any]], + system: str = "", + tools=None, + max_tokens: int = 4096, + ): + if not self.scenarios: + return + events = self.scenarios[self._call_index % len(self.scenarios)] + self._call_index += 1 + 
for event in events: + yield event + + def complete(self, messages, system="", **kwargs): + raise NotImplementedError + + +def text_scenario(text: str) -> list[Any]: + return [ + TextDeltaEvent(content=text, snapshot=text), + FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=5, model="mock"), + ] + + +def tool_call_scenario( + tool_name: str, + tool_input: dict[str, Any], + *, + tool_use_id: str = "call_1", + preamble_text: str = "", +) -> list[Any]: + events: list[Any] = [] + if preamble_text: + events.append(TextDeltaEvent(content=preamble_text, snapshot=preamble_text)) + events.append( + ToolCallEvent( + tool_use_id=tool_use_id, + tool_name=tool_name, + tool_input=tool_input, + ) + ) + events.append( + FinishEvent( + stop_reason="tool_calls", + input_tokens=10, + output_tokens=5, + model="mock", + ) + ) + return events + + +def build_ctx(spec: NodeSpec, llm: LLMProvider, *, stream_id: str) -> NodeContext: + runtime = MagicMock() + runtime.start_run = MagicMock(return_value=f"session_{stream_id}") + runtime.decide = MagicMock(return_value="dec_1") + runtime.record_outcome = MagicMock() + runtime.end_run = MagicMock() + runtime.report_problem = MagicMock() + runtime.set_node = MagicMock() + return NodeContext( + runtime=runtime, + node_id=spec.id, + node_spec=spec, + memory=SharedMemory(), + input_data={}, + llm=llm, + available_tools=[], + stream_id=stream_id, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("style", "first_turn"), + [ + ( + "control", + tool_call_scenario( + "ask_user", + { + "question": "What kind of agent should I design for you?", + "options": ["Summarizer"], + }, + tool_use_id="ask_1", + ), + ), + ( + "codex", + text_scenario("What kind of agent should I design for you?"), + ), + ], +) +async def test_codex_and_control_styles_both_count_toward_planning_gate( + style: str, + first_turn: list[Any], +) -> None: + bus = EventBus() + phase_state = QueenPhaseState(phase="planning", event_bus=bus) + received: 
list[AgentEvent] = [] + + async def capture(event: AgentEvent) -> None: + received.append(event) + if _client_input_counts_as_planning_ask(event): + phase_state.planning_ask_rounds += 1 + + bus.subscribe([EventType.CLIENT_INPUT_REQUESTED], capture, filter_stream="queen") + + spec = NodeSpec( + id="queen", + name="Queen", + description="planning orchestrator", + node_type="event_loop", + client_facing=True, + output_keys=[], + skip_judge=True, + ) + llm = MockStreamingLLM(scenarios=[first_turn]) + node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5)) + ctx = build_ctx(spec, llm, stream_id="queen") + + async def shutdown_after_first_block() -> None: + await asyncio.sleep(0.05) + node.signal_shutdown() + + task = asyncio.create_task(shutdown_after_first_block()) + result = await node.execute(ctx) + await task + + assert result.success is True + assert phase_state.planning_ask_rounds == 1 + assert received + if style == "control": + assert received[0].data["prompt"] == "What kind of agent should I design for you?" 
+ assert received[0].data.get("auto_blocked") is not True + else: + assert received[0].data["prompt"] == "" + assert received[0].data["auto_blocked"] is True + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("style", "scenarios"), + [ + ( + "control", + [ + tool_call_scenario( + "ask_user", + { + "question": "Paste old and new policy text.", + "options": ["I'll paste both now"], + }, + tool_use_id="ask_1", + ), + tool_call_scenario( + "set_output", + { + "key": "important_changes", + "value": "- Remote days increased from 2 to 4", + }, + tool_use_id="set_1", + ), + ], + ), + ( + "codex", + [ + text_scenario("Paste old and new policy text."), + tool_call_scenario( + "set_output", + { + "key": "important_changes", + "value": "- Remote days increased from 2 to 4", + }, + tool_use_id="set_1", + ), + ], + ), + ], +) +async def test_codex_and_control_styles_complete_same_human_in_loop_run( + style: str, + scenarios: list[list[Any]], +) -> None: + spec = NodeSpec( + id=f"policy_diff_{style}", + name="Policy Diff Worker", + description="Compare two policy versions", + node_type="event_loop", + output_keys=["important_changes"], + client_facing=True, + ) + llm = MockStreamingLLM(scenarios=scenarios) + node = EventLoopNode(config=LoopConfig(max_iterations=6)) + ctx = build_ctx(spec, llm, stream_id=f"worker_{style}") + + async def user_responds() -> None: + await asyncio.sleep(0.05) + await node.inject_event("Old policy ... 
New policy ...") + + task = asyncio.create_task(user_responds()) + result = await node.execute(ctx) + await task + + assert result.success is True + assert result.output["important_changes"] == "- Remote days increased from 2 to 4" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("style", "scenario"), + [ + ( + "control", + tool_call_scenario( + "ask_user", + {"question": "What would you like to do next?", "options": ["Rerun", "Stop"]}, + tool_use_id="ask_1", + preamble_text="Root cause: checkout is failing because the DB pool is exhausted.", + ), + ), + ( + "codex", + tool_call_scenario( + "ask_user", + { + "question": ( + "Root cause: checkout is failing because the DB pool is exhausted.\n\n" + "What would you like to do next?" + ), + "options": ["Rerun", "Stop"], + }, + tool_use_id="ask_1", + ), + ), + ], +) +async def test_codex_and_control_styles_surface_result_before_followup_widget( + style: str, + scenario: list[Any], +) -> None: + spec = NodeSpec( + id=f"queen_{style}", + name="Queen", + description="orchestrator", + node_type="event_loop", + client_facing=True, + output_keys=[], + skip_judge=True, + ) + llm = MockStreamingLLM(scenarios=[scenario]) + bus = EventBus() + received: list[AgentEvent] = [] + + async def capture(event: AgentEvent) -> None: + received.append(event) + + bus.subscribe( + event_types=[EventType.CLIENT_OUTPUT_DELTA, EventType.CLIENT_INPUT_REQUESTED], + handler=capture, + ) + + node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5)) + ctx = build_ctx(spec, llm, stream_id="queen") + + async def shutdown() -> None: + await asyncio.sleep(0.05) + node.signal_shutdown() + + task = asyncio.create_task(shutdown()) + await node.execute(ctx) + await task + + output_events = [e for e in received if e.type == EventType.CLIENT_OUTPUT_DELTA] + input_events = [e for e in received if e.type == EventType.CLIENT_INPUT_REQUESTED] + + assert output_events + assert input_events + assert "DB pool is exhausted" in 
output_events[0].data["snapshot"] + assert input_events[0].data["prompt"] == "What would you like to do next?" diff --git a/core/tests/test_codex_parity_gate.py b/core/tests/test_codex_parity_gate.py new file mode 100644 index 0000000000..cc21ecd7eb --- /dev/null +++ b/core/tests/test_codex_parity_gate.py @@ -0,0 +1,448 @@ +from __future__ import annotations + +import asyncio +import json +from dataclasses import dataclass, field +from pathlib import Path +from types import SimpleNamespace +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest +from aiohttp.test_utils import TestClient, TestServer + +import framework.tools.queen_lifecycle_tools as qlt +from framework.graph.event_loop_node import EventLoopNode, LoopConfig +from framework.graph.node import NodeContext, NodeSpec, SharedMemory +from framework.llm.provider import LLMProvider, Tool +from framework.llm.stream_events import FinishEvent, TextDeltaEvent, ToolCallEvent +from framework.runner.tool_registry import ToolRegistry +from framework.runtime.event_bus import AgentEvent, EventBus, EventType +from framework.server.app import create_app, validate_agent_path +from framework.server.session_manager import ( + Session, + _run_validation_report_sync, + _validation_blocks_stage_or_run, +) +from framework.tools.queen_lifecycle_tools import QueenPhaseState, register_queen_lifecycle_tools + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +class MockStreamingLLM(LLMProvider): + """Minimal streaming LLM for parity-gate regressions.""" + + def __init__(self, scenarios: list[list[Any]] | None = None): + self.scenarios = scenarios or [] + self._call_index = 0 + + async def stream( + self, + messages: list[dict[str, Any]], + system: str = "", + tools=None, + max_tokens: int = 4096, + ): + if not self.scenarios: + return + events = self.scenarios[self._call_index % len(self.scenarios)] + self._call_index += 1 + for event in events: + yield event + + def complete(self, messages, 
system="", **kwargs): + raise NotImplementedError + + +def text_scenario(text: str) -> list[Any]: + return [ + TextDeltaEvent(content=text, snapshot=text), + FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=5, model="mock"), + ] + + +def tool_call_scenario( + tool_name: str, + tool_input: dict[str, Any], + *, + tool_use_id: str = "call_1", +) -> list[Any]: + return [ + ToolCallEvent(tool_use_id=tool_use_id, tool_name=tool_name, tool_input=tool_input), + FinishEvent(stop_reason="tool_calls", input_tokens=10, output_tokens=5, model="mock"), + ] + + +def build_ctx( + spec: NodeSpec, + llm: LLMProvider, + *, + stream_id: str = "worker", + input_data: dict[str, Any] | None = None, +) -> NodeContext: + runtime = MagicMock() + runtime.start_run = MagicMock(return_value="session_codex_parity") + runtime.decide = MagicMock(return_value="dec_1") + runtime.record_outcome = MagicMock() + runtime.end_run = MagicMock() + runtime.report_problem = MagicMock() + runtime.set_node = MagicMock() + return NodeContext( + runtime=runtime, + node_id=spec.id, + node_spec=spec, + memory=SharedMemory(), + input_data=input_data or {}, + llm=llm, + available_tools=[], + stream_id=stream_id, + ) + + +@pytest.mark.parametrize( + "agent_ref", + [ + "examples/templates/tech_news_reporter", + "examples/templates/vulnerability_assessment", + ], +) +def test_codex_parity_existing_templates_validate_for_stage_run(agent_ref: str) -> None: + """Existing checked-in agents should pass the shared stage/run gate.""" + resolved = validate_agent_path(agent_ref) + report = _run_validation_report_sync(agent_ref) + + assert resolved.is_dir() + assert report.get("valid") is True + assert _validation_blocks_stage_or_run(report) is False + + +@pytest.mark.asyncio +async def test_codex_parity_local_only_human_in_loop_run_completes() -> None: + """A local-only client-facing worker flow should complete end to end.""" + spec = NodeSpec( + id="policy_diff_worker", + name="Policy Diff Worker", + 
description="Compare two policy versions", + node_type="event_loop", + output_keys=["important_changes"], + client_facing=True, + ) + llm = MockStreamingLLM( + scenarios=[ + tool_call_scenario( + "ask_user", + {"question": "Paste old and new policy text.", "options": ["I'll paste both now"]}, + tool_use_id="ask_1", + ), + tool_call_scenario( + "set_output", + { + "key": "important_changes", + "value": ( + "- Remote days increased from 2 to 4\n" + "- Security training increased from annual to twice yearly" + ), + }, + tool_use_id="set_1", + ), + ] + ) + + node = EventLoopNode(config=LoopConfig(max_iterations=6)) + ctx = build_ctx(spec, llm, stream_id="worker") + + async def user_responds() -> None: + await asyncio.sleep(0.05) + await node.inject_event("Old policy ... New policy ...") + + task = asyncio.create_task(user_responds()) + result = await node.execute(ctx) + await task + + assert result.success is True + assert "Remote days increased" in result.output["important_changes"] + + +@pytest.mark.asyncio +async def test_codex_parity_result_is_visible_before_followup_widget() -> None: + """Long result-bearing queen prompts should stream the result before the widget.""" + spec = NodeSpec( + id="queen", + name="Queen", + description="orchestrator", + node_type="event_loop", + client_facing=True, + output_keys=[], + skip_judge=True, + ) + llm = MockStreamingLLM( + scenarios=[ + tool_call_scenario( + "ask_user", + { + "question": ( + "Root cause: checkout is failing because the DB pool is exhausted.\n\n" + "What would you like to do next?" 
+ ), + "options": ["Rerun", "Stop"], + }, + tool_use_id="ask_1", + ) + ] + ) + bus = EventBus() + received: list[AgentEvent] = [] + + async def capture(event: AgentEvent) -> None: + received.append(event) + + bus.subscribe( + event_types=[EventType.CLIENT_OUTPUT_DELTA, EventType.CLIENT_INPUT_REQUESTED], + handler=capture, + ) + + node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5)) + ctx = build_ctx(spec, llm, stream_id="queen") + + async def shutdown() -> None: + await asyncio.sleep(0.05) + node.signal_shutdown() + + task = asyncio.create_task(shutdown()) + await node.execute(ctx) + await task + + output_events = [e for e in received if e.type == EventType.CLIENT_OUTPUT_DELTA] + input_events = [e for e in received if e.type == EventType.CLIENT_INPUT_REQUESTED] + + assert output_events + assert input_events + assert "DB pool is exhausted" in output_events[0].data["snapshot"] + assert input_events[0].data["prompt"] == "What would you like to do next?" + + +@pytest.mark.asyncio +async def test_codex_parity_rerun_reuses_complete_recent_defaults( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path +) -> None: + """Rerun should keep structured inputs stable instead of relying on text reconstruction.""" + registry = ToolRegistry() + registry.register( + "validate_agent_package", + Tool( + name="validate_agent_package", + description="fake validator", + parameters={"type": "object", "properties": {"agent_name": {"type": "string"}}}, + ), + lambda _inputs: json.dumps({"valid": True, "steps": {}}), + ) + + monkeypatch.setattr(qlt, "validate_credentials", lambda *args, **kwargs: None) + monkeypatch.chdir(tmp_path) + + sessions_dir = tmp_path / "agent_store" / "sessions" + sessions_dir.mkdir(parents=True) + + valid_prior_state = { + "timestamps": {"updated_at": "2026-03-24T20:44:00"}, + "input_data": { + "target_dir": "docs", + "review_dir": "docs_reviews", + "word_threshold": 800, + }, + } + malformed_recent_state = { + "timestamps": {"updated_at": 
"2026-03-24T21:20:23"}, + "input_data": { + "review_dir": "docs_reviews", + "word_threshold": "800. Validate inputs and continue.", + }, + } + + for session_name, state in { + "session_20260324_204400_good": valid_prior_state, + "session_20260324_212023_bad": malformed_recent_state, + }.items(): + session_dir = sessions_dir / session_name + session_dir.mkdir() + (session_dir / "state.json").write_text(json.dumps(state), encoding="utf-8") + + runtime = SimpleNamespace( + _session_store=SimpleNamespace(sessions_dir=sessions_dir), + resume_timers=MagicMock(), + trigger=AsyncMock(return_value="exec-rerun"), + graph=SimpleNamespace( + nodes=[], + entry_node="process", + get_node=lambda node_id: ( + SimpleNamespace(input_keys=["target_dir", "review_dir", "word_threshold"]) + if node_id == "process" + else None + ), + ), + get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")], + ) + session = SimpleNamespace( + worker_runtime=runtime, + event_bus=None, + worker_path=Path("exports/local_markdown_review_probe_2"), + runner=None, + ) + register_queen_lifecycle_tools( + registry, + session=session, + session_id="sess-rerun", + phase_state=QueenPhaseState(phase="staging"), + ) + + result_raw = await registry._tools["rerun_worker_with_last_input"].executor({}) + result = json.loads(result_raw) + + assert result["status"] == "started" + runtime.trigger.assert_awaited_once() + assert runtime.trigger.await_args.kwargs["input_data"] == { + "target_dir": str((tmp_path / "docs").resolve()), + "review_dir": str((tmp_path / "docs_reviews").resolve()), + "word_threshold": 800, + } + assert runtime.trigger.await_args.kwargs["session_state"] is None + + +@dataclass +class _MockEntryPoint: + id: str = "default" + name: str = "Default" + entry_node: str = "start" + trigger_type: str = "manual" + trigger_config: dict = field(default_factory=dict) + + +@dataclass +class _MockStream: + is_awaiting_input: bool = False + _execution_tasks: dict = 
field(default_factory=dict) + _active_executors: dict = field(default_factory=dict) + active_execution_ids: set = field(default_factory=set) + + async def cancel_execution(self, execution_id: str) -> bool: + return execution_id in self._execution_tasks + + +@dataclass +class _MockGraphRegistration: + graph: Any = field(default_factory=lambda: SimpleNamespace(nodes=[], edges=[], entry_node="")) + streams: dict = field(default_factory=dict) + entry_points: dict = field(default_factory=dict) + + +class _MockRuntime: + def __init__(self): + self._entry_points = [_MockEntryPoint()] + self._mock_streams = {"default": _MockStream()} + self._registration = _MockGraphRegistration( + streams=self._mock_streams, + entry_points={"default": self._entry_points[0]}, + ) + + def list_graphs(self): + return ["primary"] + + def get_graph_registration(self, graph_id): + if graph_id == "primary": + return self._registration + return None + + def get_entry_points(self): + return self._entry_points + + async def trigger(self, ep_id, input_data=None, session_state=None): + return "exec_test_123" + + async def inject_input(self, node_id, content, graph_id=None, *, is_client_input=False): + return True + + def pause_timers(self): + pass + + async def get_goal_progress(self): + return {"progress": 0.5, "criteria": []} + + def find_awaiting_node(self): + return None, None + + def get_stats(self): + return {"running": True, "executions": 1} + + def get_timer_next_fire_in(self, ep_id): + return None + + +def _make_queen_executor(): + mock_node = MagicMock() + mock_node.inject_event = AsyncMock() + executor = MagicMock() + executor.node_registry = {"queen": mock_node} + return executor + + +def _make_session(agent_id="test_agent") -> Session: + runner = MagicMock() + runner.intro_message = "Test intro" + return Session( + id=agent_id, + event_bus=EventBus(), + llm=MagicMock(), + loaded_at=1000000.0, + queen_executor=_make_queen_executor(), + worker_id=agent_id, + 
worker_path=Path("/tmp/test_agent"), + runner=runner, + worker_runtime=_MockRuntime(), + worker_info=SimpleNamespace( + name="test_agent", + description="A test agent", + goal_name="test_goal", + node_count=2, + ), + worker_validation_report={"valid": True, "steps": {}}, + worker_validation_failures=[], + ) + + +def _make_app_with_session(session: Session): + app = create_app() + mgr = app["manager"] + mgr._sessions[session.id] = session + return app + + +@pytest.mark.asyncio +async def test_codex_parity_done_for_now_parks_queen_without_new_followup() -> None: + """Terminal stop choices should acknowledge once and park the queen.""" + session = _make_session() + session.event_bus.get_history = MagicMock( + return_value=[ + AgentEvent( + type=EventType.CLIENT_INPUT_REQUESTED, + stream_id="queen", + node_id="queen", + execution_id=session.id, + data={"options": ["Run again with same input", "Done for now"]}, + ) + ] + ) + session.event_bus.emit_client_output_delta = AsyncMock() + app = _make_app_with_session(session) + + async with TestClient(TestServer(app)) as client: + resp = await client.post( + "/api/sessions/test_agent/chat", + json={"message": "No, stop here"}, + ) + assert resp.status == 200 + data = await resp.json() + assert data["status"] == "queen" + assert data["delivered"] is True + + assert session.queen_executor is None + session.event_bus.emit_client_output_delta.assert_awaited_once() diff --git a/core/tests/test_codex_planning_phase.py b/core/tests/test_codex_planning_phase.py new file mode 100644 index 0000000000..e2c863415d --- /dev/null +++ b/core/tests/test_codex_planning_phase.py @@ -0,0 +1,194 @@ +from __future__ import annotations + +import asyncio +import json +from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from framework.graph.event_loop_node import EventLoopNode, LoopConfig +from framework.graph.node import NodeContext, NodeSpec, SharedMemory +from framework.llm.provider 
import LLMProvider +from framework.llm.stream_events import FinishEvent, TextDeltaEvent +from framework.runner.tool_registry import ToolRegistry +from framework.runtime.event_bus import AgentEvent, EventBus, EventType +from framework.server.queen_orchestrator import _client_input_counts_as_planning_ask +from framework.tools.queen_lifecycle_tools import QueenPhaseState, register_queen_lifecycle_tools + + +class MockStreamingLLM(LLMProvider): + """Minimal streaming LLM for planning-phase regression tests.""" + + def __init__(self, scenarios: list[list[Any]] | None = None): + self.scenarios = scenarios or [] + self._call_index = 0 + + async def stream( + self, + messages: list[dict[str, Any]], + system: str = "", + tools=None, + max_tokens: int = 4096, + ): + if not self.scenarios: + return + events = self.scenarios[self._call_index % len(self.scenarios)] + self._call_index += 1 + for event in events: + yield event + + def complete(self, messages, system="", **kwargs): + raise NotImplementedError + + +def text_scenario(text: str) -> list[Any]: + return [ + TextDeltaEvent(content=text, snapshot=text), + FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=5, model="mock"), + ] + + +def build_ctx(spec: NodeSpec, llm: LLMProvider) -> NodeContext: + runtime = MagicMock() + runtime.start_run = MagicMock(return_value="session_codex_planning") + runtime.decide = MagicMock(return_value="dec_1") + runtime.record_outcome = MagicMock() + runtime.end_run = MagicMock() + runtime.report_problem = MagicMock() + runtime.set_node = MagicMock() + return NodeContext( + runtime=runtime, + node_id=spec.id, + node_spec=spec, + memory=SharedMemory(), + input_data={"greeting": "Session started."}, + llm=llm, + available_tools=[], + stream_id="queen", + ) + + +@pytest.mark.asyncio +async def test_codex_style_text_only_planning_turn_counts_toward_ask_rounds() -> None: + """Plain-text planning questions should satisfy the ask_rounds gate. 
+ + This reproduces the Codex failure mode: the queen asks a planning question + in plain text instead of calling ask_user(), which triggers an auto-blocked + CLIENT_INPUT_REQUESTED event with an empty prompt. + """ + bus = EventBus() + phase_state = QueenPhaseState(phase="planning", event_bus=bus) + received: list[AgentEvent] = [] + + async def capture(event: AgentEvent) -> None: + received.append(event) + if _client_input_counts_as_planning_ask(event): + phase_state.planning_ask_rounds += 1 + + bus.subscribe([EventType.CLIENT_INPUT_REQUESTED], capture, filter_stream="queen") + + spec = NodeSpec( + id="queen", + name="Queen", + description="planning orchestrator", + node_type="event_loop", + client_facing=True, + output_keys=[], + skip_judge=True, + ) + llm = MockStreamingLLM(scenarios=[text_scenario("What kind of agent should I design for you?")]) + node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5)) + ctx = build_ctx(spec, llm) + + async def shutdown_after_first_block() -> None: + await asyncio.sleep(0.05) + node.signal_shutdown() + + task = asyncio.create_task(shutdown_after_first_block()) + result = await node.execute(ctx) + await task + + assert result.success is True + assert len(received) >= 1 + assert received[0].data["prompt"] == "" + assert received[0].data["auto_blocked"] is True + assert received[0].data["assistant_text_present"] is True + assert received[0].data["assistant_text_requires_input"] is True + assert phase_state.planning_ask_rounds == 1 + + +@pytest.mark.asyncio +async def test_save_agent_draft_accepts_two_codex_style_planning_rounds() -> None: + """Two counted auto-blocked planning turns should unlock save_agent_draft().""" + phase_state = QueenPhaseState(phase="planning") + + codex_style_event = AgentEvent( + type=EventType.CLIENT_INPUT_REQUESTED, + stream_id="queen", + data={ + "prompt": "", + "auto_blocked": True, + "assistant_text_present": True, + "assistant_text_requires_input": True, + }, + ) + for _ in 
range(2): + if _client_input_counts_as_planning_ask(codex_style_event): + phase_state.planning_ask_rounds += 1 + + registry = ToolRegistry() + session = SimpleNamespace( + worker_runtime=None, + event_bus=None, + worker_path=None, + runner=None, + ) + register_queen_lifecycle_tools( + registry, + session=session, + session_id="session_codex_planning", + phase_state=phase_state, + ) + + save_draft = registry._tools["save_agent_draft"].executor + result_raw = await save_draft( + { + "agent_name": "codex_planning_repro", + "goal": "Reproduce the planning gate.", + "nodes": [ + {"id": "start"}, + {"id": "discover"}, + {"id": "plan"}, + {"id": "review"}, + {"id": "finish"}, + ], + "edges": [ + {"source": "start", "target": "discover"}, + {"source": "discover", "target": "plan"}, + {"source": "plan", "target": "review"}, + {"source": "review", "target": "finish"}, + ], + } + ) + result = json.loads(result_raw) + + assert phase_state.planning_ask_rounds == 2 + assert result["status"] == "draft_saved" + + +def test_status_only_auto_block_does_not_count_toward_planning_ask_rounds() -> None: + """Auto-blocked acknowledgements should not satisfy the planning ask gate.""" + event = AgentEvent( + type=EventType.CLIENT_INPUT_REQUESTED, + stream_id="queen", + data={ + "prompt": "", + "auto_blocked": True, + "assistant_text_present": True, + "assistant_text_requires_input": False, + }, + ) + + assert _client_input_counts_as_planning_ask(event) is False diff --git a/core/tests/test_config.py b/core/tests/test_config.py index 272b8403e7..f3e1971063 100644 --- a/core/tests/test_config.py +++ b/core/tests/test_config.py @@ -1,8 +1,15 @@ """Tests for framework/config.py - Hive configuration loading.""" import logging +from unittest.mock import patch -from framework.config import get_api_base, get_hive_config, get_preferred_model +from framework.config import ( + get_api_base, + get_hive_config, + get_llm_extra_kwargs, + get_preferred_model, +) +from framework.llm.codex_backend import 
CODEX_API_BASE, is_codex_api_base, normalize_codex_api_base class TestGetHiveConfig: @@ -59,9 +66,65 @@ def test_get_api_base_falls_back_to_openrouter_default(self, tmp_path, monkeypat def test_get_api_base_keeps_explicit_openrouter_api_base(self, tmp_path, monkeypatch): config_file = tmp_path / "configuration.json" config_file.write_text( - '{"llm":{"provider":"openrouter","model":"x-ai/grok-4.20-beta","api_base":"https://proxy.example/v1"}}', + ( + '{"llm":{"provider":"openrouter","model":"x-ai/grok-4.20-beta",' + '"api_base":"https://proxy.example/v1"}}' + ), encoding="utf-8", ) monkeypatch.setattr("framework.config.HIVE_CONFIG_FILE", config_file) assert get_api_base() == "https://proxy.example/v1" + + +class TestCodexConfig: + """Codex config helpers should share the same transport defaults.""" + + def test_get_api_base_uses_shared_codex_backend(self, tmp_path, monkeypatch): + config_file = tmp_path / "configuration.json" + config_file.write_text( + '{"llm":{"provider":"openai","model":"gpt-5.3-codex","use_codex_subscription":true}}', + encoding="utf-8", + ) + monkeypatch.setattr("framework.config.HIVE_CONFIG_FILE", config_file) + + assert get_api_base() == CODEX_API_BASE + + def test_get_llm_extra_kwargs_uses_shared_codex_transport(self, tmp_path, monkeypatch): + config_file = tmp_path / "configuration.json" + config_file.write_text( + '{"llm":{"provider":"openai","model":"gpt-5.3-codex","use_codex_subscription":true}}', + encoding="utf-8", + ) + monkeypatch.setattr("framework.config.HIVE_CONFIG_FILE", config_file) + + with ( + patch("framework.runner.runner.get_codex_token", return_value="tok_test"), + patch("framework.runner.runner.get_codex_account_id", return_value="acct_123"), + ): + kwargs = get_llm_extra_kwargs() + + assert kwargs["store"] is False + assert kwargs["allowed_openai_params"] == ["store"] + assert kwargs["extra_headers"] == { + "Authorization": "Bearer tok_test", + "User-Agent": "CodexBar", + "ChatGPT-Account-Id": "acct_123", + } + + def 
test_codex_api_base_detection_requires_real_chatgpt_origin(self): + assert is_codex_api_base("https://chatgpt.com/backend-api/codex") + assert is_codex_api_base("https://chatgpt.com/backend-api/codex/responses") + assert not is_codex_api_base( + "https://proxy.example/v1?target=https://chatgpt.com/backend-api/codex" + ) + + def test_normalize_codex_api_base_strips_only_real_responses_suffix(self): + assert ( + normalize_codex_api_base("https://chatgpt.com/backend-api/codex/responses") + == CODEX_API_BASE + ) + assert ( + normalize_codex_api_base("https://proxy.example/v1/responses") + == "https://proxy.example/v1/responses" + ) diff --git a/core/tests/test_continuous_conversation.py b/core/tests/test_continuous_conversation.py index 750e7e835a..60df69ba64 100644 --- a/core/tests/test_continuous_conversation.py +++ b/core/tests/test_continuous_conversation.py @@ -224,6 +224,12 @@ def test_update(self): conv.update_system_prompt("updated") assert conv.system_prompt == "updated" + def test_update_replaces_output_keys(self): + conv = NodeConversation(system_prompt="original", output_keys=["brief"]) + conv.update_system_prompt("updated", output_keys=["articles_data"]) + assert conv.system_prompt == "updated" + assert conv._output_keys == ["articles_data"] + # =========================================================================== # Conversation threading through executor @@ -372,6 +378,61 @@ async def test_continuous_transition_marker_present(self): ) assert "PHASE TRANSITION" in all_content + @pytest.mark.asyncio + async def test_transition_marker_uses_next_node_tools_not_stale_previous_tools(self): + runtime = _make_runtime() + web_scrape = _make_tool("web_scrape") + file_tool = _make_tool("read_file") + + llm = MockStreamingLLM( + scenarios=[ + _text_then_set_output("Intake done.", "brief", "enterprise ai"), + _text_finish(""), + _text_then_set_output("Research done.", "articles_data", '{"articles": []}'), + _text_finish(""), + ] + ) + + node_a = NodeSpec( + 
id="a", + name="Intake", + description="Collect user preference", + node_type="event_loop", + client_facing=True, + output_keys=["brief"], + ) + node_b = NodeSpec( + id="b", + name="Research", + description="Scrape recent articles", + node_type="event_loop", + input_keys=["brief"], + output_keys=["articles_data"], + tools=["web_scrape"], + ) + + graph = GraphSpec( + id="g1", + goal_id="g1", + entry_node="a", + nodes=[node_a, node_b], + edges=[EdgeSpec(id="e1", source="a", target="b", condition=EdgeCondition.ON_SUCCESS)], + terminal_nodes=["b"], + conversation_mode="continuous", + ) + + executor = GraphExecutor(runtime=runtime, llm=llm, tools=[file_tool, web_scrape]) + result = await executor.execute(graph=graph, goal=_make_goal()) + assert result.success + + node_b_messages = llm.stream_calls[2]["messages"] + all_content = " ".join( + m.get("content", "") for m in node_b_messages if isinstance(m.get("content"), str) + ) + assert "Available tools:" in all_content + assert "web_scrape" in all_content + assert "set_output" in all_content + # =========================================================================== # Cumulative tools diff --git a/core/tests/test_event_loop_node.py b/core/tests/test_event_loop_node.py index 5b23c8fc44..51ac13194a 100644 --- a/core/tests/test_event_loop_node.py +++ b/core/tests/test_event_loop_node.py @@ -7,6 +7,7 @@ from __future__ import annotations import asyncio +import contextvars from collections.abc import AsyncIterator from typing import Any from unittest.mock import AsyncMock, MagicMock @@ -348,6 +349,143 @@ async def test_set_output_rejects_invalid_key(self, runtime, node_spec, memory): assert result.output["result"] == "ok" assert "bad_key" not in result.output + def test_set_output_rejects_identical_duplicate_value(self): + """Identical repeated set_output calls should be treated as an error, not progress.""" + node = EventLoopNode() + + result = node._handle_set_output( + {"key": "result", "value": "42"}, + ["result"], + 
missing_keys=["result", "summary"], + current_value=42, + normalized_value=42, + ) + + assert result.is_error is True + assert "already set to the same value" in result.content + assert "summary" in result.content + + @pytest.mark.asyncio + async def test_set_output_auto_completes_non_client_facing_node( + self, + runtime, + node_spec, + memory, + ): + """A worker node should finish immediately once required outputs are set.""" + llm = MockStreamingLLM( + scenarios=[ + tool_call_scenario("set_output", {"key": "result", "value": "done"}), + ] + ) + + ctx = build_ctx(runtime, node_spec, memory, llm) + node = EventLoopNode(config=LoopConfig(max_iterations=5)) + result = await node.execute(ctx) + + assert result.success is True + assert result.output["result"] == "done" + assert len(llm.stream_calls) == 1 + + @pytest.mark.asyncio + async def test_set_output_auto_completes_client_facing_node( + self, + runtime, + memory, + ): + """Client-facing nodes should also finish once required outputs are set.""" + spec = NodeSpec( + id="review", + name="Review", + description="client-facing review node", + node_type="event_loop", + output_keys=["decision"], + client_facing=True, + ) + llm = MockStreamingLLM( + scenarios=[ + tool_call_scenario("set_output", {"key": "decision", "value": "approve"}), + ] + ) + + ctx = build_ctx(runtime, spec, memory, llm) + node = EventLoopNode(config=LoopConfig(max_iterations=5)) + result = await node.execute(ctx) + + assert result.success is True + assert result.output["decision"] == "approve" + assert len(llm.stream_calls) == 1 + + @pytest.mark.asyncio + async def test_client_facing_completes_immediately_after_user_reply_sets_all_outputs( + self, + runtime, + memory, + ): + """Client-facing nodes should finish once a post-user-reply turn sets all outputs.""" + spec = NodeSpec( + id="findings-review", + name="Findings Review", + description="review findings", + node_type="event_loop", + output_keys=["continue_scanning", "feedback", 
"all_findings"], + client_facing=True, + ) + llm = MockStreamingLLM( + scenarios=[ + tool_call_scenario( + "ask_user", + { + "question": "Continue scanning or generate report?", + "options": ["Continue", "Report"], + }, + tool_use_id="ask_1", + ), + [ + ToolCallEvent( + tool_use_id="set_continue", + tool_name="set_output", + tool_input={"key": "continue_scanning", "value": "false"}, + ), + ToolCallEvent( + tool_use_id="set_feedback", + tool_name="set_output", + tool_input={"key": "feedback", "value": "generate final report"}, + ), + ToolCallEvent( + tool_use_id="set_all", + tool_name="set_output", + tool_input={"key": "all_findings", "value": '{"ok": true}'}, + ), + FinishEvent( + stop_reason="tool_calls", + input_tokens=10, + output_tokens=5, + model="mock", + ), + ], + ] + ) + + node = EventLoopNode(config=LoopConfig(max_iterations=10)) + ctx = build_ctx(runtime, spec, memory, llm) + + async def user_responds(): + await asyncio.sleep(0.05) + await node.inject_event("Generate the report") + + task = asyncio.create_task(user_responds()) + result = await node.execute(ctx) + await task + + assert result.success is True + assert result.output == { + "continue_scanning": False, + "feedback": "generate final report", + "all_findings": {"ok": True}, + } + assert len(llm.stream_calls) == 2 + @pytest.mark.asyncio async def test_missing_keys_triggers_retry(self, runtime, node_spec, memory): """Judge accepts but output keys are missing -> retry with hint.""" @@ -399,6 +537,47 @@ async def test_stall_detection(self, runtime, node_spec, memory): assert result.success is False assert "stalled" in result.error.lower() + @pytest.mark.asyncio + async def test_no_progress_churn_detection(self, runtime, node_spec, memory): + """Different text with missing outputs should still fail if nothing progresses.""" + llm = MockStreamingLLM( + scenarios=[ + text_scenario("Reviewing the logs and thinking through the issue."), + text_scenario("I am narrowing down the likely cause."), + 
text_scenario("I have more context and am still analyzing."), + ] + ) + + ctx = build_ctx(runtime, node_spec, memory, llm) + node = EventLoopNode(config=LoopConfig(max_iterations=10, stall_detection_threshold=3)) + result = await node.execute(ctx) + + assert result.success is False + assert "no-progress loop detected" in (result.error or "").lower() + + @pytest.mark.asyncio + async def test_no_progress_counter_resets_after_output_progress( + self, + runtime, + node_spec, + memory, + ): + """A real output-setting turn should reset the no-progress churn counter.""" + llm = MockStreamingLLM( + scenarios=[ + text_scenario("Parsing the problem statement."), + text_scenario("Extracting the key facts now."), + tool_call_scenario("set_output", {"key": "result", "value": "triaged"}), + ] + ) + + ctx = build_ctx(runtime, node_spec, memory, llm) + node = EventLoopNode(config=LoopConfig(max_iterations=10, stall_detection_threshold=3)) + result = await node.execute(ctx) + + assert result.success is True + assert result.output["result"] == "triaged" + # =========================================================================== # EventBus lifecycle events @@ -647,6 +826,57 @@ async def shutdown(): assert len(received) >= 1 assert received[0].type == EventType.CLIENT_INPUT_REQUESTED + @pytest.mark.asyncio + async def test_queen_long_ask_user_prompt_surfaces_result_before_widget( + self, runtime, memory, client_spec + ): + """Long queen prompts should stream visible result text before options.""" + llm = MockStreamingLLM( + scenarios=[ + tool_call_scenario( + "ask_user", + { + "question": ( + "Root cause: checkout requests are failing because the DB pool is " + "exhausted and cart reads are timing out.\n\n" + "What would you like to do next?" 
+ ), + "options": ["Rerun", "Stop"], + }, + tool_use_id="ask_1", + ), + ] + ) + bus = EventBus() + received = [] + + async def capture(e): + received.append(e) + + bus.subscribe( + event_types=[EventType.CLIENT_OUTPUT_DELTA, EventType.CLIENT_INPUT_REQUESTED], + handler=capture, + ) + + node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5)) + ctx = build_ctx(runtime, client_spec, memory, llm, stream_id="queen") + + async def shutdown(): + await asyncio.sleep(0.05) + node.signal_shutdown() + + task = asyncio.create_task(shutdown()) + await node.execute(ctx) + await task + + output_events = [e for e in received if e.type == EventType.CLIENT_OUTPUT_DELTA] + input_events = [e for e in received if e.type == EventType.CLIENT_INPUT_REQUESTED] + + assert output_events + assert input_events + assert "Root cause: checkout requests are failing" in output_events[0].data["snapshot"] + assert input_events[0].data["prompt"] == "What would you like to do next?" + @pytest.mark.asyncio @pytest.mark.skip(reason="Hangs in non-interactive shells (client-facing blocks on stdin)") async def test_ask_user_with_real_tools(self, runtime, memory): @@ -905,9 +1135,79 @@ async def queen_reply(): assert result.success is True assert result.output["result"] == "resolved after queen guidance" - assert judge.evaluate.await_count >= 1 + assert judge.evaluate.await_count == 0 assert len(client_input_events) == 0 + @pytest.mark.asyncio + async def test_escalate_then_complete_outputs_autocompletes_without_extra_turn( + self, runtime, memory + ): + """Worker nodes should still auto-complete after queen guidance once outputs are set.""" + spec = NodeSpec( + id="csv-intake", + name="CSV Intake", + description="parse csv", + node_type="event_loop", + output_keys=["original_headers", "parsed_rows", "raw_csv_text"], + ) + llm = MockStreamingLLM( + scenarios=[ + tool_call_scenario( + "escalate", + {"reason": "need csv", "context": "input malformed"}, + tool_use_id="esc_1", + ), + [ + 
ToolCallEvent( + tool_use_id="set_headers", + tool_name="set_output", + tool_input={"key": "original_headers", "value": '["name","email"]'}, + ), + ToolCallEvent( + tool_use_id="set_rows", + tool_name="set_output", + tool_input={ + "key": "parsed_rows", + "value": '[{"name":"Alice","email":"alice@example.com"}]', + }, + ), + ToolCallEvent( + tool_use_id="set_raw", + tool_name="set_output", + tool_input={ + "key": "raw_csv_text", + "value": "name,email\\nAlice,alice@example.com", + }, + ), + FinishEvent( + stop_reason="tool_calls", + input_tokens=10, + output_tokens=5, + model="mock", + ), + ], + ] + ) + + ctx = build_ctx(runtime, spec, memory, llm, stream_id="worker") + node = EventLoopNode(config=LoopConfig(max_iterations=5)) + + async def queen_reply(): + await asyncio.sleep(0.05) + await node.inject_event("Use the sample CSV and continue.") + + task = asyncio.create_task(queen_reply()) + result = await node.execute(ctx) + await task + + assert result.success is True + assert result.output == { + "original_headers": ["name", "email"], + "parsed_rows": [{"name": "Alice", "email": "alice@example.com"}], + "raw_csv_text": "name,email\\nAlice,alice@example.com", + } + assert len(llm.stream_calls) == 2 + # =========================================================================== # Client-facing: _cf_expecting_work state machine @@ -1382,6 +1682,39 @@ async def test_pause_returns_early(self, runtime, node_spec, memory): assert llm._call_index == 0 +class TestToolExecutionContext: + @pytest.mark.asyncio + async def test_execute_tool_preserves_contextvars_in_threadpool(self): + marker = contextvars.ContextVar("marker", default="missing") + + def tool_exec(tool_use: ToolUse) -> ToolResult: + return ToolResult( + tool_use_id=tool_use.id, + content=marker.get(), + is_error=False, + ) + + node = EventLoopNode( + tool_executor=tool_exec, + config=LoopConfig(tool_call_timeout_seconds=5), + ) + + token = marker.set("present") + try: + result = await node._execute_tool( + 
ToolCallEvent( + tool_use_id="tool-1", + tool_name="echo_marker", + tool_input={}, + ) + ) + finally: + marker.reset(token) + + assert result.is_error is False + assert result.content == "present" + + # =========================================================================== # Stream errors # =========================================================================== @@ -1783,6 +2116,17 @@ def test_sort_keys_deterministic(self): ) +class TestFingerprintSetOutputCalls: + """Unit tests for _fingerprint_set_output_calls().""" + + def test_basic_fingerprint(self): + results = [ + {"tool_name": "set_output", "tool_input": {"key": "result", "value": {"a": 1}}}, + ] + fps = EventLoopNode._fingerprint_set_output_calls(results) + assert fps == [("result", '{"a": 1}')] + + class TestIsToolDoomLoop: """Unit tests for _is_tool_doom_loop().""" @@ -1821,6 +2165,25 @@ def test_empty_fingerprints_no_doom(self): assert is_doom is False +class TestIsOutputDoomLoop: + """Unit tests for _is_output_doom_loop().""" + + def test_at_threshold_identical(self): + node = EventLoopNode(config=LoopConfig(tool_doom_loop_threshold=3)) + fp = [("result", '"done"')] + is_doom, desc = node._is_output_doom_loop([fp, fp, fp]) + assert is_doom is True + assert "set_output" in desc + + def test_different_values_no_doom(self): + node = EventLoopNode(config=LoopConfig(tool_doom_loop_threshold=3)) + fp1 = [("result", '"a"')] + fp2 = [("result", '"b"')] + fp3 = [("result", '"c"')] + is_doom, _ = node._is_output_doom_loop([fp1, fp2, fp3]) + assert is_doom is False + + class ToolRepeatLLM(LLMProvider): """LLM that produces identical tool calls across outer iterations. 
@@ -1879,6 +2242,95 @@ def complete(self, messages, system="", **kwargs) -> LLMResponse: ) +class SetOutputRepeatLLM(LLMProvider): + """LLM that repeats the same set_output-only turn across iterations.""" + + def __init__(self, key: str, value: str, tool_turns: int, final_text: str = "done"): + self.key = key + self.value = value + self.tool_turns = tool_turns + self.final_text = final_text + self._call_index = 0 + + async def stream(self, messages, system="", tools=None, max_tokens=4096): + idx = self._call_index + self._call_index += 1 + outer_iter = idx // 2 + is_tool_call = (idx % 2 == 0) and outer_iter < self.tool_turns + if is_tool_call: + yield ToolCallEvent( + tool_use_id=f"set_{outer_iter}", + tool_name="set_output", + tool_input={"key": self.key, "value": self.value}, + ) + yield FinishEvent( + stop_reason="tool_calls", + input_tokens=10, + output_tokens=5, + model="mock", + ) + else: + text = f"{self.final_text} (call {idx})" + yield TextDeltaEvent(content=text, snapshot=text) + yield FinishEvent( + stop_reason="stop", + input_tokens=10, + output_tokens=5, + model="mock", + ) + + def complete(self, messages, system="", **kwargs) -> LLMResponse: + return LLMResponse( + content="ok", + model="mock", + stop_reason="stop", + ) + + +class VaryingSetOutputRepeatLLM(LLMProvider): + """LLM that repeats set_output turns with different values across iterations.""" + + def __init__(self, key: str, values: list[str], final_text: str = "done"): + self.key = key + self.values = values + self.final_text = final_text + self._call_index = 0 + + async def stream(self, messages, system="", tools=None, max_tokens=4096): + idx = self._call_index + self._call_index += 1 + outer_iter = idx // 2 + is_tool_call = (idx % 2 == 0) and outer_iter < len(self.values) + if is_tool_call: + yield ToolCallEvent( + tool_use_id=f"set_{outer_iter}", + tool_name="set_output", + tool_input={"key": self.key, "value": self.values[outer_iter]}, + ) + yield FinishEvent( + stop_reason="tool_calls", 
+ input_tokens=10, + output_tokens=5, + model="mock", + ) + else: + text = f"{self.final_text} (call {idx})" + yield TextDeltaEvent(content=text, snapshot=text) + yield FinishEvent( + stop_reason="stop", + input_tokens=10, + output_tokens=5, + model="mock", + ) + + def complete(self, messages, system="", **kwargs) -> LLMResponse: + return LLMResponse( + content="ok", + model="mock", + stop_reason="stop", + ) + + class TestToolDoomLoopIntegration: """Integration tests for doom loop detection in execute(). @@ -2263,7 +2715,97 @@ def tool_exec(tool_use: ToolUse) -> ToolResult: assert result.success is True # Doom loop MUST fire for repeatedly-failing tool calls assert len(doom_events) >= 1 - assert "failing_tool" in doom_events[0].data["description"] + + @pytest.mark.asyncio + async def test_repeated_identical_set_output_turns_fail_fast( + self, + runtime, + node_spec, + memory, + ): + """Repeated identical set_output-only turns should fail instead of spinning forever.""" + node_spec.output_keys = ["result", "review_manifest"] + judge = AsyncMock(spec=JudgeProtocol) + judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="RETRY")) + + llm = SetOutputRepeatLLM("result", "same summary", tool_turns=4) + bus = EventBus() + doom_events: list = [] + bus.subscribe( + event_types=[EventType.NODE_TOOL_DOOM_LOOP], + handler=lambda e: doom_events.append(e), + ) + + ctx = build_ctx(runtime, node_spec, memory, llm, tools=[]) + node = EventLoopNode( + judge=judge, + event_bus=bus, + config=LoopConfig( + max_iterations=10, + tool_doom_loop_threshold=3, + stall_similarity_threshold=1.0, + ), + ) + result = await node.execute(ctx) + + assert result.success is False + assert "Output doom loop detected" in (result.error or "") + assert doom_events + assert "set_output" in doom_events[0].data["description"] + assert "result" in doom_events[0].data["description"] + + @pytest.mark.asyncio + async def test_meta_reset_set_output_turns_fail_fast( + self, + runtime, + node_spec, + 
memory, + ): + """Fresh-payload reset chatter should trip the output doom loop guard.""" + node_spec.output_keys = ["rules", "candidates", "scan_stats"] + judge = AsyncMock(spec=JudgeProtocol) + judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="RETRY")) + + llm = VaryingSetOutputRepeatLLM( + "rules", + [ + ( + "New event acknowledged. Awaiting fresh request payload " + "(phase transition details + structured inputs) to proceed." + ), + ( + "Context reset complete. Awaiting fresh phase transition payload " + "and structured inputs to proceed." + ), + ( + "Ready for fresh request payload with phase transition " + "instructions and structured inputs." + ), + ], + ) + bus = EventBus() + doom_events: list = [] + bus.subscribe( + event_types=[EventType.NODE_TOOL_DOOM_LOOP], + handler=lambda e: doom_events.append(e), + ) + + ctx = build_ctx(runtime, node_spec, memory, llm, tools=[]) + node = EventLoopNode( + judge=judge, + event_bus=bus, + config=LoopConfig( + max_iterations=10, + tool_doom_loop_threshold=3, + stall_similarity_threshold=1.0, + ), + ) + result = await node.execute(ctx) + + assert result.success is False + assert "fresh payload" in (result.error or "").lower() + assert doom_events + assert "fresh payload" in doom_events[0].data["description"].lower() # =========================================================================== diff --git a/core/tests/test_litellm_provider.py b/core/tests/test_litellm_provider.py index 6024f355c3..7ee2665250 100644 --- a/core/tests/test_litellm_provider.py +++ b/core/tests/test_litellm_provider.py @@ -238,6 +238,48 @@ def test_parse_tool_call_arguments_raises_when_unrepairable(self): with pytest.raises(ValueError, match="Failed to parse tool call arguments"): provider._parse_tool_call_arguments('{"question": foo', "ask_user") + def test_parse_tool_call_arguments_recovers_pythonish_payloads(self): + """Single-quoted and trailing-comma argument payloads should be recovered.""" + provider = 
LiteLLMProvider(model="openai/gpt-5.3-codex", api_key="test-key") + + parsed = provider._parse_tool_call_arguments( + "{'question': 'Continue?', 'options': ['Yes', 'No'],}", + "ask_user", + ) + + assert parsed == { + "question": "Continue?", + "options": ["Yes", "No"], + } + + def test_parse_tool_call_arguments_keeps_null_inside_strings(self): + """Literal normalization should not mutate quoted text values.""" + provider = LiteLLMProvider(model="openai/gpt-5.3-codex", api_key="test-key") + + parsed = provider._parse_tool_call_arguments( + "{'hypothesis': 'null hypothesis', 'approved': false}", + "summarize", + ) + + assert parsed == { + "hypothesis": "null hypothesis", + "approved": False, + } + + def test_parse_tool_call_arguments_strips_json_code_fences(self): + """Fence stripping should remove the language tag before JSON parsing.""" + provider = LiteLLMProvider(model="openai/gpt-5.3-codex", api_key="test-key") + + parsed = provider._parse_tool_call_arguments( + '```json\n{"question":"Continue?","options":["Yes","No"]}\n```', + "ask_user", + ) + + assert parsed == { + "question": "Continue?", + "options": ["Yes", "No"], + } + class TestAnthropicProviderBackwardCompatibility: """Test AnthropicProvider backward compatibility with LiteLLM backend.""" @@ -728,6 +770,221 @@ def test_is_minimax_model_variants(self): assert not LiteLLMProvider(model="gpt-4o-mini", api_key="x")._is_minimax_model() +class TestCodexEmptyStreamRecovery: + """Codex empty streams should fall back before surfacing ghost-stream retries.""" + + @pytest.mark.asyncio + @patch("litellm.acompletion") + async def test_stream_recovers_empty_codex_stream_via_nonstream_completion( + self, + mock_acompletion, + ): + """An empty Codex stream should be salvaged with a non-stream completion.""" + from framework.llm.stream_events import FinishEvent, TextDeltaEvent + + provider = LiteLLMProvider( + model="openai/gpt-5.3-codex", + api_key="test-key", + api_base="https://chatgpt.com/backend-api/codex", + ) + + 
class EmptyStreamResponse: + chunks: list = [] + + def __aiter__(self): + return self + + async def __anext__(self): + raise StopAsyncIteration + + recovered = MagicMock() + recovered.choices = [MagicMock()] + recovered.choices[0].message.content = "Recovered via fallback" + recovered.choices[0].message.tool_calls = [] + recovered.choices[0].finish_reason = "stop" + recovered.model = provider.model + recovered.usage.prompt_tokens = 12 + recovered.usage.completion_tokens = 4 + + async def side_effect(*args, **kwargs): + if kwargs.get("stream"): + return EmptyStreamResponse() + return recovered + + mock_acompletion.side_effect = side_effect + + events = [] + async for event in provider.stream(messages=[{"role": "user", "content": "hi"}]): + events.append(event) + + text_events = [event for event in events if isinstance(event, TextDeltaEvent)] + assert len(text_events) == 1 + assert text_events[0].snapshot == "Recovered via fallback" + + finish_events = [event for event in events if isinstance(event, FinishEvent)] + assert len(finish_events) == 1 + assert finish_events[0].stop_reason == "stop" + assert finish_events[0].input_tokens == 12 + assert finish_events[0].output_tokens == 4 + + assert mock_acompletion.call_count == 2 + assert mock_acompletion.call_args_list[0].kwargs["stream"] is True + assert "stream" not in mock_acompletion.call_args_list[1].kwargs + + @pytest.mark.asyncio + @patch("litellm.acompletion") + async def test_stream_recovers_empty_codex_stream_with_tool_calls( + self, + mock_acompletion, + ): + """Non-stream fallback should preserve tool calls, not just text.""" + from framework.llm.stream_events import FinishEvent, ToolCallEvent + + provider = LiteLLMProvider( + model="openai/gpt-5.3-codex", + api_key="test-key", + api_base="https://chatgpt.com/backend-api/codex/responses", + ) + + class EmptyStreamResponse: + chunks: list = [] + + def __aiter__(self): + return self + + async def __anext__(self): + raise StopAsyncIteration + + tc = MagicMock() + 
tc.id = "tool_1" + tc.function.name = "ask_user" + tc.function.arguments = '{"question":"Continue?","options":["Yes","No"]}' + + recovered = MagicMock() + recovered.choices = [MagicMock()] + recovered.choices[0].message.content = "" + recovered.choices[0].message.tool_calls = [tc] + recovered.choices[0].finish_reason = "tool_calls" + recovered.model = provider.model + recovered.usage.prompt_tokens = 14 + recovered.usage.completion_tokens = 5 + + async def side_effect(*args, **kwargs): + if kwargs.get("stream"): + return EmptyStreamResponse() + return recovered + + mock_acompletion.side_effect = side_effect + + events = [] + async for event in provider.stream( + messages=[{"role": "user", "content": "Should we continue?"}], + tools=[ + Tool( + name="ask_user", + description="Ask the user", + parameters={"properties": {"question": {"type": "string"}}}, + ) + ], + ): + events.append(event) + + tool_events = [event for event in events if isinstance(event, ToolCallEvent)] + assert len(tool_events) == 1 + assert tool_events[0].tool_name == "ask_user" + assert tool_events[0].tool_input == { + "question": "Continue?", + "options": ["Yes", "No"], + } + + finish_events = [event for event in events if isinstance(event, FinishEvent)] + assert len(finish_events) == 1 + assert finish_events[0].stop_reason == "tool_calls" + + +class TestCodexRequestHardening: + def test_codex_build_completion_kwargs_splits_prompt_and_forces_tool_choice(self): + """Codex requests should chunk large system prompts and require tools when needed.""" + provider = LiteLLMProvider( + model="openai/gpt-5.3-codex", + api_key="test-key", + api_base="https://chatgpt.com/backend-api/codex/responses", + ) + kwargs = provider._build_completion_kwargs( + messages=[{"role": "user", "content": "hi"}], + system="# Identity\n" + ("rule\n" * 2000), + tools=[ + Tool( + name="ask_user", + description="Ask the user", + parameters={"properties": {"question": {"type": "string"}}}, + ) + ], + max_tokens=256, + 
response_format=None, + json_mode=False, + stream=True, + ) + + system_messages = [m for m in kwargs["messages"] if m["role"] == "system"] + assert len(system_messages) >= 2 + assert system_messages[0]["content"].startswith("# Codex Execution Contract") + assert kwargs["tool_choice"] == "required" + assert kwargs["store"] is False + assert "max_tokens" not in kwargs + assert "stream_options" not in kwargs + assert kwargs["api_base"] == "https://chatgpt.com/backend-api/codex" + assert "store" in kwargs["allowed_openai_params"] + + def test_codex_merge_tool_call_chunk_handles_parallel_calls_with_broken_indexes(self): + """Codex chunk merging should survive index=0 for multiple parallel tool calls.""" + from types import SimpleNamespace + + provider = LiteLLMProvider( + model="openai/gpt-5.3-codex", + api_key="test-key", + api_base="https://chatgpt.com/backend-api/codex", + ) + acc: dict[int, dict[str, str]] = {} + last_idx = 0 + + chunks = [ + SimpleNamespace( + id="tool_1", + index=0, + function=SimpleNamespace(name="web_search", arguments='{"query":"alpha'), + ), + SimpleNamespace( + id="tool_2", + index=0, + function=SimpleNamespace(name="read_file", arguments='{"path":"beta'), + ), + SimpleNamespace( + id=None, + index=0, + function=SimpleNamespace(name=None, arguments='"}'), + ), + SimpleNamespace( + id=None, + index=0, + function=SimpleNamespace(name=None, arguments='"}'), + ), + ] + + for chunk in chunks: + last_idx = provider._merge_tool_call_chunk(acc, chunk, last_idx) + + assert len(acc) == 2 + parsed = [ + provider._parse_tool_call_arguments(slot["arguments"], slot["name"]) + for _, slot in sorted(acc.items()) + ] + assert parsed == [ + {"query": "alpha"}, + {"path": "beta"}, + ] + + class TestOpenRouterToolCompatFallback: """OpenRouter models should fall back when native tool use is unavailable.""" diff --git a/core/tests/test_queen_lifecycle_validation.py b/core/tests/test_queen_lifecycle_validation.py new file mode 100644 index 0000000000..ba8e723aff 
--- /dev/null +++ b/core/tests/test_queen_lifecycle_validation.py @@ -0,0 +1,985 @@ +from __future__ import annotations + +import json +import os +import time +from datetime import datetime +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest + +import framework.tools.queen_lifecycle_tools as qlt +from framework.llm.provider import Tool +from framework.runner.tool_registry import ToolRegistry +from framework.runtime.event_bus import EventBus +from framework.tools.queen_lifecycle_tools import QueenPhaseState, register_queen_lifecycle_tools + + +def _write_worker_logs( + storage_path: Path, + session_id: str, + *, + session_status: str, + steps: list[dict[str, object]], +) -> Path: + session_dir = storage_path / "sessions" / session_id + logs_dir = session_dir / "logs" + logs_dir.mkdir(parents=True, exist_ok=True) + (session_dir / "state.json").write_text( + json.dumps({"status": session_status}), + encoding="utf-8", + ) + log_path = logs_dir / "tool_logs.jsonl" + log_path.write_text( + "".join(json.dumps(step) + "\n" for step in steps), + encoding="utf-8", + ) + return log_path + + +def _register_fake_validator(registry: ToolRegistry, report: dict) -> None: + registry.register( + "validate_agent_package", + Tool( + name="validate_agent_package", + description="fake validator", + parameters={"type": "object", "properties": {"agent_name": {"type": "string"}}}, + ), + lambda _inputs: json.dumps(report), + ) + + +def test_parse_validation_report_handles_saved_footer() -> None: + raw = ( + '{\n "valid": false,\n "steps": {"tool_validation": {"passed": false}}\n}\n\n' + "[Saved to 'validate.txt']" + ) + + parsed = qlt._parse_validation_report(raw) + + assert parsed == {"valid": False, "steps": {"tool_validation": {"passed": False}}} + + +def test_validation_blocks_stage_or_run_ignores_non_blocking_warnings() -> None: + report = { + "steps": { + "behavior_validation": { + "passed": True, + "warnings": 
["placeholder prompt"], + "output": "placeholder prompt", + }, + "tests": { + "passed": True, + "warnings": ["1 failed"], + "summary": "1 failed", + }, + }, + } + + assert qlt._validation_blocks_stage_or_run(report) is False + + +def test_invalid_validation_report_blocks_stage_or_run() -> None: + report = qlt._invalid_validation_report("validator returned garbage") + + assert report["valid"] is False + assert qlt._validation_blocks_stage_or_run(report) is True + + +@pytest.mark.asyncio +async def test_get_worker_status_summary_flags_retry_and_judge_pressure() -> None: + registry = ToolRegistry() + bus = EventBus() + + await bus.emit_node_retry( + stream_id="worker", + node_id="scan", + retry_count=1, + max_retries=3, + error="still missing required result", + ) + for _ in range(4): + await bus.emit_judge_verdict( + stream_id="worker", + node_id="scan", + action="RETRY", + feedback="missing structured output", + ) + + runtime = SimpleNamespace( + graph_id="worker-graph", + get_graph_registration=lambda _gid: SimpleNamespace( + streams={ + "default": SimpleNamespace( + active_execution_ids=["exec-1"], + get_context=lambda _exec_id: SimpleNamespace(started_at=datetime.now()), + get_waiting_nodes=lambda: [], + ) + } + ), + ) + session = SimpleNamespace(worker_runtime=runtime, event_bus=bus, worker_path=None, runner=None) + + register_queen_lifecycle_tools(registry, session=session, session_id="sess-status") + + summary = await registry._tools["get_worker_status"].executor({}) + + assert "issue type(s) detected" in summary + + +@pytest.mark.asyncio +async def test_get_worker_status_issues_reports_judge_pressure() -> None: + registry = ToolRegistry() + bus = EventBus() + + for action in ("CONTINUE", "RETRY", "RETRY", "ESCALATE"): + await bus.emit_judge_verdict( + stream_id="worker", + node_id="review", + action=action, + feedback="still not converging", + ) + + runtime = SimpleNamespace( + graph_id="worker-graph", + get_graph_registration=lambda _gid: 
SimpleNamespace(streams={}), + ) + session = SimpleNamespace(worker_runtime=runtime, event_bus=bus, worker_path=None, runner=None) + + register_queen_lifecycle_tools(registry, session=session, session_id="sess-issues") + + issues = await registry._tools["get_worker_status"].executor({"focus": "issues"}) + + assert "Judge pressure detected" in issues + assert "consecutive non-ACCEPT judge verdict" in issues + + +@pytest.mark.asyncio +async def test_get_worker_status_summary_uses_health_snapshot_signals(tmp_path: Path) -> None: + storage_path = tmp_path / "agent_store" + storage_path.mkdir(parents=True, exist_ok=True) + log_path = _write_worker_logs( + storage_path, + "sess-health", + session_status="running", + steps=[ + {"verdict": "CONTINUE", "llm_text": "thinking"}, + {"verdict": "RETRY", "llm_text": "retrying"}, + {"verdict": "RETRY", "llm_text": "still retrying"}, + {"verdict": "ESCALATE", "llm_text": "need help"}, + ], + ) + three_minutes_ago = time.time() - 180 + os.utime(log_path, (three_minutes_ago, three_minutes_ago)) + + registry = ToolRegistry() + bus = EventBus() + runtime = SimpleNamespace( + graph_id="worker-graph", + get_graph_registration=lambda _gid: SimpleNamespace( + streams={ + "default": SimpleNamespace( + active_execution_ids=["exec-1"], + get_context=lambda _exec_id: SimpleNamespace(started_at=datetime.now()), + get_waiting_nodes=lambda: [], + ) + } + ), + ) + session = SimpleNamespace( + worker_runtime=runtime, + event_bus=bus, + worker_path=storage_path, + runner=None, + ) + + register_queen_lifecycle_tools(registry, session=session, session_id="sess-health") + + summary = await registry._tools["get_worker_status"].executor({}) + + assert "issue signal(s) detected" in summary + assert "judge_pressure" in summary + assert "recent_non_accept_churn" in summary + + +@pytest.mark.asyncio +async def test_get_worker_status_issues_includes_health_snapshot_signals(tmp_path: Path) -> None: + storage_path = tmp_path / "agent_store" + 
storage_path.mkdir(parents=True, exist_ok=True) + log_path = _write_worker_logs( + storage_path, + "sess-health", + session_status="running", + steps=[ + {"verdict": "CONTINUE", "llm_text": "thinking"}, + {"verdict": "RETRY", "llm_text": "retrying"}, + {"verdict": "RETRY", "llm_text": "still retrying"}, + {"verdict": "ESCALATE", "llm_text": "need help"}, + ], + ) + three_minutes_ago = time.time() - 180 + os.utime(log_path, (three_minutes_ago, three_minutes_ago)) + + registry = ToolRegistry() + bus = EventBus() + runtime = SimpleNamespace( + graph_id="worker-graph", + get_graph_registration=lambda _gid: SimpleNamespace(streams={}), + ) + session = SimpleNamespace( + worker_runtime=runtime, + event_bus=bus, + worker_path=storage_path, + runner=None, + ) + + register_queen_lifecycle_tools(registry, session=session, session_id="sess-health") + + issues = await registry._tools["get_worker_status"].executor({"focus": "issues"}) + + assert "Health signals:" in issues + assert "slow_progress" in issues + assert "recent_non_accept_churn" in issues + + +def test_build_worker_input_data_maps_bullet_task_fields_to_entry_inputs( + monkeypatch, tmp_path: Path +) -> None: + monkeypatch.chdir(tmp_path) + (tmp_path / "docs").mkdir() + + runtime = SimpleNamespace( + get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")], + graph=SimpleNamespace( + entry_node="process", + get_node=lambda node_id: ( + SimpleNamespace( + input_keys=[ + "docs_dir", + "review_dir", + "word_threshold", + "style_rules", + "target_ratio", + ] + ) + if node_id == "process" + else None + ), + ), + ) + + payload = qlt._build_worker_input_data( + runtime, + ( + "Run md_condense_reviewer with the following runtime config:\n" + "- docs_dir: docs/\n" + "- review_dir: docs_reviews/\n" + "- word_threshold: 800\n" + "- target_ratio: 0.6 (default)\n" + "- style_rules: Preserve headings and links.\n\n" + "Execution requirements:\n" + "1) Scan the docs directory.\n" + "2) Write review copies." 
+ ), + ) + + assert payload == { + "docs_dir": str((tmp_path / "docs").resolve()), + "review_dir": str((tmp_path / "docs_reviews").resolve()), + "word_threshold": 800, + "style_rules": "Preserve headings and links.", + "target_ratio": 0.6, + } + + +def test_build_worker_input_data_maps_equals_style_runtime_fields( + monkeypatch, tmp_path: Path +) -> None: + monkeypatch.chdir(tmp_path) + + runtime = SimpleNamespace( + get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")], + graph=SimpleNamespace( + entry_node="process", + get_node=lambda node_id: ( + SimpleNamespace(input_keys=["target_dir", "review_dir_mode", "word_threshold"]) + if node_id == "process" + else None + ), + ), + ) + + payload = qlt._build_worker_input_data( + runtime, + ("Yes, rerun with target_dir=docs review_dir_mode=next_to_source word_threshold=800"), + ) + + assert payload == { + "target_dir": str((tmp_path / "docs").resolve()), + "review_dir_mode": "next_to_source", + "word_threshold": 800, + } + + +def test_build_worker_input_data_backfills_missing_fields_from_recent_session( + monkeypatch, tmp_path: Path +) -> None: + monkeypatch.chdir(tmp_path) + sessions_dir = tmp_path / "agent_store" / "sessions" + sessions_dir.mkdir(parents=True) + + valid_prior_state = { + "timestamps": {"updated_at": "2026-03-24T20:44:00"}, + "input_data": { + "target_dir": "docs", + "review_dir": "docs_reviews", + "word_threshold": 800, + }, + } + malformed_recent_state = { + "timestamps": {"updated_at": "2026-03-24T21:20:23"}, + "input_data": { + "review_dir": "docs_reviews", + "word_threshold": "800. 
Validate inputs and continue.", + }, + } + + for session_name, state in { + "session_20260324_204400_good": valid_prior_state, + "session_20260324_212023_bad": malformed_recent_state, + }.items(): + session_dir = sessions_dir / session_name + session_dir.mkdir() + (session_dir / "state.json").write_text(json.dumps(state), encoding="utf-8") + + runtime = SimpleNamespace( + _session_store=SimpleNamespace(sessions_dir=sessions_dir), + get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")], + graph=SimpleNamespace( + entry_node="process", + get_node=lambda node_id: ( + SimpleNamespace(input_keys=["target_dir", "review_dir", "word_threshold"]) + if node_id == "process" + else None + ), + ), + ) + + payload = qlt._build_worker_input_data( + runtime, + ("review_dir: docs_reviews\nword_threshold: 800. Validate inputs and continue."), + ) + + assert payload == { + "target_dir": str((tmp_path / "docs").resolve()), + "review_dir": str((tmp_path / "docs_reviews").resolve()), + "word_threshold": 800, + } + + +def test_build_worker_input_data_reuses_recent_defaults_for_rerun_phrase( + monkeypatch, tmp_path: Path +) -> None: + monkeypatch.chdir(tmp_path) + sessions_dir = tmp_path / "agent_store" / "sessions" + sessions_dir.mkdir(parents=True) + + state = { + "timestamps": {"updated_at": "2026-03-24T21:17:00"}, + "input_data": { + "target_dir": "docs", + "review_dir": "docs_reviews", + "word_threshold": 800, + }, + } + session_dir = sessions_dir / "session_20260324_211700_prev" + session_dir.mkdir() + (session_dir / "state.json").write_text(json.dumps(state), encoding="utf-8") + + runtime = SimpleNamespace( + _session_store=SimpleNamespace(sessions_dir=sessions_dir), + get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")], + graph=SimpleNamespace( + entry_node="process", + get_node=lambda node_id: ( + SimpleNamespace(input_keys=["target_dir", "review_dir", "word_threshold"]) + if node_id == "process" + else None + ), + ), + ) + + 
payload = qlt._build_worker_input_data(runtime, "Run again with same defaults") + + assert payload == { + "target_dir": str((tmp_path / "docs").resolve()), + "review_dir": str((tmp_path / "docs_reviews").resolve()), + "word_threshold": 800, + } + + +def test_build_worker_input_data_backfills_from_recent_result_output( + monkeypatch, tmp_path: Path +) -> None: + monkeypatch.chdir(tmp_path) + sessions_dir = tmp_path / "agent_store" / "sessions" + sessions_dir.mkdir(parents=True) + + state = { + "timestamps": {"updated_at": "2026-03-24T23:35:19"}, + "input_data": { + "review_dir": "docs_reviews", + "word_threshold": 800, + }, + "result": { + "output": { + "target_dir": "docs", + "review_dir": "docs_reviews", + "word_threshold": 800, + } + }, + } + session_dir = sessions_dir / "session_20260324_233519_prev" + session_dir.mkdir() + (session_dir / "state.json").write_text(json.dumps(state), encoding="utf-8") + + runtime = SimpleNamespace( + _session_store=SimpleNamespace(sessions_dir=sessions_dir), + get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")], + graph=SimpleNamespace( + entry_node="process", + get_node=lambda node_id: ( + SimpleNamespace(input_keys=["target_dir", "review_dir", "word_threshold"]) + if node_id == "process" + else None + ), + ), + ) + + payload = qlt._build_worker_input_data( + runtime, + ("review_dir: docs_reviews\nword_threshold: 600"), + ) + + assert payload == { + "target_dir": str((tmp_path / "docs").resolve()), + "review_dir": str((tmp_path / "docs_reviews").resolve()), + "word_threshold": 600, + } + + +@pytest.mark.asyncio +async def test_load_built_agent_blocks_invalid_package(monkeypatch, tmp_path: Path) -> None: + registry = ToolRegistry() + captured: dict[str, str] = {} + registry.register( + "validate_agent_package", + Tool( + name="validate_agent_package", + description="fake validator", + parameters={"type": "object", "properties": {"agent_name": {"type": "string"}}}, + ), + lambda inputs: ( + 
captured.setdefault("agent_name", inputs["agent_name"]), + json.dumps( + { + "valid": False, + "steps": { + "behavior_validation": { + "passed": False, + "output": ( + "Node 'scan-markdown' has a blank or placeholder system_prompt" + ), + } + }, + } + ), + )[1], + ) + + session = SimpleNamespace(worker_runtime=None, event_bus=None, worker_path=None, runner=None) + fake_manager = SimpleNamespace( + get_session=lambda _sid: None, + unload_worker=AsyncMock(), + load_worker=AsyncMock(), + ) + phase_state = QueenPhaseState(phase="building") + register_queen_lifecycle_tools( + registry, + session=session, + session_manager=fake_manager, + manager_session_id="sess-1", + phase_state=phase_state, + ) + + agent_dir = tmp_path / "broken_agent" + agent_dir.mkdir() + monkeypatch.setattr(qlt, "validate_agent_path", lambda _path: agent_dir) + + result_raw = await registry._tools["load_built_agent"].executor({"agent_path": str(agent_dir)}) + result = json.loads(result_raw) + + assert "Cannot load agent" in result["error"] + assert "behavior_validation" in result["validation_failures"][0] + assert captured["agent_name"] == str(agent_dir) + fake_manager.load_worker.assert_not_called() + fake_manager.unload_worker.assert_not_called() + + +@pytest.mark.asyncio +async def test_load_built_agent_keeps_current_worker_when_replacement_fails_validation( + monkeypatch, tmp_path: Path +) -> None: + registry = ToolRegistry() + _register_fake_validator( + registry, + { + "valid": False, + "steps": { + "behavior_validation": { + "passed": False, + "output": "Node 'scan' has a blank or placeholder system_prompt", + } + }, + }, + ) + session = SimpleNamespace( + worker_runtime=SimpleNamespace(), + event_bus=None, + worker_path=Path("exports/existing_agent"), + runner=None, + ) + fake_manager = SimpleNamespace( + get_session=lambda _sid: None, + unload_worker=AsyncMock(), + load_worker=AsyncMock(), + ) + phase_state = QueenPhaseState(phase="building") + register_queen_lifecycle_tools( + registry, + 
session=session, + session_manager=fake_manager, + manager_session_id="sess-1", + phase_state=phase_state, + ) + + agent_dir = tmp_path / "broken_agent" + agent_dir.mkdir() + monkeypatch.setattr(qlt, "validate_agent_path", lambda _path: agent_dir) + + result_raw = await registry._tools["load_built_agent"].executor({"agent_path": str(agent_dir)}) + result = json.loads(result_raw) + + assert "Cannot load agent" in result["error"] + fake_manager.unload_worker.assert_not_called() + fake_manager.load_worker.assert_not_called() + + +@pytest.mark.asyncio +async def test_run_agent_with_input_blocks_loaded_invalid_worker() -> None: + registry = ToolRegistry() + _register_fake_validator( + registry, + { + "valid": False, + "steps": { + "tool_validation": { + "passed": False, + "output": "Scan Markdown Files missing run_command", + } + }, + }, + ) + + runtime = SimpleNamespace( + resume_timers=MagicMock(), + trigger=AsyncMock(), + _get_primary_session_state=MagicMock(return_value={}), + graph=SimpleNamespace(nodes=[]), + ) + session = SimpleNamespace( + worker_runtime=runtime, + event_bus=None, + worker_path=Path("exports/broken_agent"), + runner=None, + ) + phase_state = QueenPhaseState(phase="staging") + register_queen_lifecycle_tools( + registry, + session=session, + session_id="sess-2", + phase_state=phase_state, + ) + + result_raw = await registry._tools["run_agent_with_input"].executor({"task": "run it"}) + result = json.loads(result_raw) + + assert "Cannot run agent" in result["error"] + assert "tool_validation" in result["validation_failures"][0] + runtime.trigger.assert_not_called() + + +@pytest.mark.asyncio +async def test_run_agent_with_input_uses_structured_entry_inputs( + monkeypatch, tmp_path: Path +) -> None: + registry = ToolRegistry() + _register_fake_validator(registry, {"valid": True, "steps": {}}) + + monkeypatch.setattr(qlt, "validate_credentials", lambda *args, **kwargs: None) + monkeypatch.chdir(tmp_path) + + runtime = SimpleNamespace( + 
resume_timers=MagicMock(), + trigger=AsyncMock(return_value="exec-1"), + _get_primary_session_state=MagicMock(return_value={}), + get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")], + graph=SimpleNamespace( + nodes=[], + entry_node="process", + get_node=lambda node_id: ( + SimpleNamespace( + input_keys=["docs_path", "review_path", "word_threshold", "style_rules"] + ) + if node_id == "process" + else None + ), + ), + ) + session = SimpleNamespace( + worker_runtime=runtime, + event_bus=None, + worker_path=Path("exports/markdown_condense_approver"), + runner=None, + ) + phase_state = QueenPhaseState(phase="staging") + register_queen_lifecycle_tools( + registry, + session=session, + session_id="sess-3", + phase_state=phase_state, + ) + + result_raw = await registry._tools["run_agent_with_input"].executor( + { + "task": ( + "docs_path: docs/ review_path: docs_reviews/ word_threshold: 800 " + "style_rules: Preserve headings, keep links intact." + ) + } + ) + result = json.loads(result_raw) + + assert result["status"] == "started" + runtime.trigger.assert_awaited_once() + trigger_kwargs = runtime.trigger.await_args.kwargs + assert trigger_kwargs["input_data"] == { + "docs_path": str((tmp_path / "docs").resolve()), + "review_path": str((tmp_path / "docs_reviews").resolve()), + "word_threshold": 800, + "style_rules": "Preserve headings, keep links intact.", + } + assert trigger_kwargs["session_state"] is None + runtime._get_primary_session_state.assert_not_called() + + +@pytest.mark.asyncio +async def test_rerun_worker_with_last_input_reuses_complete_recent_defaults( + monkeypatch, tmp_path: Path +) -> None: + registry = ToolRegistry() + _register_fake_validator(registry, {"valid": True, "steps": {}}) + + monkeypatch.setattr(qlt, "validate_credentials", lambda *args, **kwargs: None) + monkeypatch.chdir(tmp_path) + + sessions_dir = tmp_path / "agent_store" / "sessions" + sessions_dir.mkdir(parents=True) + + valid_prior_state = { + "timestamps": 
{"updated_at": "2026-03-24T20:44:00"}, + "input_data": { + "target_dir": "docs", + "review_dir": "docs_reviews", + "word_threshold": 800, + "feedback": "stale", + }, + } + malformed_recent_state = { + "timestamps": {"updated_at": "2026-03-24T21:20:23"}, + "input_data": { + "review_dir": "docs_reviews", + "word_threshold": "800. Validate inputs and continue.", + }, + } + + for session_name, state in { + "session_20260324_204400_good": valid_prior_state, + "session_20260324_212023_bad": malformed_recent_state, + }.items(): + session_dir = sessions_dir / session_name + session_dir.mkdir() + (session_dir / "state.json").write_text(json.dumps(state), encoding="utf-8") + + runtime = SimpleNamespace( + _session_store=SimpleNamespace(sessions_dir=sessions_dir), + resume_timers=MagicMock(), + trigger=AsyncMock(return_value="exec-rerun"), + graph=SimpleNamespace( + nodes=[], + entry_node="process", + get_node=lambda node_id: ( + SimpleNamespace(input_keys=["target_dir", "review_dir", "word_threshold"]) + if node_id == "process" + else None + ), + ), + get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")], + ) + session = SimpleNamespace( + worker_runtime=runtime, + event_bus=None, + worker_path=Path("exports/local_markdown_review_probe_2"), + runner=None, + ) + phase_state = QueenPhaseState(phase="staging") + register_queen_lifecycle_tools( + registry, + session=session, + session_id="sess-rerun", + phase_state=phase_state, + ) + + result_raw = await registry._tools["rerun_worker_with_last_input"].executor({}) + result = json.loads(result_raw) + + assert result["status"] == "started" + runtime.trigger.assert_awaited_once() + trigger_kwargs = runtime.trigger.await_args.kwargs + assert trigger_kwargs["input_data"] == { + "target_dir": str((tmp_path / "docs").resolve()), + "review_dir": str((tmp_path / "docs_reviews").resolve()), + "word_threshold": 800, + } + assert trigger_kwargs["session_state"] is None + + +@pytest.mark.asyncio +async def 
test_rerun_worker_with_last_input_preserves_legacy_task_payload( + monkeypatch, tmp_path: Path +) -> None: + registry = ToolRegistry() + _register_fake_validator(registry, {"valid": True, "steps": {}}) + + monkeypatch.setattr(qlt, "validate_credentials", lambda *args, **kwargs: None) + monkeypatch.chdir(tmp_path) + + sessions_dir = tmp_path / "agent_store" / "sessions" + sessions_dir.mkdir(parents=True) + session_dir = sessions_dir / "session_20260324_204400_task" + session_dir.mkdir() + (session_dir / "state.json").write_text( + json.dumps( + { + "timestamps": {"updated_at": "2026-03-24T20:44:00"}, + "input_data": {"task": "re-run the markdown review"}, + } + ), + encoding="utf-8", + ) + + runtime = SimpleNamespace( + _session_store=SimpleNamespace(sessions_dir=sessions_dir), + resume_timers=MagicMock(), + trigger=AsyncMock(return_value="exec-rerun"), + graph=SimpleNamespace( + nodes=[], + entry_node="process", + get_node=lambda node_id: ( + SimpleNamespace(input_keys=["task"]) if node_id == "process" else None + ), + ), + get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")], + ) + session = SimpleNamespace( + worker_runtime=runtime, + event_bus=None, + worker_path=Path("exports/legacy_worker"), + runner=None, + ) + phase_state = QueenPhaseState(phase="staging") + register_queen_lifecycle_tools( + registry, + session=session, + session_id="sess-rerun", + phase_state=phase_state, + ) + + result_raw = await registry._tools["rerun_worker_with_last_input"].executor({}) + result = json.loads(result_raw) + + assert result["status"] == "started" + trigger_kwargs = runtime.trigger.await_args.kwargs + assert trigger_kwargs["input_data"] == {"task": "re-run the markdown review"} + + +@pytest.mark.asyncio +async def test_rerun_worker_with_last_input_uses_current_session_defaults_only( + monkeypatch, tmp_path: Path +) -> None: + registry = ToolRegistry() + _register_fake_validator(registry, {"valid": True, "steps": {}}) + + monkeypatch.setattr(qlt, 
"validate_credentials", lambda *args, **kwargs: None) + monkeypatch.chdir(tmp_path) + + sessions_dir = tmp_path / "agent_store" / "sessions" + current_session_dir = sessions_dir / "sess-rerun-current" + current_session_dir.mkdir(parents=True) + (current_session_dir / "state.json").write_text( + json.dumps( + { + "timestamps": {"updated_at": "2026-03-24T20:44:00"}, + "input_data": {"target_dir": str((tmp_path / "current").resolve())}, + } + ), + encoding="utf-8", + ) + + other_session_dir = sessions_dir / "session_20260325_204400_other" + other_session_dir.mkdir(parents=True) + (other_session_dir / "state.json").write_text( + json.dumps( + { + "timestamps": {"updated_at": "2026-03-25T20:44:00"}, + "input_data": { + "target_dir": str((tmp_path / "other").resolve()), + "review_dir": str((tmp_path / "other_reviews").resolve()), + }, + } + ), + encoding="utf-8", + ) + + runtime = SimpleNamespace( + _session_store=SimpleNamespace(sessions_dir=sessions_dir), + resume_timers=MagicMock(), + trigger=AsyncMock(return_value="exec-rerun"), + graph=SimpleNamespace( + nodes=[], + entry_node="process", + get_node=lambda node_id: ( + SimpleNamespace(input_keys=["target_dir", "review_dir"]) + if node_id == "process" + else None + ), + ), + get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")], + ) + session = SimpleNamespace( + worker_runtime=runtime, + event_bus=None, + worker_path=Path("exports/docs_reviewer"), + runner=None, + ) + phase_state = QueenPhaseState(phase="staging") + register_queen_lifecycle_tools( + registry, + session=session, + session_id="sess-rerun-current", + phase_state=phase_state, + ) + + result_raw = await registry._tools["rerun_worker_with_last_input"].executor({}) + result = json.loads(result_raw) + + assert ( + result["error"] + == "No complete previous worker input is available for a same-defaults rerun." 
@pytest.mark.asyncio
async def test_run_agent_with_input_blocks_when_validator_output_is_undecodable(
    monkeypatch, tmp_path: Path
) -> None:
    """A validator that returns non-JSON output must block the run with a validator_subprocess failure."""
    validator_tool = Tool(
        name="validate_agent_package",
        description="fake validator",
        parameters={"type": "object", "properties": {"agent_name": {"type": "string"}}},
    )
    registry = ToolRegistry()
    registry.register("validate_agent_package", validator_tool, lambda _inputs: "not-json")

    monkeypatch.setattr(qlt, "validate_credentials", lambda *args, **kwargs: None)
    monkeypatch.chdir(tmp_path)

    worker_runtime = SimpleNamespace(
        resume_timers=MagicMock(),
        trigger=AsyncMock(return_value="exec-2"),
        graph=SimpleNamespace(nodes=[]),
    )
    staged_session = SimpleNamespace(
        worker_runtime=worker_runtime,
        event_bus=None,
        worker_path=Path("exports/docs_sanitizer_agent"),
        runner=None,
    )
    register_queen_lifecycle_tools(
        registry,
        session=staged_session,
        session_id="sess-invalid-validator",
        phase_state=QueenPhaseState(phase="staging"),
    )

    run_tool = registry._tools["run_agent_with_input"]
    payload = json.loads(await run_tool.executor({"task": "run it"}))

    expected_failures = [
        "validator_subprocess: validate_agent_package returned an invalid or undecodable report"
    ]
    assert "validation is failing" in payload["error"]
    assert payload["validation_failures"] == expected_failures
    worker_runtime.trigger.assert_not_awaited()
def test_select_primary_worker_result_prefers_result_over_input_like_keys() -> None:
    """The ``result`` key wins even when input-like keys precede it in the output."""
    formatted_answer = (
        "Given:\n- Item prices: ₹850, ₹1,200, ₹650\n\n"
        "Steps:\n1) Add prices.\n2) Apply discount.\n\n"
        "Answer:\n₹2,295"
    )
    # Key order matters here: the candidate keys come first, ``result`` last.
    worker_output = {
        "task": "Solve this word problem using strict Given/Steps/Answer format.",
        "givens": "Item prices and discount rule.",
        "final_value": "₹2,295",
        "result": formatted_answer,
    }

    selected = _select_primary_worker_result(worker_output)

    assert selected == ("result", formatted_answer)
def test_worker_terminal_notification_does_not_treat_artifact_filename_as_primary_result() -> None:
    """Outputs made only of briefs and artifact references get no primary-result section."""
    artifact_only_output = {
        "research_brief": "AI news roundup from the past week.",
        "articles_data": "[Saved to 'output_articles_data.json' (3957 bytes).]",
        "report_file": "tech_news_report.html",
    }

    message = _build_worker_terminal_notification(artifact_only_output)

    # The filename is still listed as a plain key/value line.
    assert "[PRIMARY_RESULT_BEGIN]" not in message
    assert "report_file: tech_news_report.html" in message
limit=None: [ + AgentEvent( + type=EventType.EXECUTION_STARTED, + stream_id="default", + execution_id="exec_digest", + run_id="run_digest", + ) + ] + + await bus.publish( + AgentEvent( + type=EventType.EXECUTION_STARTED, + stream_id="default", + execution_id="exec_digest", + run_id="run_digest", + ) + ) + await bus.publish( + AgentEvent( + type=EventType.EXECUTION_COMPLETED, + stream_id="default", + execution_id="exec_digest", + run_id="run_digest", + data={"output": {"result": "final answer"}}, + ) + ) + await asyncio.sleep(0.05) + + consolidate.assert_awaited_once() + assert queen_node.inject_event.await_count == 0 diff --git a/core/tests/test_session_manager_worker_validation.py b/core/tests/test_session_manager_worker_validation.py new file mode 100644 index 0000000000..71f3cf1926 --- /dev/null +++ b/core/tests/test_session_manager_worker_validation.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +import framework.runner as runner_mod +from framework.server import session_manager as sm + + +@pytest.mark.asyncio +async def test_load_worker_blocks_invalid_package_before_runner_load(monkeypatch) -> None: + manager = sm.SessionManager() + session = sm.Session( + id="sess-1", + event_bus=MagicMock(), + llm=MagicMock(), + loaded_at=0.0, + ) + manager._sessions[session.id] = session + + captured: dict[str, str] = {} + + def _fake_validation(agent_ref): + captured["agent_ref"] = str(agent_ref) + return { + "valid": False, + "steps": { + "behavior_validation": { + "passed": False, + "errors": ["Node 'scan' has a blank or placeholder system_prompt"], + } + }, + } + + monkeypatch.setattr( + sm, + "_run_validation_report_sync", + _fake_validation, + ) + + called = {"runner_load": False} + + class _FakeAgentRunner: + @staticmethod + def load(*args, **kwargs): + called["runner_load"] = True + raise AssertionError("AgentRunner.load should not run for invalid workers") + + 
def test_validation_blocks_stage_or_run_ignores_non_blocking_warnings() -> None:
    """Steps that passed but carry warnings must not block staging or running."""
    behavior_step = {
        "passed": True,
        "warnings": ["placeholder prompt"],
        "output": "placeholder prompt",
    }
    tests_step = {
        "passed": True,
        "warnings": ["1 failed"],
        "summary": "1 failed",
    }
    report = {"steps": {"behavior_validation": behavior_step, "tests": tests_step}}

    blocked = sm._validation_blocks_stage_or_run(report)

    assert blocked is False
def test_repo_relative_path_resolves_from_repo_root_not_cwd(self, tmp_path, monkeypatch):
    """A repo-relative agent path resolves against the repo root, not the process cwd.

    The test points the app module at a fake repository under ``tmp_path``,
    changes the working directory to an unrelated folder, and checks that
    ``examples/some_agent`` still resolves inside the fake repository.
    """
    import framework.server.app as app_module

    repo_root = tmp_path / "repo"
    examples_root = repo_root / "examples"
    agent_dir = examples_root / "some_agent"
    agent_dir.mkdir(parents=True)
    other_cwd = tmp_path / "elsewhere"
    other_cwd.mkdir()

    monkeypatch.setattr(app_module, "_REPO_ROOT", repo_root)
    # Patch the allowlist through monkeypatch (not plain assignment) so the
    # module-level value is restored after the test; otherwise later tests
    # would see roots that point into this test's deleted tmp_path.
    monkeypatch.setattr(
        app_module,
        "_ALLOWED_AGENT_ROOTS",
        (
            repo_root / "exports",
            examples_root,
            tmp_path / ".hive" / "agents",
        ),
        raising=False,
    )
    monkeypatch.chdir(other_cwd)

    result = validate_agent_path("examples/some_agent")

    assert result == agent_dir.resolve()
@pytest.mark.asyncio
async def test_worker_health_summary_marks_healthy_runs(tmp_path: Path) -> None:
    """A running session whose latest verdict is ACCEPT is reported as healthy."""
    session_id = "session-healthy"
    registry = ToolRegistry()
    event_bus = EventBus()
    storage_path = tmp_path / "agent_store"
    storage_path.mkdir(parents=True, exist_ok=True)

    logged_steps = [
        {"verdict": "RETRY", "llm_text": "first pass"},
        {"verdict": "ACCEPT", "llm_text": "done"},
    ]
    _write_session_logs(
        storage_path,
        session_id,
        session_status="running",
        steps=logged_steps,
    )

    register_worker_monitoring_tools(
        registry,
        event_bus,
        storage_path,
        default_session_id=session_id,
    )

    summary_tool = registry._tools["get_worker_health_summary"]
    summary = json.loads(await summary_tool.executor({}))

    assert summary["health_status"] == "healthy"
    assert summary["issue_signals"] == []
    assert summary["recent_verdicts"] == ["RETRY", "ACCEPT"]
    assert summary["steps_since_last_accept"] == 0
"judge_pressure" in data["issue_signals"] + assert "recent_non_accept_churn" in data["issue_signals"] + assert data["steps_since_last_accept"] == 4 + assert data["stall_minutes"] is not None + + +@pytest.mark.asyncio +async def test_worker_health_summary_errors_when_default_session_has_no_state( + tmp_path: Path, +) -> None: + registry = ToolRegistry() + event_bus = EventBus() + storage_path = tmp_path / "agent_store" + (storage_path / "sessions" / "session-stale").mkdir(parents=True, exist_ok=True) + + register_worker_monitoring_tools( + registry, + event_bus, + storage_path, + default_session_id="session-stale", + ) + + raw = await registry._tools["get_worker_health_summary"].executor({}) + data = json.loads(raw) + + assert "error" in data + assert "session-stale" in data["error"] diff --git a/tools/coder_tools_server.py b/tools/coder_tools_server.py index 2b124916a1..30144829ce 100644 --- a/tools/coder_tools_server.py +++ b/tools/coder_tools_server.py @@ -68,6 +68,15 @@ def _patched_console_init(self, *args, **kwargs): PROJECT_ROOT: str = "" SNAPSHOT_DIR: str = "" +_PLACEHOLDER_MARKERS = ( + "TODO", + "TODO:", + "TODO ", + "Add system prompt for this node", + "Add identity prompt", + "Define success criteria", + "Describe what this node does", +) # ── Path resolution ─────────────────────────────────────────────────────── @@ -138,6 +147,388 @@ def _resolve_path(path: str) -> str: return resolved +def _is_placeholder_text(value: str | None) -> bool: + text = (value or "").strip() + if not text: + return True + return any(marker in text for marker in _PLACEHOLDER_MARKERS) + + +_ENTRY_INTAKE_HINTS = ( + "parse the incoming task", + "parse the task text", + "structured runtime task", + "accept structured runtime task", + "read runtime task input", + "validate runtime", + "validate runtime path", + "validate runtime paths", + "intake & validate", + "configuration values", + "intake config", + "infer values conservatively", +) +_ENTRY_DIRECT_WORK_HINTS = ( + "scan", + 
"scanning", + "discover", + "discovery", + "search", + "fetch", + "analyze", + "analyse", + "transform", + "sanitize", + "summarize", + "summarise", + "generate", + "write", + "apply", + "candidate", +) +_TOOL_ALIAS_HINTS = { + "run_command": "execute_command_tool", +} +_OUTPUT_DIRECTORY_INPUT_HINTS = { + "review_dir", + "output_dir", + "destination_dir", + "dest_dir", + "target_dir", + "review_path", + "output_path", + "destination_path", + "dest_path", + "target_path", +} +_SESSION_DATA_TOOLS = frozenset( + { + "save_data", + "load_data", + "list_data_files", + "append_data", + "edit_data", + "serve_file_to_user", + } +) +_WORKSPACE_PATH_HINTS = frozenset( + { + "review_dir", + "review_root", + "output_dir", + "output_root", + "output_path", + "target_dir", + "target_root", + "target_path", + "workspace", + "project folder", + "project folders", + } +) +_SESSION_DATA_TOOL_PATH_OP_RE = re.compile( + r"(save_data|load_data|list_data_files|append_data|edit_data|serve_file_to_user)" + r"[^.\n]{0,200}\b(?:to|into|in|inside|under|within|from|at|on)\s+(?:the\s+)?" + r"(review_dir|review_root|output_dir|output_root|output_path|target_dir|target_root|" + r"target_path|workspace|project folder|project folders)\b|" + r"\b(?:to|into|in|inside|under|within|from|at|on)\s+(?:the\s+)?" 
+ r"(review_dir|review_root|output_dir|output_root|output_path|target_dir|target_root|" + r"target_path|workspace|project folder|project folders)\b[^.\n]{0,200}" + r"(save_data|load_data|list_data_files|append_data|edit_data|serve_file_to_user)", + re.IGNORECASE, +) + + +def _contains_hint_word(text: str, hint: str) -> bool: + """Return True when *hint* appears as a word/phrase, not just a substring.""" + if " " in hint: + return hint in text + return re.search(rf"\b{re.escape(hint)}\b", text) is not None + + +def _default_intro_message(human_name: str, description: str) -> str: + """Return a non-placeholder intro line for generated agents.""" + desc = (description or "").strip().rstrip(".") + if desc: + return f"{desc}." + return f"Ready to run {human_name}." + + +def _default_success_metric(index: int) -> str: + """Return a generic but non-placeholder success metric name.""" + return f"criterion_{index}_satisfied" + + +def _looks_like_agent_path(agent_ref: str) -> bool: + """Return True when *agent_ref* should be treated as a filesystem path.""" + candidate = Path(agent_ref).expanduser() + return candidate.is_absolute() or len(candidate.parts) > 1 or agent_ref.startswith((".", "~")) + + +def _resolve_agent_package_target(agent_ref: str) -> tuple[Path, str, str]: + """Resolve a validator target to (agent_dir, package_name, display_ref). + + Bare names still target exports/ for build-time validation. + Paths are resolved relative to the repository root and must pass the + server allowlist so existing example agents can be staged safely. 
def _node_can_progress_without_declared_tools(node) -> bool:
    """Return True when a node can legitimately work without MCP/local tools.

    Static validation should not block these two cases:
    - ``gcu`` nodes, whose browser tools are injected by the framework;
    - nodes that declare output keys, mention ``set_output`` in their system
      prompt, and whose name/description/prompt contain a direct-work hint.
    """
    if getattr(node, "node_type", "") == "gcu":
        return True

    declared_outputs = list(getattr(node, "output_keys", []) or [])
    if not declared_outputs:
        return False

    prompt_lower = (getattr(node, "system_prompt", "") or "").lower()
    fragments = [
        getattr(node, "name", "") or "",
        getattr(node, "description", "") or "",
        getattr(node, "system_prompt", "") or "",
    ]
    combined = " ".join(fragment for fragment in fragments if fragment).lower()

    set_output_phrases = (
        "set_output(",
        "call set_output",
        "use set_output",
        "set_output ",
    )
    if not any(phrase in prompt_lower for phrase in set_output_phrases):
        return False

    for hint in _ENTRY_DIRECT_WORK_HINTS:
        if _contains_hint_word(combined, hint):
            return True
    return False
blank or still contains TODO placeholders") + for criterion in list(getattr(goal, "success_criteria", []) or []): + cid = getattr(criterion, "id", "") + for attr in ("description", "metric", "target"): + if _is_placeholder_text(getattr(criterion, attr, "") or ""): + errors.append(f"Success criterion '{cid}' has blank or placeholder {attr}") + for constraint in list(getattr(goal, "constraints", []) or []): + cid = getattr(constraint, "id", "") + if _is_placeholder_text(getattr(constraint, "description", "") or ""): + errors.append(f"Constraint '{cid}' has blank or placeholder description") + + for node in nodes: + node_id = getattr(node, "id", "") + node_desc = getattr(node, "description", "") or "" + if _is_placeholder_text(node_desc): + errors.append(f"Node '{node_id}' has a blank or placeholder description") + + prompt = getattr(node, "system_prompt", "") or "" + prompt_lower = prompt.lower() + if _is_placeholder_text(prompt): + errors.append(f"Node '{node_id}' has a blank or placeholder system_prompt") + else: + tools = list(getattr(node, "tools", []) or []) + for tool_name in tools: + if isinstance(tool_name, str) and f"{tool_name}(" in prompt: + errors.append( + f"Node '{node_id}' system_prompt uses callable-style tool syntax for " + f"'{tool_name}'. Describe tool usage in prose instead of " + "Python-style calls." + ) + for alias, actual in _TOOL_ALIAS_HINTS.items(): + if alias in prompt and actual in tools and alias not in tools: + errors.append( + f"Node '{node_id}' system_prompt references unsupported tool alias " + f"'{alias}'. Use the actual registered tool name '{actual}'." 
+ ) + data_tools_used = [tool for tool in tools if tool in _SESSION_DATA_TOOLS] + if data_tools_used: + workspace_path_hints = [] + if _SESSION_DATA_TOOL_PATH_OP_RE.search(prompt): + workspace_path_hints = [ + hint + for hint in _WORKSPACE_PATH_HINTS + if _contains_hint_word(prompt_lower, hint) + ] + if workspace_path_hints: + joined_tools = ", ".join(sorted(data_tools_used)) + joined_hints = ", ".join(sorted(workspace_path_hints)) + errors.append( + f"Node '{node_id}' uses session data tools ({joined_tools}) as if they " + f"can operate on workspace paths ({joined_hints}). Data tools use the " + "framework-managed session data directory; use execute_command_tool for " + "workspace/review/output directories." + ) + + success_criteria = getattr(node, "success_criteria", "") or "" + if _is_placeholder_text(success_criteria): + errors.append(f"Node '{node_id}' has blank or placeholder success_criteria") + + tools = list(getattr(node, "tools", []) or []) + sub_agents = list(getattr(node, "sub_agents", []) or []) + client_facing = bool(getattr(node, "client_facing", False)) + if ( + node_id not in terminal_ids + and not client_facing + and not tools + and not sub_agents + and not _node_can_progress_without_declared_tools(node) + ): + errors.append(f"Autonomous node '{node_id}' has no tools or sub_agents") + + if node_id == entry_node_id: + input_keys = list(getattr(node, "input_keys", []) or []) + output_keys = list(getattr(node, "output_keys", []) or []) + text = " ".join( + filter( + None, + [ + getattr(node, "name", "") or "", + node_desc, + prompt, + ], + ) + ).lower() + lowered_input_keys = {str(key).lower() for key in input_keys} + lowered_output_keys = {str(key).lower() for key in output_keys} + generic_task_only = len(input_keys) == 1 and input_keys[0] in { + "task", + "user_request", + "raw", + "input", + "request", + "message", + } + intake_like = any(hint in text for hint in _ENTRY_INTAKE_HINTS) + direct_work_like = any( + _contains_hint_word(text, hint) for 
hint in _ENTRY_DIRECT_WORK_HINTS + ) + pass_through_inputs = bool(lowered_input_keys) and ( + lowered_input_keys <= lowered_output_keys + ) + runtime_normalization_only = any( + hint in text + for hint in ( + "validate", + "validation", + "normalize", + "normalise", + "config", + "configuration", + "runtime", + "path", + "paths", + ) + ) + if ( + intake_like + and not direct_work_like + and (generic_task_only or pass_through_inputs or runtime_normalization_only) + ): + errors.append( + f"Entry node '{node_id}' appears to be an intake/config parser. " + "The queen handles intake. Make the first real work node consume " + "structured input_keys directly instead of reparsing a generic task string." + ) + for input_key in input_keys: + lowered_key = str(input_key).lower() + if lowered_key not in _OUTPUT_DIRECTORY_INPUT_HINTS: + continue + if ( + lowered_key in text + and "exist" in text + and ("directory" in text or "directories" in text) + ): + errors.append( + f"Entry node '{node_id}' requires output path '{input_key}' to pre-exist. " + "Output/review directories should be created if missing instead of " + "blocking the run during intake validation." + ) + break + + return errors + + +def _classify_behavior_validation_errors(errors: list[str]) -> tuple[list[str], list[str]]: + """Split behavior validation findings into blocking errors and warnings. + + Hard failures are reserved for issues that are likely to break runtime + execution or violate framework contracts. Quality/style issues remain + warnings so they can be surfaced without preventing staging/runs. 
+ """ + blocking_markers = ( + "identity_prompt still contains TODO placeholders", + "blank or placeholder system_prompt", + "uses session data tools", + "Autonomous node ", + "appears to be an intake/config parser", + "requires output path", + ) + + blocking: list[str] = [] + warnings: list[str] = [] + for error in errors: + if any(marker in error for marker in blocking_markers): + blocking.append(error) + else: + warnings.append(error) + return blocking, warnings + + # ── Git snapshot system (ported from opencode's shadow git) ─────────────── @@ -752,9 +1143,14 @@ def _infer_service(tool_name: str) -> str: def _validate_agent_tools_impl(agent_path: str) -> dict: - """Validate that all tools declared in an agent's nodes exist in its MCP servers. + """Validate that all tools declared in an agent's nodes exist at runtime. - Returns a dict with validation result: pass/fail, missing tools per node, available tools. + Mirrors runtime tool discovery: + 1. MCP tools from ``mcp_servers.json`` (when present) + 2. Agent-local ``tools.py`` custom tools (when present) + + Returns a dict with validation result: pass/fail, missing tools per node, + available tools, and discovery warnings. 
""" try: resolved = _resolve_path(agent_path) @@ -781,11 +1177,8 @@ def _validate_agent_tools_impl(agent_path: str) -> dict: agent_dir = resolved # Keep path; 'resolved' is reused for MCP config in loop - # --- Discover available tools from agent's MCP servers --- + # --- Discover available tools from MCP + local tools.py --- mcp_config_path = os.path.join(agent_dir, "mcp_servers.json") - if not os.path.isfile(mcp_config_path): - return {"error": f"No mcp_servers.json found in {agent_path}"} - try: from pathlib import Path @@ -796,36 +1189,45 @@ def _validate_agent_tools_impl(agent_path: str) -> dict: available_tools: set[str] = set() discovery_errors = [] - config_dir = Path(mcp_config_path).parent - - try: - with open(mcp_config_path, encoding="utf-8") as f: - servers_config = json.load(f) - except (json.JSONDecodeError, OSError) as e: - return {"error": f"Failed to read mcp_servers.json: {e}"} - - for server_name, server_conf in servers_config.items(): - resolved = ToolRegistry.resolve_mcp_stdio_config( - {"name": server_name, **server_conf}, config_dir - ) + if os.path.isfile(mcp_config_path): + config_dir = Path(mcp_config_path).parent try: - config = MCPServerConfig( - name=server_name, - transport=resolved.get("transport", "stdio"), - command=resolved.get("command"), - args=resolved.get("args", []), - env=resolved.get("env", {}), - cwd=resolved.get("cwd"), - url=resolved.get("url"), - headers=resolved.get("headers", {}), + with open(mcp_config_path, encoding="utf-8") as f: + servers_config = json.load(f) + except (json.JSONDecodeError, OSError) as e: + return {"error": f"Failed to read mcp_servers.json: {e}"} + + for server_name, server_conf in servers_config.items(): + resolved = ToolRegistry.resolve_mcp_stdio_config( + {"name": server_name, **server_conf}, config_dir ) - client = MCPClient(config) - client.connect() - for tool in client.list_tools(): - available_tools.add(tool.name) - client.disconnect() + try: + config = MCPServerConfig( + 
name=server_name, + transport=resolved.get("transport", "stdio"), + command=resolved.get("command"), + args=resolved.get("args", []), + env=resolved.get("env", {}), + cwd=resolved.get("cwd"), + url=resolved.get("url"), + headers=resolved.get("headers", {}), + ) + client = MCPClient(config) + client.connect() + for tool in client.list_tools(): + available_tools.add(tool.name) + client.disconnect() + except Exception as e: + discovery_errors.append({"server": server_name, "error": str(e)}) + + local_tools_path = Path(agent_dir) / "tools.py" + if local_tools_path.is_file(): + try: + registry = ToolRegistry() + registry.discover_from_module(local_tools_path) + available_tools.update(registry.get_tools().keys()) except Exception as e: - discovery_errors.append({"server": server_name, "error": str(e)}) + discovery_errors.append({"server": "tools.py", "error": str(e)}) # --- Load agent nodes and extract declared tools --- agent_py = os.path.join(agent_dir, "agent.py") @@ -1227,7 +1629,7 @@ def get_agent_checkpoint( def _run_agent_tests_impl( - agent_name: str, + agent_ref: str, test_types: str = "all", fail_fast: bool = False, ) -> dict: @@ -1235,22 +1637,35 @@ def _run_agent_tests_impl( Returns a dict with summary counts, per-test results, and failure details. """ - agent_path = Path(PROJECT_ROOT) / "exports" / agent_name + try: + agent_path, agent_name, display_ref = _resolve_agent_package_target(agent_ref) + except ValueError as e: + return {"error": str(e)} + if not agent_path.is_dir(): - # Fall back to framework agents + # Fall back to framework agents for bare framework package names. 
agent_path = Path(PROJECT_ROOT) / "core" / "framework" / "agents" / agent_name tests_dir = agent_path / "tests" if not agent_path.is_dir(): return { - "error": f"Agent not found: {agent_name}", + "error": f"Agent not found: {agent_ref}", "hint": "Use list_agents() to see available agents.", } if not tests_dir.exists(): return { - "error": f"No tests directory: exports/{agent_name}/tests/", - "hint": "Create test files in the tests/ directory first.", + "agent_name": agent_name, + "agent_path": str(agent_path), + "summary": f"No tests directory: {tests_dir}", + "passed": 0, + "failed": 0, + "skipped": 1, + "errors": 0, + "total": 0, + "test_results": [], + "failures": [], + "skipped_all": True, } # Parse test types @@ -1297,7 +1712,10 @@ def _run_agent_tests_impl( core_path = os.path.join(PROJECT_ROOT, "core") exports_path = os.path.join(PROJECT_ROOT, "exports") fw_agents_path = os.path.join(PROJECT_ROOT, "core", "framework", "agents") + package_parent = str(agent_path.parent) path_parts = [core_path, exports_path, fw_agents_path, PROJECT_ROOT] + if package_parent not in path_parts: + path_parts.insert(1, package_parent) if pythonpath: path_parts.append(pythonpath) env["PYTHONPATH"] = os.pathsep.join(path_parts) @@ -1387,6 +1805,7 @@ def _run_agent_tests_impl( return { "agent_name": agent_name, + "agent_path": str(agent_path), "summary": summary_text, "passed": passed, "failed": failed, @@ -1425,28 +1844,50 @@ def run_agent_tests( # ── Meta-agent: Unified agent validation ─────────────────────────────────── -@mcp.tool() -def validate_agent_package(agent_name: str) -> str: +def _validate_agent_package_impl(agent_name: str) -> dict[str, object]: """Run structural validation checks on a built agent package in one call. - Executes 5 steps and reports all results (does not stop on first failure): + Executes multiple checks and reports all results (does not stop on first failure): 1. Class validation — checks graph structure and entry_points contract 2. 
Node completeness — every NodeSpec in nodes/ must be in the nodes list, and GCU nodes must be referenced in a parent's sub_agents 3. Graph validation — loads the agent graph without credential checks - 4. Tool validation — checks declared tools exist in MCP servers - 5. Tests — runs the agent's pytest suite + 4. Behavior validation — rejects placeholder prompts and empty autonomous nodes + 5. Tool validation — checks declared tools exist in MCP servers + 6. Tests — runs the agent's pytest suite Note: Credential validation is intentionally skipped here (building phase). Credentials are validated at run time by run_agent_with_input() preflight. Args: - agent_name: Agent package name (e.g. 'my_agent'). Must exist in exports/. + agent_name: Agent package name (e.g. 'my_agent') or an allowed + agent path such as examples/templates/my_agent. Returns: - JSON with per-step results and overall pass/fail summary + Dict with per-step results and overall pass/fail summary """ - agent_path = f"exports/{agent_name}" + global PROJECT_ROOT, SNAPSHOT_DIR + + if not PROJECT_ROOT: + PROJECT_ROOT = _find_project_root() + if not SNAPSHOT_DIR and PROJECT_ROOT: + SNAPSHOT_DIR = os.path.join( + os.path.expanduser("~"), + ".hive", + "snapshots", + os.path.basename(PROJECT_ROOT), + ) + + try: + agent_dir, package_name, display_ref = _resolve_agent_package_target(agent_name) + except ValueError as e: + return { + "valid": False, + "agent_name": agent_name, + "steps": {"target_resolution": {"passed": False, "error": str(e)}}, + "summary": "FAIL: 1 of 1 steps failed (target_resolution)", + } + steps: dict[str, dict] = {} # Set up env for subprocess calls @@ -1454,8 +1895,11 @@ def validate_agent_package(agent_name: str) -> str: core_path = os.path.join(PROJECT_ROOT, "core") exports_path = os.path.join(PROJECT_ROOT, "exports") fw_agents_path = os.path.join(PROJECT_ROOT, "core", "framework", "agents") + package_parent = str(agent_dir.parent) pythonpath = env.get("PYTHONPATH", "") path_parts = 
[core_path, exports_path, fw_agents_path, PROJECT_ROOT] + if package_parent not in path_parts: + path_parts.insert(1, package_parent) if pythonpath: path_parts.append(pythonpath) env["PYTHONPATH"] = os.pathsep.join(path_parts) @@ -1464,22 +1908,22 @@ def validate_agent_package(agent_name: str) -> str: try: _contract_script = textwrap.dedent("""\ import importlib, json - mod = importlib.import_module('{agent_name}') + mod = importlib.import_module('{package_name}') missing = [a for a in ('goal', 'nodes', 'edges') if getattr(mod, a, None) is None] if missing: print(json.dumps({{ 'valid': False, 'error': ( - "Module '{agent_name}' is missing module-level attributes: " + "Module '{package_name}' is missing module-level attributes: " + ", ".join(missing) + ". " - "Fix: in {agent_name}/__init__.py, add " + "Fix: in {package_name}/__init__.py, add " "'from .agent import " + ", ".join(missing) + "' " - "so that 'import {agent_name}' exposes them at package level." + "so that 'import {package_name}' exposes them at package level." ) }})) else: print(json.dumps({{'valid': True}})) - """).format(agent_name=agent_name) + """).format(package_name=package_name) proc = subprocess.run( ["uv", "run", "python", "-c", _contract_script], capture_output=True, @@ -1499,8 +1943,8 @@ def validate_agent_package(agent_name: str) -> str: steps["module_contract"] = { "passed": False, "error": ( - f"Failed to import '{agent_name}': {proc.stderr.strip()[:1000]}. " - f"Fix: ensure {agent_name}/__init__.py exists and can be imported " + f"Failed to import '{package_name}': {proc.stderr.strip()[:1000]}. " + f"Fix: ensure {package_name}/__init__.py exists and can be imported " f"without errors (check syntax, missing dependencies, relative imports)." 
), } @@ -1515,7 +1959,7 @@ def validate_agent_package(agent_name: str) -> str: "run", "python", "-c", - f"from {agent_name} import default_agent; print(default_agent.validate())", + f"from {package_name} import default_agent; print(default_agent.validate())", ], capture_output=True, text=True, @@ -1562,7 +2006,7 @@ def validate_agent_package(agent_name: str) -> str: ) print(json.dumps({{'valid': len(errors) == 0, 'errors': errors}})) """) - check_script = _check_template.format(agent_name=agent_name) + check_script = _check_template.format(agent_name=package_name) proc = subprocess.run( ["uv", "run", "python", "-c", check_script], capture_output=True, @@ -1603,7 +2047,7 @@ def validate_agent_package(agent_name: str) -> str: "python", "-c", f"from framework.runner.runner import AgentRunner; " - f'r = AgentRunner.load("exports/{agent_name}", ' + f"r = AgentRunner.load({str(display_ref)!r}, " f"skip_credential_validation=True); " f'print("AgentRunner.load (graph-only): OK")', ], @@ -1624,9 +2068,42 @@ def validate_agent_package(agent_name: str) -> str: except Exception as e: steps["graph_validation"] = {"passed": False, "error": str(e)} + # Step B2: Behavior validation — reject placeholder prompts and empty work nodes + try: + import importlib + + if package_parent not in sys.path: + sys.path.insert(0, package_parent) + + stale = [ + name + for name in sys.modules + if name == package_name or name.startswith(f"{package_name}.") + ] + for name in stale: + del sys.modules[name] + + agent_mod = importlib.import_module(package_name) + behavior_errors = _behavior_validation_errors(agent_mod) + behavior_blockers, behavior_warnings = _classify_behavior_validation_errors(behavior_errors) + steps["behavior_validation"] = { + "passed": len(behavior_blockers) == 0, + "output": ( + "No placeholder prompts or empty autonomous nodes detected" + if not behavior_errors + else "; ".join(behavior_blockers or behavior_warnings) + ), + } + if behavior_blockers: + 
steps["behavior_validation"]["errors"] = behavior_blockers + if behavior_warnings: + steps["behavior_validation"]["warnings"] = behavior_warnings + except Exception as e: + steps["behavior_validation"] = {"passed": False, "error": str(e)} + # Step C: Tool validation (direct call) try: - tool_result = _validate_agent_tools_impl(agent_path) + tool_result = _validate_agent_tools_impl(str(agent_dir)) if "error" in tool_result: steps["tool_validation"] = {"passed": False, "error": tool_result["error"]} else: @@ -1641,17 +2118,33 @@ def validate_agent_package(agent_name: str) -> str: # Step D: Tests (direct call) try: - test_result = _run_agent_tests_impl(agent_name) - if "error" in test_result: - steps["tests"] = {"passed": False, "error": test_result["error"]} + test_result = _run_agent_tests_impl(str(agent_dir)) + if test_result.get("skipped_all"): + steps["tests"] = { + "passed": True, + "skipped": True, + "summary": test_result.get("summary", "No tests directory found; skipped"), + } + elif "error" in test_result: + steps["tests"] = { + "passed": False, + "warning": test_result["error"], + "warnings": [test_result["error"]], + } else: all_passed = test_result.get("failed", 0) == 0 and test_result.get("errors", 0) == 0 steps["tests"] = { "passed": all_passed, "summary": test_result.get("summary", "unknown"), } - if not all_passed and test_result.get("failures"): - steps["tests"]["failures"] = test_result["failures"] + if not all_passed: + warning_summary = ( + f"Test suite not fully passing: {test_result.get('summary', 'unknown')}" + ) + steps["tests"]["warning"] = warning_summary + steps["tests"]["warnings"] = [warning_summary] + if test_result.get("failures"): + steps["tests"]["failures"] = test_result["failures"] except Exception as e: steps["tests"] = {"passed": False, "error": str(e)} @@ -1665,13 +2158,20 @@ def validate_agent_package(agent_name: str) -> str: else: summary = f"FAIL: {len(failed_steps)} of {total} steps failed ({', '.join(failed_steps)})" + 
return { + "valid": valid, + "agent_name": package_name, + "agent_path": str(agent_dir), + "steps": steps, + "summary": summary, + } + + +@mcp.tool() +def validate_agent_package(agent_name: str) -> str: + """Run structural validation checks on a built agent package in one call.""" return json.dumps( - { - "valid": valid, - "agent_name": agent_name, - "steps": steps, - "summary": summary, - }, + _validate_agent_package_impl(agent_name), indent=2, default=str, ) @@ -1804,10 +2304,10 @@ class RuntimeConfig: @dataclass class AgentMetadata: - name: str = "{human_name}" + name: str = {human_name!r} version: str = "1.0.0" - description: str = "{_draft_desc or "TODO: Add agent description."}" - intro_message: str = "TODO: Add intro message." + description: str = {(_draft_desc or "TODO: Add agent description.")!r} + intro_message: str = {_default_intro_message(human_name, _draft_desc)!r} metadata = AgentMetadata() @@ -1913,8 +2413,8 @@ class AgentMetadata: SuccessCriterion( id="sc-{i + 1}", description="{sc}", - metric="TODO", - target="TODO", + metric="{_default_success_metric(i + 1)}", + target="{_default_success_target()}", weight=1.0, ),""" for i, sc in enumerate(_draft_sc) @@ -1924,8 +2424,8 @@ class AgentMetadata: SuccessCriterion( id="sc-1", description="TODO: Define success criterion.", - metric="TODO", - target="TODO", + metric="criterion_1_satisfied", + target="true", weight=1.0, ),""" @@ -1997,7 +2497,10 @@ class AgentMetadata: terminal_nodes = [] conversation_mode = "continuous" -identity_prompt = "TODO: Add identity prompt." +identity_prompt = ( + "You are {human_name}, a focused Hive worker that follows the goal, " + "constraints, and node instructions precisely." 
+) loop_config = {{ "max_iterations": 100, "max_tool_calls_per_turn": 30, @@ -2063,7 +2566,7 @@ def _setup(self): name="Default", entry_node=self.entry_node, trigger_type="manual", - isolation_level="shared", + isolation_level="isolated", ), ], llm=llm, @@ -2347,11 +2850,13 @@ def runner_loaded(): "files": all_file_paths, "next_steps": [ ( - "IMPORTANT: All generated files are structurally complete " - "with correct imports, class definition, validate() method, " - "and __init__.py exports. Use edit_file to customize TODO " - "placeholders — do NOT use write_file to rewrite entire files, " - "as this will break imports and structure." + "IMPORTANT: The generated scaffold has correct imports, class " + "definition, validate() method, and __init__.py exports, but " + "it is NOT ready to load or run yet. Replace every TODO / " + "placeholder prompt and make validation pass before staging. " + "Use edit_file to customize placeholders — do NOT use " + "write_file to rewrite entire files, as this will break " + "imports and structure." 
), ( f"Use edit_file to customize system prompts, tools, " diff --git a/tools/tests/test_coder_tools_server.py b/tools/tests/test_coder_tools_server.py index a4ddb39d14..0b798c727b 100644 --- a/tools/tests/test_coder_tools_server.py +++ b/tools/tests/test_coder_tools_server.py @@ -5,6 +5,7 @@ import sys import types from pathlib import Path +from types import SimpleNamespace def _load_coder_tools_server(): @@ -54,10 +55,24 @@ def disconnect(self): return None class FakeToolRegistry: + def __init__(self): + self._tools = {} + @staticmethod def resolve_mcp_stdio_config(config: dict, _config_dir: Path) -> dict: return config + def discover_from_module(self, module_path: Path) -> int: + spec = importlib.util.spec_from_file_location("fake_agent_tools", module_path) + assert spec is not None and spec.loader is not None + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + self._tools.update(getattr(module, "TOOLS", {})) + return len(getattr(module, "TOOLS", {})) + + def get_tools(self) -> dict: + return dict(self._tools) + mcp_client_mod.MCPClient = FakeMCPClient mcp_client_mod.MCPServerConfig = FakeMCPServerConfig tool_registry_mod.ToolRegistry = FakeToolRegistry @@ -164,3 +179,1047 @@ def test_list_agent_tools_provider_filter_and_legacy_prefix_filter(monkeypatch, legacy_data = json.loads(legacy_raw) assert list(legacy_data["tools_by_provider"].keys()) == ["google"] assert legacy_data["all_tool_names"] == ["gmail_list_messages"] + + +def test_behavior_validation_errors_rejects_placeholder_prompts_and_empty_work_nodes(): + mod = _load_coder_tools_server() + + agent_module = SimpleNamespace( + identity_prompt="TODO: Add identity prompt.", + metadata=SimpleNamespace( + description="TODO: Add agent description.", + intro_message="TODO: Add intro message.", + ), + goal=SimpleNamespace( + description="TODO: Describe the goal.", + success_criteria=[ + SimpleNamespace( + id="sc-1", + description="TODO: Define success criterion.", + metric="TODO", 
+ target="TODO", + ) + ], + constraints=[ + SimpleNamespace(id="c-1", description="TODO: Define constraint."), + ], + ), + terminal_nodes=["done"], + nodes=[ + SimpleNamespace( + id="scan", + client_facing=False, + tools=["execute_command_tool"], + sub_agents=[], + description="TODO: Describe what this node does.", + system_prompt="TODO: Add system prompt for this node.", + success_criteria="Find files.", + ), + SimpleNamespace( + id="summarize", + client_facing=False, + tools=[], + sub_agents=[], + description="Summarize the docs.", + system_prompt="Write condensed markdown summaries.", + success_criteria="Produce summaries.", + ), + SimpleNamespace( + id="done", + client_facing=False, + tools=[], + sub_agents=[], + description="Done.", + system_prompt="Return completion message.", + success_criteria="Done.", + ), + ], + ) + + errors = mod._behavior_validation_errors(agent_module) + + assert "identity_prompt still contains TODO placeholders" in errors + assert "metadata.description is blank or still contains TODO placeholders" in errors + assert "metadata.intro_message is blank or still contains TODO placeholders" in errors + assert "goal.description is blank or still contains TODO placeholders" in errors + assert "Success criterion 'sc-1' has blank or placeholder metric" in errors + assert "Constraint 'c-1' has blank or placeholder description" in errors + assert "Node 'scan' has a blank or placeholder description" in errors + assert "Node 'scan' has a blank or placeholder system_prompt" in errors + assert "Autonomous node 'summarize' has no tools or sub_agents" in errors + assert "Autonomous node 'done' has no tools or sub_agents" not in errors + + +def test_validate_agent_package_accepts_absolute_path_and_skips_missing_tests( + tmp_path, + monkeypatch, +): + mod = _load_coder_tools_server() + mod.PROJECT_ROOT = str(tmp_path) + + agent_dir = tmp_path / "examples" / "templates" / "demo_agent" + agent_dir.mkdir(parents=True) + (agent_dir / 
"__init__.py").write_text( + "from .agent import default_agent, edges, goal, nodes\n", + encoding="utf-8", + ) + (agent_dir / "agent.py").write_text( + ( + "goal = object()\n" + "nodes = []\n" + "edges = []\n" + "class _Agent:\n" + " def validate(self):\n" + " return True\n" + "default_agent = _Agent()\n" + ), + encoding="utf-8", + ) + (agent_dir / "mcp_servers.json").write_text("{}", encoding="utf-8") + + tool_calls: dict[str, str] = {} + + def _fake_validate_agent_tools(agent_path: str) -> dict: + tool_calls["agent_path"] = agent_path + return {"valid": True, "message": "PASS: tools ok"} + + framework_mod = types.ModuleType("framework") + server_mod = types.ModuleType("framework.server") + app_mod = types.ModuleType("framework.server.app") + app_mod.validate_agent_path = lambda path: Path(path).resolve() + server_mod.app = app_mod + framework_mod.server = server_mod + monkeypatch.setitem(sys.modules, "framework", framework_mod) + monkeypatch.setitem(sys.modules, "framework.server", server_mod) + monkeypatch.setitem(sys.modules, "framework.server.app", app_mod) + + class _Proc: + def __init__(self, stdout: str): + self.returncode = 0 + self.stdout = stdout + self.stderr = "" + + def _fake_run(cmd, **kwargs): + command_text = " ".join(str(part) for part in cmd) + if "missing = [a for a in ('goal', 'nodes', 'edges')" in command_text: + return _Proc('{"valid": true}') + if "from demo_agent import default_agent" in command_text: + return _Proc("True") + if "graph_ids = {n.id for n in agent.nodes}" in command_text: + return _Proc('{"valid": true, "errors": []}') + if "AgentRunner.load" in command_text: + return _Proc("AgentRunner.load (graph-only): OK") + raise AssertionError(f"Unexpected subprocess command: {cmd}") + + monkeypatch.setattr(mod, "_validate_agent_tools_impl", _fake_validate_agent_tools) + monkeypatch.setattr(mod, "_behavior_validation_errors", lambda _module: []) + monkeypatch.setattr(mod.subprocess, "run", _fake_run) + + report = 
mod._validate_agent_package_impl(str(agent_dir)) + + assert report["valid"] is True + assert report["agent_name"] == "demo_agent" + assert report["agent_path"] == str(agent_dir) + assert report["steps"]["tests"]["passed"] is True + assert report["steps"]["tests"]["skipped"] is True + assert str(agent_dir) in report["steps"]["tests"]["summary"] + assert tool_calls["agent_path"] == str(agent_dir) + + +def test_behavior_validation_identity_prompt_placeholder_blocks_stage() -> None: + mod = _load_coder_tools_server() + + blocking, warnings = mod._classify_behavior_validation_errors( + ["identity_prompt still contains TODO placeholders"] + ) + + assert blocking == ["identity_prompt still contains TODO placeholders"] + assert warnings == [] + + +def test_behavior_validation_errors_accepts_complete_worker_nodes(): + mod = _load_coder_tools_server() + + agent_module = SimpleNamespace( + identity_prompt="You summarize markdown files conservatively.", + metadata=SimpleNamespace( + description="Summarize long markdown files for manual review.", + intro_message="Ready to review docs.", + ), + goal=SimpleNamespace( + description="Summarize docs that exceed the word threshold.", + success_criteria=[ + SimpleNamespace( + id="sc-1", + description="Create concise draft summaries.", + metric="summaries_created", + target=">=1 when files exceed threshold", + ) + ], + constraints=[SimpleNamespace(id="c-1", description="Do not overwrite files.")], + ), + terminal_nodes=["done"], + nodes=[ + SimpleNamespace( + id="scan", + client_facing=False, + tools=["execute_command_tool"], + sub_agents=[], + description="Scan docs and find long files.", + system_prompt="Scan docs/ and compute word counts.", + success_criteria="File inventory and over-limit files are set.", + ), + SimpleNamespace( + id="review", + client_facing=True, + tools=[], + sub_agents=[], + description="Collect operator approval.", + system_prompt="Ask the user whether to continue.", + success_criteria="User provides a 
decision.", + ), + SimpleNamespace( + id="done", + client_facing=False, + tools=[], + sub_agents=[], + description="Finish and report.", + system_prompt="Return final result.", + success_criteria="Finished cleanly.", + ), + ], + ) + + assert mod._behavior_validation_errors(agent_module) == [] + + +def test_behavior_validation_errors_allows_pure_llm_set_output_work_nodes(): + mod = _load_coder_tools_server() + + agent_module = SimpleNamespace( + identity_prompt="You analyze resumes conservatively.", + metadata=SimpleNamespace( + description="Analyze resumes to identify strong target roles.", + intro_message="Ready to analyze the resume.", + ), + goal=SimpleNamespace( + description="Identify the strongest role targets from the resume.", + success_criteria=[ + SimpleNamespace( + id="sc-1", + description="Role analysis is stored for later steps.", + metric="role_analysis_ready", + target="1.0", + ) + ], + constraints=[SimpleNamespace(id="c-1", description="Stay faithful to the resume.")], + ), + entry_node="intake", + terminal_nodes=["done"], + nodes=[ + SimpleNamespace( + id="intake", + name="Intake", + client_facing=False, + node_type="event_loop", + tools=[], + sub_agents=[], + input_keys=["resume_text"], + output_keys=["resume_text", "role_analysis"], + description="Analyze the resume and identify the strongest role fits.", + system_prompt=( + "Analyze the user's resume, identify 3-5 strong role fits, " + "and call set_output for resume_text and role_analysis." 
+ ), + success_criteria="Role analysis saved for downstream nodes.", + ), + SimpleNamespace( + id="done", + name="Done", + client_facing=False, + node_type="event_loop", + tools=[], + sub_agents=[], + input_keys=[], + output_keys=[], + description="Done.", + system_prompt="Return completion message.", + success_criteria="Done.", + ), + ], + ) + + errors = mod._behavior_validation_errors(agent_module) + + assert not any("Autonomous node 'intake'" in error for error in errors) + + +def test_behavior_validation_errors_allows_gcu_nodes_without_explicit_tools(): + mod = _load_coder_tools_server() + + agent_module = SimpleNamespace( + identity_prompt="You coordinate browser work carefully.", + metadata=SimpleNamespace( + description="Use a browser worker to collect business URLs.", + intro_message="Ready to collect business URLs.", + ), + goal=SimpleNamespace( + description="Collect candidate businesses before enrichment.", + success_criteria=[ + SimpleNamespace( + id="sc-1", + description="Business list returned from browser worker.", + metric="business_list_ready", + target=">=5", + ) + ], + constraints=[ + SimpleNamespace( + id="c-1", + description="Use browser tools only inside the GCU.", + ) + ], + ), + terminal_nodes=["done"], + nodes=[ + SimpleNamespace( + id="map-search-worker", + name="Maps Browser Worker", + client_facing=False, + node_type="gcu", + tools=[], + sub_agents=[], + input_keys=["query"], + output_keys=["business_list"], + description="Browser worker that searches Google Maps.", + system_prompt=( + "Search Google Maps for the query, collect relevant businesses, " + 'and call set_output("business_list", ...).' 
+ ), + success_criteria="Business list extracted.", + ), + SimpleNamespace( + id="done", + name="Done", + client_facing=False, + node_type="event_loop", + tools=[], + sub_agents=[], + input_keys=[], + output_keys=[], + description="Done.", + system_prompt="Return completion message.", + success_criteria="Done.", + ), + ], + ) + + errors = mod._behavior_validation_errors(agent_module) + + assert not any("Autonomous node 'map-search-worker'" in error for error in errors) + + +def test_validate_agent_tools_discovers_agent_local_tools_py(tmp_path, monkeypatch): + _install_fake_framework(monkeypatch, tools_by_server={}) + mod = _load_coder_tools_server() + mod.PROJECT_ROOT = str(tmp_path) + + agent_dir = tmp_path / "examples" / "templates" / "local_tools_agent" + agent_dir.mkdir(parents=True) + (agent_dir / "__init__.py").write_text("from .agent import nodes\n", encoding="utf-8") + (agent_dir / "agent.py").write_text( + ( + "from types import SimpleNamespace\n" + "nodes = [SimpleNamespace(id='fetch', name='Fetch', tools=['bulk_fetch_emails'])]\n" + ), + encoding="utf-8", + ) + (agent_dir / "mcp_servers.json").write_text("{}", encoding="utf-8") + (agent_dir / "tools.py").write_text( + ( + "from types import SimpleNamespace\n" + "TOOLS = {'bulk_fetch_emails': SimpleNamespace(name='bulk_fetch_emails')}\n" + ), + encoding="utf-8", + ) + + framework_mod = sys.modules["framework"] + server_mod = types.ModuleType("framework.server") + app_mod = types.ModuleType("framework.server.app") + app_mod.validate_agent_path = lambda path: Path(path).resolve() + server_mod.app = app_mod + framework_mod.server = server_mod + monkeypatch.setitem(sys.modules, "framework.server", server_mod) + monkeypatch.setitem(sys.modules, "framework.server.app", app_mod) + + result = mod._validate_agent_tools_impl(str(agent_dir)) + + assert result["valid"] is True + assert result["available_tool_count"] == 1 + assert "missing_tools" not in result + + +def 
test_validate_agent_tools_allows_local_only_agents_without_mcp_config(tmp_path, monkeypatch): + _install_fake_framework(monkeypatch, tools_by_server={}) + mod = _load_coder_tools_server() + mod.PROJECT_ROOT = str(tmp_path) + + agent_dir = tmp_path / "exports" / "local_only_agent" + agent_dir.mkdir(parents=True) + (agent_dir / "__init__.py").write_text("from .agent import nodes\n", encoding="utf-8") + (agent_dir / "agent.py").write_text( + ( + "from types import SimpleNamespace\n" + "nodes = [SimpleNamespace(id='clock', name='Clock', tools=['get_current_timestamp'])]\n" + ), + encoding="utf-8", + ) + (agent_dir / "tools.py").write_text( + ( + "from types import SimpleNamespace\n" + "TOOLS = {'get_current_timestamp': SimpleNamespace(name='get_current_timestamp')}\n" + ), + encoding="utf-8", + ) + + framework_mod = sys.modules["framework"] + server_mod = types.ModuleType("framework.server") + app_mod = types.ModuleType("framework.server.app") + app_mod.validate_agent_path = lambda path: Path(path).resolve() + server_mod.app = app_mod + framework_mod.server = server_mod + monkeypatch.setitem(sys.modules, "framework.server", server_mod) + monkeypatch.setitem(sys.modules, "framework.server.app", app_mod) + + result = mod._validate_agent_tools_impl(str(agent_dir)) + + assert result["valid"] is True + assert result["available_tool_count"] == 1 + assert "error" not in result + + +def test_behavior_validation_errors_rejects_callable_style_tool_prompt_usage(): + mod = _load_coder_tools_server() + + agent_module = SimpleNamespace( + identity_prompt="You are a careful file reviewer.", + metadata=SimpleNamespace( + description="Review markdown files.", + intro_message="Ready to review markdown files.", + ), + goal=SimpleNamespace( + description="Summarize large markdown files safely.", + success_criteria=[ + SimpleNamespace( + id="sc-1", + description="Scan markdown files.", + metric="scan_complete", + target="1.0", + ) + ], + constraints=[ + SimpleNamespace( + id="c-1", + 
description="Do not overwrite without approval.", + ) + ], + ), + terminal_nodes=["done"], + nodes=[ + SimpleNamespace( + id="scan", + client_facing=False, + tools=["list_dir", "load_data"], + sub_agents=[], + description="Scan the docs folder.", + system_prompt=( + "Use list_dir(path='docs') first, then load_data(filename='x.md') " + "to inspect markdown files." + ), + success_criteria="Collect files to review.", + ), + SimpleNamespace( + id="done", + client_facing=False, + tools=[], + sub_agents=[], + description="Done.", + system_prompt="Return completion message.", + success_criteria="Done.", + ), + ], + ) + + errors = mod._behavior_validation_errors(agent_module) + + assert ( + "Node 'scan' system_prompt uses callable-style tool syntax for 'list_dir'. " + "Describe tool usage in prose instead of Python-style calls." + ) in errors + assert ( + "Node 'scan' system_prompt uses callable-style tool syntax for 'load_data'. " + "Describe tool usage in prose instead of Python-style calls." + ) in errors + + +def test_behavior_validation_errors_rejects_entry_intake_parsers_and_tool_aliases(): + mod = _load_coder_tools_server() + + agent_module = SimpleNamespace( + identity_prompt="You process markdown review jobs.", + metadata=SimpleNamespace( + description="Condense markdown files for review.", + intro_message="Ready to process markdown files.", + ), + goal=SimpleNamespace( + description="Condense markdown files with review safeguards.", + success_criteria=[ + SimpleNamespace( + id="sc-1", + description="Collect runtime configuration once.", + metric="config_ready", + target="1.0", + ) + ], + constraints=[SimpleNamespace(id="c-1", description="Do not overwrite without review.")], + ), + entry_node="start-intake", + terminal_nodes=["done"], + nodes=[ + SimpleNamespace( + id="start-intake", + name="Intake Config", + client_facing=False, + tools=["execute_command_tool"], + sub_agents=[], + input_keys=["task"], + output_keys=["docs_path", "review_path", "word_threshold", 
"style_rules"], + description="Accept structured runtime task from Queen with docs path and rules.", + system_prompt=( + "Parse the incoming task text into configuration values. " + "Use run_command if you need to inspect anything." + ), + success_criteria="All required config values parsed and validated.", + ), + SimpleNamespace( + id="done", + name="Done", + client_facing=False, + tools=[], + sub_agents=[], + input_keys=[], + output_keys=[], + description="Done.", + system_prompt="Return completion message.", + success_criteria="Done.", + ), + ], + ) + + errors = mod._behavior_validation_errors(agent_module) + + assert ( + "Entry node 'start-intake' appears to be an intake/config parser. " + "The queen handles intake. Make the first real work node consume " + "structured input_keys directly instead of reparsing a generic task string." + ) in errors + assert ( + "Node 'start-intake' system_prompt references unsupported tool alias " + "'run_command'. Use the actual registered tool name 'execute_command_tool'." 
+ ) in errors + + +def test_behavior_validation_errors_rejects_structured_entry_intake_validators(): + mod = _load_coder_tools_server() + + agent_module = SimpleNamespace( + identity_prompt="You sanitize markdown docs safely.", + metadata=SimpleNamespace( + description="Prepare cleaned markdown review copies.", + intro_message="Ready to sanitize docs.", + ), + goal=SimpleNamespace( + description="Sanitize markdown files with explicit approval before overwrite.", + success_criteria=[ + SimpleNamespace( + id="sc-1", + description="Runtime paths validated.", + metric="paths_ready", + target="1.0", + ) + ], + constraints=[SimpleNamespace(id="c-1", description="Never overwrite without review.")], + ), + entry_node="intake", + terminal_nodes=["done"], + nodes=[ + SimpleNamespace( + id="intake", + name="Intake & Validate Runtime Paths", + client_facing=False, + tools=["execute_command_tool"], + sub_agents=[], + input_keys=["docs_root", "review_root"], + output_keys=["docs_root", "review_root", "run_id"], + description="Read runtime task input and validate the filesystem paths.", + system_prompt=( + "Accept structured runtime task input, validate runtime paths, " + "create review_root if missing, then pass the normalized values onward." + ), + success_criteria="Validated paths and emitted normalized runtime values.", + ), + SimpleNamespace( + id="done", + name="Done", + client_facing=False, + tools=[], + sub_agents=[], + input_keys=[], + output_keys=[], + description="Done.", + system_prompt="Return completion message.", + success_criteria="Done.", + ), + ], + ) + + errors = mod._behavior_validation_errors(agent_module) + + assert ( + "Entry node 'intake' appears to be an intake/config parser. " + "The queen handles intake. Make the first real work node consume " + "structured input_keys directly instead of reparsing a generic task string." 
+ ) in errors + + +def test_behavior_validation_errors_rejects_entry_intake_parser_with_scan_exclusions_text(): + mod = _load_coder_tools_server() + + agent_module = SimpleNamespace( + identity_prompt="You generate local markdown reviews.", + metadata=SimpleNamespace( + description="Prepare review drafts for oversized markdown files.", + intro_message="Ready to review markdown files.", + ), + goal=SimpleNamespace( + description="Generate markdown review drafts from structured runtime inputs.", + success_criteria=[ + SimpleNamespace( + id="sc-1", + description="Runtime config normalized.", + metric="config_ready", + target="1.0", + ) + ], + constraints=[SimpleNamespace(id="c-1", description="Stay local-only.")], + ), + entry_node="intake-config", + terminal_nodes=["done"], + nodes=[ + SimpleNamespace( + id="intake-config", + name="Intake Config", + client_facing=False, + tools=["execute_command_tool"], + sub_agents=[], + input_keys=["target_dir", "word_threshold", "review_dir_mode"], + output_keys=["target_dir", "word_threshold", "review_root", "scan_exclusions"], + description=( + "Validate provided directory configuration and emit " + "normalized runtime settings." + ), + system_prompt=( + "Validate target_dir, normalize word_threshold, set scan_exclusions, " + "and resolve review_root before real work begins." + ), + success_criteria="Configuration normalized.", + ), + SimpleNamespace( + id="done", + name="Done", + client_facing=False, + tools=[], + sub_agents=[], + input_keys=[], + output_keys=[], + description="Done.", + system_prompt="Return completion message.", + success_criteria="Done.", + ), + ], + ) + + errors = mod._behavior_validation_errors(agent_module) + + assert ( + "Entry node 'intake-config' appears to be an intake/config parser. " + "The queen handles intake. Make the first real work node consume " + "structured input_keys directly instead of reparsing a generic task string." 
+ ) in errors + + +def test_behavior_validation_errors_rejects_output_dirs_that_must_preexist(): + mod = _load_coder_tools_server() + + agent_module = SimpleNamespace( + identity_prompt="You review markdown rewrites safely.", + metadata=SimpleNamespace( + description="Prepare markdown review copies.", + intro_message="Ready to review markdown files.", + ), + goal=SimpleNamespace( + description="Write review copies for markdown documents.", + success_criteria=[ + SimpleNamespace( + id="sc-1", + description="Review copies are written for each eligible file.", + metric="review_copy_write_success_rate", + target=">=0.99", + ) + ], + constraints=[ + SimpleNamespace( + id="c-1", + description="Do not overwrite originals early.", + ) + ], + ), + entry_node="start", + terminal_nodes=["done"], + nodes=[ + SimpleNamespace( + id="start", + name="Initialize Inputs", + client_facing=False, + tools=["execute_command_tool"], + sub_agents=[], + input_keys=["docs_dir", "review_dir", "word_threshold"], + output_keys=["docs_dir", "review_dir", "word_threshold"], + description="Validate directories before scanning markdown files.", + system_prompt=( + "Validate docs_dir and review_dir exist and are directories before continuing." + ), + success_criteria="Paths validated.", + ), + SimpleNamespace( + id="done", + name="Done", + client_facing=False, + tools=[], + sub_agents=[], + input_keys=[], + output_keys=[], + description="Done.", + system_prompt="Return completion message.", + success_criteria="Done.", + ), + ], + ) + + errors = mod._behavior_validation_errors(agent_module) + + assert ( + "Entry node 'start' requires output path 'review_dir' to pre-exist. " + "Output/review directories should be created if missing instead of " + "blocking the run during intake validation." 
+ ) in errors + + +def test_behavior_validation_errors_allows_direct_scan_entry_nodes(): + mod = _load_coder_tools_server() + + agent_module = SimpleNamespace( + identity_prompt="You sanitize markdown docs safely.", + metadata=SimpleNamespace( + description="Prepare cleaned markdown review copies.", + intro_message="Ready to sanitize docs.", + ), + goal=SimpleNamespace( + description="Sanitize markdown files with explicit approval before overwrite.", + success_criteria=[ + SimpleNamespace( + id="sc-1", + description="Markdown candidates discovered.", + metric="candidate_discovery_success_rate", + target=">=0.99", + ) + ], + constraints=[SimpleNamespace(id="c-1", description="Never overwrite without review.")], + ), + entry_node="scan-candidates", + terminal_nodes=["done"], + nodes=[ + SimpleNamespace( + id="scan-candidates", + name="Scan Markdown Candidates", + client_facing=False, + tools=["execute_command_tool"], + sub_agents=[], + input_keys=["source_dir", "review_dir"], + output_keys=["source_dir", "review_dir", "candidates", "scan_stats", "rules"], + description=( + "Consume structured source_dir/review_dir inputs directly, ensure review_dir " + "exists, recursively scan .md files, and detect candidate files." + ), + system_prompt=( + "Start markdown candidate discovery from structured inputs source_dir and " + "review_dir. Use execute_command_tool to create review_dir if missing, " + "recursively scan .md files, and emit candidates, scan_stats, and rules." 
+ ), + success_criteria="Scanning completes and emits candidates, stats, and rules.", + ), + SimpleNamespace( + id="done", + name="Done", + client_facing=False, + tools=[], + sub_agents=[], + input_keys=[], + output_keys=[], + description="Done.", + system_prompt="Return completion message.", + success_criteria="Done.", + ), + ], + ) + + errors = mod._behavior_validation_errors(agent_module) + + assert not any("appears to be an intake/config parser" in error for error in errors) + + +def test_behavior_validation_errors_rejects_data_tools_used_for_review_root_workspace_paths(): + mod = _load_coder_tools_server() + + agent_module = SimpleNamespace( + identity_prompt="You generate local markdown reviews.", + metadata=SimpleNamespace( + description="Generate review drafts and deliver a manifest.", + intro_message="Ready to write local review outputs.", + ), + goal=SimpleNamespace( + description="Write markdown review drafts and a manifest for the user.", + success_criteria=[ + SimpleNamespace( + id="sc-1", + description="Review outputs are saved.", + metric="review_outputs_written", + target=">=0.99", + ) + ], + constraints=[SimpleNamespace(id="c-1", description="Stay local-only.")], + ), + entry_node="scan", + terminal_nodes=["done"], + nodes=[ + SimpleNamespace( + id="scan", + name="Scan Markdown Files", + client_facing=False, + tools=["execute_command_tool"], + sub_agents=[], + input_keys=["target_dir"], + output_keys=["review_root", "review_files"], + description="Scan markdown files and prepare review targets.", + system_prompt="Scan the target directory and emit review_root plus review_files.", + success_criteria="Review targets prepared.", + ), + SimpleNamespace( + id="write-manifest", + name="Write Manifest", + client_facing=False, + tools=["save_data", "list_data_files", "serve_file_to_user"], + sub_agents=[], + input_keys=["review_root", "review_files"], + output_keys=["manifest_file"], + description="Write review artifacts for the user.", + system_prompt=( + 
"Use save_data to persist each draft into review_root, then list files in " + "review_root with list_data_files and serve them to the user from review_root." + ), + success_criteria="Manifest and links delivered.", + ), + SimpleNamespace( + id="done", + name="Done", + client_facing=False, + tools=[], + sub_agents=[], + input_keys=[], + output_keys=[], + description="Done.", + system_prompt="Return completion message.", + success_criteria="Done.", + ), + ], + ) + + errors = mod._behavior_validation_errors(agent_module) + + assert any("uses session data tools" in error and "review_root" in error for error in errors) + + +def test_behavior_validation_errors_allows_session_data_tools_for_delivery_payloads(): + mod = _load_coder_tools_server() + + agent_module = SimpleNamespace( + identity_prompt="You deliver markdown review artifacts safely.", + metadata=SimpleNamespace( + description="Deliver a session artifact for generated markdown reviews.", + intro_message="Ready to deliver review artifacts.", + ), + goal=SimpleNamespace( + description=( + "Expose a manifest artifact and clickable link after local review generation." + ), + success_criteria=[ + SimpleNamespace( + id="sc-1", + description="Delivery payload is saved and link is returned.", + metric="artifact_delivery_success", + target=">=1 link", + ) + ], + constraints=[ + SimpleNamespace( + id="c-1", + description="Do not write workspace files here.", + ) + ], + ), + entry_node="publish-links", + terminal_nodes=["done"], + nodes=[ + SimpleNamespace( + id="publish-links", + name="Publish Artifact Links", + client_facing=False, + tools=["save_data", "serve_file_to_user"], + sub_agents=[], + input_keys=[ + "manifest_path", + "manifest_summary", + "scan_summary", + "draft_paths", + "review_dir", + "target_dir", + "word_threshold", + ], + output_keys=["artifact_links", "result"], + description="Deliver session-scoped artifact links to the user.", + system_prompt="""\ +This is the session artifact delivery node. 
+Do NOT write workspace files here. + +Allowed tools: +- save_data: store session-scoped delivery artifacts only. +- serve_file_to_user: return clickable links for saved session artifacts. + +Tasks: +1) Create a compact delivery payload containing manifest_path, summary counts, and draft_paths. +2) Save that payload to session data via save_data (e.g., review_delivery.json). +3) Call serve_file_to_user for the saved session artifact and capture clickable URI(s). +4) Build final result object with: + - status + - target_dir + - review_dir + - word_threshold + - total_markdown_files + - flagged_files_count + - draft_paths + - manifest_path + - artifact_links +""", + success_criteria=( + "Clickable session artifact links are returned and final result is complete." + ), + ), + SimpleNamespace( + id="done", + name="Done", + client_facing=False, + tools=[], + sub_agents=[], + input_keys=[], + output_keys=[], + description="Done.", + system_prompt="Return completion message.", + success_criteria="Done.", + ), + ], + ) + + errors = mod._behavior_validation_errors(agent_module) + + assert not any("uses session data tools" in error for error in errors) + + +def test_generated_agent_template_uses_isolated_manual_entry_point(): + mod = _load_coder_tools_server() + source = Path(mod.__file__).read_text(encoding="utf-8") + + assert 'trigger_type="manual",' in source + assert 'isolation_level="isolated"' in source + + +def test_generated_agent_template_avoids_intro_and_success_metric_todos(): + mod = _load_coder_tools_server() + source = Path(mod.__file__).read_text(encoding="utf-8") + + assert "intro_message: str = {_default_intro_message(human_name, _draft_desc)!r}" in source + assert 'metric="{_default_success_metric(i + 1)}"' in source + assert 'target="{_default_success_target()}"' in source + assert 'identity_prompt = "TODO: Add identity prompt."' not in source + + +def test_validation_reports_failed_tests_as_failed(monkeypatch) -> None: + mod = _load_coder_tools_server() + 
mod.PROJECT_ROOT = "/tmp/demo_project" + + monkeypatch.setattr( + mod, + "_resolve_agent_package_target", + lambda _name: (Path("/tmp/demo_agent"), "demo_agent", "demo_agent"), + ) + monkeypatch.setattr( + mod, + "_validate_agent_tools_impl", + lambda _path: {"valid": True, "message": "ok"}, + ) + monkeypatch.setattr(mod, "_behavior_validation_errors", lambda _module: []) + monkeypatch.setattr( + mod, + "_run_agent_tests_impl", + lambda _path: {"failed": 1, "errors": 0, "summary": "1 failed", "failures": ["boom"]}, + ) + + class _Proc: + def __init__(self, stdout: str): + self.returncode = 0 + self.stdout = stdout + self.stderr = "" + + def _fake_run(cmd, **kwargs): + command_text = " ".join(str(part) for part in cmd) + if "missing = [a for a in ('goal', 'nodes', 'edges')" in command_text: + return _Proc('{"valid": true}') + if "from demo_agent import default_agent" in command_text: + return _Proc("True") + if "graph_ids = {n.id for n in agent.nodes}" in command_text: + return _Proc('{"valid": true, "errors": []}') + if "AgentRunner.load" in command_text: + return _Proc("AgentRunner.load (graph-only): OK") + raise AssertionError(f"Unexpected subprocess command: {cmd}") + + monkeypatch.setitem( + sys.modules, + "demo_agent", + types.SimpleNamespace(goal=object(), nodes=[], edges=[]), + ) + monkeypatch.setattr(mod.subprocess, "run", _fake_run) + + report = mod._validate_agent_package_impl("demo_agent") + + assert report["valid"] is False + assert report["steps"]["tests"]["passed"] is False + assert report["steps"]["tests"]["failures"] == ["boom"]