From f7985c26210b0642df167acf3f3cd922ea2c073c Mon Sep 17 00:00:00 2001 From: bloodycoder Date: Sun, 12 Apr 2026 01:21:02 +0800 Subject: [PATCH] fix(shell): normalize timeout_s alias for shell/acp timeouts --- src/kimi_cli/acp/tools.py | 4 +- src/kimi_cli/tools/shell/__init__.py | 27 +++++++++ tests/acp/test_terminal_timeout.py | 90 ++++++++++++++++++++++++++++ tests/tools/test_background_tools.py | 43 ++++++++++++- tests/tools/test_shell_bash.py | 36 +++++++++++ 5 files changed, 197 insertions(+), 3 deletions(-) create mode 100644 tests/acp/test_terminal_timeout.py diff --git a/src/kimi_cli/acp/tools.py b/src/kimi_cli/acp/tools.py index 055c9edb3..235411b2f 100644 --- a/src/kimi_cli/acp/tools.py +++ b/src/kimi_cli/acp/tools.py @@ -79,8 +79,8 @@ async def __call__(self, params: ShellParams) -> ToolReturnValue: if not approval_result: return approval_result.rejection_error() - timeout_seconds = float(params.timeout) - timeout_label = f"{timeout_seconds:g}s" + timeout_seconds = params.timeout + timeout_label = f"{timeout_seconds}s" terminal_id: str | None = None exit_status: ( acp.schema.WaitForTerminalExitResponse | acp.schema.TerminalExitStatus | None diff --git a/src/kimi_cli/tools/shell/__init__.py b/src/kimi_cli/tools/shell/__init__.py index 6a7f6b687..f4b50db59 100644 --- a/src/kimi_cli/tools/shell/__init__.py +++ b/src/kimi_cli/tools/shell/__init__.py @@ -1,6 +1,7 @@ import asyncio from collections.abc import Callable from pathlib import Path +from typing import Any from typing import Self, override import kaos @@ -20,6 +21,8 @@ MAX_FOREGROUND_TIMEOUT = 5 * 60 MAX_BACKGROUND_TIMEOUT = 24 * 60 * 60 +TIMEOUT_ALIAS_FIELD = "timeout_s" +UNSUPPORTED_TIMEOUT_FIELDS = frozenset({"timeout_ms", "timeoutSeconds"}) class Params(BaseModel): @@ -44,6 +47,30 @@ class Params(BaseModel): ), ) + @model_validator(mode="before") + @classmethod + def _normalize_timeout_fields(cls, data: Any) -> Any: + if not isinstance(data, dict): + return data + + values = dict(data) + + unsupported = sorted(k for k in UNSUPPORTED_TIMEOUT_FIELDS if k in values) + if unsupported: + raise ValueError( + f"Unsupported timeout field(s): {', '.join(unsupported)}. " + f"Use `timeout` or `{TIMEOUT_ALIAS_FIELD}` in seconds." + ) + + has_timeout = "timeout" in values + has_timeout_alias = TIMEOUT_ALIAS_FIELD in values + if has_timeout and has_timeout_alias and values["timeout"] != values[TIMEOUT_ALIAS_FIELD]: + raise ValueError("`timeout` and `timeout_s` must match when both are provided") + if not has_timeout and has_timeout_alias: + values["timeout"] = values[TIMEOUT_ALIAS_FIELD] + + return values + @model_validator(mode="after") def _validate_background_fields(self) -> Self: if self.run_in_background and not self.description.strip(): diff --git a/tests/acp/test_terminal_timeout.py b/tests/acp/test_terminal_timeout.py new file mode 100644 index 000000000..00dd17b00 --- /dev/null +++ b/tests/acp/test_terminal_timeout.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from kimi_cli.acp.tools import Terminal +from kimi_cli.tools.shell import Params as ShellParams + +pytestmark = pytest.mark.asyncio + + +class _FakeACPConn: + def __init__(self, *, timeout_error: bool = False): + self.timeout_error = timeout_error + self.killed_terminal = False + self.released_terminal = False + + async def create_terminal(self, **_kwargs): + return SimpleNamespace(terminal_id="term-1") + + async def session_update(self, **_kwargs): + return None + + async def wait_for_terminal_exit(self, **_kwargs): + if self.timeout_error: + raise TimeoutError + return SimpleNamespace(exit_code=0, signal=None) + + async def kill_terminal(self, **_kwargs): + self.killed_terminal = True + return None + + async def terminal_output(self, **_kwargs): + return SimpleNamespace( + output="", + truncated=False, + exit_status=None, + ) + + async def release_terminal(self, **_kwargs): + self.released_terminal = True + return None + + +async def test_terminal_uses_effective_timeout_from_alias(shell_tool, approval, monkeypatch): + captured: dict[str, float] = {} + + class _FakeTimeoutCtx: + def __init__(self, seconds: float): + captured["seconds"] = seconds + + async def __aenter__(self): + return None + + async def __aexit__(self, _exc_type, _exc, _tb): + return False + + monkeypatch.setattr( + "kimi_cli.acp.session.get_current_acp_tool_call_id_or_none", + lambda: "turn-1/test-tool-call", + ) + monkeypatch.setattr("kimi_cli.acp.tools.asyncio.timeout", _FakeTimeoutCtx) + + acp_conn = _FakeACPConn() + terminal = Terminal(shell_tool, acp_conn, "session-1", approval) + + result = await terminal(ShellParams.model_validate({"command": "echo ok", "timeout_s": 9})) + + assert not result.is_error + assert captured["seconds"] == 9 + assert acp_conn.released_terminal + + +async def test_terminal_timeout_message_matches_effective_timeout(shell_tool, approval, monkeypatch): + monkeypatch.setattr( + "kimi_cli.acp.session.get_current_acp_tool_call_id_or_none", + lambda: "turn-1/test-tool-call", + ) + + acp_conn = _FakeACPConn(timeout_error=True) + terminal = Terminal(shell_tool, acp_conn, "session-1", approval) + + result = await terminal(ShellParams.model_validate({"command": "sleep 10", "timeout_s": 7})) + + assert result.is_error + assert result.message == "Command killed by timeout (7s)" + assert result.brief == "Killed by timeout (7s)" + assert acp_conn.killed_terminal + assert acp_conn.released_terminal diff --git a/tests/tools/test_background_tools.py b/tests/tools/test_background_tools.py index 75cc79b55..29dbb5e52 100644 --- a/tests/tools/test_background_tools.py +++ b/tests/tools/test_background_tools.py @@ -4,7 +4,7 @@ import pytest -from kimi_cli.background import TaskRuntime, TaskSpec, TaskStatus +from kimi_cli.background import TaskConsumerState, TaskControl, TaskRuntime, TaskSpec, TaskStatus, TaskView from kimi_cli.tools.shell import Params @@ -91,6 +91,47 @@ async def test_shell_background_starts_task(shell_tool, runtime, monkeypatch): assert "/task list" in result.output +@pytest.mark.asyncio +async def test_shell_background_timeout_alias_propagates(shell_tool, runtime, monkeypatch): + captured: dict[str, object] = {} + + def fake_create_bash_task(**kwargs): + captured.update(kwargs) + return TaskView( + spec=TaskSpec( + id="b1234567", + kind="bash", + session_id=runtime.session.id, + description="sleep task", + tool_call_id="test", + command="sleep 1", + shell_name="bash", + shell_path="/bin/bash", + cwd=str(runtime.session.work_dir), + timeout_s=kwargs["timeout_s"], + ), + runtime=TaskRuntime(status="starting"), + control=TaskControl(), + consumer=TaskConsumerState(), + ) + + monkeypatch.setattr(runtime.background_tasks, "create_bash_task", fake_create_bash_task) + + result = await shell_tool( + Params.model_validate( + { + "command": "sleep 1", + "timeout_s": 123, + "run_in_background": True, + "description": "sleep task", + } + ) + ) + + assert not result.is_error + assert captured["timeout_s"] == 123 + + @pytest.mark.asyncio async def test_shell_background_requires_description(shell_tool): with pytest.raises(ValueError, match="description"): diff --git a/tests/tools/test_shell_bash.py b/tests/tools/test_shell_bash.py index aa91cad01..a9c56f811 100644 --- a/tests/tools/test_shell_bash.py +++ b/tests/tools/test_shell_bash.py @@ -8,6 +8,7 @@ import pytest from inline_snapshot import snapshot from kaos.path import KaosPath +from kosong.tooling.error import ToolValidateError from kimi_cli.tools.shell import Params, Shell from kimi_cli.tools.utils import DEFAULT_MAX_CHARS @@ -97,6 +98,41 @@ async def test_command_timeout_expires(shell_tool: Shell): assert result.brief == snapshot("Killed by timeout (1s)") +async def test_timeout_alias_maps_to_timeout_value(): + params = Params.model_validate({"command": "echo test", "timeout_s": 7}) + assert params.timeout == 7 + + +async def test_timeout_alias_drives_foreground_timeout_message( + shell_tool: Shell, monkeypatch: pytest.MonkeyPatch +): + async def fake_run(*_args, **_kwargs) -> int: + raise TimeoutError + + monkeypatch.setattr(shell_tool, "_run_shell_command", fake_run) + result = await shell_tool(Params.model_validate({"command": "echo test", "timeout_s": 7})) + + assert result.is_error + assert result.message == "Command killed by timeout (7s)" + assert result.brief == "Killed by timeout (7s)" + + +async def test_timeout_conflict_returns_tool_validate_error(shell_tool: Shell): + ret = await shell_tool.call({"command": "echo test", "timeout": 2, "timeout_s": 3}) + assert isinstance(ret, ToolValidateError) + + +async def test_timeout_unknown_alias_returns_tool_validate_error(shell_tool: Shell): + ret = await shell_tool.call({"command": "echo test", "timeout_ms": 1500}) + assert isinstance(ret, ToolValidateError) + assert "timeout_ms" in ret.message + + +async def test_timeout_invalid_type_returns_tool_validate_error(shell_tool: Shell): + ret = await shell_tool.call({"command": "echo test", "timeout": "abc"}) + assert isinstance(ret, ToolValidateError) + + async def test_environment_variables(shell_tool: Shell): """Test setting and using environment variables.""" result = await shell_tool(Params(command="export TEST_VAR='test_value' && echo $TEST_VAR"))