diff --git a/src/kimi_cli/web/api/sessions.py b/src/kimi_cli/web/api/sessions.py index 11573715d..6e247d386 100644 --- a/src/kimi_cli/web/api/sessions.py +++ b/src/kimi_cli/web/api/sessions.py @@ -231,6 +231,40 @@ def _read_wire_lines(wire_file: Path) -> list[str]: return result +def _resolve_default_work_dir(http_request: Request) -> tuple[Path, bool]: + """Resolve default work_dir for session creation. + + Prefer app startup directory; fallback to HOME when unavailable. + """ + startup_dir = getattr(http_request.app.state, "startup_dir", None) + if isinstance(startup_dir, str) and startup_dir.strip(): + candidate = Path(startup_dir).expanduser() + try: + resolved = candidate.resolve() + except (OSError, RuntimeError) as exc: + logger.warning( + "Failed to resolve startup_dir, fallback to home: startup_dir={startup_dir}, error={error}", + startup_dir=startup_dir, + error=str(exc), + ) + else: + if resolved.exists() and resolved.is_dir(): + return resolved, False + + reason = "does not exist" + if resolved.exists(): + reason = "not a directory" + logger.warning( + "Invalid startup_dir, fallback to home: startup_dir={startup_dir}, reason={reason}", + startup_dir=str(resolved), + reason=reason, + ) + else: + logger.warning("Missing startup_dir, fallback to home") + + return Path.home(), True + + async def replay_history(ws: WebSocket, session_dir: Path) -> None: """Replay historical wire messages from wire.jsonl to a WebSocket.""" wire_file = session_dir / "wire.jsonl" @@ -296,9 +330,11 @@ async def get_session( @router.post("/", summary="Create a new session") -async def create_session(request: CreateSessionRequest | None = None) -> Session: +async def create_session( + http_request: Request, request: CreateSessionRequest | None = None +) -> Session: """Create a new session.""" - # Use provided work_dir or default to user's home directory + requested_work_dir = request.work_dir if request else None if request and request.work_dir: work_dir_path = Path(request.work_dir).expanduser().resolve() # Validate the directory exists @@ -330,7 +366,14 @@ async def create_session(request: CreateSessionRequest | None = None) -> Session ) work_dir = KaosPath.unsafe_from_local_path(work_dir_path) else: - work_dir = KaosPath.unsafe_from_local_path(Path.home()) + effective_work_dir, used_fallback_home = _resolve_default_work_dir(http_request) + work_dir = KaosPath.unsafe_from_local_path(effective_work_dir) + logger.info( + "Resolved create_session default work_dir: requested_work_dir={requested_work_dir}, effective_work_dir={effective_work_dir}, used_fallback_home={used_fallback_home}", + requested_work_dir=requested_work_dir, + effective_work_dir=str(effective_work_dir), + used_fallback_home=used_fallback_home, + ) kimi_cli_session = await KimiCLISession.create(work_dir=work_dir) context_file = kimi_cli_session.dir / "context.jsonl" invalidate_sessions_cache() diff --git a/tests/web/test_sessions_create_default_workdir.py b/tests/web/test_sessions_create_default_workdir.py new file mode 100644 index 000000000..66ba009bc --- /dev/null +++ b/tests/web/test_sessions_create_default_workdir.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +from pathlib import Path +from unittest.mock import patch + +from starlette.testclient import TestClient + +from kimi_cli.web.app import create_app + + +def _build_client(monkeypatch, tmp_path: Path, startup_dir: Path) -> TestClient: + monkeypatch.setenv("KIMI_SHARE_DIR", str(tmp_path / "share")) + monkeypatch.chdir(startup_dir) + app = create_app(session_token="test-token") + client = TestClient(app) + client.headers.update({"Authorization": "Bearer test-token"}) + return client + + +def test_create_session_without_work_dir_uses_startup_dir(monkeypatch, tmp_path: Path) -> None: + startup_dir = tmp_path / "startup" + startup_dir.mkdir() + + with _build_client(monkeypatch, tmp_path, startup_dir) as client: + # Simulate the existing web call path where request body can be omitted. + response = client.post("/api/sessions/") + + assert response.status_code == 200 + payload = response.json() + assert Path(payload["work_dir"]).resolve() == startup_dir.resolve() + + +def test_create_session_fallbacks_to_home_when_startup_dir_invalid(monkeypatch, tmp_path: Path) -> None: + startup_dir = tmp_path / "startup" + startup_dir.mkdir() + + with _build_client(monkeypatch, tmp_path, startup_dir) as client: + client.app.state.startup_dir = str(tmp_path / "missing-startup-dir") + with patch("kimi_cli.web.api.sessions.logger") as mock_logger: + response = client.post("/api/sessions/") + + assert response.status_code == 200 + payload = response.json() + assert Path(payload["work_dir"]).resolve() == Path.home().resolve() + assert mock_logger.warning.call_count >= 1 + assert any( + call.kwargs.get("used_fallback_home") is True for call in mock_logger.info.call_args_list + ) + + +def test_create_session_explicit_work_dir_kept(monkeypatch, tmp_path: Path) -> None: + startup_dir = tmp_path / "startup" + startup_dir.mkdir() + explicit_dir = tmp_path / "explicit" + explicit_dir.mkdir() + + with _build_client(monkeypatch, tmp_path, startup_dir) as client: + response = client.post("/api/sessions/", json={"work_dir": str(explicit_dir)}) + + assert response.status_code == 200 + payload = response.json() + assert Path(payload["work_dir"]).resolve() == explicit_dir.resolve() + + +def test_session_file_endpoint_resolves_against_created_work_dir(monkeypatch, tmp_path: Path) -> None: + startup_dir = tmp_path / "startup" + project_dir = startup_dir / "project" + project_dir.mkdir(parents=True) + expected_content = "mention target" + (project_dir / "note.txt").write_text(expected_content, encoding="utf-8") + + with _build_client(monkeypatch, tmp_path, startup_dir) as client: + create_response = client.post("/api/sessions/") + assert create_response.status_code == 200 + session_id = create_response.json()["session_id"] + + file_response = client.get(f"/api/sessions/{session_id}/files/project/note.txt") + + assert file_response.status_code == 200 + assert file_response.content.decode("utf-8") == expected_content