Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions miles/rollout/session/linear_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ class LinearTrajectory:

lock: asyncio.Lock = field(default_factory=asyncio.Lock, repr=False, compare=False)
closing: bool = field(default=False, repr=False, compare=False)
# Per-session in-flight gate: set under self.lock when a chat claims the
# session, cleared by that same request on every exit path.
chat_inflight: bool = field(default=False, repr=False, compare=False)
messages: list[dict[str, Any]] = field(default_factory=list)
records: list[SessionRecord] = field(default_factory=list)
trajectory_token_ids: list[list[int]] = field(default_factory=list)
Expand Down
21 changes: 20 additions & 1 deletion miles/rollout/session/session_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
├── SessionNotFoundError → 404 session does not exist
├── MessageValidationError → 400 messages structure/content invalid
├── TokenizationError → 500 TITO tokenizer / prefix mismatch
└── UpstreamResponseError → 502 SGLang response invalid or unexpected
├── UpstreamResponseError → 502 SGLang response invalid or unexpected
├── SessionBusyError → 409 session already has an in-flight chat
└── SessionInvariantError → 500 unreachable session-state invariant violated
"""


Expand Down Expand Up @@ -49,3 +51,20 @@ class UpstreamResponseError(SessionError):
"""

status_code: int = 502


class SessionBusyError(SessionError):
"""Raised when the session already has an in-flight chat completion.

One linear trajectory admits one in-flight chat at a time.
"""

status_code: int = 409


class SessionInvariantError(SessionError):
"""Raised when a session-state invariant that should be unreachable under
the in-flight gate is violated (defensive; indicates a real bug).
"""

status_code: int = 500
29 changes: 19 additions & 10 deletions miles/rollout/session/session_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@

import json
import logging
import os
from concurrent.futures import ThreadPoolExecutor

import httpx
import setproctitle
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from starlette.responses import Response

from miles.rollout.session.sessions import setup_session_routes
Expand All @@ -38,6 +39,12 @@ def __init__(self, args, backend_url: str):
# Close the httpx connection pool when uvicorn shuts down to avoid FD leaks.
self.app.router.on_shutdown.append(self.client.aclose)

self.cpu_executor = ThreadPoolExecutor(
max_workers=getattr(args, "session_server_cpu_workers", None) or min(16, os.cpu_count() or 1),
thread_name_prefix="session-cpu",
)
self.app.router.on_shutdown.append(lambda: self.cpu_executor.shutdown(wait=False, cancel_futures=True))

setup_session_routes(self.app, self, args)

async def do_proxy(
Expand Down Expand Up @@ -79,21 +86,20 @@ async def do_proxy(
}

def build_proxy_response(self, result: dict) -> Response:
content = result["response_body"]
status_code = result["status_code"]
# Drop wire-level framing headers from upstream so Starlette rebuilds them
# from the body we actually send: transfer-encoding is hop-by-hop
# httpx already decoded the body, so upstream content-encoding/length are
# stale framing headers; drop them and let Starlette rebuild from the body.
headers = {
k: v
for k, v in result["headers"].items()
if k.lower() not in ("content-length", "transfer-encoding", "content-encoding")
}
content_type = headers.get("content-type", "")
try:
data = json.loads(content)
return JSONResponse(content=data, status_code=status_code, headers=headers)
except (json.JSONDecodeError, UnicodeDecodeError):
return Response(content=content, status_code=status_code, headers=headers, media_type=content_type)
return Response(
content=result["response_body"],
status_code=result["status_code"],
headers=headers,
media_type=content_type,
)


def run_session_server(args, backend_url: str):
Expand All @@ -108,4 +114,7 @@ def run_session_server(args, backend_url: str):
args.session_server_port,
backend_url,
)
# Single uvicorn worker on purpose: extra workers would each own a separate
# SessionRegistry + asyncio.Lock, so a session_id could land on a process that
# doesn't own it. Multi-process needs sticky session ownership and is deferred.
uvicorn.run(server.app, host=args.session_server_ip, port=args.session_server_port, log_level="info")
Loading
Loading