Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/kimi_cli/acp/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ async def __call__(self, params: ShellParams) -> ToolReturnValue:
if not approval_result:
return approval_result.rejection_error()

timeout_seconds = float(params.timeout)
timeout_label = f"{timeout_seconds:g}s"
timeout_seconds = params.timeout
timeout_label = f"{timeout_seconds}s"
terminal_id: str | None = None
exit_status: (
acp.schema.WaitForTerminalExitResponse | acp.schema.TerminalExitStatus | None
Expand Down
27 changes: 27 additions & 0 deletions src/kimi_cli/tools/shell/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from collections.abc import Callable
from pathlib import Path
from typing import Any
from typing import Self, override

import kaos
Expand All @@ -20,6 +21,8 @@

MAX_FOREGROUND_TIMEOUT = 5 * 60
MAX_BACKGROUND_TIMEOUT = 24 * 60 * 60
TIMEOUT_ALIAS_FIELD = "timeout_s"
UNSUPPORTED_TIMEOUT_FIELDS = frozenset({"timeout_ms", "timeoutSeconds"})


class Params(BaseModel):
Expand All @@ -44,6 +47,30 @@ class Params(BaseModel):
),
)

@model_validator(mode="before")
@classmethod
def _normalize_timeout_fields(cls, data: Any) -> Any:
if not isinstance(data, dict):
return data

values = dict(data)

unsupported = sorted(k for k in UNSUPPORTED_TIMEOUT_FIELDS if k in values)
if unsupported:
raise ValueError(
f"Unsupported timeout field(s): {', '.join(unsupported)}. "
f"Use `timeout` or `{TIMEOUT_ALIAS_FIELD}` in seconds."
)

has_timeout = "timeout" in values
has_timeout_alias = TIMEOUT_ALIAS_FIELD in values
if has_timeout and has_timeout_alias and values["timeout"] != values[TIMEOUT_ALIAS_FIELD]:
raise ValueError("`timeout` and `timeout_s` must match when both are provided")
if not has_timeout and has_timeout_alias:
values["timeout"] = values[TIMEOUT_ALIAS_FIELD]

return values

@model_validator(mode="after")
def _validate_background_fields(self) -> Self:
if self.run_in_background and not self.description.strip():
Expand Down
90 changes: 90 additions & 0 deletions tests/acp/test_terminal_timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from __future__ import annotations

from types import SimpleNamespace

import pytest

from kimi_cli.acp.tools import Terminal
from kimi_cli.tools.shell import Params as ShellParams

pytestmark = pytest.mark.asyncio


class _FakeACPConn:
def __init__(self, *, timeout_error: bool = False):
self.timeout_error = timeout_error
self.killed_terminal = False
self.released_terminal = False

async def create_terminal(self, **_kwargs):
return SimpleNamespace(terminal_id="term-1")

async def session_update(self, **_kwargs):
return None

async def wait_for_terminal_exit(self, **_kwargs):
if self.timeout_error:
raise TimeoutError
return SimpleNamespace(exit_code=0, signal=None)

async def kill_terminal(self, **_kwargs):
self.killed_terminal = True
return None

async def terminal_output(self, **_kwargs):
return SimpleNamespace(
output="",
truncated=False,
exit_status=None,
)

async def release_terminal(self, **_kwargs):
self.released_terminal = True
return None


async def test_terminal_uses_effective_timeout_from_alias(shell_tool, approval, monkeypatch):
captured: dict[str, float] = {}

class _FakeTimeoutCtx:
def __init__(self, seconds: float):
captured["seconds"] = seconds

async def __aenter__(self):
return None

async def __aexit__(self, _exc_type, _exc, _tb):
return False

monkeypatch.setattr(
"kimi_cli.acp.session.get_current_acp_tool_call_id_or_none",
lambda: "turn-1/test-tool-call",
)
monkeypatch.setattr("kimi_cli.acp.tools.asyncio.timeout", _FakeTimeoutCtx)

acp_conn = _FakeACPConn()
terminal = Terminal(shell_tool, acp_conn, "session-1", approval)

result = await terminal(ShellParams.model_validate({"command": "echo ok", "timeout_s": 9}))

assert not result.is_error
assert captured["seconds"] == 9
assert acp_conn.released_terminal


async def test_terminal_timeout_message_matches_effective_timeout(shell_tool, approval, monkeypatch):
monkeypatch.setattr(
"kimi_cli.acp.session.get_current_acp_tool_call_id_or_none",
lambda: "turn-1/test-tool-call",
)

acp_conn = _FakeACPConn(timeout_error=True)
terminal = Terminal(shell_tool, acp_conn, "session-1", approval)

result = await terminal(ShellParams.model_validate({"command": "sleep 10", "timeout_s": 7}))

assert result.is_error
assert result.message == "Command killed by timeout (7s)"
assert result.brief == "Killed by timeout (7s)"
assert acp_conn.killed_terminal
assert acp_conn.released_terminal
43 changes: 42 additions & 1 deletion tests/tools/test_background_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from kimi_cli.background import TaskRuntime, TaskSpec, TaskStatus
from kimi_cli.background import TaskConsumerState, TaskControl, TaskRuntime, TaskSpec, TaskStatus, TaskView
from kimi_cli.tools.shell import Params


Expand Down Expand Up @@ -91,6 +91,47 @@ async def test_shell_background_starts_task(shell_tool, runtime, monkeypatch):
assert "/task list" in result.output


@pytest.mark.asyncio
async def test_shell_background_timeout_alias_propagates(shell_tool, runtime, monkeypatch):
captured: dict[str, object] = {}

def fake_create_bash_task(**kwargs):
captured.update(kwargs)
return TaskView(
spec=TaskSpec(
id="b1234567",
kind="bash",
session_id=runtime.session.id,
description="sleep task",
tool_call_id="test",
command="sleep 1",
shell_name="bash",
shell_path="/bin/bash",
cwd=str(runtime.session.work_dir),
timeout_s=kwargs["timeout_s"],
),
runtime=TaskRuntime(status="starting"),
control=TaskControl(),
consumer=TaskConsumerState(),
)

monkeypatch.setattr(runtime.background_tasks, "create_bash_task", fake_create_bash_task)

result = await shell_tool(
Params.model_validate(
{
"command": "sleep 1",
"timeout_s": 123,
"run_in_background": True,
"description": "sleep task",
}
)
)

assert not result.is_error
assert captured["timeout_s"] == 123


@pytest.mark.asyncio
async def test_shell_background_requires_description(shell_tool):
with pytest.raises(ValueError, match="description"):
Expand Down
36 changes: 36 additions & 0 deletions tests/tools/test_shell_bash.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest
from inline_snapshot import snapshot
from kaos.path import KaosPath
from kosong.tooling.error import ToolValidateError

from kimi_cli.tools.shell import Params, Shell
from kimi_cli.tools.utils import DEFAULT_MAX_CHARS
Expand Down Expand Up @@ -97,6 +98,41 @@ async def test_command_timeout_expires(shell_tool: Shell):
assert result.brief == snapshot("Killed by timeout (1s)")


async def test_timeout_alias_maps_to_timeout_value():
params = Params.model_validate({"command": "echo test", "timeout_s": 7})
assert params.timeout == 7


async def test_timeout_alias_drives_foreground_timeout_message(
shell_tool: Shell, monkeypatch: pytest.MonkeyPatch
):
async def fake_run(*_args, **_kwargs) -> int:
raise TimeoutError

monkeypatch.setattr(shell_tool, "_run_shell_command", fake_run)
result = await shell_tool(Params.model_validate({"command": "echo test", "timeout_s": 7}))

assert result.is_error
assert result.message == "Command killed by timeout (7s)"
assert result.brief == "Killed by timeout (7s)"


async def test_timeout_conflict_returns_tool_validate_error(shell_tool: Shell):
ret = await shell_tool.call({"command": "echo test", "timeout": 2, "timeout_s": 3})
assert isinstance(ret, ToolValidateError)


async def test_timeout_unknown_alias_returns_tool_validate_error(shell_tool: Shell):
ret = await shell_tool.call({"command": "echo test", "timeout_ms": 1500})
assert isinstance(ret, ToolValidateError)
assert "timeout_ms" in ret.message


async def test_timeout_invalid_type_returns_tool_validate_error(shell_tool: Shell):
ret = await shell_tool.call({"command": "echo test", "timeout": "abc"})
assert isinstance(ret, ToolValidateError)


async def test_environment_variables(shell_tool: Shell):
"""Test setting and using environment variables."""
result = await shell_tool(Params(command="export TEST_VAR='test_value' && echo $TEST_VAR"))
Expand Down
Loading