Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 46 additions & 3 deletions src/kimi_cli/web/api/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,40 @@ def _read_wire_lines(wire_file: Path) -> list[str]:
return result


def _resolve_default_work_dir(http_request: Request) -> tuple[Path, bool]:
"""Resolve default work_dir for session creation.

Prefer app startup directory; fallback to HOME when unavailable.
"""
startup_dir = getattr(http_request.app.state, "startup_dir", None)
if isinstance(startup_dir, str) and startup_dir.strip():
candidate = Path(startup_dir).expanduser()
try:
resolved = candidate.resolve()
except (OSError, RuntimeError) as exc:
logger.warning(
"Failed to resolve startup_dir, fallback to home: startup_dir={startup_dir}, error={error}",
startup_dir=startup_dir,
error=str(exc),
)
else:
if resolved.exists() and resolved.is_dir():
return resolved, False

reason = "does not exist"
if resolved.exists():
reason = "not a directory"
logger.warning(
"Invalid startup_dir, fallback to home: startup_dir={startup_dir}, reason={reason}",
startup_dir=str(resolved),
reason=reason,
)
else:
logger.warning("Missing startup_dir, fallback to home")

return Path.home(), True


async def replay_history(ws: WebSocket, session_dir: Path) -> None:
"""Replay historical wire messages from wire.jsonl to a WebSocket."""
wire_file = session_dir / "wire.jsonl"
Expand Down Expand Up @@ -296,9 +330,11 @@ async def get_session(


@router.post("/", summary="Create a new session")
async def create_session(request: CreateSessionRequest | None = None) -> Session:
async def create_session(
http_request: Request, request: CreateSessionRequest | None = None
) -> Session:
"""Create a new session."""
# Use provided work_dir or default to user's home directory
requested_work_dir = request.work_dir if request else None
if request and request.work_dir:
work_dir_path = Path(request.work_dir).expanduser().resolve()
# Validate the directory exists
Expand Down Expand Up @@ -330,7 +366,14 @@ async def create_session(request: CreateSessionRequest | None = None) -> Session
)
work_dir = KaosPath.unsafe_from_local_path(work_dir_path)
else:
work_dir = KaosPath.unsafe_from_local_path(Path.home())
effective_work_dir, used_fallback_home = _resolve_default_work_dir(http_request)
work_dir = KaosPath.unsafe_from_local_path(effective_work_dir)
logger.info(
"Resolved create_session default work_dir",
requested_work_dir=requested_work_dir,
effective_work_dir=str(effective_work_dir),
used_fallback_home=used_fallback_home,
)
kimi_cli_session = await KimiCLISession.create(work_dir=work_dir)
context_file = kimi_cli_session.dir / "context.jsonl"
invalidate_sessions_cache()
Expand Down
80 changes: 80 additions & 0 deletions tests/web/test_sessions_create_default_workdir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from __future__ import annotations

from pathlib import Path
from unittest.mock import patch

from starlette.testclient import TestClient

from kimi_cli.web.app import create_app


def _build_client(monkeypatch, tmp_path: Path, startup_dir: Path) -> TestClient:
monkeypatch.setenv("KIMI_SHARE_DIR", str(tmp_path / "share"))
monkeypatch.chdir(startup_dir)
app = create_app(session_token="test-token")
client = TestClient(app)
client.headers.update({"Authorization": "Bearer test-token"})
return client


def test_create_session_without_work_dir_uses_startup_dir(monkeypatch, tmp_path: Path) -> None:
startup_dir = tmp_path / "startup"
startup_dir.mkdir()

with _build_client(monkeypatch, tmp_path, startup_dir) as client:
# Simulate the existing web call path where request body can be omitted.
response = client.post("/api/sessions/")

assert response.status_code == 200
payload = response.json()
assert Path(payload["work_dir"]).resolve() == startup_dir.resolve()


def test_create_session_fallbacks_to_home_when_startup_dir_invalid(monkeypatch, tmp_path: Path) -> None:
startup_dir = tmp_path / "startup"
startup_dir.mkdir()

with _build_client(monkeypatch, tmp_path, startup_dir) as client:
client.app.state.startup_dir = str(tmp_path / "missing-startup-dir")
with patch("kimi_cli.web.api.sessions.logger") as mock_logger:
response = client.post("/api/sessions/")

assert response.status_code == 200
payload = response.json()
assert Path(payload["work_dir"]).resolve() == Path.home().resolve()
assert mock_logger.warning.call_count >= 1
assert any(
call.kwargs.get("used_fallback_home") is True for call in mock_logger.info.call_args_list
)


def test_create_session_explicit_work_dir_kept(monkeypatch, tmp_path: Path) -> None:
startup_dir = tmp_path / "startup"
startup_dir.mkdir()
explicit_dir = tmp_path / "explicit"
explicit_dir.mkdir()

with _build_client(monkeypatch, tmp_path, startup_dir) as client:
response = client.post("/api/sessions/", json={"work_dir": str(explicit_dir)})

assert response.status_code == 200
payload = response.json()
assert Path(payload["work_dir"]).resolve() == explicit_dir.resolve()


def test_session_file_endpoint_resolves_against_created_work_dir(monkeypatch, tmp_path: Path) -> None:
startup_dir = tmp_path / "startup"
project_dir = startup_dir / "project"
project_dir.mkdir(parents=True)
expected_content = "mention target"
(project_dir / "note.txt").write_text(expected_content, encoding="utf-8")

with _build_client(monkeypatch, tmp_path, startup_dir) as client:
create_response = client.post("/api/sessions/")
assert create_response.status_code == 200
session_id = create_response.json()["session_id"]

file_response = client.get(f"/api/sessions/{session_id}/files/project/note.txt")

assert file_response.status_code == 200
assert file_response.content.decode("utf-8") == expected_content
Loading