diff --git a/.gitignore b/.gitignore index 90ef702..f2b852e 100644 --- a/.gitignore +++ b/.gitignore @@ -44,3 +44,6 @@ platform/htmlcov/ platform/coverage.xml platform/**/*.cover platform/**/*.py,cover + +# Local git worktrees +.worktrees/ diff --git a/platform/.env.example b/platform/.env.example index cf58eb0..ced5996 100644 --- a/platform/.env.example +++ b/platform/.env.example @@ -10,6 +10,7 @@ LLM_API_KEY=your-api-key-here LLM_BASE_URL=https://api.openai.com LLM_MODEL=gpt-4o-mini LLM_TIMEOUT=120.0 +LLM_ROLE_MODEL_MAP={"orchestrator":"gpt-4o-mini","test":"gpt-4o-mini"} EXTRA_SKILL_DIR_WHITELIST= # SkillKit compatibility env vars. diff --git a/platform/app/api/v1/tasks.py b/platform/app/api/v1/tasks.py index 227149b..99ec6d0 100644 --- a/platform/app/api/v1/tasks.py +++ b/platform/app/api/v1/tasks.py @@ -95,6 +95,14 @@ async def cancel_task(task_id: str, service: TaskService = Depends(get_task_serv return task +@router.post("/{task_id}/clone", response_model=TaskDetailResponse, status_code=201) +async def clone_task(task_id: str, service: TaskService = Depends(get_task_service)): + task = await service.clone_task(task_id) + if task is None: + raise HTTPException(status_code=404, detail="Task not found") + return task + + @router.post("/{task_id}/retry", response_model=TaskDetailResponse) async def retry_task(task_id: str, service: TaskService = Depends(get_task_service)): task = await service.retry_task(task_id) diff --git a/platform/app/config.py b/platform/app/config.py index 8e0a8f1..33636eb 100644 --- a/platform/app/config.py +++ b/platform/app/config.py @@ -16,8 +16,9 @@ class Settings(BaseSettings): LLM_TIMEOUT: float = 120.0 # Per-role model routing (JSON string: {"coding": "gpt-4o", "review": "claude-sonnet-4-20250514"}) - # Unspecified roles fall back to LLM_MODEL - LLM_ROLE_MODEL_MAP: str = "{}" + # Unspecified roles fall back to LLM_MODEL. Keep lightweight defaults on + # orchestrator/test so parse + signoff stay cheaper unless env overrides them. 
+ LLM_ROLE_MODEL_MAP: str = '{"orchestrator":"gpt-4o-mini","test":"gpt-4o-mini"}' # Comma-separated absolute path prefixes allowed in agent config `extra_skill_dirs`. # Empty means only built-in platform/skills directory is allowed. EXTRA_SKILL_DIR_WHITELIST: str = "" @@ -38,9 +39,12 @@ class Settings(BaseSettings): DB_POOL_TIMEOUT: int = 30 # Circuit breaker configuration - CB_MAX_TOKENS_PER_TASK: int = 200000 - CB_MAX_COST_PER_TASK_RMB: float = 50.0 + CB_MAX_TOKENS_PER_TASK: int = 400000 + CB_MAX_COST_PER_TASK_RMB: float = 100.0 CB_TOKEN_PRICE_PER_1K: float = 0.01 + # Per-stage token budgets (JSON string). Empty means disabled. + # Example: {"parse": 60000, "code": 250000} + CB_STAGE_TOKEN_BUDGETS: str = '{"parse": 60000, "code": 250000}' # Webhook secrets (empty = skip verification) JIRA_WEBHOOK_SECRET: str = "" @@ -96,6 +100,12 @@ class Settings(BaseSettings): SANDBOX_ROLES: str = '["coding", "test"]' SANDBOX_DUMP_MODEL_API_RESPONSE: bool = True SANDBOX_MODEL_API_RAW_LOG_HOST_DIR: str = "/tmp/silicon_agent/model_api_logs" + SANDBOX_GRADLE_CMD_TIMEOUT_SECONDS: int = 480 + SANDBOX_GRADLE_CACHE_HOST_DIR: str = "/var/lib/silicon_agent/gradle-cache" + SANDBOX_GRADLE_USER_HOME: str = "/var/lib/silicon_agent/gradle-cache" + SANDBOX_DEFAULT_JAVA_VERSION: int = 8 + SANDBOX_GRADLE_WRAPPER_PREWARM: bool = True + SANDBOX_GRADLE_WRAPPER_PREWARM_TIMEOUT_SECONDS: int = 180 # Memory & compression configuration MEMORY_ENABLED: bool = True diff --git a/platform/app/services/agent_service.py b/platform/app/services/agent_service.py index 07f8352..1241b03 100644 --- a/platform/app/services/agent_service.py +++ b/platform/app/services/agent_service.py @@ -154,8 +154,8 @@ def _build_role_defaults(self, available_models: list[str]) -> dict[str, str]: for role, _ in AGENT_ROLES: model = ( role_model_map.get(role) - or FALLBACK_ROLE_DEFAULT_MODELS.get(role) or settings.LLM_MODEL + or FALLBACK_ROLE_DEFAULT_MODELS.get(role) ) if model not in available_models and available_models: if 
settings.LLM_MODEL in available_models: diff --git a/platform/app/services/project_service.py b/platform/app/services/project_service.py index d74b6b8..03b146e 100644 --- a/platform/app/services/project_service.py +++ b/platform/app/services/project_service.py @@ -7,6 +7,7 @@ from sqlalchemy import func, or_, select from sqlalchemy.ext.asyncio import AsyncSession +from app.config import settings from app.models.project import ProjectModel from app.schemas.project import ( ProjectCreateRequest, @@ -75,7 +76,7 @@ async def create_project(self, request: ProjectCreateRequest) -> ProjectResponse repo_local_path=request.repo_local_path, branch=request.branch, description=request.description, - sandbox_image=request.sandbox_image, + sandbox_image=request.sandbox_image or settings.SANDBOX_IMAGE, ) self.session.add(project) await self.session.commit() diff --git a/platform/app/services/task_service.py b/platform/app/services/task_service.py index bbcaa75..05a3e24 100644 --- a/platform/app/services/task_service.py +++ b/platform/app/services/task_service.py @@ -129,6 +129,24 @@ async def create_task(self, request: TaskCreateRequest) -> TaskDetailResponse: task = result.scalar_one() return self._task_to_response(task) + async def clone_task(self, task_id: str) -> Optional[TaskDetailResponse]: + """Create a new task by copying only safe creation fields from a source task.""" + source_task = await self._load_task_with_relations_optional(task_id) + if source_task is None: + return None + + return await self.create_task( + TaskCreateRequest( + jira_id=source_task.jira_id, + title=source_task.title, + description=source_task.description, + template_id=source_task.template_id, + project_id=source_task.project_id, + yunxiao_task_id=source_task.yunxiao_task_id, + github_issue_number=getattr(source_task, "github_issue_number", None), + ) + ) + async def get_task(self, task_id: str) -> Optional[TaskDetailResponse]: result = await self.session.execute( select(TaskModel) diff --git 
a/platform/app/worker/agents.py b/platform/app/worker/agents.py index 8ec75e7..83baf5e 100644 --- a/platform/app/worker/agents.py +++ b/platform/app/worker/agents.py @@ -40,8 +40,8 @@ ROLE_TOOLS: dict[str, set[str]] = { "orchestrator": {"read", "execute", "skill"}, "spec": {"read", "write", "edit", "skill"}, - "coding": {"read", "write", "edit", "execute", "execute_script", "skill"}, - "test": {"read", "write", "edit", "execute", "execute_script", "skill"}, + "coding": {"read", "write", "edit", "execute", "execute_script"}, + "test": {"read", "write", "edit", "execute", "execute_script"}, "review": {"read", "execute", "skill"}, "smoke": {"read", "execute", "skill"}, "doc": {"read", "write", "edit", "skill"}, @@ -57,8 +57,8 @@ _ROLE_SKILL_DIRS: dict[str, list[str]] = { "orchestrator": ["shared", "orchestrator"], "spec": ["shared", "spec"], - "coding": ["shared", "coding"], - "test": ["shared", "test"], + "coding": [], + "test": [], "review": ["shared", "review"], "smoke": ["shared", "smoke"], "doc": ["shared", "doc"], diff --git a/platform/app/worker/compressor.py b/platform/app/worker/compressor.py index c641049..31601a7 100644 --- a/platform/app/worker/compressor.py +++ b/platform/app/worker/compressor.py @@ -17,7 +17,7 @@ # Fallback truncation limits when LLM is unavailable _L0_FALLBACK_CHARS = 200 _L1_FALLBACK_CHARS = 1500 -_L2_MAX_CHARS = 20_000 # Hard cap on full-text prior output to prevent token explosion +_L2_MAX_CHARS = 4_000 # Hard cap on full-text prior output to prevent token explosion @dataclass diff --git a/platform/app/worker/engine.py b/platform/app/worker/engine.py index b322854..515ae0f 100644 --- a/platform/app/worker/engine.py +++ b/platform/app/worker/engine.py @@ -4,6 +4,7 @@ import asyncio import json import logging +import os import shutil import tempfile import time @@ -47,6 +48,24 @@ _running = False _task: Optional[asyncio.Task] = None +_PREFLIGHT_SKIP_DIRS = { + ".git", + ".hg", + ".svn", + "node_modules", + ".venv", + "venv", + "build", 
+ "dist", + "target", + ".gradle", + ".idea", + "__pycache__", +} +_PREFLIGHT_MAX_FILES = 2000 +_PREFLIGHT_MAX_DEPTH = 6 +_PREFLIGHT_MAX_CHARS = 600 + async def _safe_broadcast(event: str, data: dict) -> None: """Broadcast a WebSocket event, swallowing any errors.""" @@ -1702,6 +1721,8 @@ async def _execute_single_stage( except Exception: logger.warning("Failed to load memory for role %s", stage.agent_role, exc_info=True) + preflight_summary = _build_stage_preflight_summary(stage.stage_name, workspace_path) + # Build compressed prior context via sliding window # Phase 1.5: Cross-stage context recall — override compression for specified stages context_from = sdef.get("context_from") @@ -1816,6 +1837,7 @@ async def _execute_single_stage( compressed_outputs=compressed_prior if compressed_prior else None, project_memory=project_memory, repo_context=repo_context, + preflight_summary=preflight_summary, retry_context=retry_context, stage_model=stage_model, custom_instruction=custom_instruction, @@ -1827,6 +1849,7 @@ async def _execute_single_stage( compressed_outputs=compressed_prior if compressed_prior else None, project_memory=project_memory, repo_context=repo_context, + preflight_summary=preflight_summary, retry_context=retry_context, stage_model=stage_model, workdir_override=effective_workdir, @@ -2802,6 +2825,191 @@ def _build_repo_context(project) -> str: return "\n\n".join(parts) +def _iter_preflight_files(workspace_path: str): + root = Path(workspace_path) + scanned = 0 + for dirpath, dirnames, filenames in os.walk(root): + current = Path(dirpath) + try: + rel = current.relative_to(root) + depth = len(rel.parts) + except ValueError: + depth = 0 + dirnames[:] = [ + name for name in dirnames + if name not in _PREFLIGHT_SKIP_DIRS and depth < _PREFLIGHT_MAX_DEPTH + ] + for filename in filenames: + scanned += 1 + yield current / filename + if scanned >= _PREFLIGHT_MAX_FILES: + return + + +def _format_preflight_section(title: str, items: list[str], *, limit: int = 4) -> 
str: + if not items: + return "" + unique: list[str] = [] + seen: set[str] = set() + for item in items: + if item in seen: + continue + seen.add(item) + unique.append(item) + if len(unique) >= limit: + break + if not unique: + return "" + return f"- {title}: {', '.join(unique)}" + + +def _rank_preflight_path(rel_path: str, *, kind: str) -> tuple[int, int, int, str]: + lowered = rel_path.lower() + score = 0 + + if kind == "impl": + if "src/main/" in lowered: + score += 4 + if any(token in lowered for token in ("/controller/", "/handler/", "/service/", "/api/", "response")): + score += 5 + if "/src/test/" in lowered or lowered.startswith("src/test/"): + score -= 6 + elif kind == "test": + if any(token in lowered for token in ("/controller/", "/api/", "controller", "api")): + score += 5 + if lowered.endswith("test.java") or lowered.endswith("tests.java") or lowered.endswith("_test.go"): + score += 3 + if any(token in lowered for token in ("basetest", "sdk/", "mybatisgenerator", "mapper")): + score -= 4 + + return (-score, len(rel_path.split("/")), len(rel_path), rel_path) + + +def _pick_preflight_paths(items: list[str], *, kind: str, limit: int) -> list[str]: + unique = list(dict.fromkeys(items)) + return sorted(unique, key=lambda value: _rank_preflight_path(value, kind=kind))[:limit] + + +def _infer_validation_command(build_files: list[str]) -> str: + lowered = {item.lower() for item in build_files} + if "build.gradle" in lowered or "build.gradle.kts" in lowered: + return "./gradlew test" + if "pom.xml" in lowered: + return "./mvnw test" + if "package.json" in lowered: + return "npm test" + if "pyproject.toml" in lowered: + return "pytest" + if "go.mod" in lowered: + return "go test ./..." 
+ if "cargo.toml" in lowered: + return "cargo test" + return "优先执行最小相关验证命令" + + +def _infer_coding_edit_target(source_roots: list[str], impl_examples: list[str]) -> str: + if impl_examples: + first = impl_examples[0] + parent = str(Path(first).parent).replace("\\", "/") + return parent if parent and parent != "." else first + if source_roots: + return source_roots[0] + return "优先在现有 controller/service 相邻目录做最小修改" + + +def _infer_test_target(test_examples: list[str], impl_examples: list[str]) -> str: + if test_examples: + return test_examples[0] + if impl_examples: + return impl_examples[0] + return "优先补充与当前改动直接相关的最小测试" + + +def _build_stage_preflight_summary(stage_name: str, workspace_path: Optional[str]) -> Optional[str]: + normalized = (stage_name or "").strip().lower() + if normalized not in {"code", "coding", "test"}: + return None + if not workspace_path: + return None + + root = Path(workspace_path) + if not root.exists() or not root.is_dir(): + return None + + build_files: list[str] = [] + source_roots: list[str] = [] + impl_examples: list[str] = [] + test_examples: list[str] = [] + + build_file_names = { + "build.gradle", + "build.gradle.kts", + "settings.gradle", + "settings.gradle.kts", + "pom.xml", + "package.json", + "pyproject.toml", + "go.mod", + "cargo.toml", + } + impl_keywords = ("controller", "handler", "service", "api", "route", "response") + test_keywords = ("test", "spec") + + for path in _iter_preflight_files(str(root)): + try: + rel = path.relative_to(root).as_posix() + except ValueError: + rel = path.as_posix() + lower_rel = rel.lower() + name = path.name.lower() + + if name in build_file_names: + build_files.append(rel) + if any(token in lower_rel for token in ("src/main", "app/", "cmd/", "internal/", "lib/")): + parent = str(Path(rel).parent).replace("\\", "/") + if parent and parent != ".": + source_roots.append(parent) + if any(keyword in name for keyword in impl_keywords) or any( + segment in lower_rel for segment in ("/controller/", 
"/handler/", "/service/", "/api/") + ): + impl_examples.append(rel) + if any(keyword in name for keyword in test_keywords) or any( + segment in lower_rel for segment in ("/test/", "/tests/", "/__tests__/") + ): + test_examples.append(rel) + + build_files = list(dict.fromkeys(build_files)) + source_roots = list(dict.fromkeys(source_roots)) + impl_examples = _pick_preflight_paths(impl_examples, kind="impl", limit=3) + test_examples = _pick_preflight_paths(test_examples, kind="test", limit=3) + validation_command = _infer_validation_command(build_files) + + lines = [] + if normalized in {"code", "coding"}: + lines.append(_format_preflight_section("构建入口", build_files, limit=2)) + lines.append(f"- 推荐修改落点: {_infer_coding_edit_target(source_roots, impl_examples)}") + lines.append(_format_preflight_section("最相关实现参考", impl_examples, limit=2)) + lines.append(_format_preflight_section("最相关测试参考", test_examples, limit=2)) + lines.append(f"- 推荐最小验证命令: {validation_command}") + if not any(lines): + lines.append("- 未发现明显的实现参考,请直接聚焦最小修改并谨慎验证。") + else: + lines.append(_format_preflight_section("构建入口", build_files, limit=2)) + lines.append(f"- 推荐验证落点: {_infer_test_target(test_examples, impl_examples)}") + lines.append(_format_preflight_section("最相关测试参考", test_examples, limit=2)) + lines.append(_format_preflight_section("对应实现参考", impl_examples, limit=2)) + lines.append(f"- 推荐最小验证命令: {validation_command}") + if not any(lines): + lines.append("- 未发现明显测试样例,请优先选择最小、最快的验证路径。") + + summary = "\n".join(line for line in lines if line).strip() + if not summary: + return None + if len(summary) > _PREFLIGHT_MAX_CHARS: + summary = summary[:_PREFLIGHT_MAX_CHARS] + "\n...(预扫摘要已截断)" + return summary + + async def _fail_task(session: AsyncSession, task: TaskModel, reason: str) -> None: """Mark task as failed, broadcast, and send external notification.""" failed_at = datetime.now(timezone.utc) diff --git a/platform/app/worker/executor.py b/platform/app/worker/executor.py index 60ba1ce..671c8e5 100644 
--- a/platform/app/worker/executor.py +++ b/platform/app/worker/executor.py @@ -77,6 +77,26 @@ def _build_runtime_overrides( } +_DEFAULT_STAGE_MAX_TURNS: dict[str, int] = { + "spec": 10, + "coding": 6, + "doc": 10, + "test": 6, +} + +_STAGE_MAX_TURN_CAPS: dict[str, int] = { + "coding": 6, + "test": 6, +} + + +def _resolve_stage_max_turns(agent_role: str, override: Optional[int]) -> int: + default_value = _DEFAULT_STAGE_MAX_TURNS.get(agent_role, 10) + requested = override if isinstance(override, int) and override > 0 else default_value + cap = _STAGE_MAX_TURN_CAPS.get(agent_role) + return min(requested, cap) if cap else requested + + def _chat_kwargs_for_runner(runner: Any, runtime_overrides: dict[str, Any]) -> dict[str, Any]: kwargs: dict[str, Any] = {} try: @@ -172,7 +192,16 @@ def _is_signoff_stage(stage_name: str) -> bool: def _output_summary_limit(stage_name: str) -> int: # Cap stage output stored in DB to limit downstream prior-context injection. - return 50_000 + normalized = (stage_name or "").strip().lower() + if normalized == "parse": + return 600 + if normalized in {"code", "coding", "test"}: + return 1200 + if _is_signoff_stage(normalized): + return 1500 + if normalized in {"spec", "approve", "review", "doc"}: + return 1800 + return 1500 def _format_tool_digest(tool_items: list[dict[str, str]], limit: int = 6) -> str: @@ -210,6 +239,20 @@ def _resolve_stage_output_summary( return _clip_text(resolved, _output_summary_limit(stage_name)) +def _stage_goal_summary(stage_name: str | None) -> str: + normalized = (stage_name or "").strip().lower() + if normalized in {"code", "coding"}: + return "直接完成最小必要代码修改,并提供最小验证结果。" + if normalized == "test": + return "直接完成最小、最相关的验证,并明确成功或阻塞结论。" + return "完成当前阶段的最终结果。" + + +def _prefer_restart_continuations(stage_name: str | None) -> bool: + normalized = (stage_name or "").strip().lower() + return normalized in {"code", "coding", "test"} + + # --------------------------------------------------------------------------- # 
Module-level helpers extracted from execute_stage # --------------------------------------------------------------------------- @@ -237,6 +280,49 @@ def _clear_current_task_cancellation_state() -> None: current.uncancel() +_EXPLORATION_EXECUTE_PREFIXES = ( + "ls ", + "find ", + "pwd", + "cat ", + "head ", + "tail ", + "tree", + "rg ", +) +_VERIFICATION_EXECUTE_MARKERS = ( + "pytest", + "unittest", + "gradlew test", + "gradlew build", + "gradlew testclasses", + "mvn test", + "npm test", + "pnpm test", + "yarn test", + "go test", + "cargo test", +) +_RESTART_OUTPUT_CHARS = 700 + + +def _classify_tool_activity(tool_name: str, args: dict[str, Any]) -> str: + normalized_tool = (tool_name or "").strip().lower() + if normalized_tool in {"write", "edit"}: + return "implementation" + if normalized_tool == "read": + return "exploration" + + if normalized_tool in {"execute", "execute_script"}: + command = str(args.get("command") or "").strip().lower() + if any(marker in command for marker in _VERIFICATION_EXECUTE_MARKERS): + return "verification" + if any(command.startswith(prefix) for prefix in _EXPLORATION_EXECUTE_PREFIXES): + return "exploration" + + return "other" + + # --------------------------------------------------------------------------- # StageEventTracker – encapsulates mutable tracking state and event helpers # --------------------------------------------------------------------------- @@ -266,6 +352,12 @@ def __init__( self._instrumented_runners: list[Any] = [] self._instrumented_runner_ids: set[int] = set() self._completed_tool_runs: list[dict[str, str]] = [] + self._exploration_actions = 0 + self._implementation_actions = 0 + self._verification_attempts = 0 + self._verification_failures = 0 + self._successful_verifications = 0 + self._forced_convergence_used = False # -- public emit helpers -------------------------------------------------- @@ -352,6 +444,37 @@ async def emit_chat_received( def get_completed_tool_runs(self) -> list[dict[str, str]]: return 
list(self._completed_tool_runs) + def should_force_convergence(self) -> bool: + if self._forced_convergence_used: + return False + + normalized = self.stage_name.strip().lower() + if normalized == "code": + return self._implementation_actions == 0 and self._exploration_actions >= 4 + if normalized == "test": + if self._verification_failures > 0 and self._successful_verifications == 0: + return True + return self._verification_attempts == 0 and self._exploration_actions >= 3 + return False + + def mark_forced_convergence_used(self) -> None: + self._forced_convergence_used = True + + def record_tool_activity(self, tool_name: str, args: dict[str, Any], status: str) -> None: + activity = _classify_tool_activity(tool_name, args) + if activity == "exploration": + self._exploration_actions += 1 + return + if activity == "implementation": + self._implementation_actions += 1 + return + if activity == "verification": + self._verification_attempts += 1 + if status == "success": + self._successful_verifications += 1 + else: + self._verification_failures += 1 + # -- runner event registration ------------------------------------------- def register_runner_events(self, current_runner: Any) -> None: @@ -495,6 +618,7 @@ async def _on_after_tool_result(event: Any) -> None: args = {} output = str(getattr(event, "result", "")) status = infer_tool_status(output) + tracker.record_tool_activity(tool_name, args, status) run_info = tracker._tool_runs.get(tool_call_id) duration_ms: Optional[float] = None @@ -655,68 +779,310 @@ def detach_all_handlers(self) -> None: # Extracted helpers: continuations and stage success # --------------------------------------------------------------------------- +def _build_continuation_prompt(stage_name: str | None) -> str: + normalized = (stage_name or "").strip().lower() + if normalized == "code": + return ( + "请停止继续广泛探索。基于已知信息直接修改代码;" + "如果仍缺信息,只允许再查看 1 个最关键文件,然后必须完成修改并给出最小验证结果。" + ) + if normalized == "test": + return ( + "请停止扩展测试范围。只做最小、最相关的验证;" + 
"如果验证命令失败,必须直接给出失败命令、关键报错和唯一阻塞点," + "不要再用代码阅读代替测试结论。" + ) + return "请继续完成上面的输出,从你停下的地方继续。" + + +def _build_forced_convergence_prompt(stage_name: str | None) -> str: + normalized = (stage_name or "").strip().lower() + if normalized == "code": + return ( + "你已经在当前阶段花了过多轮次进行探索。现在禁止继续浏览仓库。" + "请直接做最小代码修改,并只执行最小必要验证。" + "如果仍然无法完成,请只输出唯一阻塞点和证据。" + ) + if normalized == "test": + return ( + "你已经在当前阶段花了过多轮次进行探索。现在禁止继续扩展测试范围。" + "请直接执行最小、最相关的验证。" + "如果验证命令失败,必须明确给出失败命令、关键报错和唯一阻塞点;" + "不要仅凭代码阅读判断测试通过。" + ) + return "请立即收敛到当前阶段的最终结果,不要继续扩展。" + + +def _build_stage_restart_prompt( + restart_context: dict[str, Any] | None, + tracker: StageEventTracker, + output: str, + *, + reason: str, +) -> str: + context = restart_context or {} + title = str(context.get("task_title") or "").strip() + description = str(context.get("task_description") or "").strip() + stage_name = str(context.get("stage_name") or tracker.stage_name).strip() + preflight_summary = str(context.get("preflight_summary") or "").strip() + partial_output = _clip_text((output or "").replace("[Max turns reached. 
Please continue the conversation.]", "").strip(), _RESTART_OUTPUT_CHARS) + tool_digest = _format_tool_digest(tracker.get_completed_tool_runs(), limit=2) + action_prompt = ( + _build_forced_convergence_prompt(stage_name) + if reason == "forced_convergence" + else _build_continuation_prompt(stage_name) + ) + + parts: list[str] = [] + if title: + parts.append(f"## 任务\n**{title}**") + if description: + parts.append(description) + parts.append(f"\n## 当前阶段\n{stage_name}") + parts.append(_stage_goal_summary(stage_name)) + if preflight_summary: + parts.append(f"\n## 阶段预扫摘要\n{preflight_summary}") + if partial_output: + parts.append(f"\n## 当前阶段已有部分输出\n{partial_output}") + if tool_digest: + parts.append(f"\n## 最近关键工具结果\n{tool_digest}") + parts.append("\n## 下一步要求") + parts.append(action_prompt) + parts.append("不要重新展开整段历史;只基于上面的当前状态继续完成必要工作。") + return "\n".join(parts).strip() + + +async def _run_stage_restart( + runner: Any, + output: str, + runtime_overrides: dict[str, Any], + tracker: StageEventTracker, + *, + reason: str, + restart_index: int, + restart_context: dict[str, Any] | None = None, +) -> str: + prompt = _build_stage_restart_prompt(restart_context, tracker, output, reason=reason) + chat_started = time.monotonic() + chat_correlation = await tracker.emit_chat_sent( + request_body={ + "prompt": prompt, + "model": getattr(getattr(runner, "config", None), "model", None), + "stage": tracker.stage_name, + "agent_role": tracker.agent_role, + "temperature": runtime_overrides.get("temperature"), + "max_tokens": runtime_overrides.get("max_tokens"), + "restart": restart_index, + "restart_reason": reason, + "forced_convergence": reason == "forced_convergence", + "reset": True, + "timeout_seconds": settings.WORKER_STAGE_TIMEOUT, + }, + ) + try: + restart_kwargs = _chat_kwargs_for_runner(runner, runtime_overrides) + response = await asyncio.wait_for( + runner.chat(prompt, reset=True, **restart_kwargs), + timeout=settings.WORKER_STAGE_TIMEOUT, + ) + restart_text = 
response.text_content or "" + await tracker.emit_chat_received( + chat_correlation, + status="success", + response_body={ + "restart": restart_index, + "restart_reason": reason, + "forced_convergence": reason == "forced_convergence", + "content": restart_text, + }, + duration_ms=round((time.monotonic() - chat_started) * 1000, 2), + ) + cleaned = (output or "").replace("[Max turns reached. Please continue the conversation.]", "").strip() + return f"{cleaned}\n\n{restart_text}".strip() if cleaned else restart_text + except asyncio.CancelledError: + _clear_current_task_cancellation_state() + await tracker.emit_chat_received( + chat_correlation, + status="cancelled", + response_body={"restart": restart_index, "restart_reason": reason, "error": "cancelled"}, + duration_ms=round((time.monotonic() - chat_started) * 1000, 2), + ) + raise + except Exception as exc: + await tracker.emit_chat_received( + chat_correlation, + status="failed", + response_body={ + "restart": restart_index, + "restart_reason": reason, + "forced_convergence": reason == "forced_convergence", + "error": str(exc), + }, + duration_ms=round((time.monotonic() - chat_started) * 1000, 2), + ) + return output + + +async def _run_forced_convergence( + runner: Any, + output: str, + runtime_overrides: dict[str, Any], + tracker: StageEventTracker, + stage_name: str | None = None, +) -> str: + if not tracker.should_force_convergence(): + return output + + tracker.mark_forced_convergence_used() + prompt = _build_forced_convergence_prompt(stage_name or tracker.stage_name) + chat_started = time.monotonic() + chat_correlation = await tracker.emit_chat_sent( + request_body={ + "prompt": prompt, + "model": getattr(getattr(runner, "config", None), "model", None), + "stage": tracker.stage_name, + "agent_role": tracker.agent_role, + "temperature": runtime_overrides.get("temperature"), + "max_tokens": runtime_overrides.get("max_tokens"), + "forced_convergence": True, + "timeout_seconds": settings.WORKER_STAGE_TIMEOUT, + 
}, + ) + + try: + followup_kwargs = _chat_kwargs_for_runner(runner, runtime_overrides) + response = await asyncio.wait_for( + runner.chat(prompt, reset=False, **followup_kwargs), + timeout=settings.WORKER_STAGE_TIMEOUT, + ) + forced_text = response.text_content or "" + await tracker.emit_chat_received( + chat_correlation, + status="success", + response_body={"forced_convergence": True, "content": forced_text}, + duration_ms=round((time.monotonic() - chat_started) * 1000, 2), + ) + cleaned = output.replace( + "[Max turns reached. Please continue the conversation.]", + "", + ).strip() + return f"{cleaned}\n\n{forced_text}".strip() if cleaned else forced_text + except asyncio.CancelledError: + _clear_current_task_cancellation_state() + await tracker.emit_chat_received( + chat_correlation, + status="cancelled", + response_body={"forced_convergence": True, "error": "cancelled"}, + duration_ms=round((time.monotonic() - chat_started) * 1000, 2), + ) + raise + except Exception as exc: + await tracker.emit_chat_received( + chat_correlation, + status="failed", + response_body={"forced_convergence": True, "error": str(exc)}, + duration_ms=round((time.monotonic() - chat_started) * 1000, 2), + ) + return output + + async def _handle_continuations( runner: Any, output: str, runtime_overrides: dict[str, Any], tracker: StageEventTracker, + stage_name: str | None = None, + restart_context: dict[str, Any] | None = None, ) -> tuple[str, int]: """Follow up with continuation prompts when the LLM output was truncated.""" _MAX_CONTINUATIONS = 3 _TRUNCATION_SENTINEL = "Max turns reached" - continuations = 0 - - while _TRUNCATION_SENTINEL in (output or "") and continuations < _MAX_CONTINUATIONS: - continuations += 1 - continuation_started = time.monotonic() - prompt = "请继续完成上面的输出,从你停下的地方继续。" - chat_correlation = await tracker.emit_chat_sent( - request_body={ - "prompt": prompt, - "model": getattr(getattr(runner, "config", None), "model", None), - "stage": tracker.stage_name, - 
"agent_role": tracker.agent_role, - "temperature": runtime_overrides.get("temperature"), - "max_tokens": runtime_overrides.get("max_tokens"), - "continuation": continuations, - "timeout_seconds": settings.WORKER_STAGE_TIMEOUT, - }, - ) - try: - continuation_kwargs = _chat_kwargs_for_runner(runner, runtime_overrides) - cont_response = await asyncio.wait_for( - runner.chat(prompt, reset=False, **continuation_kwargs), - timeout=settings.WORKER_STAGE_TIMEOUT, - ) - cont_text = cont_response.text_content or "" - await tracker.emit_chat_received( - chat_correlation, - status="success", - response_body={"continuation": continuations, "content": cont_text}, - duration_ms=round((time.monotonic() - continuation_started) * 1000, 2), + restarts = 0 + effective_stage_name = stage_name or tracker.stage_name + + if restart_context is None and _prefer_restart_continuations(effective_stage_name): + restart_context = {"stage_name": effective_stage_name} + + if restart_context is None: + continuations = 0 + output = await _run_forced_convergence(runner, output, runtime_overrides, tracker, effective_stage_name) + + while _TRUNCATION_SENTINEL in (output or "") and continuations < _MAX_CONTINUATIONS: + continuations += 1 + continuation_started = time.monotonic() + prompt = _build_continuation_prompt(effective_stage_name) + chat_correlation = await tracker.emit_chat_sent( + request_body={ + "prompt": prompt, + "model": getattr(getattr(runner, "config", None), "model", None), + "stage": tracker.stage_name, + "agent_role": tracker.agent_role, + "temperature": runtime_overrides.get("temperature"), + "max_tokens": runtime_overrides.get("max_tokens"), + "continuation": continuations, + "timeout_seconds": settings.WORKER_STAGE_TIMEOUT, + }, ) - output = output.replace( - f"[{_TRUNCATION_SENTINEL}. 
Please continue the conversation.]", - "", - ).strip() - output = f"{output}\n\n{cont_text}".strip() - except asyncio.CancelledError: - _clear_current_task_cancellation_state() - await tracker.emit_chat_received( - chat_correlation, - status="cancelled", - response_body={"continuation": continuations, "error": "cancelled"}, - duration_ms=round((time.monotonic() - continuation_started) * 1000, 2), + try: + continuation_kwargs = _chat_kwargs_for_runner(runner, runtime_overrides) + cont_response = await asyncio.wait_for( + runner.chat(prompt, reset=False, **continuation_kwargs), + timeout=settings.WORKER_STAGE_TIMEOUT, + ) + cont_text = cont_response.text_content or "" + await tracker.emit_chat_received( + chat_correlation, + status="success", + response_body={"continuation": continuations, "content": cont_text}, + duration_ms=round((time.monotonic() - continuation_started) * 1000, 2), + ) + output = output.replace( + f"[{_TRUNCATION_SENTINEL}. Please continue the conversation.]", + "", + ).strip() + output = f"{output}\n\n{cont_text}".strip() + except asyncio.CancelledError: + _clear_current_task_cancellation_state() + await tracker.emit_chat_received( + chat_correlation, + status="cancelled", + response_body={"continuation": continuations, "error": "cancelled"}, + duration_ms=round((time.monotonic() - continuation_started) * 1000, 2), + ) + raise + except Exception as e: + await tracker.emit_chat_received( + chat_correlation, + status="failed", + response_body={"continuation": continuations, "error": str(e)}, + duration_ms=round((time.monotonic() - continuation_started) * 1000, 2), + ) + break + else: + if tracker.should_force_convergence(): + restarts += 1 + tracker.mark_forced_convergence_used() + output = await _run_stage_restart( + runner, + output, + runtime_overrides, + tracker, + reason="forced_convergence", + restart_index=restarts, + restart_context=restart_context, ) - raise - except Exception as e: - await tracker.emit_chat_received( - chat_correlation, - 
status="failed", - response_body={"continuation": continuations, "error": str(e)}, - duration_ms=round((time.monotonic() - continuation_started) * 1000, 2), + + while _TRUNCATION_SENTINEL in (output or "") and restarts < _MAX_CONTINUATIONS: + restarts += 1 + output = await _run_stage_restart( + runner, + output, + runtime_overrides, + tracker, + reason="truncation", + restart_index=restarts, + restart_context=restart_context, ) - break total_tokens = runner.cumulative_usage.total_tokens return output, total_tokens @@ -796,6 +1162,7 @@ async def execute_stage( compressed_outputs: Optional[List[Dict[str, str]]] = None, project_memory: Optional[str] = None, repo_context: Optional[str] = None, + preflight_summary: Optional[str] = None, retry_context: Optional[Dict[str, str]] = None, stage_model: Optional[str] = None, workdir_override: Optional[str] = None, @@ -850,6 +1217,7 @@ async def execute_stage( compressed_outputs=compressed_outputs, project_memory=project_memory, repo_context=repo_context, + preflight_summary=preflight_summary, retry_context=retry_context, custom_instruction=custom_instruction, gate_rejection_context=gate_rejection_context, @@ -857,13 +1225,15 @@ async def execute_stage( user_prompt = build_user_prompt(ctx) runtime_overrides = _build_runtime_overrides(agent, stage_model) - runner = get_agent( + stage_max_turns = _resolve_stage_max_turns(stage.agent_role, runtime_overrides["max_turns"]) + runner_factory = get_agent_text_only if _is_signoff_stage(stage.stage_name) else get_agent + runner = runner_factory( stage.agent_role, task_id, model=runtime_overrides["model"], temperature=runtime_overrides["temperature"], max_tokens=runtime_overrides["max_tokens"], - max_turns=runtime_overrides["max_turns"], + max_turns=stage_max_turns, extra_skill_dirs=runtime_overrides["extra_skill_dirs"], system_prompt_append=runtime_overrides["system_prompt_append"], ) @@ -895,6 +1265,7 @@ async def execute_stage( "agent_role": stage.agent_role, "temperature": 
runtime_overrides.get("temperature"), "max_tokens": runtime_overrides.get("max_tokens"), + "max_turns": stage_max_turns, "attempt": attempt + 1, "timeout_seconds": settings.WORKER_STAGE_TIMEOUT, }, @@ -968,7 +1339,7 @@ async def execute_stage( model=runtime_overrides["model"], temperature=runtime_overrides["temperature"], max_tokens=runtime_overrides["max_tokens"], - max_turns=runtime_overrides["max_turns"], + max_turns=stage_max_turns, extra_skill_dirs=runtime_overrides["extra_skill_dirs"], system_prompt_append=runtime_overrides["system_prompt_append"], ) @@ -1015,9 +1386,15 @@ async def execute_stage( elapsed = time.monotonic() - start_time output = response.text_content total_tokens = runner.cumulative_usage.total_tokens + restart_context = { + "task_title": task.title, + "task_description": task.description, + "stage_name": stage.stage_name, + "preflight_summary": preflight_summary, + } output, total_tokens = await _handle_continuations( - runner, output, runtime_overrides, tracker + runner, output, runtime_overrides, tracker, stage.stage_name, restart_context ) # Phase 2.2: Evaluator-optimizer loop (if configured for this stage) @@ -1194,6 +1571,7 @@ async def execute_stage_sandboxed( compressed_outputs: Optional[List[Dict[str, str]]] = None, project_memory: Optional[str] = None, repo_context: Optional[str] = None, + preflight_summary: Optional[str] = None, retry_context: Optional[Dict[str, str]] = None, stage_model: Optional[str] = None, custom_instruction: Optional[str] = None, @@ -1250,6 +1628,7 @@ async def execute_stage_sandboxed( compressed_outputs=compressed_outputs, project_memory=project_memory, repo_context=repo_context, + preflight_summary=preflight_summary, retry_context=retry_context, custom_instruction=custom_instruction, gate_rejection_context=gate_rejection_context, @@ -1269,8 +1648,7 @@ async def execute_stage_sandboxed( from app.worker.agents import _get_skill_dirs skill_dirs = [f"/skills/{d.name}" for d in _get_skill_dirs(stage.agent_role)] 
- max_turns_map = {"spec": 20, "coding": 20, "doc": 20, "test": 20} - max_turns = max_turns_map.get(stage.agent_role, 10) + max_turns = _resolve_stage_max_turns(stage.agent_role, runtime_overrides["max_turns"]) # 5. Log the request via shared pipeline contract pipeline = get_task_log_pipeline() @@ -1291,6 +1669,7 @@ async def execute_stage_sandboxed( "model": resolved_model, "temperature": runtime_overrides.get("temperature"), "max_tokens": runtime_overrides.get("max_tokens"), + "max_turns": max_turns, "stage": stage.stage_name, "agent_role": stage.agent_role, "prompt": user_prompt, diff --git a/platform/app/worker/prompts.py b/platform/app/worker/prompts.py index 597b8b4..9162177 100644 --- a/platform/app/worker/prompts.py +++ b/platform/app/worker/prompts.py @@ -1,9 +1,27 @@ """Role-based system prompts and stage instruction templates for Agent Worker.""" from __future__ import annotations +import re from dataclasses import dataclass from typing import Dict, List, Optional +_EXECUTION_STAGE_NAMES = {"code", "coding", "test"} +_EXECUTION_MEMORY_LIMIT = 320 +_EXECUTION_REPO_HINT_LIMIT = 720 +_EXECUTION_PRIOR_LIMITS = { + "parse": 520, + "approve": 520, + "spec": 720, + "review": 720, + "doc": 720, + "code": 960, + "coding": 960, + "test": 960, + "signoff": 960, +} +_EXECUTION_PRIOR_MARKER = "\n...(前序阶段产出已截断)" +_REPO_SECTION_PATTERN = re.compile(r"^###\s+(?P<title>[^\n]+)\n", re.MULTILINE) + # --------------------------------------------------------------------------- # System prompts per agent role @@ -130,13 +148,19 @@ STAGE_GUARDRAILS: Dict[str, str] = { "code": ( "只完成当前阶段,不要提前执行后续阶段任务。\n" - "你可以为了验证实现而运行必要命令,但不要提前生成最终签收/验收报告," - "也不要调用 signoff、review、smoke、e2e-test 等后续阶段能力。\n" + "不要为了理解整个仓库而广泛探索,优先基于已知信息直接实现。\n" + "只有在缺少关键实现信息时才少量补读文件;最多再检查 3 个关键文件或执行 1 次探索性目录命令," + "之后必须开始修改代码。\n" + "你可以为了验证实现而运行必要命令,但目标必须是最小必要验证。\n" + "不要提前生成最终签收/验收报告,也不要调用 signoff、review、smoke、e2e-test 等后续阶段能力。\n" "完成实现并简要总结本阶段改动后结束。" ), "test": ( "只完成当前阶段,不要提前执行后续阶段任务。\n" -
"请聚焦当前任务直接相关的自动化测试与验证;如果相关测试已经通过,且已覆盖验收标准,请立即停止。\n" + "请聚焦当前任务直接相关的自动化测试与验证,优先最小、最相关、最快的验证路径。\n" + "最多再补读 2 个关键文件、执行 2 条验证命令;超过后必须停止扩展并给出结论。\n" + "如果验证命令失败,必须明确给出失败命令、关键报错和阻塞点;不要只根据代码阅读就判定测试通过。\n" + "如果相关测试已经通过,且已满足验收标准,请立即停止。\n" "不要继续扩展额外类型的测试,例如 E2E、冒烟、性能或签收报告,除非任务明确要求。" ), "signoff": ( @@ -163,6 +187,7 @@ class StageContext: compressed_outputs: Optional[List[Dict[str, str]]] = None # sliding-window compressed project_memory: Optional[str] = None # injected project memory text repo_context: Optional[str] = None # injected repo context (tech stack + dir tree) + preflight_summary: Optional[str] = None # deterministic stage-local workspace scan summary # Smart retry: failure context from previous attempt (Ralph Loop V2 pattern) retry_context: Optional[Dict[str, str]] = None # {"error": msg, "prior_output": text} # Phase 1.4: Custom instruction from template stage definition @@ -171,6 +196,142 @@ class StageContext: gate_rejection_context: Optional[Dict[str, str]] = None # {"comment": ..., "retry": "2/3"} +def _clip_stage_context(value: Optional[str], *, limit: int, marker: str) -> Optional[str]: + text = (value or "").strip() + if not text: + return None + if len(text) <= limit: + return text + keep_len = max(0, limit - len(marker)) + return text[:keep_len].rstrip() + marker + + +def _is_execution_stage(stage_name: str) -> bool: + return (stage_name or "").strip().lower() in _EXECUTION_STAGE_NAMES + + +def _extract_repo_section(repo_context: str, title: str) -> str: + text = (repo_context or "").strip() + if not text: + return "" + matches = list(_REPO_SECTION_PATTERN.finditer(text)) + for index, match in enumerate(matches): + if match.group("title").strip() != title: + continue + start = match.end() + end = matches[index + 1].start() if index + 1 < len(matches) else len(text) + return text[start:end].strip() + return "" + + +def _collect_tree_matches( + tree_lines: list[str], + *, + predicates: tuple[str, ...], + limit: int, + require_file: bool = False, +) -> 
list[str]: + matches: list[str] = [] + seen: set[str] = set() + for raw in tree_lines: + line = raw.strip() + lowered = line.lower() + if not line or line.startswith("...(目录树已截断)"): + continue + if require_file and "." not in line.rsplit("/", 1)[-1]: + continue + if not any(token in lowered for token in predicates): + continue + if line in seen: + continue + seen.add(line) + matches.append(line) + if len(matches) >= limit: + break + return matches + + +def _build_execution_repo_hint(repo_context: Optional[str]) -> Optional[str]: + text = (repo_context or "").strip() + if not text: + return None + + tech_stack = _extract_repo_section(text, "技术栈") + repo_tree = _extract_repo_section(text, "目录结构") + repo_tree_lines = [line for line in repo_tree.splitlines() if line.strip()] + + build_files = _collect_tree_matches( + repo_tree_lines, + predicates=( + "build.gradle", + "build.gradle.kts", + "pom.xml", + "package.json", + "pyproject.toml", + "go.mod", + "cargo.toml", + ), + limit=3, + require_file=True, + ) + source_roots = _collect_tree_matches( + repo_tree_lines, + predicates=("src/main", "app/", "app\\", "server/", "lib/", "internal/"), + limit=3, + ) + test_roots = _collect_tree_matches( + repo_tree_lines, + predicates=("src/test", "tests/", "__tests__", "spec/"), + limit=3, + ) + impl_refs = _collect_tree_matches( + repo_tree_lines, + predicates=("controller", "handler", "service", "api", "route", "response"), + limit=2, + require_file=True, + ) + + parts: list[str] = [] + if tech_stack: + parts.append(f"- 技术栈: {tech_stack[:180].strip()}") + if build_files: + parts.append(f"- 构建入口: {', '.join(build_files)}") + if source_roots: + parts.append(f"- 源码目录: {', '.join(source_roots)}") + if test_roots: + parts.append(f"- 测试目录: {', '.join(test_roots)}") + if impl_refs: + parts.append(f"- 参考实现: {', '.join(impl_refs)}") + + if not parts: + return _clip_stage_context( + text, + limit=_EXECUTION_REPO_HINT_LIMIT, + marker="...(执行阶段仓库信息已截断)", + ) + + return _clip_stage_context( + 
"\n".join(parts), + limit=_EXECUTION_REPO_HINT_LIMIT, + marker="...(执行阶段仓库信息已截断)", + ) + + +def _clip_execution_prior_outputs(prior: List[Dict[str, str]]) -> List[Dict[str, str]]: + clipped: List[Dict[str, str]] = [] + for item in prior: + stage = str(item.get("stage") or "").strip() + output = str(item.get("output") or "") + limit = _EXECUTION_PRIOR_LIMITS.get(stage.lower(), 720) + clipped_output = _clip_stage_context( + output, + limit=limit, + marker=_EXECUTION_PRIOR_MARKER, + ) or "" + clipped.append({"stage": stage, "output": clipped_output}) + return clipped + + def build_user_prompt(ctx: StageContext) -> str: """Build the user prompt text for an AgentRunner chat call. @@ -184,16 +345,34 @@ def build_user_prompt(ctx: StageContext) -> str: if ctx.task_description: parts.append(f"\n{ctx.task_description}") + repo_context = ctx.repo_context + project_memory = ctx.project_memory + if _is_execution_stage(ctx.stage_name): + if ctx.preflight_summary: + repo_context = None + else: + repo_context = _build_execution_repo_hint(repo_context) + project_memory = _clip_stage_context( + project_memory, + limit=_EXECUTION_MEMORY_LIMIT, + marker="...(执行阶段记忆已截断)", + ) + # Inject repo context (tech stack + directory structure) - if ctx.repo_context: - parts.append(f"\n## 项目代码库信息\n{ctx.repo_context}") + if repo_context: + parts.append(f"\n## 项目代码库信息\n{repo_context}") # Inject project memory from historical tasks - if ctx.project_memory: - parts.append(f"\n## 项目上下文(来自历史任务)\n{ctx.project_memory}") + if project_memory: + parts.append(f"\n## 项目上下文(来自历史任务)\n{project_memory}") + + if ctx.preflight_summary: + parts.append(f"\n## 阶段预扫摘要\n{ctx.preflight_summary}") # Use compressed outputs (sliding-window) when available, otherwise raw prior = ctx.compressed_outputs if ctx.compressed_outputs is not None else ctx.prior_outputs + if prior and _is_execution_stage(ctx.stage_name): + prior = _clip_execution_prior_outputs(prior) if prior: parts.append("\n## 前序阶段产出") for po in prior: diff --git 
a/platform/app/worker/sandbox.py b/platform/app/worker/sandbox.py index f84a703..83b1be6 100644 --- a/platform/app/worker/sandbox.py +++ b/platform/app/worker/sandbox.py @@ -457,6 +457,24 @@ def _build_docker_run_cmd( ) capture_model_api_raw = False + gradle_cache_host_dir = Path(settings.SANDBOX_GRADLE_CACHE_HOST_DIR).expanduser() + gradle_cache_container_dir = str(settings.SANDBOX_GRADLE_USER_HOME).strip() or "/var/lib/silicon_agent/gradle-cache" + try: + gradle_cache_host_dir.mkdir(parents=True, exist_ok=True) + parts.extend( + [ + "--mount", + f"type=bind,src={gradle_cache_host_dir},dst={gradle_cache_container_dir}", + ] + ) + except Exception: + logger.warning( + "Failed to prepare gradle cache mount directory: %s", + gradle_cache_host_dir, + exc_info=True, + ) + gradle_cache_container_dir = "/tmp/.gradle" + if settings.SANDBOX_READONLY_ROOT: parts.append("--read-only") parts.extend(["--tmpfs", "/tmp:size=512m"]) @@ -478,10 +496,24 @@ def _build_docker_run_cmd( [ "-e", f"SANDBOX_DUMP_MODEL_API_RESPONSE={'true' if capture_model_api_raw else 'false'}", + "-e", + f"GRADLE_USER_HOME={gradle_cache_container_dir}", + "-e", + f"SANDBOX_DEFAULT_JAVA_VERSION={int(settings.SANDBOX_DEFAULT_JAVA_VERSION)}", + "-e", + f"SANDBOX_GRADLE_WRAPPER_PREWARM={'true' if settings.SANDBOX_GRADLE_WRAPPER_PREWARM else 'false'}", + "-e", + f"SANDBOX_GRADLE_WRAPPER_PREWARM_TIMEOUT_SECONDS={int(settings.SANDBOX_GRADLE_WRAPPER_PREWARM_TIMEOUT_SECONDS)}", ] ) if capture_model_api_raw and container_raw_log_path: parts.extend(["-e", f"SANDBOX_MODEL_API_RAW_LOG_PATH={container_raw_log_path}"]) + parts.extend( + [ + "-e", + f"SANDBOX_GRADLE_CMD_TIMEOUT_SECONDS={int(settings.SANDBOX_GRADLE_CMD_TIMEOUT_SECONDS)}", + ] + ) parts.append(image) diff --git a/platform/sandbox/Dockerfile.base b/platform/sandbox/Dockerfile.base index d26014d..7c7f017 100644 --- a/platform/sandbox/Dockerfile.base +++ b/platform/sandbox/Dockerfile.base @@ -17,11 +17,13 @@ COPY sandbox/requirements.txt 
/app/requirements.txt RUN pip install --no-cache-dir -r /app/requirements.txt # Copy agent server -COPY sandbox/agent_server.py /app/agent_server.py -COPY sandbox/tool_policy.py /app/tool_policy.py +COPY --chown=agent:agent sandbox/agent_server.py /app/agent_server.py +COPY --chown=agent:agent sandbox/tool_policy.py /app/tool_policy.py -# Copy skills -COPY skills/ /skills/ +# Copy skills and keep them readable for sandbox processes that run as +# the workspace owner instead of the image's default agent user. +COPY --chown=agent:agent skills/ /skills/ +RUN chmod -R a+rX /app /skills EXPOSE 9090 diff --git a/platform/sandbox/Dockerfile.coding b/platform/sandbox/Dockerfile.coding index 9299e6d..4cae582 100644 --- a/platform/sandbox/Dockerfile.coding +++ b/platform/sandbox/Dockerfile.coding @@ -2,22 +2,40 @@ FROM eclipse-temurin:8-jdk AS jdk8 FROM eclipse-temurin:17-jdk AS jdk17 FROM silicon-agent-sandbox:base +ARG GRADLE_VERSION=8.5 +ENV GRADLE_PREWARM_USER_HOME=/opt/gradle-prewarm + USER root # Development toolchain RUN apt-get update && apt-get install -y --no-install-recommends \ build-essential gcc g++ make \ nodejs npm \ + unzip \ && rm -rf /var/lib/apt/lists/* +# Install Gradle from the official distribution so we can reuse the bundled +# JDK 17 and avoid pulling Debian's extra OpenJDK runtime chain. 
+RUN wget -q "https://services.gradle.org/distributions/gradle-${GRADLE_VERSION}-bin.zip" -O /tmp/gradle.zip && \ + unzip -q /tmp/gradle.zip -d /opt && \ + ln -s "/opt/gradle-${GRADLE_VERSION}/bin/gradle" /usr/local/bin/gradle && \ + rm -f /tmp/gradle.zip + COPY --from=jdk8 /opt/java/openjdk /opt/jdk8 COPY --from=jdk17 /opt/java/openjdk /opt/jdk17 +COPY sandbox/scripts/prewarm_gradle_cache.sh /usr/local/bin/prewarm_gradle_cache.sh ENV JAVA8_HOME=/opt/jdk8 ENV JAVA17_HOME=/opt/jdk17 ENV JAVA_HOME=/opt/jdk17 ENV PATH="${JAVA_HOME}/bin:${PATH}" +RUN ln -sf /opt/jdk17/bin/java /usr/local/bin/java && \ + ln -sf /opt/jdk17/bin/javac /usr/local/bin/javac + +RUN chmod +x /usr/local/bin/prewarm_gradle_cache.sh && \ + /usr/local/bin/prewarm_gradle_cache.sh + # Common Python dev tools RUN pip install --no-cache-dir \ pytest ruff black mypy \ diff --git a/platform/sandbox/agent_server.py b/platform/sandbox/agent_server.py index 1edf197..fc983c3 100644 --- a/platform/sandbox/agent_server.py +++ b/platform/sandbox/agent_server.py @@ -14,6 +14,7 @@ import logging import os import re +import shlex import sys import time from datetime import datetime, timezone @@ -63,19 +64,40 @@ "build.gradle", "build.gradle.kts", "gradle.properties", + "settings.gradle", + "settings.gradle.kts", + ".java-version", + ".tool-versions", ) _JAVA8_PATTERNS = ( r"<java\.version>\s*(?:1\.8|8)\s*</java\.version>", r"<maven\.compiler\.(?:source|target|release)>\s*(?:1\.8|8)\s*</maven\.compiler\.(?:source|target|release)>", r"sourceCompatibility\s*=\s*(?:['\"]?1\.8['\"]?|JavaVersion\.VERSION_1_8)", r"targetCompatibility\s*=\s*(?:['\"]?1\.8['\"]?|JavaVersion\.VERSION_1_8)", + r"JavaLanguageVersion\.of\(\s*(?:1\.8|8)\s*\)", + r"^\s*8(?:\.\d+)?\s*$", + r"(?m)^\s*java\s+(?:temurin-)?(?:1\.8|8)\s*$", ) _JAVA17_PATTERNS = ( r"<java\.version>\s*17\s*</java\.version>", r"<maven\.compiler\.(?:source|target|release)>\s*17\s*</maven\.compiler\.(?:source|target|release)>", 
r"sourceCompatibility\s*=\s*(?:['\"]?17['\"]?|JavaVersion\.VERSION_17)", r"targetCompatibility\s*=\s*(?:['\"]?17['\"]?|JavaVersion\.VERSION_17)", + r"JavaLanguageVersion\.of\(\s*17\s*\)", + r"^\s*17(?:\.\d+)?\s*$", + r"(?m)^\s*java\s+(?:temurin-)?17\s*$", ) +_JAVA_VERSION_MISMATCH_PATTERNS = ( + r"Unsupported class file major version", + r"invalid source release", + r"release version \d+ not supported", + r"Could not target platform", +) +_GRADLE_ANY_CMD_RE = re.compile(r"(?<![\w./-])(?:gradle|(?:sh\s+)?(?:\./)?gradlew)(?![\w.-])") +_RUNTIME_PREFLIGHT_DONE = False +_RUNTIME_PREFLIGHT_LOCK = asyncio.Lock() +_WRAPPER_PREWARM_DONE = False +_WRAPPER_PREWARM_LOCK = asyncio.Lock() def _normalize_openai_base_url(base_url: str | None) -> str: @@ -114,9 +136,14 @@ def _detect_java_major_version(workdir: str) -> int | None: def _configure_java_runtime_for_workspace(workdir: str) -> int | None: - major = _detect_java_major_version(workdir) - if major is None: - return None + override_raw = (os.environ.get("SANDBOX_JAVA_VERSION") or "").strip() + if override_raw in {"8", "17"}: + major = int(override_raw) + else: + default_major = _env_int("SANDBOX_DEFAULT_JAVA_VERSION", 8) + if default_major not in {8, 17}: + default_major = 8 + major = _detect_java_major_version(workdir) or default_major java_home_key = "JAVA8_HOME" if major == 8 else "JAVA17_HOME" target_java_home = (os.environ.get(java_home_key) or "").strip() @@ -142,6 +169,14 @@ def _configure_java_runtime_for_workspace(workdir: str) -> int | None: return major +def _should_retry_with_other_java(output: str) -> bool: + text = str(output or "") + return any( + re.search(pattern, text, re.IGNORECASE) + for pattern in _JAVA_VERSION_MISMATCH_PATTERNS + ) + + def _is_gemini_model(model: str | None) -> bool: return "gemini" in ((model or "").lower()) @@ -168,6 +203,128 @@ def _sanitize_reasoning_kwargs_for_model( return sanitized_kwargs +def _env_flag(name: str, default: bool) -> bool: + raw = (os.environ.get(name) or 
"").strip().lower() + if not raw: + return default + return raw in {"1", "true", "yes", "on"} + + +def _env_int(name: str, default: int) -> int: + raw = (os.environ.get(name) or "").strip() + if not raw: + return default + try: + parsed = int(raw) + except ValueError: + return default + return parsed if parsed > 0 else default + + +def _wrap_with_timeout(command: str, timeout_seconds: int) -> str: + if timeout_seconds <= 0: + return command + payload = ( + "if command -v timeout >/dev/null 2>&1; " + f"then timeout {timeout_seconds}s bash -lc {shlex.quote(command)}; " + f"else bash -lc {shlex.quote(command)}; fi" + ) + return payload + + +async def _run_runtime_preflight_once() -> None: + global _RUNTIME_PREFLIGHT_DONE + if _RUNTIME_PREFLIGHT_DONE: + return + + async with _RUNTIME_PREFLIGHT_LOCK: + if _RUNTIME_PREFLIGHT_DONE: + return + + async def _run_line(cmd: str, *, timeout: float = 5.0) -> str: + proc = await asyncio.create_subprocess_exec( + "sh", "-lc", cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + ) + try: + out, _ = await asyncio.wait_for(proc.communicate(), timeout=timeout) + except Exception: + with contextlib.suppress(Exception): + proc.kill() + await proc.communicate() + return "" + return (out or b"").decode("utf-8", errors="ignore").strip() + + gradle_line = await _run_line("gradle -v 2>&1 | head -n 5") + java_line = await _run_line("java -version 2>&1 | head -n 3") + if gradle_line: + logger.info("sandbox_runtime_preflight gradle=%s", gradle_line.replace("\n", " | ")) + else: + logger.info("sandbox_runtime_preflight gradle=unavailable") + if java_line: + logger.info("sandbox_runtime_preflight java=%s", java_line.replace("\n", " | ")) + else: + logger.info("sandbox_runtime_preflight java=unavailable") + + _RUNTIME_PREFLIGHT_DONE = True + + +async def _run_gradle_wrapper_prewarm_once(workdir: str) -> None: + global _WRAPPER_PREWARM_DONE + if _WRAPPER_PREWARM_DONE: + return + if not 
_env_flag("SANDBOX_GRADLE_WRAPPER_PREWARM", True): + return + + gradlew = Path(workdir) / "gradlew" + if not gradlew.is_file(): + return + + timeout_seconds = _env_int("SANDBOX_GRADLE_WRAPPER_PREWARM_TIMEOUT_SECONDS", 180) + cmd = _wrap_with_timeout(f"cd {shlex.quote(workdir)} && ./gradlew --version", timeout_seconds) + + async with _WRAPPER_PREWARM_LOCK: + if _WRAPPER_PREWARM_DONE: + return + + started = time.monotonic() + proc = await asyncio.create_subprocess_exec( + "sh", "-lc", cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + ) + try: + out, _ = await proc.communicate() + elapsed_ms = int((time.monotonic() - started) * 1000) + output = (out or b"").decode("utf-8", errors="ignore") + if proc.returncode == 0: + _WRAPPER_PREWARM_DONE = True + logger.info( + "gradle_wrapper_prewarm status=success workdir=%s elapsed_ms=%d gradle_user_home=%s", + workdir, + elapsed_ms, + os.environ.get("GRADLE_USER_HOME", ""), + ) + return + + logger.warning( + "gradle_wrapper_prewarm status=failed workdir=%s elapsed_ms=%d exit_code=%s output=%s", + workdir, + elapsed_ms, + proc.returncode, + output[-600:].replace("\n", " | "), + ) + except Exception as exc: + elapsed_ms = int((time.monotonic() - started) * 1000) + logger.warning( + "gradle_wrapper_prewarm status=error workdir=%s elapsed_ms=%d error=%s", + workdir, + elapsed_ms, + exc, + ) + + def _extract_gemini_thought_signatures_from_response(response_obj: Any) -> dict[str, str]: """Extract OpenAI-compat thought signatures keyed by tool_call id.""" if hasattr(response_obj, "model_dump"): @@ -463,6 +620,46 @@ async def _execute_tool_base(self, tool_call, on_output=None) -> str: async def _execute_tool(self, tool_call, on_output=None): return await self._execute_tool_with_policy(tool_call, on_output=on_output) + def _preprocess_validated_tool_call( + self, + *, + tool_name: str, + args: dict[str, Any], + tool_call: dict[str, Any], + ) -> tuple[dict[str, Any], dict[str, Any], str | None, str | None]: + if 
tool_name != "execute": + return super()._preprocess_validated_tool_call( + tool_name=tool_name, + args=args, + tool_call=tool_call, + ) + + command = str(args.get("command") or "").strip() + if not command or not _GRADLE_ANY_CMD_RE.search(command): + return super()._preprocess_validated_tool_call( + tool_name=tool_name, + args=args, + tool_call=tool_call, + ) + + timeout_seconds = _env_int("SANDBOX_GRADLE_CMD_TIMEOUT_SECONDS", 480) + rewritten = command + if timeout_seconds > 0: + rewritten = _wrap_with_timeout(rewritten, timeout_seconds) + + if rewritten != command: + logger.info( + "gradle_command_wrap strategy=wrapper rewritten=true original_command=%s rewritten_command=%s", + command, + rewritten, + ) + + updated_args = dict(args) + updated_args["command"] = rewritten + normalized_tool_call = dict(tool_call) + normalized_tool_call["arguments"] = json.dumps(updated_args, ensure_ascii=False) + return normalized_tool_call, updated_args, None, None + def _on_tool_validation_error( self, *, @@ -585,6 +782,58 @@ def _create_runner(parsed: dict[str, Any]) -> ContainerAgentRunner: return runner +def _build_restart_prompt( + user_prompt: str, + text_content: str, + tool_calls: list[dict[str, Any]], + *, + reason: str, +) -> str: + prompt_excerpt = (user_prompt or "").strip() + if len(prompt_excerpt) > 1200: + prompt_excerpt = prompt_excerpt[:1200] + "\n...(任务上下文已截断)" + partial_output = (text_content or "").replace( + "[Max turns reached. 
Please continue the conversation.]", + "", + ).strip() + if len(partial_output) > 1200: + partial_output = partial_output[:1200] + "\n...(已有输出已截断)" + + digest_lines: list[str] = [] + for item in tool_calls[-4:]: + status = str(item.get("status") or "success").upper() + tool_name = str(item.get("tool_name") or "tool") + preview = str(item.get("result_preview") or "").strip() + if len(preview) > 240: + preview = preview[:240] + "...[truncated]" + line = f"- [{status}] {tool_name}" + if preview: + line += f"\n 结果: {preview}" + digest_lines.append(line) + + action_prompt = ( + "请停止继续广泛探索,直接完成最小必要工作。" + if reason == "forced_convergence" + else "请不要重复整段历史,只基于当前状态继续完成剩余必要内容。" + ) + + parts = [ + "## 原始任务摘要", + prompt_excerpt or "(无)", + ] + if partial_output: + parts.extend(["\n## 当前阶段已有部分输出", partial_output]) + if digest_lines: + parts.extend(["\n## 最近关键工具结果", "\n".join(digest_lines)]) + parts.extend( + [ + "\n## 下一步要求", + action_prompt, + ] + ) + return "\n".join(parts).strip() + + async def _run_stage_chat( runner: ContainerAgentRunner, *, @@ -610,8 +859,14 @@ async def _run_stage_chat( max_continuations, ) try: + restart_prompt = _build_restart_prompt( + user_prompt, + text_content, + runner.tool_calls_log, + reason="truncation", + ) cont = await asyncio.wait_for( - runner.chat("请继续完成上面的输出,从你停下的地方继续。", reset=False), + runner.chat(restart_prompt, reset=True), timeout=timeout, ) cont_text = cont.text_content or "" @@ -747,7 +1002,9 @@ async def handle_execute(request: web.Request) -> web.Response: parsed["workdir"], parsed["timeout"], ) + await _run_runtime_preflight_once() _configure_java_runtime_for_workspace(parsed["workdir"]) + await _run_gradle_wrapper_prewarm_once(parsed["workdir"]) try: runner = _create_runner(parsed) @@ -818,7 +1075,9 @@ async def handle_execute_stream(request: web.Request) -> web.StreamResponse: parsed["workdir"], parsed["timeout"], ) + await _run_runtime_preflight_once() _configure_java_runtime_for_workspace(parsed["workdir"]) + await 
_run_gradle_wrapper_prewarm_once(parsed["workdir"]) try: runner = _create_runner(parsed) diff --git a/platform/sandbox/scripts/prewarm_gradle_cache.sh b/platform/sandbox/scripts/prewarm_gradle_cache.sh new file mode 100644 index 0000000..ddf5193 --- /dev/null +++ b/platform/sandbox/scripts/prewarm_gradle_cache.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash +set -euo pipefail + +GRADLE_PREWARM_USER_HOME="${GRADLE_PREWARM_USER_HOME:-/opt/gradle-prewarm}" +mkdir -p "${GRADLE_PREWARM_USER_HOME}" + +run_gradle() { + local version="$1" + local install_dir="/opt/gradle-${version}" + + if [ ! -x "${install_dir}/bin/gradle" ]; then + return 0 + fi + + "${install_dir}/bin/gradle" \ + --gradle-user-home "${GRADLE_PREWARM_USER_HOME}" \ + --no-daemon \ + -v >/dev/null 2>&1 || true +} + +run_gradle "8.5" diff --git a/platform/tests/fixtures/sandbox/java17-springboot-gradle/build.gradle.kts b/platform/tests/fixtures/sandbox/java17-springboot-gradle/build.gradle.kts new file mode 100644 index 0000000..836d6da --- /dev/null +++ b/platform/tests/fixtures/sandbox/java17-springboot-gradle/build.gradle.kts @@ -0,0 +1,13 @@ +plugins { + java + id("org.springframework.boot") version "3.2.4" +} + +group = "com.example" +version = "0.0.1-SNAPSHOT" + +java { + toolchain { + languageVersion.set(JavaLanguageVersion.of(17)) + } +} diff --git a/platform/tests/fixtures/sandbox/java8-springboot-gradle/build.gradle b/platform/tests/fixtures/sandbox/java8-springboot-gradle/build.gradle new file mode 100644 index 0000000..1ea6620 --- /dev/null +++ b/platform/tests/fixtures/sandbox/java8-springboot-gradle/build.gradle @@ -0,0 +1,9 @@ +plugins { + id 'java' + id 'org.springframework.boot' version '2.7.18' +} + +group = 'com.example' +version = '0.0.1-SNAPSHOT' + +sourceCompatibility = JavaVersion.VERSION_1_8 diff --git a/platform/tests/test_agents.py b/platform/tests/test_agents.py index fbfe1fb..7b7df35 100644 --- a/platform/tests/test_agents.py +++ b/platform/tests/test_agents.py @@ -14,11 +14,27 @@ def 
test_role_tools_all_valid(): def test_coding_has_core_tools(): - assert {"read", "write", "edit", "execute", "execute_script", "skill"}.issubset(ROLE_TOOLS["coding"]) + assert {"read", "write", "edit", "execute", "execute_script"}.issubset(ROLE_TOOLS["coding"]) + assert "skill" not in ROLE_TOOLS["coding"] def test_test_has_core_tools(): - assert {"read", "write", "edit", "execute", "execute_script", "skill"}.issubset(ROLE_TOOLS["test"]) + assert {"read", "write", "edit", "execute", "execute_script"}.issubset(ROLE_TOOLS["test"]) + assert "skill" not in ROLE_TOOLS["test"] + + +def test_coding_skill_dirs_exclude_shared_by_default(): + dirs = agents_mod._get_skill_dirs("coding") + rendered = [p.name for p in dirs] + assert rendered == [] + assert "shared" not in rendered + + +def test_test_skill_dirs_exclude_shared_by_default(): + dirs = agents_mod._get_skill_dirs("test") + rendered = [p.name for p in dirs] + assert rendered == [] + assert "shared" not in rendered def test_spec_no_execute(): diff --git a/platform/tests/test_agents_api.py b/platform/tests/test_agents_api.py index 0cd3099..e24773b 100644 --- a/platform/tests/test_agents_api.py +++ b/platform/tests/test_agents_api.py @@ -154,6 +154,28 @@ async def list_models(self): assert data["role_defaults"]["coding"] in data["available_models"] +@pytest.mark.asyncio +async def test_get_agent_config_options_uses_lightweight_orchestrator_and_test_defaults( + client, monkeypatch +): + """Config options should keep orchestrator/test on lighter defaults while coding follows global.""" + + monkeypatch.setattr(settings, "LLM_API_KEY", "") + monkeypatch.setattr(settings, "LLM_MODEL", "gpt-5.1-codex") + monkeypatch.setattr( + settings, + "LLM_ROLE_MODEL_MAP", + '{"orchestrator":"gpt-4o-mini","test":"gpt-4o-mini"}', + ) + + resp = await client.get("/api/v1/agents/config/options") + assert resp.status_code == 200 + data = resp.json() + assert data["role_defaults"]["orchestrator"] == "gpt-4o-mini" + assert 
data["role_defaults"]["test"] == "gpt-4o-mini" + assert data["role_defaults"]["coding"] == "gpt-5.1-codex" + + @pytest.mark.asyncio async def test_update_agent_config_rejects_extra_skill_dirs_outside_whitelist( client, seed_agent, tmp_path, monkeypatch diff --git a/platform/tests/test_compressor.py b/platform/tests/test_compressor.py index 05ef85d..a4046cb 100644 --- a/platform/tests/test_compressor.py +++ b/platform/tests/test_compressor.py @@ -56,6 +56,22 @@ def test_compression_result_sliding_window(): assert ctx[3]["output"] == "l2_3_full_content" +def test_compression_result_caps_immediate_prior_l2(): + cr = CompressionResult() + cr.add( + CompressedOutput( + stage_name="parse", + l0="short", + l1="brief", + l2="x" * 10_000, + ) + ) + + ctx = cr.build_prior_context(1) + assert ctx[0]["output"].endswith("...(输出已截断)") + assert len(ctx[0]["output"]) < 10_000 + + @pytest.mark.asyncio async def test_compress_stage_output_fallback(): """When compression is disabled, should use fallback.""" diff --git a/platform/tests/test_engine_stage_execution.py b/platform/tests/test_engine_stage_execution.py index 53a06ed..0e4a308 100644 --- a/platform/tests/test_engine_stage_execution.py +++ b/platform/tests/test_engine_stage_execution.py @@ -164,6 +164,66 @@ async def test_execute_single_stage_reflection_disabled_uses_plain_context(monke await session.commit() +@pytest.mark.asyncio +async def test_execute_single_stage_passes_preflight_summary(monkeypatch, tmp_path): + monkeypatch.setattr(engine.settings, "SANDBOX_ENABLED", False) + monkeypatch.setattr(engine.settings, "MEMORY_ENABLED", False) + execute_stage_mock = AsyncMock(return_value="stage output") + monkeypatch.setattr(engine, "execute_stage", execute_stage_mock) + monkeypatch.setattr(engine, "execute_stage_sandboxed", execute_stage_mock) + monkeypatch.setattr(engine, "_emit_system_log", AsyncMock(return_value="log-id")) + monkeypatch.setattr(engine, "_close_started_system_log", AsyncMock()) + + (tmp_path / 
"build.gradle").write_text("plugins {}", encoding="utf-8") + (tmp_path / "src/main/java/demo/controller").mkdir(parents=True) + (tmp_path / "src/main/java/demo/controller/HelloController.java").write_text("class X {}", encoding="utf-8") + + task_id = "tt-exec-preflight-1" + async with async_session_factory() as session: + session.add(TaskModel(id=task_id, title="Preflight Test", status="running")) + await session.commit() + + async with async_session_factory() as session: + task = await session.get(TaskModel, task_id) + stage = SimpleNamespace( + id="stage-preflight-1", + stage_name="coding", + agent_role="coding", + error_message=None, + output_summary=None, + output_structured=None, + execution_count=0, + status="pending", + ) + + from app.worker.compressor import CompressionResult + compression = CompressionResult() + + result = await engine._execute_single_stage( + session, # type: ignore[arg-type] + task, # type: ignore[arg-type] + stage, # type: ignore[arg-type] + 0, + [], + compression, + None, + None, + {}, + str(tmp_path), + None, + ) + + assert result == "stage output" + call_kwargs = execute_stage_mock.call_args.kwargs + assert "HelloController.java" in (call_kwargs.get("preflight_summary") or "") + + async with async_session_factory() as session: + t = await session.get(TaskModel, task_id) + if t: + await session.delete(t) + await session.commit() + + @pytest.mark.asyncio async def test_execute_single_stage_uses_sandbox(monkeypatch): """sandbox_info is truthy AND agent_role='coding' → calls execute_stage_sandboxed.""" diff --git a/platform/tests/test_executor_stage_logs.py b/platform/tests/test_executor_stage_logs.py index 07d4147..8b9ae2d 100644 --- a/platform/tests/test_executor_stage_logs.py +++ b/platform/tests/test_executor_stage_logs.py @@ -136,6 +136,71 @@ async def emit_update(self, *, log_id: str, updates: dict, priority: str = 'norm return True +class _ContinuationRunner: + def __init__(self, *, response_text: str = 'done') -> None: + 
self.config = SimpleNamespace(model='test-model') + self.cumulative_usage = SimpleNamespace(total_tokens=11) + self.prompts: list[str] = [] + self.resets: list[bool] = [] + self.response_text = response_text + + async def chat(self, prompt: str, reset: bool = True, **_: object): + self.prompts.append(prompt) + self.resets.append(reset) + return SimpleNamespace(text_content=self.response_text) + + +class _ContinuationTracker: + def __init__(self, stage_name: str, agent_role: str = 'coding') -> None: + self.stage_name = stage_name + self.agent_role = agent_role + self.sent: list[dict[str, object]] = [] + self.received: list[dict[str, object]] = [] + self._forced_convergence_used = False + self._implementation_actions = 0 + self._exploration_actions = 0 + self._verification_failures = 0 + self._successful_verifications = 0 + self._verification_attempts = 0 + + async def emit_chat_sent(self, **kwargs): + self.sent.append(kwargs) + return f"sent-{len(self.sent)}" + + async def emit_chat_received(self, *args, **kwargs): + self.received.append({"args": args, "kwargs": kwargs}) + return True + + def should_force_convergence(self) -> bool: + if self._forced_convergence_used: + return False + if self.stage_name == 'code': + return self._implementation_actions == 0 and self._exploration_actions >= 4 + if self.stage_name == 'test': + if self._verification_failures > 0 and self._successful_verifications == 0: + return True + return self._verification_attempts == 0 and self._exploration_actions >= 3 + return False + + def mark_forced_convergence_used(self) -> None: + self._forced_convergence_used = True + + def get_completed_tool_runs(self): + return [] + + +class _DigestTracker(_ContinuationTracker): + def __init__(self, stage_name: str, agent_role: str = 'coding') -> None: + super().__init__(stage_name=stage_name, agent_role=agent_role) + self._items = [ + {"status": "success", "command": f"cmd-{i}", "result_preview": f"preview-{i}"} + for i in range(4) + ] + + def 
get_completed_tool_runs(self): + return self._items + + class _CancelledRunner(_FakeRunner): async def chat(self, _prompt: str, reset: bool = True, **_: object): await self.events.emit( @@ -195,6 +260,14 @@ async def execute_stage(self, info, **kwargs): return self._result +def test_output_summary_limit_is_stage_specific(): + assert executor._output_summary_limit('parse') <= 600 + assert executor._output_summary_limit('code') <= 1200 + assert executor._output_summary_limit('test') <= 1200 + assert executor._output_summary_limit('signoff') <= 1600 + assert executor._output_summary_limit('spec') > executor._output_summary_limit('parse') + + def test_is_tool_call_error_matches_gemini_thought_signature_error(): err = RuntimeError( "Error code: 400 - [{'error': {'code': 400, 'message': " @@ -203,6 +276,36 @@ def test_is_tool_call_error_matches_gemini_thought_signature_error(): assert executor._is_tool_call_error(err) is True +def test_classify_tool_activity_marks_exploration_implementation_and_verification(): + assert executor._classify_tool_activity('read', {}) == 'exploration' + assert executor._classify_tool_activity('edit', {}) == 'implementation' + assert executor._classify_tool_activity('execute', {'command': 'find src -name "*.java"'}) == 'exploration' + assert executor._classify_tool_activity('execute', {'command': './gradlew test'}) == 'verification' + + +def test_stage_event_tracker_force_convergence_budget_for_code_and_test(): + tracker = executor.StageEventTracker( + pipeline=_FakePipeline(), + task_id='task-1', + stage_id='stage-1', + stage_name='code', + agent_role='coding', + ) + for _ in range(4): + tracker.record_tool_activity('read', {}, 'success') + assert tracker.should_force_convergence() is True + + test_tracker = executor.StageEventTracker( + pipeline=_FakePipeline(), + task_id='task-2', + stage_id='stage-2', + stage_name='test', + agent_role='test', + ) + test_tracker.record_tool_activity('execute', {'command': './gradlew test'}, 'failed') + 
assert test_tracker.should_force_convergence() is True + + @pytest.mark.asyncio async def test_execute_stage_falls_back_to_text_only_on_thought_signature_error(monkeypatch): session = SimpleNamespace(commit=AsyncMock()) @@ -275,6 +378,66 @@ def _fallback_runner( assert len(fallback_events) == 1 +@pytest.mark.asyncio +async def test_execute_stage_uses_text_only_runner_for_signoff(monkeypatch): + session = SimpleNamespace(commit=AsyncMock()) + task = SimpleNamespace( + id='task-signoff-1', + title='task title', + description='task description', + total_tokens=0, + total_cost_rmb=0.0, + ) + stage = SimpleNamespace( + id='stage-signoff-1', + stage_name='signoff', + agent_role='orchestrator', + status='pending', + started_at=None, + completed_at=None, + duration_seconds=None, + tokens_used=0, + output_summary=None, + ) + + fake_pipeline = _FakePipeline() + text_only_called = {'value': False} + + monkeypatch.setattr(executor, 'get_task_log_pipeline', lambda: fake_pipeline) + monkeypatch.setattr(executor, '_get_agent', AsyncMock(return_value=None)) + monkeypatch.setattr(executor, '_safe_broadcast', AsyncMock()) + monkeypatch.setattr(executor, 'build_user_prompt', lambda _ctx: 'signoff prompt') + + def _unexpected_agent(*args, **kwargs): + raise AssertionError('signoff should not use tool-enabled get_agent') + + def _text_only_runner( + _role, + _task_id, + model=None, + temperature=None, + max_tokens=None, + max_turns=None, + extra_skill_dirs=None, + system_prompt_append=None, + ): + text_only_called['value'] = True + return _FakeRunner() + + monkeypatch.setattr(executor, 'get_agent', _unexpected_agent) + monkeypatch.setattr(executor, 'get_agent_text_only', _text_only_runner) + + result = await executor.execute_stage( + session=session, + task=task, + stage=stage, + prior_outputs=[], + ) + + assert result == 'stage output' + assert text_only_called['value'] is True + + def test_apply_runner_workspace_override_replaces_prompt_and_cwd(): runner = SimpleNamespace( 
default_cwd='/tmp/old-workspace', @@ -438,7 +601,7 @@ def _capture_runner( assert captured_params['model'] == 'gpt-5.1-codex-mini' assert captured_params['temperature_override'] == 0.2 assert captured_params['max_tokens_override'] == 1200 - assert captured_params['max_turns'] == 18 + assert captured_params['max_turns'] == 6 assert captured_params['extra_skill_dirs'] == ['/tmp/skills'] assert captured_params['system_prompt_append'] == 'extra prompt' assert captured_params['temperature'] == 0.2 @@ -572,6 +735,190 @@ async def test_execute_stage_cancellation_still_finalizes_started_logs(monkeypat assert isinstance(turn_updates[-1]['updates']['duration_ms'], float) +@pytest.mark.asyncio +async def test_handle_continuations_uses_coding_specific_prompt(): + runner = _ContinuationRunner() + tracker = _ContinuationTracker(stage_name='code', agent_role='coding') + + output, total_tokens = await executor._handle_continuations( + runner, + "[Max turns reached. Please continue the conversation.]", + {}, + tracker, + ) + + assert total_tokens == 11 + assert output == 'done' + assert runner.resets == [True] + assert '## 当前阶段\ncode' in runner.prompts[0] + assert '请停止继续广泛探索。基于已知信息直接修改代码' in runner.prompts[0] + + +@pytest.mark.asyncio +async def test_handle_continuations_uses_test_specific_prompt(): + runner = _ContinuationRunner() + tracker = _ContinuationTracker(stage_name='test', agent_role='test') + + output, total_tokens = await executor._handle_continuations( + runner, + "[Max turns reached. 
Please continue the conversation.]", + {}, + tracker, + ) + + assert total_tokens == 11 + assert output == 'done' + assert runner.resets == [True] + assert '## 当前阶段\ntest' in runner.prompts[0] + assert '请停止扩展测试范围。只做最小、最相关的验证' in runner.prompts[0] + + +@pytest.mark.asyncio +async def test_handle_continuations_injects_forced_convergence_for_coding_budget(): + runner = _ContinuationRunner() + tracker = _ContinuationTracker(stage_name='code', agent_role='coding') + tracker._exploration_actions = 4 + + output, total_tokens = await executor._handle_continuations( + runner, + 'partial summary', + {}, + tracker, + ) + + assert total_tokens == 11 + assert output == 'partial summary\n\ndone' + assert runner.resets == [True] + assert '禁止继续浏览仓库' in runner.prompts[0] + + +@pytest.mark.asyncio +async def test_handle_continuations_injects_forced_convergence_for_failed_test_verification(): + runner = _ContinuationRunner() + tracker = _ContinuationTracker(stage_name='test', agent_role='test') + tracker._verification_failures = 1 + + output, total_tokens = await executor._handle_continuations( + runner, + 'analysis only', + {}, + tracker, + ) + + assert total_tokens == 11 + assert output == 'analysis only\n\ndone' + assert runner.resets == [True] + assert '禁止继续扩展测试范围' in runner.prompts[0] + + +@pytest.mark.asyncio +async def test_handle_continuations_uses_checkpoint_restart_with_reset_true(): + runner = _ContinuationRunner(response_text='final answer') + tracker = _ContinuationTracker(stage_name='code', agent_role='coding') + + output, total_tokens = await executor._handle_continuations( + runner, + "[Max turns reached. 
Please continue the conversation.]", + {}, + tracker, + 'code', + { + 'task_title': 'Hello Task', + 'task_description': 'Implement hello endpoint', + 'stage_name': 'code', + 'preflight_summary': '- 构建文件: build.gradle', + }, + ) + + assert total_tokens == 11 + assert output == 'final answer' + assert runner.resets == [True] + assert '## 任务\n**Hello Task**' in runner.prompts[0] + assert '## 阶段预扫摘要' in runner.prompts[0] + assert '不要重新展开整段历史' in runner.prompts[0] + assert tracker.sent[0]['request_body']['restart'] == 1 + assert tracker.sent[0]['request_body']['restart_reason'] == 'truncation' + assert tracker.sent[0]['request_body']['reset'] is True + + +@pytest.mark.asyncio +async def test_handle_continuations_restarts_from_checkpoint_for_forced_convergence(): + runner = _ContinuationRunner(response_text='implemented result') + tracker = _ContinuationTracker(stage_name='code', agent_role='coding') + tracker._exploration_actions = 4 + + output, total_tokens = await executor._handle_continuations( + runner, + 'partial summary', + {}, + tracker, + 'code', + { + 'task_title': 'Hello Task', + 'task_description': 'Implement hello endpoint', + 'stage_name': 'code', + 'preflight_summary': '- 实现参考: src/main/java/demo/HelloController.java', + }, + ) + + assert total_tokens == 11 + assert output == 'partial summary\n\nimplemented result' + assert runner.resets == [True] + assert tracker.sent[0]['request_body']['restart'] == 1 + assert tracker.sent[0]['request_body']['restart_reason'] == 'forced_convergence' + assert tracker.sent[0]['request_body']['forced_convergence'] is True + + +def test_build_stage_restart_prompt_limits_tool_digest_items(): + tracker = _DigestTracker(stage_name='code', agent_role='coding') + prompt = executor._build_stage_restart_prompt( + { + 'task_title': 'Hello Task', + 'task_description': 'Implement hello endpoint', + 'stage_name': 'code', + 'preflight_summary': '- 构建文件: build.gradle', + }, + tracker, + 'partial output', + reason='truncation', + ) + + 
assert 'cmd-3' in prompt + assert 'cmd-2' in prompt + assert 'cmd-1' not in prompt + assert 'cmd-0' not in prompt + + +def test_resolve_stage_max_turns_caps_coding_and_test(): + assert executor._resolve_stage_max_turns('coding', None) == 6 + assert executor._resolve_stage_max_turns('coding', 18) == 6 + assert executor._resolve_stage_max_turns('test', 12) == 6 + + +def test_resolve_stage_max_turns_preserves_other_roles(): + assert executor._resolve_stage_max_turns('doc', None) == 10 + assert executor._resolve_stage_max_turns('doc', 18) == 18 + + +@pytest.mark.asyncio +async def test_handle_continuations_uses_generic_prompt_for_other_stage(): + runner = _ContinuationRunner() + tracker = _ContinuationTracker(stage_name='review', agent_role='review') + + output, total_tokens = await executor._handle_continuations( + runner, + "[Max turns reached. Please continue the conversation.]", + {}, + tracker, + ) + + assert total_tokens == 11 + assert output == 'done' + assert runner.prompts == [ + '请继续完成上面的输出,从你停下的地方继续。' + ] + + @pytest.mark.asyncio async def test_execute_stage_sandboxed_emits_standardized_pipeline_events(monkeypatch): from app.worker import agents as worker_agents diff --git a/platform/tests/test_project_service.py b/platform/tests/test_project_service.py index 4f5ccf2..a6507ad 100644 --- a/platform/tests/test_project_service.py +++ b/platform/tests/test_project_service.py @@ -302,6 +302,20 @@ async def test_create_project_minimal(): await _cleanup_project(resp.id) +@pytest.mark.asyncio +async def test_create_project_defaults_sandbox_image(): + """create_project falls back to the configured sandbox image when none is provided.""" + name = _unique_name("svc-create-default-image") + request = ProjectCreateRequest(name=name, display_name="Default Image") + async with async_session_factory() as session: + svc = ProjectService(session) + resp = await svc.create_project(request) + + assert resp.sandbox_image == "silicon-agent-sandbox:coding" + + await 
_cleanup_project(resp.id) + + # ── update_project ──────────────────────────────────────────────────────────── diff --git a/platform/tests/test_prompts.py b/platform/tests/test_prompts.py index a1900c6..c8f9ca4 100644 --- a/platform/tests/test_prompts.py +++ b/platform/tests/test_prompts.py @@ -32,6 +32,23 @@ def test_minimal_title_only(): assert STAGE_INSTRUCTIONS["code"] in result +def test_code_guardrail_emphasizes_convergence(): + ctx = _minimal_ctx(stage_name="code") + result = build_user_prompt(ctx) + assert "不要为了理解整个仓库而广泛探索" in result + assert "最小必要验证" in result + assert "最多再检查 3 个关键文件" in result + + +def test_test_guardrail_emphasizes_minimal_validation(): + ctx = _minimal_ctx(stage_name="test", agent_role="test") + result = build_user_prompt(ctx) + assert "最小、最相关、最快的验证路径" in result + assert "满足验收标准" in result + assert "执行 2 条验证命令" in result + assert "不要只根据代码阅读就判定测试通过" in result + + # --------------------------------------------------------------------------- # With description # --------------------------------------------------------------------------- @@ -60,6 +77,47 @@ def test_with_repo_context(): assert "Python 3.11 / FastAPI" in result +def test_code_stage_clips_large_repo_context(): + repo_context = ( + "### 技术栈\nJava 17, Spring Boot, Gradle\n\n" + "### 目录结构\n" + "build.gradle\n" + "src/main/java/demo/controller/HelloController.java\n" + "src/main/java/demo/service/HelloService.java\n" + "src/test/java/demo/controller/HelloControllerTest.java\n" + "docs/design.md\n" + ) + ctx = _minimal_ctx(stage_name="code", agent_role="coding", repo_context=repo_context) + result = build_user_prompt(ctx) + assert "## 项目代码库信息" in result + assert "- 技术栈: Java 17, Spring Boot, Gradle" in result + assert "- 构建入口: build.gradle" in result + assert "- 源码目录:" in result + assert "- 测试目录:" in result + assert "- 参考实现:" in result + assert "### 目录结构" not in result + + +def test_spec_stage_keeps_full_repo_context(): + repo_context = "STACK\n" + ("src/main/java/demo/File.java\n" 
* 40) + ctx = _minimal_ctx(stage_name="spec", agent_role="spec", repo_context=repo_context) + result = build_user_prompt(ctx) + assert "...(执行阶段上下文已截断)" not in result + assert repo_context in result + + +def test_code_stage_omits_repo_context_when_preflight_present(): + ctx = _minimal_ctx( + stage_name="code", + agent_role="coding", + repo_context="STACK\nsrc/main/java/demo/File.java", + preflight_summary="- 构建文件: build.gradle", + ) + result = build_user_prompt(ctx) + assert "## 项目代码库信息" not in result + assert "## 阶段预扫摘要" in result + + def test_without_repo_context(): ctx = _minimal_ctx(repo_context=None) result = build_user_prompt(ctx) @@ -77,12 +135,34 @@ def test_with_project_memory(): assert "Previous task: added auth module." in result +def test_test_stage_clips_large_project_memory(): + project_memory = "Memory line\n" * 300 + ctx = _minimal_ctx(stage_name="test", agent_role="test", project_memory=project_memory) + result = build_user_prompt(ctx) + assert "## 项目上下文(来自历史任务)" in result + assert "...(执行阶段记忆已截断)" in result + assert len(result) < len(project_memory) + 500 + + def test_without_project_memory(): ctx = _minimal_ctx(project_memory=None) result = build_user_prompt(ctx) assert "## 项目上下文(来自历史任务)" not in result +def test_with_preflight_summary(): + ctx = _minimal_ctx(preflight_summary="- 构建文件: build.gradle\n- 实现参考: src/main/java/demo/HelloController.java") + result = build_user_prompt(ctx) + assert "## 阶段预扫摘要" in result + assert "HelloController.java" in result + + +def test_without_preflight_summary(): + ctx = _minimal_ctx(preflight_summary=None) + result = build_user_prompt(ctx) + assert "## 阶段预扫摘要" not in result + + # --------------------------------------------------------------------------- # With prior_outputs (raw) # --------------------------------------------------------------------------- @@ -101,6 +181,25 @@ def test_with_prior_outputs_raw(): assert "Spec document:" in result +def test_execution_stage_clips_prior_outputs_aggressively(): + 
long_parse = "需求分析\n" + ("parse-line\n" * 200) + long_spec = "技术方案\n" + ("spec-line\n" * 200) + ctx = _minimal_ctx( + stage_name="code", + agent_role="coding", + prior_outputs=[ + {"stage": "parse", "output": long_parse}, + {"stage": "spec", "output": long_spec}, + ], + ) + result = build_user_prompt(ctx) + assert "## 前序阶段产出" in result + assert "...(前序阶段产出已截断)" in result + assert "parse-line\nparse-line\nparse-line" in result + assert result.count("parse-line") < 80 + assert result.count("spec-line") < 100 + + def test_with_empty_prior_outputs(): ctx = _minimal_ctx(prior_outputs=[]) result = build_user_prompt(ctx) diff --git a/platform/tests/test_sandbox_agent_server.py b/platform/tests/test_sandbox_agent_server.py index a0bb670..7676227 100644 --- a/platform/tests/test_sandbox_agent_server.py +++ b/platform/tests/test_sandbox_agent_server.py @@ -3,6 +3,7 @@ import importlib import os import sys +from pathlib import Path from types import ModuleType, SimpleNamespace @@ -161,6 +162,29 @@ def test_detect_java_version_finds_java17_markers(tmp_path): assert agent_server._detect_java_major_version(str(tmp_path)) == 17 +def test_detect_java_version_finds_gradle_toolchain_markers(tmp_path): + agent_server = _load_agent_server_with_fake_skillkit() + gradle = tmp_path / "build.gradle.kts" + gradle.write_text( + """ + java { + toolchain { + languageVersion.set(JavaLanguageVersion.of(17)) + } + } + """.strip(), + encoding="utf-8", + ) + assert agent_server._detect_java_major_version(str(tmp_path)) == 17 + + +def test_detect_java_version_returns_none_without_markers(tmp_path): + agent_server = _load_agent_server_with_fake_skillkit() + settings = tmp_path / "settings.gradle" + settings.write_text('rootProject.name = "demo"', encoding="utf-8") + assert agent_server._detect_java_major_version(str(tmp_path)) is None + + def test_configure_java_runtime_sets_java_home_and_path(tmp_path, monkeypatch): agent_server = _load_agent_server_with_fake_skillkit() gradle = tmp_path / 
"build.gradle" @@ -175,3 +199,120 @@ def test_configure_java_runtime_sets_java_home_and_path(tmp_path, monkeypatch): assert selected == 8 assert os.environ["JAVA_HOME"] == "/opt/jdk8" assert os.environ["PATH"].split(":")[0] == "/opt/jdk8/bin" + + +def test_configure_java_runtime_respects_explicit_override(tmp_path, monkeypatch): + agent_server = _load_agent_server_with_fake_skillkit() + monkeypatch.setenv("SANDBOX_JAVA_VERSION", "17") + monkeypatch.setenv("JAVA17_HOME", "/opt/jdk17") + monkeypatch.setenv("JAVA_HOME", "/opt/jdk8") + monkeypatch.setenv("PATH", "/opt/jdk8/bin:/usr/bin:/bin") + + selected = agent_server._configure_java_runtime_for_workspace(str(tmp_path)) + assert selected == 17 + assert os.environ["JAVA_HOME"] == "/opt/jdk17" + assert os.environ["PATH"].split(":")[0] == "/opt/jdk17/bin" + + +def test_configure_java_runtime_defaults_to_java8_without_markers(tmp_path, monkeypatch): + agent_server = _load_agent_server_with_fake_skillkit() + monkeypatch.delenv("SANDBOX_JAVA_VERSION", raising=False) + monkeypatch.setenv("SANDBOX_DEFAULT_JAVA_VERSION", "8") + monkeypatch.setenv("JAVA8_HOME", "/opt/jdk8") + monkeypatch.setenv("JAVA_HOME", "/opt/jdk17") + monkeypatch.setenv("PATH", "/opt/jdk17/bin:/usr/bin:/bin") + + selected = agent_server._configure_java_runtime_for_workspace(str(tmp_path)) + assert selected == 8 + assert os.environ["JAVA_HOME"] == "/opt/jdk8" + assert os.environ["PATH"].split(":")[0] == "/opt/jdk8/bin" + + +def test_container_runner_keeps_gradlew_and_wraps_timeout(monkeypatch): + agent_server = _load_agent_server_with_fake_skillkit() + monkeypatch.setenv("SANDBOX_GRADLE_CMD_TIMEOUT_SECONDS", "480") + runner = agent_server._create_runner( + { + "skill_dirs": [], + "system_prompt": "system", + "max_turns": 5, + "enable_tools": True, + "model": "gpt-4o", + "temperature": None, + "max_tokens": None, + "workdir": "/workspace", + "allowed_tools": {"execute"}, + } + ) + normalized, args, err, result = runner._preprocess_validated_tool_call( + 
tool_name="execute", + args={"command": "cd /workspace && ./gradlew test"}, + tool_call={"name": "execute", "arguments": '{"command":"cd /workspace && ./gradlew test"}'}, + ) + assert err is None + assert result is None + assert "timeout 480s bash -lc" in args["command"] + assert "./gradlew test" in args["command"] + assert "gradle test" not in args["command"] + assert normalized["arguments"] + + +def test_run_gradle_wrapper_prewarm_once_marks_done(tmp_path, monkeypatch): + agent_server = _load_agent_server_with_fake_skillkit() + gradlew = tmp_path / "gradlew" + gradlew.write_text("#!/bin/sh\nexit 0\n", encoding="utf-8") + gradlew.chmod(0o755) + monkeypatch.setenv("SANDBOX_GRADLE_WRAPPER_PREWARM", "true") + monkeypatch.setenv("SANDBOX_GRADLE_WRAPPER_PREWARM_TIMEOUT_SECONDS", "30") + agent_server._WRAPPER_PREWARM_DONE = False + import asyncio + asyncio.run(agent_server._run_gradle_wrapper_prewarm_once(str(tmp_path))) + assert agent_server._WRAPPER_PREWARM_DONE is True + + +def test_build_restart_prompt_includes_task_excerpt_and_tool_digest(): + agent_server = _load_agent_server_with_fake_skillkit() + prompt = agent_server._build_restart_prompt( + "## 任务\n实现 hello 接口", + "[Max turns reached. 
Please continue the conversation.]", + [ + { + "tool_name": "execute", + "result_preview": "src/main/java/demo/HelloController.java", + "status": "success", + } + ], + reason="truncation", + ) + assert "原始任务摘要" in prompt + assert "最近关键工具结果" in prompt + assert "HelloController.java" in prompt + + +def test_should_retry_with_other_java_on_version_mismatch(): + agent_server = _load_agent_server_with_fake_skillkit() + assert agent_server._should_retry_with_other_java("Unsupported class file major version 61") + assert agent_server._should_retry_with_other_java("invalid source release: 17") + assert not agent_server._should_retry_with_other_java("Execution failed for task ':test'") + + +def test_detect_java_version_from_java8_fixture(): + agent_server = _load_agent_server_with_fake_skillkit() + fixture = ( + Path(__file__).resolve().parent + / "fixtures" + / "sandbox" + / "java8-springboot-gradle" + ) + assert agent_server._detect_java_major_version(str(fixture)) == 8 + + +def test_detect_java_version_from_java17_fixture(): + agent_server = _load_agent_server_with_fake_skillkit() + fixture = ( + Path(__file__).resolve().parent + / "fixtures" + / "sandbox" + / "java17-springboot-gradle" + ) + assert agent_server._detect_java_major_version(str(fixture)) == 17 diff --git a/platform/tests/test_sandbox_env_contract.py b/platform/tests/test_sandbox_env_contract.py index a4782b9..66124a3 100644 --- a/platform/tests/test_sandbox_env_contract.py +++ b/platform/tests/test_sandbox_env_contract.py @@ -46,6 +46,13 @@ def test_build_docker_run_cmd_includes_skillkit_compat_env(monkeypatch, tmp_path "SANDBOX_MODEL_API_RAW_LOG_HOST_DIR", str(raw_log_dir), ) + monkeypatch.setattr(sandbox_mod.settings, "SANDBOX_GRADLE_CMD_TIMEOUT_SECONDS", 480) + gradle_cache_dir = tmp_path / "gradle_cache" + monkeypatch.setattr(sandbox_mod.settings, "SANDBOX_GRADLE_CACHE_HOST_DIR", str(gradle_cache_dir)) + monkeypatch.setattr(sandbox_mod.settings, "SANDBOX_GRADLE_USER_HOME", 
"/var/lib/silicon_agent/gradle-cache") + monkeypatch.setattr(sandbox_mod.settings, "SANDBOX_GRADLE_WRAPPER_PREWARM", True) + monkeypatch.setattr(sandbox_mod.settings, "SANDBOX_GRADLE_WRAPPER_PREWARM_TIMEOUT_SECONDS", 180) + monkeypatch.setattr(sandbox_mod.settings, "SANDBOX_DEFAULT_JAVA_VERSION", 8) backend = DockerSandboxBackend() cmd = backend._build_docker_run_cmd( @@ -66,7 +73,13 @@ def test_build_docker_run_cmd_includes_skillkit_compat_env(monkeypatch, tmp_path assert env["AGENT_PORT"] == "19090" assert env["SANDBOX_DUMP_MODEL_API_RESPONSE"] == "true" assert env["SANDBOX_MODEL_API_RAW_LOG_PATH"] == "/model_api_logs/task-123.jsonl" + assert env["SANDBOX_GRADLE_CMD_TIMEOUT_SECONDS"] == "480" + assert env["GRADLE_USER_HOME"] == "/var/lib/silicon_agent/gradle-cache" + assert env["SANDBOX_DEFAULT_JAVA_VERSION"] == "8" + assert env["SANDBOX_GRADLE_WRAPPER_PREWARM"] == "true" + assert env["SANDBOX_GRADLE_WRAPPER_PREWARM_TIMEOUT_SECONDS"] == "180" assert f"type=bind,src={raw_log_dir},dst=/model_api_logs" in mounts + assert f"type=bind,src={gradle_cache_dir},dst=/var/lib/silicon_agent/gradle-cache" in mounts def test_build_docker_run_cmd_disables_raw_model_dump_when_config_off(monkeypatch, tmp_path): @@ -143,3 +156,28 @@ def test_coding_sandbox_image_provides_java_toolchain(): assert "JAVA8_HOME" in content assert "JAVA17_HOME" in content assert "ENV JAVA_HOME=/opt/jdk17" in content + + +def test_coding_sandbox_image_prepares_offline_gradle_cache(): + dockerfile_path = Path(__file__).resolve().parents[1] / "sandbox" / "Dockerfile.coding" + content = dockerfile_path.read_text(encoding="utf-8") + + assert "sandbox/scripts/prewarm_gradle_cache.sh" in content + assert "GRADLE_PREWARM_USER_HOME" in content + assert "prewarm_gradle_cache.sh" in content + + +def test_base_sandbox_image_makes_runtime_entrypoints_world_readable(): + dockerfile_path = Path(__file__).resolve().parents[1] / "sandbox" / "Dockerfile.base" + content = dockerfile_path.read_text(encoding="utf-8") + + 
assert "COPY --chown=agent:agent sandbox/agent_server.py /app/agent_server.py" in content + assert "COPY --chown=agent:agent sandbox/tool_policy.py /app/tool_policy.py" in content + assert "RUN chmod -R a+rX /app /skills" in content + + +def test_base_sandbox_image_copies_skills_with_agent_ownership(): + dockerfile_path = Path(__file__).resolve().parents[1] / "sandbox" / "Dockerfile.base" + content = dockerfile_path.read_text(encoding="utf-8") + + assert "COPY --chown=agent:agent skills/ /skills/" in content diff --git a/platform/tests/test_task_service.py b/platform/tests/test_task_service.py index 0020842..219c14b 100644 --- a/platform/tests/test_task_service.py +++ b/platform/tests/test_task_service.py @@ -351,6 +351,70 @@ async def test_create_task_template_not_found(): assert result.id == "new-task-4" +# --------------------------------------------------------------------------- +# clone_task +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_clone_task_not_found(): + """clone_task returns None when source task is missing.""" + session = _make_session() + session.execute.return_value = _mock_result(scalar_one_or_none=None) + svc = TaskService(session) + + result = await svc.clone_task("missing-task") + + assert result is None + + +@pytest.mark.asyncio +async def test_clone_task_reuses_create_task_with_whitelisted_fields(): + """clone_task should create a fresh task from copy-safe source fields only.""" + session = _make_session() + source = _make_task( + id="task-source", + title="Clone Me", + description="Original task body", + status="failed", + jira_id="JIRA-1", + template_id="tmpl-1", + project_id="proj-1", + target_branch="silicon_agent/source", + yunxiao_task_id="YX-1", + branch_name="feature/source", + pr_url="https://example.com/pr/1", + ) + session.execute.return_value = _mock_result(scalar_one_or_none=source) + svc = TaskService(session) + cloned = _make_task( + id="task-clone", + 
title="Clone Me", + description="Original task body", + status="pending", + jira_id="JIRA-1", + template_id="tmpl-1", + project_id="proj-1", + target_branch="silicon_agent/clone", + yunxiao_task_id="YX-1", + ) + svc.create_task = AsyncMock(return_value=svc._task_to_response(cloned)) + + result = await svc.clone_task("task-source") + + assert result.id == "task-clone" + svc.create_task.assert_awaited_once() + request = svc.create_task.await_args.args[0] + assert isinstance(request, TaskCreateRequest) + assert request.title == "Clone Me" + assert request.description == "Original task body" + assert request.jira_id == "JIRA-1" + assert request.template_id == "tmpl-1" + assert request.project_id == "proj-1" + assert request.yunxiao_task_id == "YX-1" + assert request.target_branch is None + + # --------------------------------------------------------------------------- # get_task (lines 133-136) # --------------------------------------------------------------------------- diff --git a/platform/tests/test_tasks_api.py b/platform/tests/test_tasks_api.py index c1c3cc7..8558dd0 100644 --- a/platform/tests/test_tasks_api.py +++ b/platform/tests/test_tasks_api.py @@ -261,6 +261,80 @@ async def test_get_task_404(client): assert resp.status_code == 404 +@pytest.mark.asyncio +async def test_clone_task_creates_fresh_pending_copy(client, seed_template_with_stages): + """POST /api/v1/tasks/{id}/clone creates a new task without inheriting runtime state.""" + template_id = seed_template_with_stages + create_resp = await client.post("/api/v1/tasks", json={ + "title": "TT Clone Source", + "description": "Clone this task", + "template_id": template_id, + "jira_id": "TT-123", + "project_id": None, + "yunxiao_task_id": "YX-123", + }) + assert create_resp.status_code == 201 + source = create_resp.json() + + async with async_session_factory() as session: + result = await session.execute(select(TaskModel).where(TaskModel.id == source["id"])) + task = result.scalar_one() + task.status = 
"failed" + task.branch_name = "feature/original" + task.pr_url = "https://example.com/pr/123" + + stage_result = await session.execute( + select(TaskStageModel).where(TaskStageModel.task_id == source["id"]) + ) + stages = stage_result.scalars().all() + stages[0].status = "completed" + stages[1].status = "failed" + stages[1].retry_count = 2 + stages[1].error_message = "compile failed" + await session.commit() + + clone_resp = await client.post(f"/api/v1/tasks/{source['id']}/clone") + assert clone_resp.status_code == 201 + cloned = clone_resp.json() + + assert cloned["id"] != source["id"] + assert cloned["title"] == source["title"] + assert cloned["description"] == source["description"] + assert cloned["jira_id"] == source["jira_id"] + assert cloned["template_id"] == source["template_id"] + assert cloned["yunxiao_task_id"] == source["yunxiao_task_id"] + assert cloned["status"] == "pending" + assert cloned["branch_name"] is None + assert cloned["pr_url"] is None + assert cloned["target_branch"] == f"silicon_agent/{cloned['id'].rsplit('-', 1)[-1]}" + assert len(cloned["stages"]) == 3 + for stage in cloned["stages"]: + assert stage["status"] == "pending" + assert stage["retry_count"] == 0 + assert stage["error_message"] is None + + async with async_session_factory() as session: + stage_result = await session.execute( + select(TaskStageModel).where(TaskStageModel.task_id.in_([source["id"], cloned["id"]])) + ) + for stage in stage_result.scalars().all(): + await session.delete(stage) + + task_result = await session.execute( + select(TaskModel).where(TaskModel.id.in_([source["id"], cloned["id"]])) + ) + for task in task_result.scalars().all(): + await session.delete(task) + await session.commit() + + +@pytest.mark.asyncio +async def test_clone_task_404(client): + """POST /api/v1/tasks/{id}/clone returns 404 for nonexistent task.""" + resp = await client.post("/api/v1/tasks/tt-nonexistent-id/clone") + assert resp.status_code == 404 + + # ── List Tasks Tests 
────────────────────────────────────── diff --git a/platform/tests/test_worker.py b/platform/tests/test_worker.py index c4780f5..6cc064a 100644 --- a/platform/tests/test_worker.py +++ b/platform/tests/test_worker.py @@ -1,11 +1,17 @@ """Unit tests for worker engine pure functions.""" import asyncio import json +from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import pytest -from app.worker.engine import _parse_gates, _sort_stages, _build_repo_context +from app.worker.engine import ( + _parse_gates, + _sort_stages, + _build_repo_context, + _build_stage_preflight_summary, +) from app.worker.engine import _safe_broadcast as engine_safe_broadcast from app.worker.executor import _safe_broadcast as executor_safe_broadcast @@ -108,6 +114,66 @@ def test_build_repo_context_default_branch(self): assert "main" in result +class TestBuildStagePreflightSummary: + def test_build_stage_preflight_summary_for_coding(self, tmp_path: Path): + (tmp_path / "build.gradle").write_text("plugins {}", encoding="utf-8") + (tmp_path / "src/main/java/demo/controller").mkdir(parents=True) + (tmp_path / "src/main/java/demo/service").mkdir(parents=True) + (tmp_path / "src/test/java/demo/controller").mkdir(parents=True) + (tmp_path / "src/main/java/demo/controller/HelloController.java").write_text("class X {}", encoding="utf-8") + (tmp_path / "src/main/java/demo/service/HelloService.java").write_text("class S {}", encoding="utf-8") + (tmp_path / "src/test/java/demo/controller/HelloControllerTest.java").write_text("class T {}", encoding="utf-8") + + result = _build_stage_preflight_summary("coding", str(tmp_path)) + + assert result is not None + assert "构建入口" in result + assert "推荐修改落点" in result + assert "最相关实现参考" in result + assert "最相关测试参考" in result + assert "推荐最小验证命令" in result + assert "HelloController.java" in result + assert "./gradlew test" in result + + def test_build_stage_preflight_summary_for_test(self, tmp_path: Path): + (tmp_path / 
"pom.xml").write_text("<project/>", encoding="utf-8") + (tmp_path / "src/test/java/demo").mkdir(parents=True) + (tmp_path / "src/test/java/demo/controller").mkdir(parents=True) + (tmp_path / "src/main/java/demo/service").mkdir(parents=True) + (tmp_path / "src/main/java/demo/controller").mkdir(parents=True) + (tmp_path / "src/test/java/demo/DemoServiceTest.java").write_text("class T {}", encoding="utf-8") + (tmp_path / "src/test/java/demo/controller/HelloControllerTest.java").write_text("class HC {}", encoding="utf-8") + (tmp_path / "src/main/java/demo/service/DemoService.java").write_text("class S {}", encoding="utf-8") + (tmp_path / "src/main/java/demo/controller/HelloController.java").write_text("class C {}", encoding="utf-8") + + result = _build_stage_preflight_summary("test", str(tmp_path)) + + assert result is not None + assert "构建入口" in result + assert "推荐验证落点" in result + assert "最相关测试参考" in result + assert "对应实现参考" in result + assert "推荐最小验证命令: ./mvnw test" in result + assert "HelloControllerTest.java" in result + + def test_build_stage_preflight_summary_prioritizes_controller_tests(self, tmp_path: Path): + (tmp_path / "build.gradle").write_text("plugins {}", encoding="utf-8") + (tmp_path / "src/test/java/demo/sdk").mkdir(parents=True) + (tmp_path / "src/test/java/demo/controller").mkdir(parents=True) + (tmp_path / "src/main/java/demo/controller").mkdir(parents=True) + (tmp_path / "src/test/java/demo/sdk/TaobaoApiTest.java").write_text("class T {}", encoding="utf-8") + (tmp_path / "src/test/java/demo/controller/HelloControllerTest.java").write_text("class C {}", encoding="utf-8") + (tmp_path / "src/main/java/demo/controller/HelloController.java").write_text("class X {}", encoding="utf-8") + + result = _build_stage_preflight_summary("test", str(tmp_path)) + + assert result is not None + assert result.index("HelloControllerTest.java") < result.index("TaobaoApiTest.java") + + def test_build_stage_preflight_summary_ignores_other_stages(self, tmp_path: Path): + 
assert _build_stage_preflight_summary("signoff", str(tmp_path)) is None + + class TestSafeBroadcast: @pytest.mark.asyncio async def test_engine_safe_broadcast_swallows_errors(self): diff --git a/web/src/components/ReActTimeline/index.tsx b/web/src/components/ReActTimeline/index.tsx index e976c22..8609489 100644 --- a/web/src/components/ReActTimeline/index.tsx +++ b/web/src/components/ReActTimeline/index.tsx @@ -19,6 +19,11 @@ import './styles.css'; const { Text, Paragraph } = Typography; +export interface TurnBadge { + label: string; + color: string; +} + interface ReActTurn { id: string; turnNumber: number; @@ -34,6 +39,8 @@ interface ReActViewProps { loading?: boolean; } +const MAX_TURNS_SENTINEL = '[Max turns reached. Please continue the conversation.]'; + function getLogContent(log?: TaskLogEvent): string { if (!log || !log.response_body) return ''; const raw = (log.response_body as Record<string, unknown>).content; @@ -56,6 +63,64 @@ function getLogContent(log?: TaskLogEvent): string { return ''; } +function getRecordValue(record: Record<string, unknown> | null | undefined, key: string): unknown { + if (!record) return undefined; + return record[key]; +} + +function getContinuationNumber(log?: TaskLogEvent): number | null { + const requestValue = getRecordValue(log?.request_body, 'continuation'); + const responseValue = getRecordValue(log?.response_body, 'continuation'); + const candidate = requestValue ?? 
responseValue; + if (typeof candidate === 'number' && Number.isFinite(candidate)) return candidate; + if (typeof candidate === 'string' && candidate.trim()) { + const parsed = Number(candidate); + if (Number.isFinite(parsed)) return parsed; + } + return null; +} + +function hasForcedConvergence(log?: TaskLogEvent): boolean { + return Boolean(getRecordValue(log?.request_body, 'forced_convergence') || getRecordValue(log?.response_body, 'forced_convergence')); +} + +export function getTurnBadges(log?: TaskLogEvent): TurnBadge[] { + if (!log) return []; + + const badges: TurnBadge[] = []; + const continuation = getContinuationNumber(log); + if (continuation != null) { + badges.push({ label: `Continuation #${continuation}`, color: 'blue' }); + } + + if (hasForcedConvergence(log)) { + badges.push({ label: 'Forced Convergence', color: 'gold' }); + } + + return badges; +} + +export function stripMaxTurnsSentinel(text: string): { text: string; truncated: boolean } { + if (!text.includes(MAX_TURNS_SENTINEL)) { + return { text, truncated: false }; + } + + return { + text: text.replace(MAX_TURNS_SENTINEL, '\n').replace(/\n{2,}/g, '\n').trim(), + truncated: true, + }; +} + +export function getThoughtDisplay(log?: TaskLogEvent): { text: string; truncated: boolean; badges: TurnBadge[] } { + const content = getLogContent(log); + const { text, truncated } = stripMaxTurnsSentinel(content); + return { + text, + truncated, + badges: getTurnBadges(log), + }; +} + function parseReActTurns(logs: TaskLogEvent[]): ReActTurn[] { const sorted = [...logs].sort((a, b) => a.event_seq - b.event_seq); const turnsMap = new Map<string, ReActTurn>(); @@ -100,7 +165,21 @@ function parseReActTurns(logs: TaskLogEvent[]): ReActTurn[] { return Array.from(turnsMap.values()); } -const ExpandablePromptBlock: React.FC<{ content: string; title: string; maxHeight?: number }> = ({ content, title, maxHeight = 150 }) => { +const BadgeRow: React.FC<{ badges: TurnBadge[] }> = ({ badges }) => { + if (badges.length 
=== 0) return null; + + return ( + <div className="react-gemini-badge-row"> + {badges.map((badge) => ( + <Tag key={badge.label} color={badge.color} className="react-gemini-badge"> + {badge.label} + </Tag> + ))} + </div> + ); +}; + +const ExpandablePromptBlock: React.FC<{ content: string; title: string; badges?: TurnBadge[]; maxHeight?: number }> = ({ content, title, badges = [], maxHeight = 150 }) => { const [expanded, setExpanded] = useState(false); const [isOverflowing, setIsOverflowing] = useState(false); const contentRef = useRef<HTMLDivElement>(null); @@ -127,6 +206,8 @@ const ExpandablePromptBlock: React.FC<{ content: string; title: string; maxHeigh )} </div> + <BadgeRow badges={badges} /> + <div className="message-bubble user-bubble" style={{ width: '100%', boxSizing: 'border-box' }}> <div style={{ @@ -147,8 +228,10 @@ const ExpandablePromptBlock: React.FC<{ content: string; title: string; maxHeigh const CursorStyleThoughtBlock: React.FC<{ thoughtText: string; durationMs?: number | null; + badges?: TurnBadge[]; + truncated?: boolean; defaultExpanded?: boolean; -}> = ({ thoughtText, durationMs, defaultExpanded = false }) => { +}> = ({ thoughtText, durationMs, badges = [], truncated = false, defaultExpanded = false }) => { const [expanded, setExpanded] = useState(defaultExpanded); // In some cases duration_ms might be tiny or null; default to <1s or omit @@ -164,6 +247,13 @@ const CursorStyleThoughtBlock: React.FC<{ {expanded ? 
<DownOutlined style={{ fontSize: 10, marginRight: 8 }} /> : <RightOutlined style={{ fontSize: 10, marginRight: 8 }} />} <Text type="secondary">{label}</Text> </div> + <BadgeRow badges={badges} /> + {truncated && ( + <div className="react-gemini-truncation-note"> + <Tag color="default" className="react-gemini-badge">Max turns reached</Tag> + <Text type="secondary">系统已截断当前轮次并请求继续,下面展示的是后续收敛输出。</Text> + </div> + )} {expanded && ( <div className="cursor-thought-content markdown-body message-body"> <ReactMarkdown remarkPlugins={[remarkGfm]}>{thoughtText}</ReactMarkdown> @@ -214,11 +304,13 @@ export const ReActTimeline: React.FC<ReActViewProps> = ({ logs, loading }) => { const isRunning = turn.thought?.status === 'running' || turn.action?.status === 'running' || turn.observation?.status === 'running'; const promptContent = turn.prompt?.request_body?.prompt as string; + const promptBadges = getTurnBadges(turn.prompt); const isLLMRunning = turn.thought_sent?.status === 'running' && !turn.thought; const streamLogId = turn.thought_sent?.id; - let thoughtText = getLogContent(turn.thought); + const thoughtDisplay = getThoughtDisplay(turn.thought); + let thoughtText = thoughtDisplay.text; if (isLLMRunning && streamLogId) { const streamLines = linesByLog[streamLogId]; if (streamLines && streamLines.length > 0) { @@ -226,6 +318,9 @@ export const ReActTimeline: React.FC<ReActViewProps> = ({ logs, loading }) => { } } + const thoughtBadges = getTurnBadges(turn.thought); + const thoughtTruncated = thoughtDisplay.truncated || thoughtText.includes(MAX_TURNS_SENTINEL); + if (thoughtText && thoughtText.includes('<thought>')) { const match = thoughtText.match(/<thought>([\s\S]*?)(?:<\/thought>|$)/); if (match) { @@ -234,7 +329,7 @@ export const ReActTimeline: React.FC<ReActViewProps> = ({ logs, loading }) => { } // AI Response Segment needs to show up if it's currently running, even if text is empty yet - const hasAIActivity = thoughtText || turn.action?.command || isLLMRunning; + const 
hasAIActivity = thoughtText || thoughtTruncated || thoughtBadges.length > 0 || turn.action?.command || isLLMRunning; const actionCommand = turn.action?.command || turn.observation?.command; const actionArgs = turn.action?.command_args || turn.observation?.command_args; @@ -252,7 +347,7 @@ export const ReActTimeline: React.FC<ReActViewProps> = ({ logs, loading }) => { <div className="react-gemini-message user-message"> <Avatar icon={<UserOutlined />} className="message-avatar" style={{ backgroundColor: '#87d068' }} /> <div className="message-content" style={{ minWidth: 0, width: '100%' }}> - <ExpandablePromptBlock content={promptContent} title="System / User" maxHeight={120} /> + <ExpandablePromptBlock content={promptContent} title="System / User" badges={promptBadges} maxHeight={120} /> </div> </div> )} @@ -265,10 +360,12 @@ export const ReActTimeline: React.FC<ReActViewProps> = ({ logs, loading }) => { <Text strong className="message-author">Silicon Agent</Text> {/* Thought formatted with Cursor-style collapse */} - {thoughtText && ( + {(thoughtText || thoughtTruncated || thoughtBadges.length > 0) && ( <CursorStyleThoughtBlock thoughtText={thoughtText} durationMs={turn.thought?.duration_ms || turn.thought_sent?.duration_ms} + badges={thoughtBadges} + truncated={thoughtTruncated} defaultExpanded={turn.thought?.event_type === 'llm_turn_received'} /> )} diff --git a/web/src/components/ReActTimeline/styles.css b/web/src/components/ReActTimeline/styles.css index 4f61435..74981b5 100644 --- a/web/src/components/ReActTimeline/styles.css +++ b/web/src/components/ReActTimeline/styles.css @@ -31,6 +31,29 @@ color: #1f1f1f; } +.react-gemini-badge-row { + display: flex; + flex-wrap: wrap; + gap: 6px; + margin-bottom: 10px; +} + +.react-gemini-badge { + margin-inline-end: 0 !important; +} + +.react-gemini-truncation-note { + display: flex; + align-items: center; + gap: 8px; + margin: 6px 0 10px; + padding: 8px 10px; + border-radius: 8px; + background: #fafafa; + border: 1px 
dashed #d9d9d9; + color: #595959; +} + .message-bubble { border-radius: 8px; padding: 12px 16px; @@ -189,4 +212,4 @@ padding-left: 14px; border-left: 2px solid #e8e8e8; margin-left: 5px; -} \ No newline at end of file +} diff --git a/web/tests/ReActTimeline.test.ts b/web/tests/ReActTimeline.test.ts new file mode 100644 index 0000000..3beea40 --- /dev/null +++ b/web/tests/ReActTimeline.test.ts @@ -0,0 +1,80 @@ +import { describe, expect, it } from 'vitest'; +import type { TaskLogEvent } from '@/services/taskLogApi'; +import { getThoughtDisplay, getTurnBadges, stripMaxTurnsSentinel } from '@/components/ReActTimeline'; + +function makeLog(overrides: Partial<TaskLogEvent>): TaskLogEvent { + return { + id: 'log-1', + task_id: 'task-1', + stage_id: 'stage-1', + stage_name: 'coding', + agent_role: 'coding', + correlation_id: 'chat-1', + event_seq: 1, + event_type: 'agent_runner_chat_received', + event_source: 'llm', + status: 'success', + request_body: null, + response_body: null, + command: null, + command_args: null, + workspace: null, + execution_mode: null, + duration_ms: null, + result: null, + output_summary: null, + output_truncated: false, + missing_fields: [], + created_at: '2026-03-19T00:00:00.000Z', + ...overrides, + }; +} + +describe('ReActTimeline helper transforms', () => { + it('extracts continuation and forced convergence badges from task logs', () => { + const log = makeLog({ + request_body: { continuation: 2, forced_convergence: true }, + }); + + expect(getTurnBadges(log)).toEqual([ + { label: 'Continuation #2', color: 'blue' }, + { label: 'Forced Convergence', color: 'gold' }, + ]); + }); + + it('prefers response metadata when request body is absent', () => { + const log = makeLog({ + response_body: { continuation: '1', forced_convergence: true }, + }); + + expect(getTurnBadges(log)).toEqual([ + { label: 'Continuation #1', color: 'blue' }, + { label: 'Forced Convergence', color: 'gold' }, + ]); + }); + + it('strips the max-turn sentinel while preserving 
surrounding text', () => { + expect( + stripMaxTurnsSentinel( + 'alpha\n[Max turns reached. Please continue the conversation.]\nbeta' + ) + ).toEqual({ + text: 'alpha\nbeta', + truncated: true, + }); + }); + + it('treats sentinel-only thoughts as truncated system notes', () => { + const thought = getThoughtDisplay( + makeLog({ + response_body: { content: '[Max turns reached. Please continue the conversation.]' }, + }) + ); + + expect(thought).toEqual({ + text: '', + truncated: true, + badges: [], + }); + }); +});