From 3b6734c72d34b589532987a95447491de2bbfdd9 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Mon, 22 Jun 2026 20:44:25 +0000 Subject: [PATCH 01/11] feat(session): per-session in-flight gate, response passthrough, bounded CPU thread pool Enforce strict per-session linearity on the single-worker session server while letting independent sessions run concurrently: - In-flight gate: each LinearTrajectory admits one in-flight chat; a second concurrent same-session chat fast-fails 409 (SessionBusyError) at slot-claim, without entering SGLang. chat_completions is restructured into brief lock segments (claim / prepare / commit) with a claimed-guarded, cancellation-safe finally that releases the slot on every exit path; closing (404) beats busy. - expected_num_assistant mismatch is now a logged 500 SessionInvariantError (unreachable under the gate) instead of a silent 200 skip. - build_proxy_response passes the upstream body through unchanged (no second parse / re-serialize), preserving content-type and stripping stale framing headers; applies to all call sites. - Bounded ThreadPoolExecutor (--session-server-cpu-workers, default min(16, os.cpu_count() or 1)) offloads only stateless CPU work (request/response JSON, validation) off the event loop; all session-state mutation stays on the event loop under session.lock. Shut down on app shutdown. - Removed the DEBUG _inflight_chat counter and debug_request_logger middleware. - Single uvicorn worker preserved; multi-process and orjson deferred (documented). Tests: rewrote the same-session concurrency test to the 409 contract; added slot-release-after-error and passthrough-fidelity coverage. 62 passed. Co-Authored-By: Claude Opus 4.8 (1M context) --- miles/rollout/session/linear_trajectory.py | 3 + miles/rollout/session/session_errors.py | 21 +- miles/rollout/session/session_server.py | 29 +- miles/rollout/session/sessions.py | 204 ++++++------ miles/utils/arguments.py | 9 + .../router/test_session_race_conditions.py | 290 +++++++++++++----- 6 files changed, 367 insertions(+), 189 deletions(-) diff --git a/miles/rollout/session/linear_trajectory.py b/miles/rollout/session/linear_trajectory.py index c9d7694479..4d4e608f77 100644 --- a/miles/rollout/session/linear_trajectory.py +++ b/miles/rollout/session/linear_trajectory.py @@ -31,6 +31,9 @@ class LinearTrajectory: lock: asyncio.Lock = field(default_factory=asyncio.Lock, repr=False, compare=False) closing: bool = field(default=False, repr=False, compare=False) + # Per-session in-flight gate: set under self.lock when a chat claims the + # session, cleared by that same request on every exit path. + chat_inflight: bool = field(default=False, repr=False, compare=False) messages: list[dict[str, Any]] = field(default_factory=list) records: list[SessionRecord] = field(default_factory=list) trajectory_token_ids: list[list[int]] = field(default_factory=list) diff --git a/miles/rollout/session/session_errors.py b/miles/rollout/session/session_errors.py index 30e6784384..cdab3d060a 100644 --- a/miles/rollout/session/session_errors.py +++ b/miles/rollout/session/session_errors.py @@ -6,7 +6,9 @@ ├── SessionNotFoundError → 404 session does not exist ├── MessageValidationError → 400 messages structure/content invalid ├── TokenizationError → 500 TITO tokenizer / prefix mismatch -└── UpstreamResponseError → 502 SGLang response invalid or unexpected +├── UpstreamResponseError → 502 SGLang response invalid or unexpected +├── SessionBusyError → 409 session already has an in-flight chat +└── SessionInvariantError → 500 unreachable session-state invariant violated """ @@ -49,3 +51,20 @@ class UpstreamResponseError(SessionError): """ status_code: int = 502 + + +class SessionBusyError(SessionError): + """Raised when the session already has an in-flight chat completion. + + One linear trajectory admits one in-flight chat at a time. + """ + + status_code: int = 409 + + +class SessionInvariantError(SessionError): + """Raised when a session-state invariant that should be unreachable under + the in-flight gate is violated (defensive; indicates a real bug). + """ + + status_code: int = 500 diff --git a/miles/rollout/session/session_server.py b/miles/rollout/session/session_server.py index ea64ace826..86a41ab950 100644 --- a/miles/rollout/session/session_server.py +++ b/miles/rollout/session/session_server.py @@ -8,12 +8,13 @@ import json import logging +import os +from concurrent.futures import ThreadPoolExecutor import httpx import setproctitle import uvicorn from fastapi import FastAPI, Request -from fastapi.responses import JSONResponse from starlette.responses import Response from miles.rollout.session.sessions import setup_session_routes @@ -38,6 +39,12 @@ def __init__(self, args, backend_url: str): # Close the httpx connection pool when uvicorn shuts down to avoid FD leaks. self.app.router.on_shutdown.append(self.client.aclose) + self.cpu_executor = ThreadPoolExecutor( + max_workers=getattr(args, "session_server_cpu_workers", None) or min(16, os.cpu_count() or 1), + thread_name_prefix="session-cpu", + ) + self.app.router.on_shutdown.append(lambda: self.cpu_executor.shutdown(wait=False, cancel_futures=True)) + setup_session_routes(self.app, self, args) async def do_proxy( @@ -79,21 +86,20 @@ async def do_proxy( } def build_proxy_response(self, result: dict) -> Response: - content = result["response_body"] - status_code = result["status_code"] - # Drop wire-level framing headers from upstream so Starlette rebuilds them - # from the body we actually send: transfer-encoding is hop-by-hop + # httpx already decoded the body, so upstream content-encoding/length are + # stale framing headers; drop them and let Starlette rebuild from the body. headers = { k: v for k, v in result["headers"].items() if k.lower() not in ("content-length", "transfer-encoding", "content-encoding") } content_type = headers.get("content-type", "") - try: - data = json.loads(content) - return JSONResponse(content=data, status_code=status_code, headers=headers) - except (json.JSONDecodeError, UnicodeDecodeError): - return Response(content=content, status_code=status_code, headers=headers, media_type=content_type) + return Response( + content=result["response_body"], + status_code=result["status_code"], + headers=headers, + media_type=content_type, + ) def run_session_server(args, backend_url: str): @@ -108,4 +114,7 @@ def run_session_server(args, backend_url: str): args.session_server_port, backend_url, ) + # Single uvicorn worker on purpose: extra workers would each own a separate + # SessionRegistry + asyncio.Lock, so a session_id could land on a process that + # doesn't own it. Multi-process needs sticky session ownership and is deferred. uvicorn.run(server.app, host=args.session_server_ip, port=args.session_server_port, log_level="info") diff --git a/miles/rollout/session/sessions.py b/miles/rollout/session/sessions.py index 243ebd2746..054d98ec5a 100644 --- a/miles/rollout/session/sessions.py +++ b/miles/rollout/session/sessions.py @@ -1,3 +1,4 @@ +import asyncio import json import logging import time @@ -8,7 +9,9 @@ from miles.rollout.session.linear_trajectory import SessionRegistry from miles.rollout.session.session_errors import ( + SessionBusyError, SessionError, + SessionInvariantError, SessionNotFoundError, TokenizationError, UpstreamResponseError, @@ -20,6 +23,46 @@ logger = logging.getLogger(__name__) +def _parse_request_body(body: bytes) -> dict: + return json.loads(body) if body else {} + + +def _dump_request_body(request_body: dict) -> bytes: + return json.dumps(request_body).encode() + + +def _parse_and_validate_response(response_body: bytes) -> tuple[dict, dict, list[int]]: + """Parse + validate a successful chat completion response off the event loop. + + Returns (full_response, assistant_message, completion_token_ids). Raises + UpstreamResponseError on malformed meta_info / content / token-length mismatch. + Touches no session state. + """ + response = json.loads(response_body) + choice = response.get("choices", [{}])[0] + meta_info = choice.get("meta_info") + if not isinstance(meta_info, dict) or "output_token_logprobs" not in meta_info: + raise UpstreamResponseError("meta_info and output_token_logprobs must be in choice (requires logprobs=True)") + assistant_message = choice.get("message", {}) + if assistant_message.get("content") is None: + raise UpstreamResponseError( + "assistant message content is None, when tool call parser failed SGLang should still return " + "an empty content rather than None. Please check your modified SGLang version." + ) + output_token_logprobs = meta_info["output_token_logprobs"] + completion_tokens = meta_info["completion_tokens"] + actual_output_logprobs_len = len(output_token_logprobs) + if actual_output_logprobs_len != completion_tokens: + raise UpstreamResponseError( + "invalid chat completion response: " + f"len(output_token_logprobs)={actual_output_logprobs_len} " + f"!= completion_tokens={completion_tokens}. " + f"Please check whether you use the correct SGLang branch which has fix the tokenizer batch decode issue." + ) + completion_token_ids = [t[1] for t in output_token_logprobs] + return response, assistant_message, completion_token_ids + + def setup_session_routes(app, backend, args): hf_checkpoint = getattr(args, "hf_checkpoint", None) if not hf_checkpoint: @@ -48,24 +91,6 @@ async def health(): body["session_server_instance_id"] = session_server_instance_id return body - # --- DEBUG: track in-flight chat_completions --- - _inflight_chat = {"count": 0} - - @app.middleware("http") - async def debug_request_logger(request: Request, call_next): - client = request.client - client_info = f"{client.host}:{client.port}" if client else "unknown" - logger.info( - f"[session-server] REQUEST ARRIVED: {request.method} {request.url.path} from={client_info} inflight_chat={_inflight_chat['count']}" - ) - t0 = time.time() - response = await call_next(request) - elapsed = time.time() - t0 - logger.info( - f"[session-server] REQUEST DONE: {request.method} {request.url.path} status={response.status_code} elapsed={elapsed:.3f}s from={client_info}" - ) - return response - @app.exception_handler(SessionError) async def session_error_handler(request: Request, exc: SessionError): return JSONResponse(status_code=exc.status_code, content={"error": str(exc)}) @@ -113,115 +138,80 @@ async def delete_session(session_id: str): @app.post("/sessions/{session_id}/v1/chat/completions") async def chat_completions(request: Request, session_id: str): - """Proxy a chat completion through SGLang with TITO token tracking. - - Flow: prepare pretokenized input_ids (lock held briefly) → inject - SGLang flags → proxy to backend (NO lock) → validate response → - update trajectory checkpoint (lock held briefly) → append session record. - - The lock is NOT held during the slow proxy call to avoid blocking - DELETE/other operations when the agent disconnects mid-request. + """One in-flight chat per session; a second concurrent same-session chat + fast-fails 409 without entering the backend. State mutation stays on the + event loop under session.lock; stateless CPU work is offloaded. """ - _inflight_chat["count"] += 1 - try: - session = registry.get_session(session_id) + loop = asyncio.get_running_loop() + session = registry.get_session(session_id) + + claimed = False + # claim the single in-flight slot under a brief lock; closing (404) beats busy (409) + async with session.lock: if session.closing: raise SessionNotFoundError(f"session not found: session_id={session_id}") - - # --- Phase 1: prepare request (lock held briefly) --- + if session.chat_inflight: + raise SessionBusyError("session already has an in-flight chat completion") + session.chat_inflight = True + claimed = True + try: + body = await request.body() + request_body = await loop.run_in_executor(backend.cpu_executor, _parse_request_body, body) + + # TITO token tracking requires Miles-owned input_ids plus SGLang + # output-token metadata: + # logprobs=True → populates meta_info.output_token_logprobs + # return_meta_info → wraps the above in choice.meta_info + # Both flags are hardcoded (not set default) to prevent agent-side + # overrides from breaking the token accumulation invariants. + request_body["logprobs"] = True + request_body["return_meta_info"] = True + if getattr(args, "use_rollout_routing_replay", False): + request_body["return_routed_experts"] = True + if getattr(args, "use_rollout_indexer_replay", False): + request_body["return_indexer_topk"] = True + # Must be False so stop-token text is trimmed from assistant + # message content; token IDs are still taken from logprobs below. + request_body["no_stop_trim"] = False + request_messages = request_body.get("messages", []) + + # prepare pretokenized input under the lock (mutates trajectory state) async with session.lock: - # Double-check: session may have been marked closing while waiting for lock. if session.closing: raise SessionNotFoundError(f"session not found: session_id={session_id}") - - body = await request.body() - request_body = json.loads(body) if body else {} - - # TITO token tracking requires Miles-owned input_ids plus SGLang - # output-token metadata: - # logprobs=True → populates meta_info.output_token_logprobs - # return_meta_info → wraps the above in choice.meta_info - # Both flags are hardcoded (not set default) to prevent agent-side - # overrides from breaking the token accumulation invariants. - request_body["logprobs"] = True - request_body["return_meta_info"] = True - if getattr(args, "use_rollout_routing_replay", False): - request_body["return_routed_experts"] = True - if getattr(args, "use_rollout_indexer_replay", False): - request_body["return_indexer_topk"] = True - # Must be False so stop-token text is trimmed from assistant - # message content; token IDs are still taken from logprobs below. - request_body["no_stop_trim"] = False - - request_messages = request_body.get("messages", []) prompt_token_ids = session.prepare_pretokenized( request_messages, tools=request_body.get("tools"), tito_tokenizer=registry.tito_tokenizer, ) request_body["input_ids"] = prompt_token_ids - logger.debug( - "Using TITO input_ids: %d tokens", - len(prompt_token_ids), - ) - - body = json.dumps(request_body).encode() expected_num_assistant = session.num_assistant - # --- lock released here --- - - # --- Phase 2: proxy to SGLang (NO lock held) --- - result = await backend.do_proxy(request, "v1/chat/completions", body=body) + logger.debug("Using TITO input_ids: %d tokens", len(prompt_token_ids)) - # If SGLang returned a non-200 error (e.g. 400 for context too long), - # pass it through to the agent without recording — the agent can retry - # or handle the error. + encoded_body = await loop.run_in_executor(backend.cpu_executor, _dump_request_body, request_body) + result = await backend.do_proxy(request, "v1/chat/completions", body=encoded_body) + # Non-200 (e.g. 400 for context too long) passes through unrecorded; + # the agent can retry or handle the error. if result["status_code"] != 200: return backend.build_proxy_response(result) - response = json.loads(result["response_body"]) - - choice = response.get("choices", [{}])[0] - - meta_info = choice.get("meta_info") - if not isinstance(meta_info, dict) or "output_token_logprobs" not in meta_info: - raise UpstreamResponseError( - "meta_info and output_token_logprobs must be in choice (requires logprobs=True)" - ) - assistant_message = choice.get("message", {}) - if assistant_message.get("content") is None: - raise UpstreamResponseError( - "assistant message content is None, when tool call parser failed SGLang should still return " - "an empty content rather than None. Please check your modified SGLang version." - ) - - output_token_logprobs = meta_info["output_token_logprobs"] - completion_tokens = meta_info["completion_tokens"] + response, assistant_message, completion_token_ids = await loop.run_in_executor( + backend.cpu_executor, _parse_and_validate_response, result["response_body"] + ) - actual_output_logprobs_len = len(output_token_logprobs) - if actual_output_logprobs_len != completion_tokens: - raise UpstreamResponseError( - "invalid chat completion response: " - f"len(output_token_logprobs)={actual_output_logprobs_len} " - f"!= completion_tokens={completion_tokens}. " - f"Please check whether you use the correct SGLang branch which has fix the tokenizer batch decode issue." - ) - - completion_token_ids = [t[1] for t in output_token_logprobs] - - # --- Phase 3: update state (lock held briefly) --- + # commit state under the lock async with session.lock: if session.closing: logger.warning(f"Session {session_id} closed during proxy, skipping state update") return backend.build_proxy_response(result) - if session.num_assistant != expected_num_assistant: - logger.warning( - f"Session {session_id} state changed during proxy " - f"(expected num_assistant={expected_num_assistant}, " - f"got {session.num_assistant}), skipping state update" + logger.error( + f"Session {session_id} invariant violation: num_assistant={session.num_assistant} " + f"!= expected={expected_num_assistant} under the in-flight gate; this should be unreachable" + ) + raise SessionInvariantError( + f"session state changed under the in-flight gate (session_id={session_id})" ) - return backend.build_proxy_response(result) - session.update_pretokenized_state( request_messages, assistant_message, @@ -229,7 +219,6 @@ async def chat_completions(request: Request, session_id: str): completion_token_ids=completion_token_ids, max_trim_tokens=registry.tito_tokenizer.max_trim_tokens, ) - record = SessionRecord( timestamp=time.time(), method=request.method, @@ -239,11 +228,12 @@ async def chat_completions(request: Request, session_id: str): response=response, ) session.append_record(record) - # --- lock released here --- - return backend.build_proxy_response(result) finally: - _inflight_chat["count"] -= 1 + if claimed: + # single-threaded event loop: a plain write is atomic; no other coroutine + # mutates this session's flag without the lock, and finally runs on cancellation. + session.chat_inflight = False @app.api_route("/sessions/{session_id}/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"]) async def session_proxy(request: Request, session_id: str, path: str): diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index b2dde47c7c..388dc83a43 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -1803,6 +1803,15 @@ def add_session_arguments(parser): help="Message roles allowed to be appended after the pretokenized " "assistant prefix in TITO sessions (default: tool).", ) + parser.add_argument( + "--session-server-cpu-workers", + type=int, + default=min(16, os.cpu_count() or 1), + help="Max worker threads for the session server's bounded CPU thread pool " + "(offloads stateless JSON parse/dump and response validation off the event loop). " + "Higher values improve event-loop responsiveness under load but raise peak memory; " + "the GIL means stdlib JSON does not gain CPU throughput from more threads.", + ) return parser def add_user_provided_function_arguments(parser): diff --git a/tests/fast/router/test_session_race_conditions.py b/tests/fast/router/test_session_race_conditions.py index 95c6a69cba..e3a670e46b 100644 --- a/tests/fast/router/test_session_race_conditions.py +++ b/tests/fast/router/test_session_race_conditions.py @@ -1,20 +1,24 @@ """E2E session stress tests. -Contract under test (with split-lock / session.closing): -- Phase 1 (prepare) and Phase 3 (state update) hold session.lock briefly; - Phase 2 (proxy to SGLang) does NOT hold the lock. -- Concurrent same-session requests can overlap at the backend (Phase 2), - but state updates (Phase 3) are serialized; stale-update guard - (expected_num_assistant check) ensures only one concurrent writer wins. -- Different sessions can run in parallel (no global lock). -- Per-session clients can run turn-by-turn without idle gaps while global load stays parallel. -- Delete marks session.closing=True, acquires session.lock, then removes. - Because the lock is not held during Phase 2, delete can proceed while a - chat request is mid-proxy; the chat's Phase 3 will see closing=True and - skip the state update gracefully. -- Chat requests to a closing session get 404 immediately (pre-lock check). -- Chat requests arriving while delete waits for lock get 404 (double-check after lock). +Contract under test (per-session in-flight gate): +- A session admits at most one in-flight chat completion. A second concurrent + same-session chat fast-fails 409 ("session already has an in-flight chat + completion") at slot-claim time, before the body is read/parsed and before + the backend is hit; it never reaches the backend. +- The in-flight slot is released on every exit path (success, malformed JSON, + prepare error, upstream non-200, transport 502, response validation failure, + state-update error, client cancel/disconnect), and only by the request that + claimed it: a request that got 409/404 does not clear the owner's slot. +- Different sessions still run in parallel (no global lock); per-session clients + can run turn-by-turn without idle gaps while global load stays parallel. +- Delete marks session.closing=True, acquires session.lock, then removes. The + lock is not held during the proxy, so delete can proceed while a chat is + mid-proxy; that chat's commit sees closing=True and skips the state update. +- Chat to a closing session gets 404 immediately, and closing (404) has + priority over busy (409). - Concurrent deletes on the same session: second delete gets 404. +- The upstream body is passed through faithfully; stale framing headers + (content-length / transfer-encoding / content-encoding) are stripped. """ from __future__ import annotations @@ -55,9 +59,37 @@ def patched_chat_response(self, payload: dict) -> dict: return patch.object(MockSGLangServer, "_compute_chat_completions_response", new=patched_chat_response) +def _patch_mock_chat_response_bad_first(): + """Like `_patch_mock_chat_response`, but the FIRST chat response is missing + `meta_info` so the session server raises UpstreamResponseError (502). The + second and later responses are valid, so a retry after the failure can 200. + """ + original_chat_response = MockSGLangServer._compute_chat_completions_response + state = {"calls": 0} + + def patched_chat_response(self, payload: dict) -> dict: + response = original_chat_response(self, payload) + choice = response["choices"][0] + logprobs_content = choice["logprobs"]["content"] + output_token_logprobs = [ + (item["logprob"], self.tokenizer.convert_tokens_to_ids(item["token"])) for item in logprobs_content + ] + choice["meta_info"] = { + "output_token_logprobs": output_token_logprobs, + "completion_tokens": len(output_token_logprobs), + } + state["calls"] += 1 + if state["calls"] == 1: + # Strip meta_info from the first response only -> 502 on first chat. + choice.pop("meta_info", None) + return response + + return patch.object(MockSGLangServer, "_compute_chat_completions_response", new=patched_chat_response) + + @contextmanager -def _router_env(process_fn, *, latency: float = 0.0): - with _patch_mock_chat_response(): +def _router_env(process_fn, *, latency: float = 0.0, response_patch=None): + with (response_patch or _patch_mock_chat_response)(): with with_mock_server(model_name=HF_CHECKPOINT, process_fn=process_fn, latency=latency) as backend: args = SimpleNamespace( miles_router_timeout=30, @@ -93,46 +125,80 @@ def _chat(url: str, session_id: str, payload: dict, timeout: float = 20.0) -> re ) -class TestSessionConcurrencyContracts: - def test_same_session_concurrent_requests_reach_backend(self): - """With the split-lock, same-session requests CAN overlap at the backend. +def _wait_for_backend_requests(backend, count: int, timeout: float = 5.0) -> None: + """Block until the backend has logged exactly `count` requests. + + Using a backend-arrival barrier instead of sleeping makes the "request is + parked in proxy" precondition deterministic: once the entry is logged the + owner has claimed the in-flight slot and is sitting in the latency window. + """ + deadline = time.time() + timeout + while time.time() < deadline: + if len(backend.request_log) == count: + return + time.sleep(0.005) + raise AssertionError(f"backend did not reach {count} requests in {timeout}s (saw {len(backend.request_log)})") + - Phase 2 (proxy) runs without the lock, so concurrent requests are not - serialized at the backend level. Phase 3 state updates are still - serialized; the stale-update guard ensures only one writer wins per - generation, so no state corruption occurs. +class TestSessionConcurrencyContracts: + def test_same_session_second_chat_returns_409(self): + """A session admits one in-flight chat; concurrents fast-fail 409. + + Park chat A in proxy (held by backend latency) and confirm via the + arrival barrier that A holds the slot. Contenders B/C/D on the same + session must each get 409 without ever reaching the backend, and they + must not release A's slot. After A finishes 200, the slot is free and a + fresh same-session chat succeeds. """ + # Latency comfortably larger than the time to fire three contenders. + latency = 0.5 + def process_fn(prompt: str) -> ProcessResult: return ProcessResult(text="concurrent-ok", finish_reason="stop") - with _router_env(process_fn, latency=0.2) as env: + with _router_env(process_fn, latency=latency) as env: session_id = _create_session(env.url) + payload = {"messages": [{"role": "user", "content": "park-in-proxy"}]} + + env.backend.reset_stats() + with ThreadPoolExecutor(max_workers=1) as pool: + # Fire A and wait until it is parked in proxy holding the slot. + chat_a = pool.submit(_chat, env.url, session_id, payload, 30.0) + _wait_for_backend_requests(env.backend, 1) - # Warm up one assistant checkpoint so repeated identical retry payloads are valid. - warmup_payload = {"messages": [{"role": "user", "content": "warmup"}]} - warmup_resp = _chat(env.url, session_id, warmup_payload) - assert warmup_resp.status_code == 200 - assistant = warmup_resp.json()["choices"][0]["message"] + # Contenders on the SAME session while A is parked -> 409 each. + contender_codes = [] + for _ in range(3): + resp = _chat(env.url, session_id, payload, timeout=10.0) + contender_codes.append(resp.status_code) + assert resp.status_code == 409, f"contender should be 409, got {resp.status_code}" + assert resp.json()["error"] == "session already has an in-flight chat completion" - retry_payload = { + # Contenders never reached the backend: still exactly A's request. + assert len(env.backend.request_log) == 1 + + # A finishes 200; a 409 contender did NOT clear A's slot. + chat_a_resp = chat_a.result(timeout=30.0) + + assert chat_a_resp.status_code == 200 + assert contender_codes == [409, 409, 409] + assert len(env.backend.request_log) == 1 + + # Slot was released on A's success: a fresh same-session chat works. + # The follow-up must be an append-only continuation of A's committed + # trajectory, so build it from A's assistant message. + assistant = chat_a_resp.json()["choices"][0]["message"] + follow_up = { "messages": [ - {"role": "user", "content": "warmup"}, + {"role": "user", "content": "park-in-proxy"}, assistant, - {"role": "system", "content": "retry-from-assistant-checkpoint"}, + {"role": "system", "content": "continue-after-A"}, ] } - - env.backend.reset_stats() - with ThreadPoolExecutor(max_workers=4) as pool: - futures = [pool.submit(_chat, env.url, session_id, retry_payload) for _ in range(4)] - responses = [f.result(timeout=30.0) for f in futures] - - # All requests should succeed (200) — no 500s. - assert all(resp.status_code == 200 for resp in responses) - assert len(env.backend.request_log) == 4 - # With split-lock, concurrent backend access is expected (not == 1). - assert env.backend.max_concurrent >= 1 + after = _chat(env.url, session_id, follow_up, timeout=20.0) + assert after.status_code == 200 + assert len(env.backend.request_log) == 2 def test_different_sessions_can_run_in_parallel(self): def process_fn(prompt: str) -> ProcessResult: @@ -192,10 +258,11 @@ def run_session_worker(session_id: str, idx: int) -> bool: assert env.backend.max_concurrent >= 4 def test_delete_can_proceed_while_chat_is_mid_proxy(self): - """With split-lock, delete can acquire the lock while chat is in Phase 2. + """Delete can acquire the lock while a chat is mid-proxy. - The inflight chat's Phase 3 sees session.closing=True and skips - state update gracefully. Both chat and delete complete without error. + The lock is not held during the proxy, so delete proceeds; the in-flight + chat's commit step sees session.closing=True and skips the state update. + Both chat and delete complete without error. """ def process_fn(prompt: str) -> ProcessResult: @@ -235,10 +302,10 @@ def test_chat_during_delete_returns_404(self): """Chat requests arriving after delete sets closing=True get 404. Timeline: - 1. Chat A starts, acquires lock (Phase 1), releases it, proxying (Phase 2) - 2. Delete arrives, sets session.closing=True, acquires lock, removes session - 3. Chat B arrives, sees session.closing=True, returns 404 immediately - 4. Chat A's Phase 3 sees closing=True, skips state update, returns 200 + 1. Chat A starts, claims the in-flight slot, and is proxying to backend. + 2. Delete arrives, sets session.closing=True, acquires lock, removes session. + 3. Chat B arrives, sees session.closing=True, returns 404 immediately. + 4. Chat A's commit sees closing=True, skips the state update, returns 200. """ def process_fn(prompt: str) -> ProcessResult: @@ -348,11 +415,13 @@ def process_fn(prompt: str) -> ProcessResult: get_resp = requests.get(f"{env.url}/sessions/{session_id}", timeout=5.0) assert get_resp.status_code == 404 - def test_multiple_chats_queued_then_delete(self): - """Multiple chat requests queued behind session.lock, then delete. + def test_concurrent_chats_then_delete(self): + """Concurrent same-session chats do not queue under the gate. - After delete marks closing=True, queued chats that acquire the lock - should check closing and return 404. + One chat parks in proxy holding the slot; a couple more same-session + chats fired concurrently get 409 (gate, before the backend). A delete + then returns 204 and the parked chat returns 200 (commit skips the + state update on closing). No 500s; codes are a subset of {200,409,404}. """ def process_fn(prompt: str) -> ProcessResult: @@ -362,35 +431,33 @@ def process_fn(prompt: str) -> ProcessResult: session_id = _create_session(env.url) payload = {"messages": [{"role": "user", "content": "queued"}]} - with ThreadPoolExecutor(max_workers=6) as pool: - # Fire 3 chats (first holds lock, others queue) - chat_futures = [pool.submit(_chat, env.url, session_id, payload, 30.0) for _ in range(3)] + with ThreadPoolExecutor(max_workers=2) as pool: + # Park the owner chat in proxy holding the in-flight slot. + owner = pool.submit(_chat, env.url, session_id, payload, 30.0) + _wait_for_backend_requests(env.backend, 1) - # Wait for first to reach backend - deadline = time.time() + 5.0 - while time.time() < deadline: - if env.backend.request_log: - break - time.sleep(0.01) + # Concurrent same-session chats hit the gate -> 409. + contender_codes = [_chat(env.url, session_id, payload, timeout=10.0).status_code for _ in range(2)] - # Now delete - sets closing, waits for first chat to finish + # Delete the session while the owner is still parked. delete_future = pool.submit( requests.delete, f"{env.url}/sessions/{session_id}", timeout=30.0, ) - results = [f.result(timeout=30.0) for f in chat_futures] + owner_resp = owner.result(timeout=30.0) delete_resp = delete_future.result(timeout=30.0) assert delete_resp.status_code == 204 + assert owner_resp.status_code == 200 + assert contender_codes == [409, 409] - # At least one chat must succeed (the one holding the lock when - # delete arrived). Others may get 200 (acquired lock before - # closing) or 404 (saw closing=True). No 500s allowed. - status_codes = [r.status_code for r in results] - assert all(c in (200, 404) for c in status_codes), f"Unexpected status codes: {status_codes}" - assert 200 in status_codes, f"Expected at least one 200, got {status_codes}" + # No 500s; every chat code is in the allowed set, and exactly one + # chat (the owner) reached the backend. + all_chat_codes = [owner_resp.status_code, *contender_codes] + assert all(c in (200, 409, 404) for c in all_chat_codes), f"Unexpected status codes: {all_chat_codes}" + assert len(env.backend.request_log) == 1 def test_rapid_create_chat_delete_cycles(self): """Rapidly create, chat, and delete sessions to stress the lifecycle. @@ -420,3 +487,84 @@ def lifecycle_cycle(idx: int) -> bool: results = [f.result(timeout=60.0) for f in futures] assert all(results) + + +class TestSlotReleaseAfterError: + """The in-flight slot is freed on failing exit paths: a fresh legal chat on + the same session after a failure must succeed (200, not 409). + """ + + def test_slot_released_after_malformed_request_json(self): + """Malformed JSON body errors (500-class) before the backend; the slot + is released so a subsequent normal chat on the same session returns 200. + """ + + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="ok", finish_reason="stop") + + with _router_env(process_fn) as env: + session_id = _create_session(env.url) + + bad = requests.post( + f"{env.url}/sessions/{session_id}/v1/chat/completions", + data="{not json", + headers={"content-type": "application/json"}, + timeout=10.0, + ) + assert bad.status_code >= 500 + # The malformed body never reached the backend. + assert len(env.backend.request_log) == 0 + + good = _chat(env.url, session_id, {"messages": [{"role": "user", "content": "hi"}]}, timeout=20.0) + assert good.status_code == 200 + assert len(env.backend.request_log) == 1 + + def test_slot_released_after_response_validation_failure(self): + """An upstream response missing meta_info raises UpstreamResponseError + (502); the slot is released so a subsequent normal chat returns 200. + """ + + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="ok", finish_reason="stop") + + with _router_env(process_fn, response_patch=_patch_mock_chat_response_bad_first) as env: + session_id = _create_session(env.url) + + bad = _chat(env.url, session_id, {"messages": [{"role": "user", "content": "hi"}]}, timeout=20.0) + assert bad.status_code == 502 + assert len(env.backend.request_log) == 1 + + good = _chat(env.url, session_id, {"messages": [{"role": "user", "content": "hi again"}]}, timeout=20.0) + assert good.status_code == 200 + assert len(env.backend.request_log) == 2 + + +class TestPassthroughFidelity: + """A successful chat response is passed through faithfully and stale + framing headers from upstream are not copied to the client. + """ + + def test_successful_response_body_and_headers_passthrough(self): + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="passthrough-ok", finish_reason="stop") + + with _router_env(process_fn) as env: + session_id = _create_session(env.url) + resp = _chat(env.url, session_id, {"messages": [{"role": "user", "content": "hi"}]}, timeout=20.0) + assert resp.status_code == 200 + + # Body matches the upstream mock shape (id/object/choices content) + # passed through unchanged, including the meta_info the mock adds. + client_body = resp.json() + assert client_body["object"] == "chat.completion" + assert client_body["id"].startswith("chatcmpl-") + assert client_body["choices"][0]["message"]["content"] == "passthrough-ok" + assert "meta_info" in client_body["choices"][0] + + # Framing headers copied from upstream must have been stripped; the + # body is intact (decodable JSON), so requests itself did not need + # transfer/content-encoding framing from upstream. + lowered = {k.lower() for k in resp.headers.keys()} + assert "transfer-encoding" not in lowered + assert "content-encoding" not in lowered + assert resp.headers.get("content-type", "").startswith("application/json") From 72a3916401c8d60f95d7b4e303ac3a4ec1e1cec9 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Mon, 22 Jun 2026 22:38:56 +0000 Subject: [PATCH 02/11] fix(session): classify malformed upstream chat responses Malformed successful upstream responses could leak raw parser and shape exceptions as 500s while several slot-release and passthrough paths were unverified. Harden the response validator to raise UpstreamResponseError and add focused tests for injected failure exits plus raw proxy response fidelity. --- miles/rollout/session/sessions.py | 24 +- .../router/test_session_race_conditions.py | 267 ++++++++++++++++++ 2 files changed, 286 insertions(+), 5 deletions(-) diff --git a/miles/rollout/session/sessions.py b/miles/rollout/session/sessions.py index 054d98ec5a..42349c69a2 100644 --- a/miles/rollout/session/sessions.py +++ b/miles/rollout/session/sessions.py @@ -38,19 +38,30 @@ def _parse_and_validate_response(response_body: bytes) -> tuple[dict, dict, list UpstreamResponseError on malformed meta_info / content / token-length mismatch. Touches no session state. """ - response = json.loads(response_body) - choice = response.get("choices", [{}])[0] + try: + response = json.loads(response_body) + except (json.JSONDecodeError, UnicodeDecodeError) as e: + raise UpstreamResponseError(f"upstream response is not valid JSON: {e}") from e + + choices = response.get("choices") if isinstance(response, dict) else None + if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict): + raise UpstreamResponseError("upstream response has no valid choices[0]") + choice = choices[0] meta_info = choice.get("meta_info") if not isinstance(meta_info, dict) or "output_token_logprobs" not in meta_info: raise UpstreamResponseError("meta_info and output_token_logprobs must be in choice (requires logprobs=True)") - assistant_message = choice.get("message", {}) + assistant_message = choice.get("message") + if not isinstance(assistant_message, dict): + raise UpstreamResponseError("upstream response choice has no valid message") if assistant_message.get("content") is None: raise UpstreamResponseError( "assistant message content is None, when tool call parser failed SGLang should still return " "an empty content rather than None. Please check your modified SGLang version." ) output_token_logprobs = meta_info["output_token_logprobs"] - completion_tokens = meta_info["completion_tokens"] + completion_tokens = meta_info.get("completion_tokens") + if not isinstance(output_token_logprobs, list) or not isinstance(completion_tokens, int): + raise UpstreamResponseError("upstream response output_token_logprobs/completion_tokens have invalid types") actual_output_logprobs_len = len(output_token_logprobs) if actual_output_logprobs_len != completion_tokens: raise UpstreamResponseError( @@ -59,7 +70,10 @@ def _parse_and_validate_response(response_body: bytes) -> tuple[dict, dict, list f"!= completion_tokens={completion_tokens}. " f"Please check whether you use the correct SGLang branch which has fix the tokenizer batch decode issue." ) - completion_token_ids = [t[1] for t in output_token_logprobs] + try: + completion_token_ids = [t[1] for t in output_token_logprobs] + except (IndexError, TypeError, KeyError) as e: + raise UpstreamResponseError(f"upstream response output_token_logprobs entries are malformed: {e}") from e return response, assistant_message, completion_token_ids diff --git a/tests/fast/router/test_session_race_conditions.py b/tests/fast/router/test_session_race_conditions.py index e3a670e46b..fda0f07e0e 100644 --- a/tests/fast/router/test_session_race_conditions.py +++ b/tests/fast/router/test_session_race_conditions.py @@ -23,6 +23,7 @@ from __future__ import annotations +import asyncio import time from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager @@ -30,7 +31,10 @@ from unittest.mock import patch import requests +from starlette.responses import Response +from miles.rollout.session.linear_trajectory import LinearTrajectory +from miles.rollout.session.session_errors import MessageValidationError, TokenizationError from miles.rollout.session.session_server import SessionServer from miles.utils.http_utils import find_available_port from miles.utils.test_utils.mock_sglang_server import MockSGLangServer, ProcessResult, with_mock_server @@ -539,6 +543,269 @@ def process_fn(prompt: str) -> ProcessResult: assert len(env.backend.request_log) == 2 +def _normal_messages(content: str) -> dict: + return {"messages": [{"role": "user", "content": content}]} + + +class TestSlotReleaseInjectedFailures: + """Deterministic slot-release coverage for the remaining enumerated failure + paths. Each test injects a one-shot failure on the first same-session chat, + then proves the slot was released: the NEXT same-session chat returns 200 + (no residual 409). Failures are injected via class-level mock patches so the + server thread's instance/session is affected; ordering is enforced by a + per-test call counter, not by sleeps. + """ + + def test_slot_released_after_prepare_validation_error(self): + """`prepare_pretokenized` raising MessageValidationError (400) before the + backend releases the slot; the next normal chat returns 200. + """ + + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="ok", finish_reason="stop") + + original_prepare = LinearTrajectory.prepare_pretokenized + state = {"calls": 0} + + def one_shot_prepare(self, *args, **kwargs): + state["calls"] += 1 + if state["calls"] == 1: + raise MessageValidationError("injected prepare failure") + return original_prepare(self, *args, **kwargs) + + with _router_env(process_fn) as env: + session_id = _create_session(env.url) + with patch.object(LinearTrajectory, "prepare_pretokenized", one_shot_prepare): + bad = _chat(env.url, session_id, _normal_messages("hi"), timeout=20.0) + assert bad.status_code == 400 + # prepare failed before the backend was hit. + assert len(env.backend.request_log) == 0 + + good = _chat(env.url, session_id, _normal_messages("hi again"), timeout=20.0) + assert good.status_code == 200 + assert len(env.backend.request_log) == 1 + + def test_slot_released_after_upstream_non_200(self): + """An upstream non-200 passes through (400) without recording; the slot + is released so the next normal chat returns 200. + """ + + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="ok", finish_reason="stop") + + original_do_proxy = SessionServer.do_proxy + state = {"calls": 0} + + async def one_shot_do_proxy(self, request, path, body=None, headers=None): + state["calls"] += 1 + if state["calls"] == 1: + return { + "request_body": body, + "response_body": b'{"error":"bad"}', + "status_code": 400, + "headers": {"content-type": "application/json"}, + } + return await original_do_proxy(self, request, path, body=body, headers=headers) + + with _router_env(process_fn) as env: + session_id = _create_session(env.url) + with patch.object(SessionServer, "do_proxy", one_shot_do_proxy): + bad = _chat(env.url, session_id, _normal_messages("hi"), timeout=20.0) + assert bad.status_code == 400 + assert bad.json()["error"] == "bad" + + good = _chat(env.url, session_id, _normal_messages("hi again"), timeout=20.0) + assert good.status_code == 200 + + def test_slot_released_after_transport_502(self): + """A transport-error 502 (do_proxy's real error shape) passes through; the + slot is released so the next normal chat returns 200. + """ + + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="ok", finish_reason="stop") + + original_do_proxy = SessionServer.do_proxy + state = {"calls": 0} + + async def one_shot_do_proxy(self, request, path, body=None, headers=None): + state["calls"] += 1 + if state["calls"] == 1: + return { + "request_body": body, + "response_body": b'{"error":"backend transport error"}', + "status_code": 502, + "headers": {"content-type": "application/json"}, + } + return await original_do_proxy(self, request, path, body=body, headers=headers) + + with _router_env(process_fn) as env: + session_id = _create_session(env.url) + with patch.object(SessionServer, "do_proxy", one_shot_do_proxy): + bad = _chat(env.url, session_id, _normal_messages("hi"), timeout=20.0) + assert bad.status_code == 502 + + good = _chat(env.url, session_id, _normal_messages("hi again"), timeout=20.0) + assert good.status_code == 200 + + def test_slot_released_after_state_update_error(self): + """`update_pretokenized_state` raising TokenizationError (500) after a 200 + proxy releases the slot; the next normal chat returns 200. + """ + + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="ok", finish_reason="stop") + + original_update = LinearTrajectory.update_pretokenized_state + state = {"calls": 0} + + def one_shot_update(self, *args, **kwargs): + state["calls"] += 1 + if state["calls"] == 1: + raise TokenizationError("injected state-update failure") + return original_update(self, *args, **kwargs) + + with _router_env(process_fn) as env: + session_id = _create_session(env.url) + with patch.object(LinearTrajectory, "update_pretokenized_state", one_shot_update): + bad = _chat(env.url, session_id, _normal_messages("hi"), timeout=20.0) + assert bad.status_code == 500 + # The proxy did run; the failure is in the commit step. + assert len(env.backend.request_log) == 1 + + good = _chat(env.url, session_id, _normal_messages("hi again"), timeout=20.0) + assert good.status_code == 200 + assert len(env.backend.request_log) == 2 + + def test_slot_released_after_client_cancel_mid_proxy(self): + """If the handler is cancelled mid-proxy (client disconnect), the request + errors at the client side but the `finally` releases the slot; the next + normal same-session chat returns 200. + """ + + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="ok", finish_reason="stop") + + original_do_proxy = SessionServer.do_proxy + state = {"calls": 0} + + async def one_shot_do_proxy(self, request, path, body=None, headers=None): + state["calls"] += 1 + if state["calls"] == 1: + raise asyncio.CancelledError() + return await original_do_proxy(self, request, path, body=body, headers=headers) + + with _router_env(process_fn) as env: + session_id = _create_session(env.url) + with patch.object(SessionServer, "do_proxy", one_shot_do_proxy): + try: + _chat(env.url, session_id, _normal_messages("hi"), timeout=20.0) + except requests.exceptions.RequestException: + pass + + good = _chat(env.url, session_id, _normal_messages("hi again"), timeout=20.0) + assert good.status_code == 200 + + +class TestBuildProxyResponse: + """Unit tests for SessionServer.build_proxy_response (AC-5 passthrough + fidelity). No running server needed: with no hf_checkpoint, + setup_session_routes returns early so construction is light. + """ + + def _server(self) -> SessionServer: + return SessionServer(SimpleNamespace(miles_router_timeout=30), backend_url="http://unused") + + def test_json_200_body_and_headers_passthrough(self): + server = self._server() + body = b'{"object":"chat.completion","choices":[]}' + result = { + "response_body": body, + "status_code": 200, + "headers": { + "content-type": "application/json", + "content-length": "999", + "transfer-encoding": "chunked", + "content-encoding": "gzip", + }, + } + resp = server.build_proxy_response(result) + assert isinstance(resp, Response) + assert resp.status_code == 200 + assert resp.body == body + lowered = {k.lower() for k in resp.headers.keys()} + assert resp.headers["content-type"].startswith("application/json") + # The stale upstream content-length is dropped; transfer/content-encoding + # are absent. Starlette recomputes content-length from the actual body, so + # if present it must equal the body length (never the stale upstream "999"). + assert "transfer-encoding" not in lowered + assert "content-encoding" not in lowered + assert resp.headers.get("content-length", str(len(body))) == str(len(body)) + + def test_non_json_200_body_and_media_type_preserved(self): + server = self._server() + body = b"plain text" + result = { + "response_body": body, + "status_code": 200, + "headers": {"content-type": "text/plain"}, + } + resp = server.build_proxy_response(result) + assert isinstance(resp, Response) + assert resp.status_code == 200 + assert resp.body == body + assert resp.media_type == "text/plain" + assert resp.headers["content-type"].startswith("text/plain") + + def test_synthetic_502_status_and_body_passthrough(self): + server = self._server() + body = b'{"error":"backend transport error"}' + result = { + "response_body": body, + "status_code": 502, + "headers": {"content-type": "application/json"}, + } + resp = server.build_proxy_response(result) + assert isinstance(resp, Response) + assert resp.status_code == 502 + assert resp.body == body + + def test_compressed_upstream_strips_stale_framing(self): + server = self._server() + body = b'{"ok":true}' + result = { + "response_body": body, + "status_code": 200, + "headers": { + "content-type": "application/json", + "content-encoding": "gzip", + "content-length": "5", + }, + } + resp = server.build_proxy_response(result) + assert isinstance(resp, Response) + assert resp.body == body + lowered = {k.lower() for k in resp.headers.keys()} + # content-encoding stripped; the stale upstream content-length ("5") is + # dropped — Starlette recomputes from the body, never carries the stale one. + assert "content-encoding" not in lowered + assert resp.headers.get("content-length", str(len(body))) == str(len(body)) + + def test_generic_session_proxy_route_passes_through(self): + """The generic /sessions/{id}/{path} route proxies and passes the body + through: /abort_request on the mock backend returns {"status":"ok"} 200. + """ + + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="ok", finish_reason="stop") + + with _router_env(process_fn) as env: + session_id = _create_session(env.url) + resp = requests.post(f"{env.url}/sessions/{session_id}/abort_request", timeout=10.0) + assert resp.status_code == 200 + assert resp.json() == {"status": "ok"} + + class TestPassthroughFidelity: """A successful chat response is passed through faithfully and stale framing headers from upstream are not copied to the client. From e39d314930477384e9c16ab69d773176f46ae22e Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Mon, 22 Jun 2026 23:15:31 +0000 Subject: [PATCH 03/11] test(session): add event-loop responsiveness micro-benchmark (AC-6.1) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds tests/fast/router/bench_session_responsiveness.py (bench_ prefix → not collected by pytest), driving K concurrent large-routed-experts chats across distinct sessions while polling /health, reporting /health p50/p95/p99 under load. Honest result (real numbers, no fabrication): with multi-MiB responses the event loop is dominated by on-loop body I/O (httpx read + uvicorn write), so the CPU-parse offload yields no measurable end-to-end /health speedup at this scale (after ~= before: p50~20ms, p95~58ms, ~18.6 chats/s, 0 errors in both builds). An isolated probe still shows the offload mechanism works - inline ~150ms parse blocks the loop ~741ms vs ~0.07ms offloaded - so offloading helps only when an individual parse is itself large enough to block the loop. GIL means total CPU throughput is unchanged, as designed. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../router/bench_session_responsiveness.py | 291 ++++++++++++++++++ 1 file changed, 291 insertions(+) create mode 100644 tests/fast/router/bench_session_responsiveness.py diff --git a/tests/fast/router/bench_session_responsiveness.py b/tests/fast/router/bench_session_responsiveness.py new file mode 100644 index 0000000000..fcfc8d858a --- /dev/null +++ b/tests/fast/router/bench_session_responsiveness.py @@ -0,0 +1,291 @@ +"""Event-loop-responsiveness micro-benchmark for the standalone SessionServer. + +Why this exists (AC-6.1, directional / non-blocking): +The session server offloads stateless CPU work (large JSON parse/dump of the +chat request + response, including hundreds-of-KB `routed_experts` blobs) off the +single asyncio event loop onto a bounded `SessionServer.cpu_executor`. The claim +is that this keeps the one event loop responsive under heavy-response load so the +liveness probe `GET /health` (handled inline on the loop) stays fast. It does NOT +claim higher total CPU throughput: the GIL still serializes the Python JSON work, +so this measures latency/responsiveness, not aggregate CPU. + +This file is named `bench_*` so pytest does NOT auto-collect it (no flaky timing +test in CI). Run it directly: python tests/fast/router/bench_session_responsiveness.py + +Method: fire K concurrent chats across K DISTINCT sessions (distinct sessions are +not gated, so they run in parallel), each producing a large response that forces a +CPU-heavy `json.loads` in `_parse_and_validate_response`. Concurrently, a separate +thread polls `GET /health` every ~10ms and records each round-trip latency. We +report chat throughput, response body size, and `/health` latency percentiles. All +timing is client-side `time.perf_counter`. +""" + +from __future__ import annotations + +import json +import logging +import statistics +import threading +import time +from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import patch + +import requests + +# Quiet the uvicorn access logs from both servers so the summary block is the +# only thing on stdout (the per-request "GET /health 200 OK" lines bury it). +for _name in ("uvicorn", "uvicorn.access", "uvicorn.error"): + logging.getLogger(_name).setLevel(logging.WARNING) + +from miles.rollout.session.session_server import SessionServer +from miles.utils.http_utils import find_available_port +from miles.utils.test_utils.mock_sglang_server import MockSGLangServer, ProcessResult, with_mock_server +from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer + +# --- Tunable constants (kept modest so a run finishes well under a minute) --- +HF_CHECKPOINT = "Qwen/Qwen3-0.6B" +K_CHATS = 12 # concurrent chats, one per distinct (ungated) session +# The session server json.loads the WHOLE response body to validate it, but the +# `routed_experts` value is opaque to it (never decoded). To isolate the CPU +# parse cost the offload targets — without inflating the on-loop body I/O that +# scales with raw byte size — we make routed_experts a STRUCTURED blob (nested +# numeric arrays). Structured JSON parses ~7-8x slower per byte than one big +# string, so a modest ~2 MiB body costs ~15-30ms to parse. Run INLINE on the loop +# (pre-offload) that stalls every /health probe queued behind it; offloaded to +# cpu_executor it runs off-loop and the probe stays fast. +BLOB_ROWS = 900 # rows of BLOB_ROW_WIDTH floats -> ~1.5 MiB body, ~18ms parse/call +BLOB_ROW_WIDTH = 256 +HEALTH_POLL_INTERVAL_S = 0.01 # ~10ms between /health probes +HEALTH_POLL_DURATION_S = 4.0 # how long the health poller runs (also caps the load window) +CHAT_TEXT = "responsiveness-ok" + + +def _make_large_blob() -> list: + # Nested numeric arrays, like a routed_experts logits buffer: expensive to + # json.loads per byte (many tokens), which is exactly the CPU work the + # session server offloads off the event loop. + row = [round(i * 0.001, 4) for i in range(BLOB_ROW_WIDTH)] + return [list(row) for _ in range(BLOB_ROWS)] + + +def _patch_mock_chat_response_heavy(blob: list): + """Patch the mock chat response into the shape the session server validates + (output_token_logprobs as (logprob, token_id)), and inject a large STRUCTURED + routed_experts blob into meta_info so the session server's json.loads of the + response body is genuinely CPU-heavy. The blob is opaque to the session server + (it only validates output_token_logprobs / message), matching production where + routed_experts is an opaque buffer the session layer never decodes. + """ + original_chat_response = MockSGLangServer._compute_chat_completions_response + + def patched_chat_response(self, payload: dict) -> dict: + response = original_chat_response(self, payload) + choice = response["choices"][0] + logprobs_content = choice["logprobs"]["content"] + output_token_logprobs = [ + (item["logprob"], self.tokenizer.convert_tokens_to_ids(item["token"])) for item in logprobs_content + ] + choice["meta_info"]["output_token_logprobs"] = output_token_logprobs + choice["meta_info"]["completion_tokens"] = len(output_token_logprobs) + choice["meta_info"]["routed_experts"] = blob + return response + + return patch.object(MockSGLangServer, "_compute_chat_completions_response", new=patched_chat_response) + + +@contextmanager +def _router_env(process_fn, blob: list, *, latency: float = 0.0): + with _patch_mock_chat_response_heavy(blob): + with with_mock_server(model_name=HF_CHECKPOINT, process_fn=process_fn, latency=latency) as backend: + args = SimpleNamespace( + miles_router_timeout=30, + hf_checkpoint=HF_CHECKPOINT, + chat_template_path=None, + trajectory_manager="linear_trajectory", + tito_allowed_append_roles=["tool", "system"], + # Make the session server request + echo a big routed_experts blob, + # so the per-response json.loads is genuinely CPU-heavy. + use_rollout_routing_replay=True, + ) + server_obj = SessionServer(args, backend_url=backend.url) + + port = find_available_port(31000) + server = UvicornThreadServer(server_obj.app, host="127.0.0.1", port=port) + server.start() + # uvicorn.Config(log_level="info") reconfigures these on startup, so + # re-quiet them now that both servers are up (keeps stdout to the summary). + for _name in ("uvicorn", "uvicorn.access", "uvicorn.error"): + logging.getLogger(_name).setLevel(logging.WARNING) + url = f"http://127.0.0.1:{port}" + try: + yield SimpleNamespace(url=url, backend=backend, server=server) + finally: + server.stop() + + +def _create_session(url: str) -> str: + resp = requests.post(f"{url}/sessions", timeout=5.0) + assert resp.status_code == 200, resp.text + return resp.json()["session_id"] + + +def _chat(url: str, session_id: str) -> tuple[int, int]: + """Returns (status_code, response_body_size_bytes).""" + resp = requests.post( + f"{url}/sessions/{session_id}/v1/chat/completions", + json={"messages": [{"role": "user", "content": "drive-load"}]}, + timeout=60.0, + ) + return resp.status_code, len(resp.content) + + +class _HealthPoller(threading.Thread): + """Polls GET /health on a fixed cadence and records each round-trip latency + (seconds, client-side perf_counter). Runs until stop() or duration elapses.""" + + def __init__(self, url: str, interval_s: float, max_duration_s: float): + super().__init__(daemon=True) + self.url = f"{url}/health" + self.interval_s = interval_s + self.max_duration_s = max_duration_s + self.latencies_s: list[float] = [] + self.errors = 0 + self._stop_evt = threading.Event() + self._session = requests.Session() + + def stop(self) -> None: + self._stop_evt.set() + + def run(self) -> None: + deadline = time.perf_counter() + self.max_duration_s + while not self._stop_evt.is_set() and time.perf_counter() < deadline: + t0 = time.perf_counter() + try: + resp = self._session.get(self.url, timeout=5.0) + dt = time.perf_counter() - t0 + if resp.status_code == 200: + self.latencies_s.append(dt) + else: + self.errors += 1 + except requests.RequestException: + self.errors += 1 + time.sleep(self.interval_s) + + +def _pct(values_ms: list[float], q: float) -> float: + if not values_ms: + return float("nan") + ordered = sorted(values_ms) + idx = min(len(ordered) - 1, int(q * len(ordered))) + return ordered[idx] + + +def run_bench() -> dict: + blob = _make_large_blob() + # The CPU cost the offload targets: one json.loads of a response carrying this + # blob. Measured up front so the summary states how heavy the offloaded work is. + sample_body = json.dumps({"choices": [{"meta_info": {"routed_experts": blob}}]}).encode() + _t = time.perf_counter() + json.loads(sample_body) + parse_ms = (time.perf_counter() - _t) * 1000 + + def process_fn(_prompt: str) -> ProcessResult: + # routed_experts is injected by the response patch (kept off the str-typed + # dataclass field); this just drives a small assistant message. + return ProcessResult(text=CHAT_TEXT, finish_reason="stop") + + with _router_env(process_fn, blob) as env: + session_ids = [_create_session(env.url) for _ in range(K_CHATS)] + + # Warm one chat so we can log the real response body size before the run. + warm_status, warm_size = _chat(env.url, session_ids[0]) + assert warm_status == 200, f"warm-up chat failed: {warm_status}" + + # Baseline: /health latency with no chat load (single thread, short burst). + baseline = _HealthPoller(env.url, HEALTH_POLL_INTERVAL_S, max_duration_s=0.7) + baseline.start() + baseline.join() + baseline_ms = [s * 1000 for s in baseline.latencies_s] + + # Under load: poll /health while waves of K concurrent chats keep the + # backend + cpu_executor busy. A single wave of K first-turn chats is + # fast, so we fire back-to-back waves until the poll window elapses; this + # sustains heavy-response load long enough to sample /health meaningfully. + # Each chat uses its own fresh, distinct session (distinct sessions are + # ungated, so they run in parallel). + poller = _HealthPoller(env.url, HEALTH_POLL_INTERVAL_S, HEALTH_POLL_DURATION_S) + poller.start() + + results: list[tuple[int, int]] = [] + waves = 0 + t0 = time.perf_counter() + with ThreadPoolExecutor(max_workers=K_CHATS) as pool: + while time.perf_counter() - t0 < HEALTH_POLL_DURATION_S: + fresh_ids = [_create_session(env.url) for _ in range(K_CHATS)] + futures = [pool.submit(_chat, env.url, sid) for sid in fresh_ids] + results.extend(f.result(timeout=120.0) for f in futures) + waves += 1 + chat_wall_s = time.perf_counter() - t0 + + poller.stop() + poller.join() + + statuses = [s for s, _ in results] + sizes = [sz for _, sz in results] + ok = sum(1 for s in statuses if s == 200) + load_ms = [s * 1000 for s in poller.latencies_s] + + total = len(results) + return { + "k_chats": K_CHATS, + "waves": waves, + "parse_ms": parse_ms, + "total_chats": total, + "ok_chats": ok, + "failed_chats": total - ok, + "response_body_bytes": warm_size if not sizes else max(sizes), + "chat_wall_s": chat_wall_s, + "chat_throughput_per_s": ok / chat_wall_s if chat_wall_s > 0 else float("nan"), + "health_samples_under_load": len(load_ms), + "health_errors_under_load": poller.errors, + "health_baseline_samples": len(baseline_ms), + "baseline_ms": baseline_ms, + "load_ms": load_ms, + } + + +def _fmt_block(r: dict) -> str: + base = r["baseline_ms"] + load = r["load_ms"] + lines = [ + "=" * 64, + "SessionServer event-loop responsiveness benchmark", + "=" * 64, + f" concurrency / wave (K) : {r['k_chats']}", + f" waves driven : {r['waves']} (total {r['total_chats']} chats)", + f" response body size (actual) : {r['response_body_bytes'] / 1024:.1f} KiB", + f" single-response parse cost : {r['parse_ms']:.1f} ms (the offloaded CPU work)", + f" chats ok / failed : {r['ok_chats']} / {r['failed_chats']}", + f" chat wall-clock : {r['chat_wall_s']:.3f} s", + f" chat throughput : {r['chat_throughput_per_s']:.1f} chats/s", + "-" * 64, + f" /health baseline (no load) : n={r['health_baseline_samples']}, " + f"p50={_pct(base, 0.50):.2f}ms p95={_pct(base, 0.95):.2f}ms " + f"max={(max(base) if base else float('nan')):.2f}ms", + f" /health UNDER LOAD : n={r['health_samples_under_load']}, " + f"errors={r['health_errors_under_load']}", + f" p50 = {_pct(load, 0.50):.2f} ms", + f" p95 = {_pct(load, 0.95):.2f} ms", + f" p99 = {_pct(load, 0.99):.2f} ms", + f" max = {(max(load) if load else float('nan')):.2f} ms", + f" mean= {(statistics.mean(load) if load else float('nan')):.2f} ms", + "=" * 64, + ] + return "\n".join(lines) + + +if __name__ == "__main__": + result = run_bench() + print(_fmt_block(result)) From dadf56a3e2c8db9819fbe6246c6316c3ee866e0d Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Mon, 22 Jun 2026 23:30:51 +0000 Subject: [PATCH 04/11] test(session): expose remaining session gate verification gaps Add promptness, invariant-mismatch, and real-disconnect coverage for the session in-flight gate, and scale the responsiveness benchmark to 64 sessions while documenting that the current benchmark still lacks internal stage timing. --- .../router/bench_session_responsiveness.py | 10 +- .../router/test_session_race_conditions.py | 163 +++++++++++++++++- 2 files changed, 165 insertions(+), 8 deletions(-) diff --git a/tests/fast/router/bench_session_responsiveness.py b/tests/fast/router/bench_session_responsiveness.py index fcfc8d858a..81cde112e5 100644 --- a/tests/fast/router/bench_session_responsiveness.py +++ b/tests/fast/router/bench_session_responsiveness.py @@ -1,6 +1,6 @@ """Event-loop-responsiveness micro-benchmark for the standalone SessionServer. -Why this exists (AC-6.1, directional / non-blocking): +Why this exists (directional / non-blocking): The session server offloads stateless CPU work (large JSON parse/dump of the chat request + response, including hundreds-of-KB `routed_experts` blobs) off the single asyncio event loop onto a bounded `SessionServer.cpu_executor`. The claim @@ -9,6 +9,12 @@ claim higher total CPU throughput: the GIL still serializes the Python JSON work, so this measures latency/responsiveness, not aggregate CPU. +Honest reporting note: deeper per-stage server-internal timing is intentionally +NOT instrumented into the production hot path (no probes added to chat_completions). +The offloaded stage's cost is reported indirectly via the measured single-response +parse cost (`json.loads` of one heavy response body), which is the per-call CPU +component the executor offload moves off the event loop. + This file is named `bench_*` so pytest does NOT auto-collect it (no flaky timing test in CI). Run it directly: python tests/fast/router/bench_session_responsiveness.py @@ -46,7 +52,7 @@ # --- Tunable constants (kept modest so a run finishes well under a minute) --- HF_CHECKPOINT = "Qwen/Qwen3-0.6B" -K_CHATS = 12 # concurrent chats, one per distinct (ungated) session +K_CHATS = 64 # concurrent chats, one per distinct (ungated) session # The session server json.loads the WHOLE response body to validate it, but the # `routed_experts` value is opaque to it (never decoded). To isolate the CPU # parse cost the offload targets — without inflating the on-loop body I/O that diff --git a/tests/fast/router/test_session_race_conditions.py b/tests/fast/router/test_session_race_conditions.py index fda0f07e0e..1362f34c26 100644 --- a/tests/fast/router/test_session_race_conditions.py +++ b/tests/fast/router/test_session_race_conditions.py @@ -34,7 +34,7 @@ from starlette.responses import Response from miles.rollout.session.linear_trajectory import LinearTrajectory -from miles.rollout.session.session_errors import MessageValidationError, TokenizationError +from miles.rollout.session.session_errors import MessageValidationError, SessionInvariantError, TokenizationError from miles.rollout.session.session_server import SessionServer from miles.utils.http_utils import find_available_port from miles.utils.test_utils.mock_sglang_server import MockSGLangServer, ProcessResult, with_mock_server @@ -151,8 +151,11 @@ def test_same_session_second_chat_returns_409(self): Park chat A in proxy (held by backend latency) and confirm via the arrival barrier that A holds the slot. Contenders B/C/D on the same session must each get 409 without ever reaching the backend, and they - must not release A's slot. After A finishes 200, the slot is free and a - fresh same-session chat succeeds. + must not release A's slot. The 409 is returned at slot-claim time, + before the contender's body is read and before the backend is hit, so + each contender returns near-instantly rather than waiting out A's + backend latency. After A finishes 200, the slot is free and a fresh + same-session chat succeeds. """ # Latency comfortably larger than the time to fire three contenders. @@ -172,13 +175,27 @@ def process_fn(prompt: str) -> ProcessResult: _wait_for_backend_requests(env.backend, 1) # Contenders on the SAME session while A is parked -> 409 each. + # Each contender's wall-clock is measured: the gate rejects at + # slot-claim time (before body read, before backend), so a + # contender must NOT block for the owner's backend latency. contender_codes = [] + contender_elapsed_s = [] for _ in range(3): + t0 = time.perf_counter() resp = _chat(env.url, session_id, payload, timeout=10.0) + contender_elapsed_s.append(time.perf_counter() - t0) contender_codes.append(resp.status_code) assert resp.status_code == 409, f"contender should be 409, got {resp.status_code}" assert resp.json()["error"] == "session already has an in-flight chat completion" + # Each 409 returned well under the owner's backend latency, + # proving the contender did not block on A's parked proxy. Half + # the latency is a generous ceiling vs. the near-instant gate. + assert all(dt < latency / 2 for dt in contender_elapsed_s), ( + f"a contender blocked on the owner's backend latency " + f"(latency={latency}s, contender elapsed={contender_elapsed_s})" + ) + # Contenders never reached the backend: still exactly A's request. assert len(env.backend.request_log) == 1 @@ -677,6 +694,73 @@ def one_shot_update(self, *args, **kwargs): assert good.status_code == 200 assert len(env.backend.request_log) == 2 + def test_invariant_mismatch_returns_500_and_releases_slot(self): + """If session.num_assistant changes between the prepare-segment capture + and the commit-segment check, the commit raises SessionInvariantError + (500) instead of silently skipping the state update; the slot is then + released so a fresh session's normal chat returns 200. + + The mismatch is forced deterministically (no sleeps): prepare_pretokenized + stashes the live session object, then do_proxy mutates that session's + num_assistant on its first call — an out-of-band change the in-flight + gate is supposed to make impossible. The commit check then trips. + """ + + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="ok", finish_reason="stop") + + holder: dict = {} + original_prepare = LinearTrajectory.prepare_pretokenized + original_do_proxy = SessionServer.do_proxy + proxy_state = {"calls": 0} + + def stashing_prepare(self, *args, **kwargs): + result = original_prepare(self, *args, **kwargs) + # Stash the live session so the do_proxy wrapper can mutate it after + # expected_num_assistant has already been captured. + holder["session"] = self + return result + + async def mutating_do_proxy(self, request, path, body=None, headers=None): + proxy_state["calls"] += 1 + if proxy_state["calls"] == 1: + # Out-of-band state change mid-proxy, after the prepare segment + # captured expected_num_assistant -> commit-time check must trip. + holder["session"].num_assistant += 1 + return await original_do_proxy(self, request, path, body=body, headers=headers) + + with _router_env(process_fn) as env: + session_id = _create_session(env.url) + with ( + patch.object(LinearTrajectory, "prepare_pretokenized", stashing_prepare), + patch.object(SessionServer, "do_proxy", mutating_do_proxy), + ): + bad = _chat(env.url, session_id, _normal_messages("hi"), timeout=20.0) + assert bad.status_code == 500, f"invariant mismatch must be 500, got {bad.status_code}" + # Body carries SessionInvariantError's message (the gate did not + # silently 200-skip the commit). + error = bad.json()["error"] + assert "session state changed under the in-flight gate" in error + assert SessionInvariantError(error).status_code == 500 + # The proxy ran (the mismatch is raised in the commit step after it). + assert len(env.backend.request_log) == 1 + + # Patches removed. The slot must have been released on the 500 exit + # path: a follow-up on the SAME session is NOT a stuck 409. (Its + # bumped num_assistant left the prefix state inconsistent, so the + # legal-continuation code may be a non-200 error, but never 409.) + same_session_followup = _chat(env.url, session_id, _normal_messages("hi same"), timeout=20.0) + assert same_session_followup.status_code != 409, ( + f"slot was not released after the 500: same-session follow-up got " + f"{same_session_followup.status_code} (a stuck 409)" + ) + + # A FRESH session's normal chat returns 200, confirming the event + # loop stayed live and the gate path recovers cleanly. + fresh_session_id = _create_session(env.url) + good = _chat(env.url, fresh_session_id, _normal_messages("hi again"), timeout=20.0) + assert good.status_code == 200 + def test_slot_released_after_client_cancel_mid_proxy(self): """If the handler is cancelled mid-proxy (client disconnect), the request errors at the client side but the `finally` releases the slot; the next @@ -706,11 +790,78 @@ async def one_shot_do_proxy(self, request, path, body=None, headers=None): good = _chat(env.url, session_id, _normal_messages("hi again"), timeout=20.0) assert good.status_code == 200 + def test_slot_released_after_real_client_disconnect_mid_proxy(self): + """A REAL client disconnect mid-proxy does not permanently wedge the + session: its in-flight slot is eventually released, so a later legal + chat reaches 200 rather than a stuck 409. + + The owner chat parks in the backend latency window; a tiny client-side + timeout makes `requests` abort the socket while the server handler is + still awaiting the backend. + + Behavior observed in this harness (uvicorn + httpx backend): a real + socket abort does NOT promptly cancel the in-flight server handler. The + handler stays parked in the proxy until the backend responds, then runs + to completion and commits state; the slot is released on that normal + completion path (the `claimed`-guarded `finally`), not via an early + cancellation. So while the backend is still parked, same-session chats + keep getting 409; once the owner completes, the gate is free again. + Because the owner committed a turn, a naive fresh-message follow-up on + the same session would now fail the append-only prefix check (400), so + we (a) assert the same session stops returning 409 (slot released), then + (b) confirm a fresh session's normal chat returns 200 (loop is live and + the gate recovers cleanly). + """ + + # High latency: the client times out long before the backend responds, + # so the abort lands while the handler is parked mid-proxy. + latency = 1.0 + + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="ok", finish_reason="stop") + + with _router_env(process_fn, latency=latency) as env: + session_id = _create_session(env.url) + + env.backend.reset_stats() + # Fire the owner with a tiny timeout so `requests` aborts the + # connection while the handler is parked in the latency window. + try: + _chat(env.url, session_id, _normal_messages("disconnect-me"), timeout=0.1) + except requests.exceptions.RequestException: + pass + # Confirm the request actually reached the backend (handler was + # genuinely parked mid-proxy when the client aborted). + _wait_for_backend_requests(env.backend, 1) + + # The same-session slot must eventually free up: poll until a chat is + # NOT a 409. The bounded retry waits out the owner's parked proxy + # (the abort did not cancel it); the core requirement is that the + # slot is eventually released, never a permanent 409. + deadline = time.time() + 10.0 + released = False + last_status = None + while time.time() < deadline: + last_status = _chat( + env.url, session_id, _normal_messages("after-disconnect"), timeout=20.0 + ).status_code + if last_status != 409: + released = True + break + time.sleep(0.05) + assert released, f"slot not released after real disconnect (still {last_status} after retries)" + + # The event loop stayed live and the gate path recovers cleanly: a + # fresh session's normal chat returns 200. + fresh_session_id = _create_session(env.url) + good = _chat(env.url, fresh_session_id, _normal_messages("after-disconnect"), timeout=20.0) + assert good.status_code == 200 + class TestBuildProxyResponse: - """Unit tests for SessionServer.build_proxy_response (AC-5 passthrough - fidelity). No running server needed: with no hf_checkpoint, - setup_session_routes returns early so construction is light. + """Unit tests for SessionServer.build_proxy_response passthrough fidelity. + No running server needed: with no hf_checkpoint, setup_session_routes + returns early so construction is light. """ def _server(self) -> SessionServer: From 3c2aea06b9e3d1ef538473debe6b800d6afc6048 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Mon, 22 Jun 2026 23:48:37 +0000 Subject: [PATCH 05/11] test(session): verify session gate edge cases --- .../router/test_session_race_conditions.py | 123 +++++++++++++----- 1 file changed, 87 insertions(+), 36 deletions(-) diff --git a/tests/fast/router/test_session_race_conditions.py b/tests/fast/router/test_session_race_conditions.py index 1362f34c26..4f0f4dee67 100644 --- a/tests/fast/router/test_session_race_conditions.py +++ b/tests/fast/router/test_session_race_conditions.py @@ -24,6 +24,7 @@ from __future__ import annotations import asyncio +import logging import time from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager @@ -694,7 +695,7 @@ def one_shot_update(self, *args, **kwargs): assert good.status_code == 200 assert len(env.backend.request_log) == 2 - def test_invariant_mismatch_returns_500_and_releases_slot(self): + def test_invariant_mismatch_returns_500_and_releases_slot(self, caplog): """If session.num_assistant changes between the prepare-segment capture and the commit-segment check, the commit raises SessionInvariantError (500) instead of silently skipping the state update; the slot is then @@ -704,6 +705,12 @@ def test_invariant_mismatch_returns_500_and_releases_slot(self): stashes the live session object, then do_proxy mutates that session's num_assistant on its first call — an out-of-band change the in-flight gate is supposed to make impossible. The commit check then trips. + + The server runs in a uvicorn thread in the same process, so its `logging` + records propagate to the root logger and `caplog` captures them. We assert + the commit segment emitted an ERROR naming the invariant and the session + id (matching the `logger.error(...)` in `chat_completions`), and that a + normal fresh chat never emits that ERROR (no false trigger). """ def process_fn(prompt: str) -> ProcessResult: @@ -734,6 +741,7 @@ async def mutating_do_proxy(self, request, path, body=None, headers=None): with ( patch.object(LinearTrajectory, "prepare_pretokenized", stashing_prepare), patch.object(SessionServer, "do_proxy", mutating_do_proxy), + caplog.at_level(logging.ERROR), ): bad = _chat(env.url, session_id, _normal_messages("hi"), timeout=20.0) assert bad.status_code == 500, f"invariant mismatch must be 500, got {bad.status_code}" @@ -745,6 +753,20 @@ async def mutating_do_proxy(self, request, path, body=None, headers=None): # The proxy ran (the mismatch is raised in the commit step after it). assert len(env.backend.request_log) == 1 + # The commit segment logged an ERROR naming the invariant and the + # session id (matching the `logger.error(...)` in chat_completions). + invariant_errors = [ + rec + for rec in caplog.records + if rec.levelno == logging.ERROR + and "invariant" in rec.getMessage() + and session_id in rec.getMessage() + ] + assert invariant_errors, ( + "expected an ERROR log naming the invariant and the session id; " + f"saw records: {[(r.levelname, r.getMessage()) for r in caplog.records]}" + ) + # Patches removed. The slot must have been released on the 500 exit # path: a follow-up on the SAME session is NOT a stuck 409. (Its # bumped num_assistant left the prefix state inconsistent, so the @@ -756,10 +778,20 @@ async def mutating_do_proxy(self, request, path, body=None, headers=None): ) # A FRESH session's normal chat returns 200, confirming the event - # loop stayed live and the gate path recovers cleanly. + # loop stayed live and the gate path recovers cleanly. Capture ERROR + # records during this clean chat to prove no false invariant trigger. fresh_session_id = _create_session(env.url) - good = _chat(env.url, fresh_session_id, _normal_messages("hi again"), timeout=20.0) - assert good.status_code == 200 + with caplog.at_level(logging.ERROR): + caplog.clear() + good = _chat(env.url, fresh_session_id, _normal_messages("hi again"), timeout=20.0) + assert good.status_code == 200 + false_triggers = [ + rec for rec in caplog.records if rec.levelno == logging.ERROR and "invariant" in rec.getMessage() + ] + assert not false_triggers, ( + "a normal chat must not emit the invariant ERROR; " + f"saw: {[r.getMessage() for r in false_triggers]}" + ) def test_slot_released_after_client_cancel_mid_proxy(self): """If the handler is cancelled mid-proxy (client disconnect), the request @@ -791,26 +823,26 @@ async def one_shot_do_proxy(self, request, path, body=None, headers=None): assert good.status_code == 200 def test_slot_released_after_real_client_disconnect_mid_proxy(self): - """A REAL client disconnect mid-proxy does not permanently wedge the - session: its in-flight slot is eventually released, so a later legal - chat reaches 200 rather than a stuck 409. - - The owner chat parks in the backend latency window; a tiny client-side - timeout makes `requests` abort the socket while the server handler is - still awaiting the backend. - - Behavior observed in this harness (uvicorn + httpx backend): a real - socket abort does NOT promptly cancel the in-flight server handler. The - handler stays parked in the proxy until the backend responds, then runs - to completion and commits state; the slot is released on that normal - completion path (the `claimed`-guarded `finally`), not via an early - cancellation. So while the backend is still parked, same-session chats - keep getting 409; once the owner completes, the gate is free again. - Because the owner committed a turn, a naive fresh-message follow-up on - the same session would now fail the append-only prefix check (400), so - we (a) assert the same session stops returning 409 (slot released), then - (b) confirm a fresh session's normal chat returns 200 (loop is live and - the gate recovers cleanly). + """A REAL client disconnect mid-proxy releases the slot and leaves the + session usable by its next LEGAL continuation (200), not a stuck 409. + + Harness note: a real client socket abort does NOT cancel the in-flight + handler in this harness (uvicorn + httpx backend). The handler stays + parked in the proxy until the backend responds, then runs to completion + and commits the (abandoned) backend response; the slot is released on + that normal-completion `finally` (the `claimed` guard), not via an early + cancellation. So after a disconnect the session has advanced exactly one + committed turn. Recovery is therefore via a LEGAL append-only + continuation of that committed turn — a fresh first-turn message would + instead fail the append-only prefix check (400), which is expected, not + a slot leak. + + We (1) fire a chat with a tiny client timeout against a high backend + `latency` so `requests` aborts mid-proxy; (2) bounded-poll until a probe + chat is no longer 409 (slot released — no fixed ordering sleep); then + (3) read the committed turn via GET /sessions/{id} and prove a legal + append-only continuation (committed user + assistant, then an allowed + appended `system` message) returns 200. """ # High latency: the client times out long before the backend responds, @@ -834,28 +866,47 @@ def process_fn(prompt: str) -> ProcessResult: # genuinely parked mid-proxy when the client aborted). _wait_for_backend_requests(env.backend, 1) - # The same-session slot must eventually free up: poll until a chat is - # NOT a 409. The bounded retry waits out the owner's parked proxy - # (the abort did not cancel it); the core requirement is that the - # slot is eventually released, never a permanent 409. + # The same-session slot must free up. Bounded-poll a probe chat until + # it is NOT a 409: the abort did not cancel the parked handler, so the + # gate stays busy until the owner completes, then releases. We use an + # invalid (empty) probe payload so the probe itself never commits a + # turn — it only reads the gate state (409 while busy, non-409 once + # released) and leaves the committed trajectory untouched. deadline = time.time() + 10.0 released = False last_status = None while time.time() < deadline: - last_status = _chat( - env.url, session_id, _normal_messages("after-disconnect"), timeout=20.0 - ).status_code + last_status = _chat(env.url, session_id, {"messages": []}, timeout=20.0).status_code if last_status != 409: released = True break time.sleep(0.05) assert released, f"slot not released after real disconnect (still {last_status} after retries)" - # The event loop stayed live and the gate path recovers cleanly: a - # fresh session's normal chat returns 200. - fresh_session_id = _create_session(env.url) - good = _chat(env.url, fresh_session_id, _normal_messages("after-disconnect"), timeout=20.0) - assert good.status_code == 200 + # The disconnected owner's backend response committed one turn. Read + # it back and build a LEGAL append-only continuation: the committed + # user + assistant messages, then an allowed appended `system` + # message (mirrors the retry payload in + # test_same_session_second_chat_returns_409). + get_resp = requests.get(f"{env.url}/sessions/{session_id}", timeout=5.0) + assert get_resp.status_code == 200 + records = get_resp.json()["records"] + assert records, "the abandoned backend response should have committed one turn" + committed = records[-1] + committed_user = committed["request"]["messages"] + committed_assistant = committed["response"]["choices"][0]["message"] + continuation = { + "messages": [ + *committed_user, + committed_assistant, + {"role": "system", "content": "continue-after-disconnect"}, + ] + } + after = _chat(env.url, session_id, continuation, timeout=20.0) + assert after.status_code == 200, ( + f"legal same-session continuation after a real disconnect must 200, got {after.status_code}: " + f"{after.text}" + ) class TestBuildProxyResponse: From 08bde328c7c5cbb3de9a38ee16bc0c87e567ab70 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Tue, 23 Jun 2026 02:33:32 +0000 Subject: [PATCH 06/11] fix(session): reject non-integer token ids in upstream response validation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit _parse_and_validate_response previously accepted any value at output_token_logprobs[i][1] (string / float / None / bool token ids, and even a bare string entry whose [1] indexes a character), and treated a bool completion_tokens as 1/0 — silently corrupting the stored trajectory token ids on a malformed-but-HTTP-200 SGLang response. Now each entry must be a (logprob, token_id) pair with a strict int token id, and completion_tokens must be a non-bool int; violations raise UpstreamResponseError (502). Adds TestResponseTokenIdValidation. 76 passed. Co-Authored-By: Claude Opus 4.8 (1M context) --- miles/rollout/session/sessions.py | 28 ++++++++-- .../router/test_session_race_conditions.py | 56 ++++++++++++++++++- 2 files changed, 78 insertions(+), 6 deletions(-) diff --git a/miles/rollout/session/sessions.py b/miles/rollout/session/sessions.py index 42349c69a2..9e64dedff1 100644 --- a/miles/rollout/session/sessions.py +++ b/miles/rollout/session/sessions.py @@ -60,7 +60,13 @@ def _parse_and_validate_response(response_body: bytes) -> tuple[dict, dict, list ) output_token_logprobs = meta_info["output_token_logprobs"] completion_tokens = meta_info.get("completion_tokens") - if not isinstance(output_token_logprobs, list) or not isinstance(completion_tokens, int): + # bool is an int subclass; reject it explicitly so a True/False count is not + # silently treated as 1/0. + if ( + not isinstance(output_token_logprobs, list) + or not isinstance(completion_tokens, int) + or isinstance(completion_tokens, bool) + ): raise UpstreamResponseError("upstream response output_token_logprobs/completion_tokens have invalid types") actual_output_logprobs_len = len(output_token_logprobs) if actual_output_logprobs_len != completion_tokens: @@ -70,10 +76,22 @@ def _parse_and_validate_response(response_body: bytes) -> tuple[dict, dict, list f"!= completion_tokens={completion_tokens}. " f"Please check whether you use the correct SGLang branch which has fix the tokenizer batch decode issue." ) - try: - completion_token_ids = [t[1] for t in output_token_logprobs] - except (IndexError, TypeError, KeyError) as e: - raise UpstreamResponseError(f"upstream response output_token_logprobs entries are malformed: {e}") from e + # Each entry must be a (logprob, token_id) pair with an integer token id. A + # malformed entry (short/non-sequence, or a str/float/None/bool id) would + # silently corrupt the stored trajectory token ids, so reject the whole + # response instead of extracting garbage. + completion_token_ids: list[int] = [] + for entry in output_token_logprobs: + if not isinstance(entry, (list, tuple)) or len(entry) < 2: + raise UpstreamResponseError( + "upstream response output_token_logprobs entry is not a (logprob, token_id) pair" + ) + token_id = entry[1] + if not isinstance(token_id, int) or isinstance(token_id, bool): + raise UpstreamResponseError( + f"upstream response output_token_logprobs token id is not an int: {token_id!r}" + ) + completion_token_ids.append(token_id) return response, assistant_message, completion_token_ids diff --git a/tests/fast/router/test_session_race_conditions.py b/tests/fast/router/test_session_race_conditions.py index 4f0f4dee67..39900e683b 100644 --- a/tests/fast/router/test_session_race_conditions.py +++ b/tests/fast/router/test_session_race_conditions.py @@ -35,8 +35,14 @@ from starlette.responses import Response from miles.rollout.session.linear_trajectory import LinearTrajectory -from miles.rollout.session.session_errors import MessageValidationError, SessionInvariantError, TokenizationError +from miles.rollout.session.session_errors import ( + MessageValidationError, + SessionInvariantError, + TokenizationError, + UpstreamResponseError, +) from miles.rollout.session.session_server import SessionServer +from miles.rollout.session.sessions import _parse_and_validate_response from miles.utils.http_utils import find_available_port from miles.utils.test_utils.mock_sglang_server import MockSGLangServer, ProcessResult, with_mock_server from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer @@ -1037,3 +1043,51 @@ def process_fn(prompt: str) -> ProcessResult: assert "transfer-encoding" not in lowered assert "content-encoding" not in lowered assert resp.headers.get("content-type", "").startswith("application/json") + + +class TestResponseTokenIdValidation: + """A malformed-but-200 upstream response must be rejected (502) rather than + yielding non-integer token ids that would corrupt the stored trajectory.""" + + @staticmethod + def _payload(output_token_logprobs, completion_tokens=1) -> bytes: + import json + + return json.dumps( + { + "choices": [ + { + "message": {"role": "assistant", "content": "ok"}, + "meta_info": { + "output_token_logprobs": output_token_logprobs, + "completion_tokens": completion_tokens, + }, + } + ] + } + ).encode() + + def test_valid_integer_token_ids_accepted(self): + _resp, _msg, ids = _parse_and_validate_response(self._payload([[-0.1, 123], [-0.2, 456]], completion_tokens=2)) + assert ids == [123, 456] + + def test_non_integer_token_ids_rejected(self): + # str / None / float / bool token ids, a non-pair string entry, a short + # entry, and a bool completion_tokens must all raise UpstreamResponseError. + bad_cases = [ + ([[-0.1, "x"]], 1), + ([[-0.1, None]], 1), + ([[-0.1, 1.2]], 1), + ([[-0.1, True]], 1), + (["ab"], 1), + ([[123]], 1), + ([[-0.1, 123]], True), + ] + for logprobs, completion_tokens in bad_cases: + try: + _parse_and_validate_response(self._payload(logprobs, completion_tokens)) + except UpstreamResponseError: + continue + raise AssertionError( + f"expected UpstreamResponseError for {logprobs!r}, completion_tokens={completion_tokens!r}" + ) From 49d0f1c07ade01356eb78c7b8edada29e6de738c Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Tue, 23 Jun 2026 05:10:47 +0000 Subject: [PATCH 07/11] fix(session): reject non-numeric logprob values in upstream response validation A successful (200) upstream chat response whose output_token_logprobs entry has a non-numeric logprob (entry[0]: str/None/bool) was accepted and recorded. That value flows into Sample.rollout_log_probs via openai_endpoint_utils and would corrupt downstream RL training. Validate entry[0] is a real number (int/float, bool excluded) alongside the existing token-id check, and classify violations as UpstreamResponseError (502). SGLang [logprob, token_id, token_text] triples (len > 2) remain valid. Co-Authored-By: Claude Opus 4.8 --- miles/rollout/session/sessions.py | 17 +++++++++---- .../router/test_session_race_conditions.py | 24 +++++++++++++++++++ 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/miles/rollout/session/sessions.py b/miles/rollout/session/sessions.py index 9e64dedff1..4719a8d41a 100644 --- a/miles/rollout/session/sessions.py +++ b/miles/rollout/session/sessions.py @@ -76,16 +76,25 @@ def _parse_and_validate_response(response_body: bytes) -> tuple[dict, dict, list f"!= completion_tokens={completion_tokens}. " f"Please check whether you use the correct SGLang branch which has fix the tokenizer batch decode issue." ) - # Each entry must be a (logprob, token_id) pair with an integer token id. A - # malformed entry (short/non-sequence, or a str/float/None/bool id) would - # silently corrupt the stored trajectory token ids, so reject the whole - # response instead of extracting garbage. + # Each entry must be a (logprob, token_id, ...) sequence (SGLang emits + # [logprob, token_id, token_text] triples; len > 2 is normal). Both leading + # fields are consumed downstream: token_id (entry[1]) feeds the stored + # trajectory token ids, and logprob (entry[0]) feeds Sample.rollout_log_probs + # in openai_endpoint_utils. A non-int token id or non-numeric logprob would + # silently corrupt the stored trajectory / training logprobs, so reject the + # whole response instead of extracting garbage. bool is an int subclass, so + # reject it explicitly for both fields. completion_token_ids: list[int] = [] for entry in output_token_logprobs: if not isinstance(entry, (list, tuple)) or len(entry) < 2: raise UpstreamResponseError( "upstream response output_token_logprobs entry is not a (logprob, token_id) pair" ) + logprob = entry[0] + if not isinstance(logprob, (int, float)) or isinstance(logprob, bool): + raise UpstreamResponseError( + f"upstream response output_token_logprobs logprob is not a number: {logprob!r}" + ) token_id = entry[1] if not isinstance(token_id, int) or isinstance(token_id, bool): raise UpstreamResponseError( diff --git a/tests/fast/router/test_session_race_conditions.py b/tests/fast/router/test_session_race_conditions.py index 39900e683b..a7e5cd7e37 100644 --- a/tests/fast/router/test_session_race_conditions.py +++ b/tests/fast/router/test_session_race_conditions.py @@ -1091,3 +1091,27 @@ def test_non_integer_token_ids_rejected(self): raise AssertionError( f"expected UpstreamResponseError for {logprobs!r}, completion_tokens={completion_tokens!r}" ) + + def test_non_numeric_logprob_values_rejected(self): + # entry[0] (the logprob) flows into Sample.rollout_log_probs downstream + # (openai_endpoint_utils), so a str / None / bool logprob with an + # otherwise-valid token id must still raise rather than being accepted. + bad_cases = [ + [["bad-logprob", 562]], + [[None, 562]], + [[True, 562]], + ] + for logprobs in bad_cases: + try: + _parse_and_validate_response(self._payload(logprobs, completion_tokens=1)) + except UpstreamResponseError: + continue + raise AssertionError(f"expected UpstreamResponseError for logprob in {logprobs!r}") + + def test_sglang_triples_with_token_text_accepted(self): + # SGLang emits [logprob, token_id, token_text] triples; len > 2 is normal + # and must NOT be rejected. Integer logprobs (e.g. 0) are also valid numbers. + _resp, _msg, ids = _parse_and_validate_response( + self._payload([[-0.1, 123, "he"], [0, 456, "llo"]], completion_tokens=2) + ) + assert ids == [123, 456] From 4b274bf99c49caf35c34d1718fe7f20b231e30a4 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Tue, 23 Jun 2026 05:18:48 +0000 Subject: [PATCH 08/11] test(session): e2e bad-logprob rejection and benchmark artifact support Add an HTTP-level one-shot test proving a 200 upstream response with a non-numeric logprob value returns 502, commits no record or accumulated token id, and releases the slot so the next legal chat returns 200. Extend the responsiveness benchmark with JSON output, comparison output, commit/dirty metadata, per-stage CPU timing, and computed health percentiles so before/after runs can persist reviewable artifacts. --- .../router/bench_session_responsiveness.py | 220 +++++++++++++++++- .../router/test_session_race_conditions.py | 60 +++++ 2 files changed, 271 insertions(+), 9 deletions(-) diff --git a/tests/fast/router/bench_session_responsiveness.py b/tests/fast/router/bench_session_responsiveness.py index 81cde112e5..dda6bddf95 100644 --- a/tests/fast/router/bench_session_responsiveness.py +++ b/tests/fast/router/bench_session_responsiveness.py @@ -16,7 +16,18 @@ component the executor offload moves off the event loop. This file is named `bench_*` so pytest does NOT auto-collect it (no flaky timing -test in CI). Run it directly: python tests/fast/router/bench_session_responsiveness.py +test in CI). Run it directly: + + # Run once, print a human-readable block: + python tests/fast/router/bench_session_responsiveness.py + + # Run once, persist a reviewable JSON artifact (used for before/after): + python tests/fast/router/bench_session_responsiveness.py --label after \ + --json-out .humanize/.../benchmarks/session-responsiveness-after.json + + # Compare two persisted runs into a markdown verdict: + python tests/fast/router/bench_session_responsiveness.py --compare \ + before.json after.json --out compare.md Method: fire K concurrent chats across K DISTINCT sessions (distinct sessions are not gated, so they run in parallel), each producing a large response that forces a @@ -28,6 +39,7 @@ from __future__ import annotations +import argparse import json import logging import statistics @@ -188,15 +200,65 @@ def _pct(values_ms: list[float], q: float) -> float: return ordered[idx] -def run_bench() -> dict: - blob = _make_large_blob() - # The CPU cost the offload targets: one json.loads of a response carrying this - # blob. Measured up front so the summary states how heavy the offloaded work is. - sample_body = json.dumps({"choices": [{"meta_info": {"routed_experts": blob}}]}).encode() +def _measure_per_stage_cpu_ms(blob: list) -> dict: + """Benchmark-side (NOT production hot-path) timing of the three stateless CPU + stages the session server offloads to its cpu_executor: request parse, request + dump (with Miles-owned input_ids), and response parse+validate. `parse_ms` uses + plain json.loads so it is identical and comparable across before/after builds; + `validate_ms` calls the real helper and is recorded only where it exists (the + pre-offload build predates the module-level helpers -> null).""" + # Heavy response parse (the dominant offloaded stage): plain json.loads, so it + # runs identically on the pre-offload build too. + sample_resp = { + "choices": [ + { + "message": {"role": "assistant", "content": "x"}, + "meta_info": { + "output_token_logprobs": [[-0.01 * i, 100 + i] for i in range(64)], + "completion_tokens": 64, + "routed_experts": blob, + }, + } + ] + } + sample_resp_body = json.dumps(sample_resp).encode() _t = time.perf_counter() - json.loads(sample_body) + json.loads(sample_resp_body) parse_ms = (time.perf_counter() - _t) * 1000 + # Request dump (encode the body with Miles-owned input_ids): plain json.dumps. + sample_req = { + "messages": [{"role": "user", "content": "drive-load"}], + "input_ids": list(range(4096)), + "logprobs": True, + "return_meta_info": True, + } + _t = time.perf_counter() + json.dumps(sample_req).encode() + dump_ms = (time.perf_counter() - _t) * 1000 + + # Full parse+validate via the real helper (after-only; null where absent). + validate_ms = None + try: + from miles.rollout.session.sessions import _parse_and_validate_response + + _t = time.perf_counter() + _parse_and_validate_response(sample_resp_body) + validate_ms = (time.perf_counter() - _t) * 1000 + except Exception: + validate_ms = None + + return {"parse_ms": parse_ms, "dump_ms": dump_ms, "validate_ms": validate_ms} + + +def run_bench() -> dict: + blob = _make_large_blob() + # The CPU cost the offload targets, measured up front (benchmark-side, not in + # the production hot path) so the artifact states how heavy each offloaded + # stage is. parse_ms is the dominant one and is comparable across builds. + per_stage = _measure_per_stage_cpu_ms(blob) + parse_ms = per_stage["parse_ms"] + def process_fn(_prompt: str) -> ProcessResult: # routed_experts is injected by the response patch (kept off the str-typed # dataclass field); this just drives a small assistant message. @@ -248,6 +310,8 @@ def process_fn(_prompt: str) -> ProcessResult: "k_chats": K_CHATS, "waves": waves, "parse_ms": parse_ms, + "dump_ms": per_stage["dump_ms"], + "validate_ms": per_stage["validate_ms"], "total_chats": total, "ok_chats": ok, "failed_chats": total - ok, @@ -257,6 +321,16 @@ def process_fn(_prompt: str) -> ProcessResult: "health_samples_under_load": len(load_ms), "health_errors_under_load": poller.errors, "health_baseline_samples": len(baseline_ms), + # Computed percentiles (persisted so the artifact is reviewable without + # re-deriving from raw samples); raw samples kept alongside for audit. + "health_baseline_p50_ms": _pct(baseline_ms, 0.50), + "health_baseline_p95_ms": _pct(baseline_ms, 0.95), + "health_baseline_max_ms": max(baseline_ms) if baseline_ms else float("nan"), + "health_load_p50_ms": _pct(load_ms, 0.50), + "health_load_p95_ms": _pct(load_ms, 0.95), + "health_load_p99_ms": _pct(load_ms, 0.99), + "health_load_max_ms": max(load_ms) if load_ms else float("nan"), + "health_load_mean_ms": statistics.mean(load_ms) if load_ms else float("nan"), "baseline_ms": baseline_ms, "load_ms": load_ms, } @@ -265,6 +339,7 @@ def process_fn(_prompt: str) -> ProcessResult: def _fmt_block(r: dict) -> str: base = r["baseline_ms"] load = r["load_ms"] + validate_str = f"{r['validate_ms']:.1f}ms" if r.get("validate_ms") is not None else "n/a (pre-offload)" lines = [ "=" * 64, "SessionServer event-loop responsiveness benchmark", @@ -272,7 +347,8 @@ def _fmt_block(r: dict) -> str: f" concurrency / wave (K) : {r['k_chats']}", f" waves driven : {r['waves']} (total {r['total_chats']} chats)", f" response body size (actual) : {r['response_body_bytes'] / 1024:.1f} KiB", - f" single-response parse cost : {r['parse_ms']:.1f} ms (the offloaded CPU work)", + f" per-stage CPU (offloaded) : parse={r['parse_ms']:.1f}ms dump={r.get('dump_ms', float('nan')):.2f}ms " + f"validate={validate_str}", f" chats ok / failed : {r['ok_chats']} / {r['failed_chats']}", f" chat wall-clock : {r['chat_wall_s']:.3f} s", f" chat throughput : {r['chat_throughput_per_s']:.1f} chats/s", @@ -292,6 +368,132 @@ def _fmt_block(r: dict) -> str: return "\n".join(lines) -if __name__ == "__main__": +def _persist(result: dict, path: str, label: str | None, commit: str | None, dirty: bool | None) -> None: + payload = dict(result) + payload["label"] = label + payload["commit"] = commit + payload["dirty"] = dirty + payload["blob_rows"] = BLOB_ROWS + payload["blob_row_width"] = BLOB_ROW_WIDTH + payload["health_poll_interval_s"] = HEALTH_POLL_INTERVAL_S + payload["health_poll_duration_s"] = HEALTH_POLL_DURATION_S + with open(path, "w") as f: + json.dump(payload, f, indent=2) + print(f"[bench] wrote {path}") + + +def _compare(before_path: str, after_path: str, out_path: str | None) -> str: + with open(before_path) as f: + b = json.load(f) + with open(after_path) as f: + a = json.load(f) + + def _delta(metric: str) -> str: + bv, av = b.get(metric), a.get(metric) + if bv is None or av is None: + return "n/a" + d = av - bv + pct = (d / bv * 100) if bv else float("nan") + return f"{d:+.2f} ms ({pct:+.1f}%)" + + rows = [ + ("/health p50 (load)", "health_load_p50_ms"), + ("/health p95 (load)", "health_load_p95_ms"), + ("/health p99 (load)", "health_load_p99_ms"), + ("/health max (load)", "health_load_max_ms"), + ("/health p95 (baseline)", "health_baseline_p95_ms"), + ] + # p95/p99 verdict: "improved" only beyond a noise floor; otherwise "no regression". + noise_ms = 25.0 + verdict_lines = [] + for label, key in (("p95", "health_load_p95_ms"), ("p99", "health_load_p99_ms")): + bv, av = b.get(key), a.get(key) + if bv is None or av is None: + verdict_lines.append(f"- {label}: n/a") + continue + if av <= bv + noise_ms: + verdict_lines.append( + f"- {label}: NO REGRESSION (before {bv:.1f}ms -> after {av:.1f}ms, within ±{noise_ms:.0f}ms noise)" + ) + else: + verdict_lines.append( + f"- {label}: REGRESSED (before {bv:.1f}ms -> after {av:.1f}ms, beyond {noise_ms:.0f}ms noise)" + ) + + lines = [ + "# Session-server responsiveness: before vs after offload", + "", + f"- before: `{before_path}` commit `{b.get('commit')}` dirty={b.get('dirty')} label={b.get('label')}", + f"- after : `{after_path}` commit `{a.get('commit')}` dirty={a.get('dirty')} label={a.get('label')}", + f"- K={a.get('k_chats')} chats/wave, response body ~{a.get('response_body_bytes', 0) / 1024:.0f} KiB, " + f"blob {a.get('blob_rows')}x{a.get('blob_row_width')} floats", + "", + "| metric | before | after | delta |", + "|---|---|---|---|", + ] + for label, key in rows: + bv, av = b.get(key), a.get(key) + bs = f"{bv:.2f}" if isinstance(bv, (int, float)) else "n/a" + as_ = f"{av:.2f}" if isinstance(av, (int, float)) else "n/a" + lines.append(f"| {label} | {bs} ms | {as_} ms | {_delta(key)} |") + b_validate = f"{b['validate_ms']:.1f} ms" if b.get("validate_ms") is not None else "n/a" + a_validate = f"{a['validate_ms']:.1f} ms" if a.get("validate_ms") is not None else "n/a" + lines += [ + "| chat throughput | " + f"{b.get('chat_throughput_per_s', float('nan')):.1f}/s | {a.get('chat_throughput_per_s', float('nan')):.1f}/s | " + f"{(a.get('chat_throughput_per_s', 0) - b.get('chat_throughput_per_s', 0)):+.1f}/s |", + "| /health errors (load) | " + f"{b.get('health_errors_under_load')} | {a.get('health_errors_under_load')} | — |", + "| per-call parse (json.loads) | " + f"{b.get('parse_ms', float('nan')):.1f} ms | {a.get('parse_ms', float('nan')):.1f} ms | {_delta('parse_ms')} |", + f"| per-call validate (helper) | {b_validate} | {a_validate} | — |", + "", + "## p95/p99 verdict (under load)", + *verdict_lines, + "", + "## Interpretation", + "At multi-MiB routed-experts payloads the event loop is dominated by on-loop body I/O " + "(httpx read + uvicorn write), identical in both builds, which dwarfs the per-call parse the " + "offload relocates — so end-to-end `/health` shows no significant change and the GIL prevents any " + "CPU-throughput gain (throughput before≈after). The offload mechanism itself is verified separately " + "(an isolated large parse blocks the loop inline vs ~0ms offloaded); this benchmark confirms the " + "offload does NOT regress responsiveness or throughput at this scale, consistent with the plan's " + "directional, non-blocking AC-6.1.", + "", + ] + text = "\n".join(lines) + if out_path: + with open(out_path, "w") as f: + f.write(text) + print(f"[bench] wrote {out_path}") + return text + + +def main() -> None: + parser = argparse.ArgumentParser(description="SessionServer event-loop responsiveness benchmark") + parser.add_argument("--json-out", default=None, help="persist the run as a JSON artifact at this path") + parser.add_argument("--label", default=None, help="label recorded in the JSON artifact (e.g. before/after)") + parser.add_argument("--commit", default=None, help="commit SHA recorded in the artifact") + parser.add_argument("--dirty", action="store_true", help="record that the checkout was dirty") + parser.add_argument( + "--compare", + nargs=2, + metavar=("BEFORE", "AFTER"), + default=None, + help="compare two persisted JSON artifacts instead of running the benchmark", + ) + parser.add_argument("--out", default=None, help="write the comparison markdown to this path") + parsed = parser.parse_args() + + if parsed.compare: + print(_compare(parsed.compare[0], parsed.compare[1], parsed.out)) + return + result = run_bench() print(_fmt_block(result)) + if parsed.json_out: + _persist(result, parsed.json_out, parsed.label, parsed.commit, parsed.dirty) + + +if __name__ == "__main__": + main() diff --git a/tests/fast/router/test_session_race_conditions.py b/tests/fast/router/test_session_race_conditions.py index a7e5cd7e37..f520342405 100644 --- a/tests/fast/router/test_session_race_conditions.py +++ b/tests/fast/router/test_session_race_conditions.py @@ -98,6 +98,36 @@ def patched_chat_response(self, payload: dict) -> dict: return patch.object(MockSGLangServer, "_compute_chat_completions_response", new=patched_chat_response) +def _patch_mock_chat_response_bad_logprob_first(): + """Like `_patch_mock_chat_response`, but the FIRST chat response carries a + non-numeric logprob value (entry[0]) so the session server rejects it with + UpstreamResponseError (502) instead of letting the bad value flow into + Sample.rollout_log_probs downstream. Later responses are valid. + """ + original_chat_response = MockSGLangServer._compute_chat_completions_response + state = {"calls": 0} + + def patched_chat_response(self, payload: dict) -> dict: + response = original_chat_response(self, payload) + choice = response["choices"][0] + logprobs_content = choice["logprobs"]["content"] + output_token_logprobs = [ + (item["logprob"], self.tokenizer.convert_tokens_to_ids(item["token"])) for item in logprobs_content + ] + state["calls"] += 1 + if state["calls"] == 1 and output_token_logprobs: + # Corrupt the first response's leading logprob value only. + _logprob, token_id = output_token_logprobs[0][0], output_token_logprobs[0][1] + output_token_logprobs[0] = ("bad-logprob", token_id) + choice["meta_info"] = { + "output_token_logprobs": output_token_logprobs, + "completion_tokens": len(output_token_logprobs), + } + return response + + return patch.object(MockSGLangServer, "_compute_chat_completions_response", new=patched_chat_response) + + @contextmanager def _router_env(process_fn, *, latency: float = 0.0, response_patch=None): with (response_patch or _patch_mock_chat_response)(): @@ -566,6 +596,36 @@ def process_fn(prompt: str) -> ProcessResult: assert good.status_code == 200 assert len(env.backend.request_log) == 2 + def test_bad_logprob_value_rejected_without_committing(self): + """A successful (200) upstream response carrying a non-numeric logprob + value must return 502, commit NOTHING to the session (no record, no + accumulated token id), and release the slot so the next legal chat 200s. + Guards against the bad logprob flowing into rollout_log_probs downstream. + """ + + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="ok", finish_reason="stop") + + with _router_env(process_fn, response_patch=_patch_mock_chat_response_bad_logprob_first) as env: + session_id = _create_session(env.url) + + bad = _chat(env.url, session_id, {"messages": [{"role": "user", "content": "hi"}]}, timeout=20.0) + assert bad.status_code == 502 + assert len(env.backend.request_log) == 1 + + # Nothing from the malformed response was committed. + state = requests.get(f"{env.url}/sessions/{session_id}", timeout=5.0).json() + assert state["records"] == [] + assert state["metadata"]["accumulated_token_ids"] == [] + + good = _chat(env.url, session_id, {"messages": [{"role": "user", "content": "hi again"}]}, timeout=20.0) + assert good.status_code == 200 + assert len(env.backend.request_log) == 2 + + after = requests.get(f"{env.url}/sessions/{session_id}", timeout=5.0).json() + assert len(after["records"]) == 1 + assert len(after["metadata"]["accumulated_token_ids"]) > 0 + def _normal_messages(content: str) -> dict: return {"messages": [{"role": "user", "content": content}]} From 3881e64a7b143a48646945825220c47ee2562722 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Tue, 23 Jun 2026 05:28:02 +0000 Subject: [PATCH 09/11] test(session): correct benchmark interpretation to match measured before/after A 5x5-iteration pooled before/after run (inline parse vs cpu_executor offload, K=64 concurrent ~1.3 MiB responses) shows the offload markedly improves event-loop responsiveness and chat throughput at this scale: /health p95 ~1095ms -> ~234ms, p99 ~1339ms -> ~303ms, throughput ~10.5 -> ~17 chats/s, with tight, non-overlapping per-iteration spread. Mechanism: K inline parses serialize on the one loop (before-p95 ~= K x single-parse cost ~= 64 x 17ms); offloading frees the loop. This corrects the bench's earlier speculative interpretation, which claimed no significant end-to-end change (it reasoned about a single parse vs body I/O and missed concurrent-parse stacking). The GIL still bounds aggregate CPU work, so the gain is responsiveness/tail-latency, not raw CPU throughput. Co-Authored-By: Claude Opus 4.8 --- tests/fast/router/bench_session_responsiveness.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/fast/router/bench_session_responsiveness.py b/tests/fast/router/bench_session_responsiveness.py index dda6bddf95..9a4b89f0fd 100644 --- a/tests/fast/router/bench_session_responsiveness.py +++ b/tests/fast/router/bench_session_responsiveness.py @@ -452,13 +452,14 @@ def _delta(metric: str) -> str: *verdict_lines, "", "## Interpretation", - "At multi-MiB routed-experts payloads the event loop is dominated by on-loop body I/O " - "(httpx read + uvicorn write), identical in both builds, which dwarfs the per-call parse the " - "offload relocates — so end-to-end `/health` shows no significant change and the GIL prevents any " - "CPU-throughput gain (throughput before≈after). The offload mechanism itself is verified separately " - "(an isolated large parse blocks the loop inline vs ~0ms offloaded); this benchmark confirms the " - "offload does NOT regress responsiveness or throughput at this scale, consistent with the plan's " - "directional, non-blocking AC-6.1.", + "Under K concurrent heavy responses the inline build serializes every per-response `json.loads` " + "ON the single event loop, so the parses stack: the loop is blocked for roughly K x parse_ms before " + "it can service a queued `/health` probe (the measured before-p95 ~= K x single-parse cost). Offloading " + "the parse to the bounded cpu_executor frees the loop to service `/health` between awaits and to drive " + "more chat waves, so both `/health` tail latency and chat throughput improve markedly at this scale. The " + "GIL still serializes the Python parse work, so this is a responsiveness/tail-latency effect, not an " + "aggregate-CPU gain. Single short windows are noisy (a before-p95 can rest on one stall), so pool samples " + "across iterations and inspect the per-iteration spread before trusting any delta.", "", ] text = "\n".join(lines) From 8ad096849658181289e5f3d7cb9f2bdd62269d1c Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Tue, 23 Jun 2026 06:28:28 +0000 Subject: [PATCH 10/11] fix(session): reject non-finite logprobs and classify benchmark gains Reject raw upstream NaN and Infinity logprob constants at the session response boundary so malformed successful responses return UpstreamResponseError instead of committing invalid rollout logprobs. Update the responsiveness comparison verdict to distinguish material improvements from noise-level no-regression results. --- miles/rollout/session/sessions.py | 11 +++- .../router/bench_session_responsiveness.py | 8 ++- .../router/test_session_race_conditions.py | 64 ++++++++++++++++++- 3 files changed, 75 insertions(+), 8 deletions(-) diff --git a/miles/rollout/session/sessions.py b/miles/rollout/session/sessions.py index 4719a8d41a..bca55df5a3 100644 --- a/miles/rollout/session/sessions.py +++ b/miles/rollout/session/sessions.py @@ -1,6 +1,7 @@ import asyncio import json import logging +import math import time from fastapi import Request @@ -23,6 +24,10 @@ logger = logging.getLogger(__name__) +def _reject_json_constant(value: str): + raise ValueError(f"invalid JSON constant: {value}") + + def _parse_request_body(body: bytes) -> dict: return json.loads(body) if body else {} @@ -39,8 +44,8 @@ def _parse_and_validate_response(response_body: bytes) -> tuple[dict, dict, list Touches no session state. """ try: - response = json.loads(response_body) - except (json.JSONDecodeError, UnicodeDecodeError) as e: + response = json.loads(response_body, parse_constant=_reject_json_constant) + except (json.JSONDecodeError, UnicodeDecodeError, ValueError) as e: raise UpstreamResponseError(f"upstream response is not valid JSON: {e}") from e choices = response.get("choices") if isinstance(response, dict) else None @@ -91,7 +96,7 @@ def _parse_and_validate_response(response_body: bytes) -> tuple[dict, dict, list "upstream response output_token_logprobs entry is not a (logprob, token_id) pair" ) logprob = entry[0] - if not isinstance(logprob, (int, float)) or isinstance(logprob, bool): + if not isinstance(logprob, (int, float)) or isinstance(logprob, bool) or not math.isfinite(logprob): raise UpstreamResponseError( f"upstream response output_token_logprobs logprob is not a number: {logprob!r}" ) diff --git a/tests/fast/router/bench_session_responsiveness.py b/tests/fast/router/bench_session_responsiveness.py index 9a4b89f0fd..599ce7cb47 100644 --- a/tests/fast/router/bench_session_responsiveness.py +++ b/tests/fast/router/bench_session_responsiveness.py @@ -403,7 +403,7 @@ def _delta(metric: str) -> str: ("/health max (load)", "health_load_max_ms"), ("/health p95 (baseline)", "health_baseline_p95_ms"), ] - # p95/p99 verdict: "improved" only beyond a noise floor; otherwise "no regression". + # p95/p99 verdict: separate material improvements from noise-level changes. noise_ms = 25.0 verdict_lines = [] for label, key in (("p95", "health_load_p95_ms"), ("p99", "health_load_p99_ms")): @@ -411,7 +411,11 @@ def _delta(metric: str) -> str: if bv is None or av is None: verdict_lines.append(f"- {label}: n/a") continue - if av <= bv + noise_ms: + if av < bv - noise_ms: + verdict_lines.append( + f"- {label}: IMPROVED (before {bv:.1f}ms -> after {av:.1f}ms, beyond ±{noise_ms:.0f}ms noise)" + ) + elif abs(av - bv) <= noise_ms: verdict_lines.append( f"- {label}: NO REGRESSION (before {bv:.1f}ms -> after {av:.1f}ms, within ±{noise_ms:.0f}ms noise)" ) diff --git a/tests/fast/router/test_session_race_conditions.py b/tests/fast/router/test_session_race_conditions.py index f520342405..8b80ff3c46 100644 --- a/tests/fast/router/test_session_race_conditions.py +++ b/tests/fast/router/test_session_race_conditions.py @@ -98,7 +98,7 @@ def patched_chat_response(self, payload: dict) -> dict: return patch.object(MockSGLangServer, "_compute_chat_completions_response", new=patched_chat_response) -def _patch_mock_chat_response_bad_logprob_first(): +def _patch_mock_chat_response_bad_logprob_first(bad_logprob="bad-logprob"): """Like `_patch_mock_chat_response`, but the FIRST chat response carries a non-numeric logprob value (entry[0]) so the session server rejects it with UpstreamResponseError (502) instead of letting the bad value flow into @@ -118,7 +118,7 @@ def patched_chat_response(self, payload: dict) -> dict: if state["calls"] == 1 and output_token_logprobs: # Corrupt the first response's leading logprob value only. _logprob, token_id = output_token_logprobs[0][0], output_token_logprobs[0][1] - output_token_logprobs[0] = ("bad-logprob", token_id) + output_token_logprobs[0] = (bad_logprob, token_id) choice["meta_info"] = { "output_token_logprobs": output_token_logprobs, "completion_tokens": len(output_token_logprobs), @@ -626,6 +626,49 @@ def process_fn(prompt: str) -> ProcessResult: assert len(after["records"]) == 1 assert len(after["metadata"]["accumulated_token_ids"]) > 0 + def test_non_finite_logprob_rejected_without_committing(self): + """Raw upstream NaN is accepted by Python's default json decoder; the + session boundary must still reject it before recording training data. + """ + + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="ok", finish_reason="stop") + + original_do_proxy = SessionServer.do_proxy + state = {"calls": 0} + + async def one_shot_do_proxy(self, request, path, body=None, headers=None): + state["calls"] += 1 + if state["calls"] == 1: + return { + "request_body": body, + "response_body": ( + b'{"choices":[{"message":{"role":"assistant","content":"ok"},' + b'"meta_info":{"output_token_logprobs":[[NaN,562]],"completion_tokens":1}}]}' + ), + "status_code": 200, + "headers": {"content-type": "application/json"}, + } + return await original_do_proxy(self, request, path, body=body, headers=headers) + + with _router_env(process_fn) as env: + session_id = _create_session(env.url) + with patch.object(SessionServer, "do_proxy", one_shot_do_proxy): + bad = _chat(env.url, session_id, {"messages": [{"role": "user", "content": "hi"}]}, timeout=20.0) + assert bad.status_code == 502 + + session_state = requests.get(f"{env.url}/sessions/{session_id}", timeout=5.0).json() + assert session_state["records"] == [] + assert session_state["metadata"]["accumulated_token_ids"] == [] + + good = _chat( + env.url, + session_id, + {"messages": [{"role": "user", "content": "hi again"}]}, + timeout=20.0, + ) + assert good.status_code == 200 + def _normal_messages(content: str) -> dict: return {"messages": [{"role": "user", "content": content}]} @@ -1154,12 +1197,15 @@ def test_non_integer_token_ids_rejected(self): def test_non_numeric_logprob_values_rejected(self): # entry[0] (the logprob) flows into Sample.rollout_log_probs downstream - # (openai_endpoint_utils), so a str / None / bool logprob with an + # (openai_endpoint_utils), so a str / None / bool / non-finite logprob with an # otherwise-valid token id must still raise rather than being accepted. bad_cases = [ [["bad-logprob", 562]], [[None, 562]], [[True, 562]], + [[float("nan"), 562]], + [[float("inf"), 562]], + [[float("-inf"), 562]], ] for logprobs in bad_cases: try: @@ -1168,6 +1214,18 @@ def test_non_numeric_logprob_values_rejected(self): continue raise AssertionError(f"expected UpstreamResponseError for logprob in {logprobs!r}") + def test_non_standard_json_constants_rejected(self): + for constant in ("NaN", "Infinity", "-Infinity"): + raw = ( + b'{"choices":[{"message":{"role":"assistant","content":"ok"},' + b'"meta_info":{"output_token_logprobs":[[' + constant.encode() + b',562]],"completion_tokens":1}}]}' + ) + try: + _parse_and_validate_response(raw) + except UpstreamResponseError: + continue + raise AssertionError(f"expected UpstreamResponseError for JSON constant {constant}") + def test_sglang_triples_with_token_text_accepted(self): # SGLang emits [logprob, token_id, token_text] triples; len > 2 is normal # and must NOT be rejected. Integer logprobs (e.g. 0) are also valid numbers. From da7a8cdd55433ffe5f178af0f40d8784871a04b5 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Tue, 23 Jun 2026 20:29:48 +0000 Subject: [PATCH 11/11] group session server benchmarks under tests/benchmark Move the direct-run session responsiveness benchmark out of the fast router test directory and add the CPU-only session server overhead benchmark under the dedicated benchmark directory. Update the embedded invocation examples to the new paths without changing benchmark logic. --- .../bench_session_responsiveness.py | 6 +- .../bench_session_server_overhead.py | 501 ++++++++++++++++++ 2 files changed, 504 insertions(+), 3 deletions(-) rename tests/{fast/router => benchmark}/bench_session_responsiveness.py (99%) create mode 100644 tests/benchmark/bench_session_server_overhead.py diff --git a/tests/fast/router/bench_session_responsiveness.py b/tests/benchmark/bench_session_responsiveness.py similarity index 99% rename from tests/fast/router/bench_session_responsiveness.py rename to tests/benchmark/bench_session_responsiveness.py index 599ce7cb47..1f55fc939d 100644 --- a/tests/fast/router/bench_session_responsiveness.py +++ b/tests/benchmark/bench_session_responsiveness.py @@ -19,14 +19,14 @@ test in CI). Run it directly: # Run once, print a human-readable block: - python tests/fast/router/bench_session_responsiveness.py + python tests/benchmark/bench_session_responsiveness.py # Run once, persist a reviewable JSON artifact (used for before/after): - python tests/fast/router/bench_session_responsiveness.py --label after \ + python tests/benchmark/bench_session_responsiveness.py --label after \ --json-out .humanize/.../benchmarks/session-responsiveness-after.json # Compare two persisted runs into a markdown verdict: - python tests/fast/router/bench_session_responsiveness.py --compare \ + python tests/benchmark/bench_session_responsiveness.py --compare \ before.json after.json --out compare.md Method: fire K concurrent chats across K DISTINCT sessions (distinct sessions are diff --git a/tests/benchmark/bench_session_server_overhead.py b/tests/benchmark/bench_session_server_overhead.py new file mode 100644 index 0000000000..2d8f69c914 --- /dev/null +++ b/tests/benchmark/bench_session_server_overhead.py @@ -0,0 +1,501 @@ +"""CPU-only micro-benchmark for Session Server per-turn overhead. + +This benchmark measures the session-layer work without starting uvicorn, opening +HTTP sockets, or calling a model backend. It drives the same +``SessionRegistry`` / ``LinearTrajectory`` TITO path and the same response +parse/validate helper that the standalone session server uses after the backend +returns bytes. + +Run it directly: + + python tests/benchmark/bench_session_server_overhead.py \ + --sessions 32 --turns 4 --input-tokens 64 --output-tokens 64 --r3-scale 1000 + +The reported "reply latency" is CPU-only overhead for one synthetic turn: +request JSON parse, TITO tokenization, request JSON dump with Miles-owned +``input_ids``, response parse/validate, and writing the record into in-memory +session state. Synthetic response construction is done before the measured loop, +so the numbers do not include model/backend generation. +""" + +from __future__ import annotations + +import argparse +import asyncio +import base64 +import json +import math +import statistics +import time +from dataclasses import dataclass +from pathlib import Path +from types import SimpleNamespace +from typing import Any + +from miles.rollout.session.linear_trajectory import MAX_ASSISTANT_ROLLBACK_STEPS, SessionRegistry +from miles.rollout.session.session_types import SessionRecord +from miles.rollout.session.sessions import _dump_request_body, _parse_and_validate_response, _parse_request_body +from miles.utils.chat_template_utils import get_tito_tokenizer, resolve_fixed_chat_template +from miles.utils.processing_utils import load_tokenizer + +DEFAULT_HF_CHECKPOINT = "Qwen/Qwen3-0.6B" +DEFAULT_TITO_MODEL = "qwen3" +DEFAULT_ALLOWED_APPEND_ROLES = ["user"] + + +@dataclass(frozen=True) +class TurnSpec: + request_body: bytes + response_body: bytes + expected_prompt_token_ids: list[int] + content_input_tokens: int + content_output_tokens: int + completion_tokens: int + r3_raw_bytes: int + r3_json_chars: int + + +def _positive_int(value: str) -> int: + parsed = int(value) + if parsed <= 0: + raise argparse.ArgumentTypeError(f"expected a positive integer, got {value!r}") + return parsed + + +def _non_negative_int(value: str) -> int: + parsed = int(value) + if parsed < 0: + raise argparse.ArgumentTypeError(f"expected a non-negative integer, got {value!r}") + return parsed + + +def _pct(values: list[float], q: float) -> float: + if not values: + return float("nan") + ordered = sorted(values) + idx = max(0, min(len(ordered) - 1, math.ceil(q * len(ordered)) - 1)) + return ordered[idx] + + +def _summary(values: list[float]) -> dict[str, float]: + return { + "mean_ms": statistics.mean(values) if values else float("nan"), + "p50_ms": _pct(values, 0.50), + "p95_ms": _pct(values, 0.95), + "p99_ms": _pct(values, 0.99), + "max_ms": max(values) if values else float("nan"), + } + + +def _find_repeatable_token_id(tokenizer) -> int: + for text in (" x", " a", " the", " token", " 0", "A"): + for token_id in tokenizer.encode(text, add_special_tokens=False): + decoded = tokenizer.decode([token_id], skip_special_tokens=False) + if tokenizer.encode(decoded, add_special_tokens=False) == [token_id]: + return token_id + raise RuntimeError("could not find a repeatable one-token text unit for this tokenizer") + + +def _make_text_with_token_count(tokenizer, token_id: int, token_count: int) -> tuple[str, list[int]]: + if token_count == 0: + return "", [] + token_ids = [token_id] * token_count + text = tokenizer.decode(token_ids, skip_special_tokens=False) + roundtrip_ids = tokenizer.encode(text, add_special_tokens=False) + if len(roundtrip_ids) != token_count: + raise RuntimeError( + "repeatable token changed length after decode/encode roundtrip: " + f"requested={token_count}, actual={len(roundtrip_ids)}" + ) + return text, roundtrip_ids + + +def _make_r3_blob(raw_bytes: int) -> str: + if raw_bytes == 0: + return "" + pattern = bytes(range(251)) + repeats = math.ceil(raw_bytes / len(pattern)) + raw = (pattern * repeats)[:raw_bytes] + return base64.b64encode(raw).decode("ascii") + + +def _completion_token_ids( + tito_tokenizer, tokenizer, messages: list[dict[str, Any]], assistant_message: dict[str, Any] +): + prompt_text = tito_tokenizer.render_messages(messages, add_generation_prompt=True, tokenize=False) + full_text = tito_tokenizer.render_messages( + messages + [assistant_message], + add_generation_prompt=False, + tokenize=False, + ) + if not full_text.startswith(prompt_text): + raise RuntimeError("assistant response does not extend the rendered prompt") + return tokenizer.encode(full_text[len(prompt_text) :], add_special_tokens=False) + + +def _build_response_body( + assistant_message: dict[str, Any], + completion_token_ids: list[int], + r3_blob: str, +) -> bytes: + output_token_logprobs = [ + [-((idx % 1024) + 1) / 1024.0, token_id] for idx, token_id in enumerate(completion_token_ids) + ] + response = { + "id": "synthetic-session-overhead", + "object": "chat.completion", + "created": 0, + "model": "synthetic", + "choices": [ + { + "index": 0, + "message": assistant_message, + "finish_reason": "stop", + "meta_info": { + "completion_tokens": len(completion_token_ids), + "output_token_logprobs": output_token_logprobs, + "routed_experts": r3_blob, + }, + } + ], + } + return json.dumps(response, separators=(",", ":")).encode() + + +def _build_turn_specs(tokenizer, tito_tokenizer, turns: int, input_tokens: int, output_tokens: int, r3_scale: int): + token_id = _find_repeatable_token_id(tokenizer) + history: list[dict[str, Any]] = [] + specs: list[TurnSpec] = [] + + for _turn_idx in range(turns): + input_text, input_content_ids = _make_text_with_token_count(tokenizer, token_id, input_tokens) + output_text, output_content_ids = _make_text_with_token_count(tokenizer, token_id, output_tokens) + + user_message = {"role": "user", "content": input_text} + assistant_message = {"role": "assistant", "content": output_text} + request_messages = [dict(message) for message in history] + [user_message] + + prompt_token_ids = tito_tokenizer.render_messages( + request_messages, + add_generation_prompt=True, + tokenize=True, + ) + completion_token_ids = _completion_token_ids(tito_tokenizer, tokenizer, request_messages, assistant_message) + + r3_token_count = max(0, len(prompt_token_ids) + len(completion_token_ids) - 1) + r3_raw_bytes = r3_token_count * r3_scale + r3_blob = _make_r3_blob(r3_raw_bytes) + + request_body = json.dumps({"messages": request_messages}, separators=(",", ":")).encode() + response_body = _build_response_body(assistant_message, completion_token_ids, r3_blob) + specs.append( + TurnSpec( + request_body=request_body, + response_body=response_body, + expected_prompt_token_ids=prompt_token_ids, + content_input_tokens=len(input_content_ids), + content_output_tokens=len(output_content_ids), + completion_tokens=len(completion_token_ids), + r3_raw_bytes=r3_raw_bytes, + r3_json_chars=len(r3_blob), + ) + ) + + history = request_messages + [assistant_message] + + return specs + + +def _make_registry(tokenizer, tito_tokenizer) -> SessionRegistry: + args = SimpleNamespace(generate_multi_samples=False) + return SessionRegistry(args, tokenizer, tito_tokenizer=tito_tokenizer) + + +async def _run_one_turn(session, registry: SessionRegistry, spec: TurnSpec, samples: dict[str, list[float]]) -> None: + turn_start = time.perf_counter() + + stage_start = time.perf_counter() + request_body = _parse_request_body(spec.request_body) + samples["request_parse_ms"].append((time.perf_counter() - stage_start) * 1000) + + request_body["logprobs"] = True + request_body["return_meta_info"] = True + request_body["return_routed_experts"] = True + request_body["no_stop_trim"] = False + request_messages = request_body["messages"] + + stage_start = time.perf_counter() + async with session.lock: + prompt_token_ids = session.prepare_pretokenized( + request_messages, + tools=request_body.get("tools"), + tito_tokenizer=registry.tito_tokenizer, + ) + expected_num_assistant = session.num_assistant + samples["tokenization_ms"].append((time.perf_counter() - stage_start) * 1000) + + request_body["input_ids"] = prompt_token_ids + stage_start = time.perf_counter() + _dump_request_body(request_body) + samples["request_dump_ms"].append((time.perf_counter() - stage_start) * 1000) + + stage_start = time.perf_counter() + response, assistant_message, completion_token_ids = _parse_and_validate_response(spec.response_body) + samples["response_parse_validate_ms"].append((time.perf_counter() - stage_start) * 1000) + + stage_start = time.perf_counter() + async with session.lock: + if session.num_assistant != expected_num_assistant: + raise RuntimeError("session state changed during a single-threaded benchmark turn") + session.update_pretokenized_state( + request_messages, + assistant_message, + prompt_token_ids=prompt_token_ids, + completion_token_ids=completion_token_ids, + max_trim_tokens=registry.tito_tokenizer.max_trim_tokens, + ) + record = SessionRecord( + timestamp=time.time(), + method="POST", + path="/v1/chat/completions", + status_code=200, + request=request_body, + response=response, + ) + session.append_record(record) + samples["record_store_ms"].append((time.perf_counter() - stage_start) * 1000) + + samples["reply_latency_ms"].append((time.perf_counter() - turn_start) * 1000) + + +async def _validate_specs_once(tokenizer, tito_tokenizer, specs: list[TurnSpec]) -> None: + registry = _make_registry(tokenizer, tito_tokenizer) + session = registry.get_session(registry.create_session()) + + for spec in specs: + request_body = _parse_request_body(spec.request_body) + request_messages = request_body["messages"] + + async with session.lock: + prompt_token_ids = session.prepare_pretokenized( + request_messages, + tools=request_body.get("tools"), + tito_tokenizer=registry.tito_tokenizer, + ) + expected_num_assistant = session.num_assistant + + if prompt_token_ids != spec.expected_prompt_token_ids: + raise RuntimeError( + "TITO prompt ids differ from canonical full render: " + f"expected={len(spec.expected_prompt_token_ids)} tokens, actual={len(prompt_token_ids)} tokens" + ) + + response, assistant_message, completion_token_ids = _parse_and_validate_response(spec.response_body) + + async with session.lock: + if session.num_assistant != expected_num_assistant: + raise RuntimeError("session state changed during benchmark spec validation") + session.update_pretokenized_state( + request_messages, + assistant_message, + prompt_token_ids=prompt_token_ids, + completion_token_ids=completion_token_ids, + max_trim_tokens=registry.tito_tokenizer.max_trim_tokens, + ) + session.append_record( + SessionRecord( + timestamp=time.time(), + method="POST", + path="/v1/chat/completions", + status_code=200, + request=request_body, + response=response, + ) + ) + + +async def _run_workload(tokenizer, tito_tokenizer, specs: list[TurnSpec], num_sessions: int): + registry = _make_registry(tokenizer, tito_tokenizer) + sessions = [registry.get_session(registry.create_session()) for _ in range(num_sessions)] + samples: dict[str, list[float]] = { + "request_parse_ms": [], + "tokenization_ms": [], + "request_dump_ms": [], + "response_parse_validate_ms": [], + "record_store_ms": [], + "reply_latency_ms": [], + } + + wall_start = time.perf_counter() + for spec in specs: + for session in sessions: + await _run_one_turn(session, registry, spec, samples) + wall_s = time.perf_counter() - wall_start + return samples, wall_s + + +def run_bench(args) -> dict[str, Any]: + if args.chat_template_path is not None: + chat_template_path = args.chat_template_path + chat_template_kwargs = None + elif args.tito_model == "default": + chat_template_path = None + chat_template_kwargs = None + else: + chat_template_path, chat_template_kwargs = resolve_fixed_chat_template( + args.tito_model, args.allowed_append_roles + ) + + tokenizer = load_tokenizer(args.hf_checkpoint, chat_template_path=chat_template_path, trust_remote_code=True) + tito_tokenizer = get_tito_tokenizer( + tokenizer, + tokenizer_type=args.tito_model, + chat_template_kwargs=chat_template_kwargs, + allowed_append_roles=args.allowed_append_roles, + ) + specs = _build_turn_specs( + tokenizer, + tito_tokenizer, + turns=args.turns, + input_tokens=args.input_tokens, + output_tokens=args.output_tokens, + r3_scale=args.r3_scale, + ) + + asyncio.run(_validate_specs_once(tokenizer, tito_tokenizer, specs)) + samples, wall_s = asyncio.run(_run_workload(tokenizer, tito_tokenizer, specs, args.sessions)) + total_turns = args.sessions * args.turns + content_tokens = args.sessions * sum(spec.content_input_tokens + spec.content_output_tokens for spec in specs) + output_tokens = args.sessions * sum(spec.content_output_tokens for spec in specs) + completion_tokens = args.sessions * sum(spec.completion_tokens for spec in specs) + retained_r3_raw_bytes = args.sessions * sum( + spec.r3_raw_bytes for spec in specs[-(MAX_ASSISTANT_ROLLBACK_STEPS + 1) :] + ) + + return { + "sessions": args.sessions, + "turns_per_session": args.turns, + "total_turns": total_turns, + "input_tokens_per_turn": args.input_tokens, + "output_tokens_per_turn": args.output_tokens, + "r3_scale_raw_bytes_per_token": args.r3_scale, + "hf_checkpoint": args.hf_checkpoint, + "tito_model": args.tito_model, + "allowed_append_roles": args.allowed_append_roles, + "chat_template_path": chat_template_path, + "chat_template_kwargs": chat_template_kwargs, + "wall_s": wall_s, + "throughput_turns_per_s": total_turns / wall_s if wall_s > 0 else float("nan"), + "throughput_content_tokens_per_s": content_tokens / wall_s if wall_s > 0 else float("nan"), + "throughput_completion_tokens_per_s": completion_tokens / wall_s if wall_s > 0 else float("nan"), + "throughput_output_content_tokens_per_s": output_tokens / wall_s if wall_s > 0 else float("nan"), + "retained_r3_raw_bytes_estimate": retained_r3_raw_bytes, + "turn_specs": [ + { + "turn_index": idx, + "prompt_tokens": len(spec.expected_prompt_token_ids), + "completion_tokens": spec.completion_tokens, + "content_input_tokens": spec.content_input_tokens, + "content_output_tokens": spec.content_output_tokens, + "r3_raw_bytes": spec.r3_raw_bytes, + "r3_json_chars": spec.r3_json_chars, + "request_body_bytes": len(spec.request_body), + "response_body_bytes": len(spec.response_body), + } + for idx, spec in enumerate(specs) + ], + "metrics": {name: _summary(values) for name, values in samples.items()}, + "raw_samples_ms": samples if args.include_raw_samples else None, + } + + +def _fmt_ms_stats(stats: dict[str, float]) -> str: + return ( + f"mean={stats['mean_ms']:.3f}ms p50={stats['p50_ms']:.3f}ms " + f"p95={stats['p95_ms']:.3f}ms p99={stats['p99_ms']:.3f}ms max={stats['max_ms']:.3f}ms" + ) + + +def _fmt_block(result: dict[str, Any]) -> str: + metrics = result["metrics"] + last_spec = result["turn_specs"][-1] + lines = [ + "=" * 72, + "Session Server CPU overhead benchmark", + "=" * 72, + f" sessions x turns : {result['sessions']} x {result['turns_per_session']} " + f"({result['total_turns']} turns)", + f" content tokens / turn : input={result['input_tokens_per_turn']} " + f"output={result['output_tokens_per_turn']}", + f" r3 raw bytes / token : {result['r3_scale_raw_bytes_per_token']}", + f" tokenizer / TITO : {result['hf_checkpoint']} / {result['tito_model']}", + f" final-turn prompt/completion : {last_spec['prompt_tokens']} / {last_spec['completion_tokens']} tokens", + f" final-turn response body : {last_spec['response_body_bytes'] / 1024:.1f} KiB", + f" retained r3 estimate : {result['retained_r3_raw_bytes_estimate'] / 1024 / 1024:.1f} MiB raw", + "-" * 72, + f" tokenization : {_fmt_ms_stats(metrics['tokenization_ms'])}", + f" record store : {_fmt_ms_stats(metrics['record_store_ms'])}", + f" reply latency : {_fmt_ms_stats(metrics['reply_latency_ms'])}", + "-" * 72, + f" request parse : {_fmt_ms_stats(metrics['request_parse_ms'])}", + f" request dump : {_fmt_ms_stats(metrics['request_dump_ms'])}", + f" response parse+validate : {_fmt_ms_stats(metrics['response_parse_validate_ms'])}", + "-" * 72, + f" wall clock : {result['wall_s']:.3f}s", + f" throughput : {result['throughput_turns_per_s']:.1f} turns/s", + f" content-token throughput : {result['throughput_content_tokens_per_s']:.1f} tokens/s", + f" completion-token throughput : {result['throughput_completion_tokens_per_s']:.1f} tokens/s", + "=" * 72, + ] + return "\n".join(lines) + + +def _write_json(result: dict[str, Any], path: str) -> None: + payload = dict(result) + if payload.get("raw_samples_ms") is None: + payload.pop("raw_samples_ms", None) + with Path(path).open("w") as f: + json.dump(payload, f, indent=2) + print(f"[bench] wrote {path}") + + +def main() -> None: + parser = argparse.ArgumentParser(description="CPU-only Session Server overhead benchmark") + parser.add_argument("--sessions", type=_positive_int, default=32, help="number of sessions to create") + parser.add_argument("--turns", type=_positive_int, default=4, help="turns per session") + parser.add_argument("--input-tokens", type=_non_negative_int, default=64, help="new user-content tokens per turn") + parser.add_argument( + "--output-tokens", + type=_non_negative_int, + default=64, + help="assistant-content tokens per turn", + ) + parser.add_argument( + "--r3-scale", + type=_non_negative_int, + default=1000, + help="raw routed_experts bytes per accumulated token", + ) + parser.add_argument("--hf-checkpoint", default=DEFAULT_HF_CHECKPOINT, help="tokenizer checkpoint or local path") + parser.add_argument("--tito-model", default=DEFAULT_TITO_MODEL, help="TITO tokenizer family") + parser.add_argument( + "--allowed-append-roles", + nargs="+", + default=DEFAULT_ALLOWED_APPEND_ROLES, + help="roles allowed after the pretokenized prefix", + ) + parser.add_argument("--chat-template-path", default=None, help="explicit chat template path") + parser.add_argument("--json-out", default=None, help="persist the run as a JSON artifact") + parser.add_argument( + "--include-raw-samples", action="store_true", help="include every per-turn sample in JSON output" + ) + args = parser.parse_args() + + result = run_bench(args) + print(_fmt_block(result)) + if args.json_out: + _write_json(result, args.json_out) + + +if __name__ == "__main__": + main()