diff --git a/miles/rollout/session/linear_trajectory.py b/miles/rollout/session/linear_trajectory.py index c9d7694479..4d4e608f77 100644 --- a/miles/rollout/session/linear_trajectory.py +++ b/miles/rollout/session/linear_trajectory.py @@ -31,6 +31,9 @@ class LinearTrajectory: lock: asyncio.Lock = field(default_factory=asyncio.Lock, repr=False, compare=False) closing: bool = field(default=False, repr=False, compare=False) + # Per-session in-flight gate: set under self.lock when a chat claims the + # session, cleared by that same request on every exit path. + chat_inflight: bool = field(default=False, repr=False, compare=False) messages: list[dict[str, Any]] = field(default_factory=list) records: list[SessionRecord] = field(default_factory=list) trajectory_token_ids: list[list[int]] = field(default_factory=list) diff --git a/miles/rollout/session/session_errors.py b/miles/rollout/session/session_errors.py index 30e6784384..cdab3d060a 100644 --- a/miles/rollout/session/session_errors.py +++ b/miles/rollout/session/session_errors.py @@ -6,7 +6,9 @@ ├── SessionNotFoundError → 404 session does not exist ├── MessageValidationError → 400 messages structure/content invalid ├── TokenizationError → 500 TITO tokenizer / prefix mismatch -└── UpstreamResponseError → 502 SGLang response invalid or unexpected +├── UpstreamResponseError → 502 SGLang response invalid or unexpected +├── SessionBusyError → 409 session already has an in-flight chat +└── SessionInvariantError → 500 unreachable session-state invariant violated """ @@ -49,3 +51,20 @@ class UpstreamResponseError(SessionError): """ status_code: int = 502 + + +class SessionBusyError(SessionError): + """Raised when the session already has an in-flight chat completion. + + One linear trajectory admits one in-flight chat at a time. + """ + + status_code: int = 409 + + +class SessionInvariantError(SessionError): + """Raised when a session-state invariant that should be unreachable under + the in-flight gate is violated (defensive; indicates a real bug). + """ + + status_code: int = 500 diff --git a/miles/rollout/session/session_server.py b/miles/rollout/session/session_server.py index ea64ace826..86a41ab950 100644 --- a/miles/rollout/session/session_server.py +++ b/miles/rollout/session/session_server.py @@ -8,12 +8,13 @@ import json import logging +import os +from concurrent.futures import ThreadPoolExecutor import httpx import setproctitle import uvicorn from fastapi import FastAPI, Request -from fastapi.responses import JSONResponse from starlette.responses import Response from miles.rollout.session.sessions import setup_session_routes @@ -38,6 +39,12 @@ def __init__(self, args, backend_url: str): # Close the httpx connection pool when uvicorn shuts down to avoid FD leaks. self.app.router.on_shutdown.append(self.client.aclose) + self.cpu_executor = ThreadPoolExecutor( + max_workers=getattr(args, "session_server_cpu_workers", None) or min(16, os.cpu_count() or 1), + thread_name_prefix="session-cpu", + ) + self.app.router.on_shutdown.append(lambda: self.cpu_executor.shutdown(wait=False, cancel_futures=True)) + setup_session_routes(self.app, self, args) async def do_proxy( @@ -79,21 +86,20 @@ async def do_proxy( } def build_proxy_response(self, result: dict) -> Response: - content = result["response_body"] - status_code = result["status_code"] - # Drop wire-level framing headers from upstream so Starlette rebuilds them - # from the body we actually send: transfer-encoding is hop-by-hop + # httpx already decoded the body, so upstream content-encoding/length are + # stale framing headers; drop them and let Starlette rebuild from the body. headers = { k: v for k, v in result["headers"].items() if k.lower() not in ("content-length", "transfer-encoding", "content-encoding") } content_type = headers.get("content-type", "") - try: - data = json.loads(content) - return JSONResponse(content=data, status_code=status_code, headers=headers) - except (json.JSONDecodeError, UnicodeDecodeError): - return Response(content=content, status_code=status_code, headers=headers, media_type=content_type) + return Response( + content=result["response_body"], + status_code=result["status_code"], + headers=headers, + media_type=content_type, + ) def run_session_server(args, backend_url: str): @@ -108,4 +114,7 @@ def run_session_server(args, backend_url: str): args.session_server_port, backend_url, ) + # Single uvicorn worker on purpose: extra workers would each own a separate + # SessionRegistry + asyncio.Lock, so a session_id could land on a process that + # doesn't own it. Multi-process needs sticky session ownership and is deferred. uvicorn.run(server.app, host=args.session_server_ip, port=args.session_server_port, log_level="info") diff --git a/miles/rollout/session/sessions.py b/miles/rollout/session/sessions.py index 243ebd2746..bca55df5a3 100644 --- a/miles/rollout/session/sessions.py +++ b/miles/rollout/session/sessions.py @@ -1,5 +1,7 @@ +import asyncio import json import logging +import math import time from fastapi import Request @@ -8,7 +10,9 @@ from miles.rollout.session.linear_trajectory import SessionRegistry from miles.rollout.session.session_errors import ( + SessionBusyError, SessionError, + SessionInvariantError, SessionNotFoundError, TokenizationError, UpstreamResponseError, @@ -20,6 +24,91 @@ logger = logging.getLogger(__name__) +def _reject_json_constant(value: str): + raise ValueError(f"invalid JSON constant: {value}") + + +def _parse_request_body(body: bytes) -> dict: + return json.loads(body) if body else {} + + +def _dump_request_body(request_body: dict) -> bytes: + return json.dumps(request_body).encode() + + +def _parse_and_validate_response(response_body: bytes) -> tuple[dict, dict, list[int]]: + """Parse + validate a successful chat completion response off the event loop. + + Returns (full_response, assistant_message, completion_token_ids). Raises + UpstreamResponseError on malformed meta_info / content / token-length mismatch. + Touches no session state. + """ + try: + response = json.loads(response_body, parse_constant=_reject_json_constant) + except (json.JSONDecodeError, UnicodeDecodeError, ValueError) as e: + raise UpstreamResponseError(f"upstream response is not valid JSON: {e}") from e + + choices = response.get("choices") if isinstance(response, dict) else None + if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict): + raise UpstreamResponseError("upstream response has no valid choices[0]") + choice = choices[0] + meta_info = choice.get("meta_info") + if not isinstance(meta_info, dict) or "output_token_logprobs" not in meta_info: + raise UpstreamResponseError("meta_info and output_token_logprobs must be in choice (requires logprobs=True)") + assistant_message = choice.get("message") + if not isinstance(assistant_message, dict): + raise UpstreamResponseError("upstream response choice has no valid message") + if assistant_message.get("content") is None: + raise UpstreamResponseError( + "assistant message content is None, when tool call parser failed SGLang should still return " + "an empty content rather than None. Please check your modified SGLang version." + ) + output_token_logprobs = meta_info["output_token_logprobs"] + completion_tokens = meta_info.get("completion_tokens") + # bool is an int subclass; reject it explicitly so a True/False count is not + # silently treated as 1/0. + if ( + not isinstance(output_token_logprobs, list) + or not isinstance(completion_tokens, int) + or isinstance(completion_tokens, bool) + ): + raise UpstreamResponseError("upstream response output_token_logprobs/completion_tokens have invalid types") + actual_output_logprobs_len = len(output_token_logprobs) + if actual_output_logprobs_len != completion_tokens: + raise UpstreamResponseError( + "invalid chat completion response: " + f"len(output_token_logprobs)={actual_output_logprobs_len} " + f"!= completion_tokens={completion_tokens}. " + f"Please check whether you use the correct SGLang branch which has fix the tokenizer batch decode issue." + ) + # Each entry must be a (logprob, token_id, ...) sequence (SGLang emits + # [logprob, token_id, token_text] triples; len > 2 is normal). Both leading + # fields are consumed downstream: token_id (entry[1]) feeds the stored + # trajectory token ids, and logprob (entry[0]) feeds Sample.rollout_log_probs + # in openai_endpoint_utils. A non-int token id or non-numeric logprob would + # silently corrupt the stored trajectory / training logprobs, so reject the + # whole response instead of extracting garbage. bool is an int subclass, so + # reject it explicitly for both fields. + completion_token_ids: list[int] = [] + for entry in output_token_logprobs: + if not isinstance(entry, (list, tuple)) or len(entry) < 2: + raise UpstreamResponseError( + "upstream response output_token_logprobs entry is not a (logprob, token_id) pair" + ) + logprob = entry[0] + if not isinstance(logprob, (int, float)) or isinstance(logprob, bool) or not math.isfinite(logprob): + raise UpstreamResponseError( + f"upstream response output_token_logprobs logprob is not a number: {logprob!r}" + ) + token_id = entry[1] + if not isinstance(token_id, int) or isinstance(token_id, bool): + raise UpstreamResponseError( + f"upstream response output_token_logprobs token id is not an int: {token_id!r}" + ) + completion_token_ids.append(token_id) + return response, assistant_message, completion_token_ids + + def setup_session_routes(app, backend, args): hf_checkpoint = getattr(args, "hf_checkpoint", None) if not hf_checkpoint: @@ -48,24 +137,6 @@ async def health(): body["session_server_instance_id"] = session_server_instance_id return body - # --- DEBUG: track in-flight chat_completions --- - _inflight_chat = {"count": 0} - - @app.middleware("http") - async def debug_request_logger(request: Request, call_next): - client = request.client - client_info = f"{client.host}:{client.port}" if client else "unknown" - logger.info( - f"[session-server] REQUEST ARRIVED: {request.method} {request.url.path} from={client_info} inflight_chat={_inflight_chat['count']}" - ) - t0 = time.time() - response = await call_next(request) - elapsed = time.time() - t0 - logger.info( - f"[session-server] REQUEST DONE: {request.method} {request.url.path} status={response.status_code} elapsed={elapsed:.3f}s from={client_info}" - ) - return response - @app.exception_handler(SessionError) async def session_error_handler(request: Request, exc: SessionError): return JSONResponse(status_code=exc.status_code, content={"error": str(exc)}) @@ -113,115 +184,80 @@ async def delete_session(session_id: str): @app.post("/sessions/{session_id}/v1/chat/completions") async def chat_completions(request: Request, session_id: str): - """Proxy a chat completion through SGLang with TITO token tracking. - - Flow: prepare pretokenized input_ids (lock held briefly) → inject - SGLang flags → proxy to backend (NO lock) → validate response → - update trajectory checkpoint (lock held briefly) → append session record. - - The lock is NOT held during the slow proxy call to avoid blocking - DELETE/other operations when the agent disconnects mid-request. + """One in-flight chat per session; a second concurrent same-session chat + fast-fails 409 without entering the backend. State mutation stays on the + event loop under session.lock; stateless CPU work is offloaded. """ - _inflight_chat["count"] += 1 - try: - session = registry.get_session(session_id) + loop = asyncio.get_running_loop() + session = registry.get_session(session_id) + + claimed = False + # claim the single in-flight slot under a brief lock; closing (404) beats busy (409) + async with session.lock: if session.closing: raise SessionNotFoundError(f"session not found: session_id={session_id}") - - # --- Phase 1: prepare request (lock held briefly) --- + if session.chat_inflight: + raise SessionBusyError("session already has an in-flight chat completion") + session.chat_inflight = True + claimed = True + try: + body = await request.body() + request_body = await loop.run_in_executor(backend.cpu_executor, _parse_request_body, body) + + # TITO token tracking requires Miles-owned input_ids plus SGLang + # output-token metadata: + # logprobs=True → populates meta_info.output_token_logprobs + # return_meta_info → wraps the above in choice.meta_info + # Both flags are hardcoded (not set default) to prevent agent-side + # overrides from breaking the token accumulation invariants. + request_body["logprobs"] = True + request_body["return_meta_info"] = True + if getattr(args, "use_rollout_routing_replay", False): + request_body["return_routed_experts"] = True + if getattr(args, "use_rollout_indexer_replay", False): + request_body["return_indexer_topk"] = True + # Must be False so stop-token text is trimmed from assistant + # message content; token IDs are still taken from logprobs below. + request_body["no_stop_trim"] = False + request_messages = request_body.get("messages", []) + + # prepare pretokenized input under the lock (mutates trajectory state) async with session.lock: - # Double-check: session may have been marked closing while waiting for lock. if session.closing: raise SessionNotFoundError(f"session not found: session_id={session_id}") - - body = await request.body() - request_body = json.loads(body) if body else {} - - # TITO token tracking requires Miles-owned input_ids plus SGLang - # output-token metadata: - # logprobs=True → populates meta_info.output_token_logprobs - # return_meta_info → wraps the above in choice.meta_info - # Both flags are hardcoded (not set default) to prevent agent-side - # overrides from breaking the token accumulation invariants. - request_body["logprobs"] = True - request_body["return_meta_info"] = True - if getattr(args, "use_rollout_routing_replay", False): - request_body["return_routed_experts"] = True - if getattr(args, "use_rollout_indexer_replay", False): - request_body["return_indexer_topk"] = True - # Must be False so stop-token text is trimmed from assistant - # message content; token IDs are still taken from logprobs below. - request_body["no_stop_trim"] = False - - request_messages = request_body.get("messages", []) prompt_token_ids = session.prepare_pretokenized( request_messages, tools=request_body.get("tools"), tito_tokenizer=registry.tito_tokenizer, ) request_body["input_ids"] = prompt_token_ids - logger.debug( - "Using TITO input_ids: %d tokens", - len(prompt_token_ids), - ) - - body = json.dumps(request_body).encode() expected_num_assistant = session.num_assistant - # --- lock released here --- + logger.debug("Using TITO input_ids: %d tokens", len(prompt_token_ids)) - # --- Phase 2: proxy to SGLang (NO lock held) --- - result = await backend.do_proxy(request, "v1/chat/completions", body=body) - - # If SGLang returned a non-200 error (e.g. 400 for context too long), - # pass it through to the agent without recording — the agent can retry - # or handle the error. + encoded_body = await loop.run_in_executor(backend.cpu_executor, _dump_request_body, request_body) + result = await backend.do_proxy(request, "v1/chat/completions", body=encoded_body) + # Non-200 (e.g. 400 for context too long) passes through unrecorded; + # the agent can retry or handle the error. if result["status_code"] != 200: return backend.build_proxy_response(result) - response = json.loads(result["response_body"]) - - choice = response.get("choices", [{}])[0] - - meta_info = choice.get("meta_info") - if not isinstance(meta_info, dict) or "output_token_logprobs" not in meta_info: - raise UpstreamResponseError( - "meta_info and output_token_logprobs must be in choice (requires logprobs=True)" - ) - assistant_message = choice.get("message", {}) - if assistant_message.get("content") is None: - raise UpstreamResponseError( - "assistant message content is None, when tool call parser failed SGLang should still return " - "an empty content rather than None. Please check your modified SGLang version." - ) - - output_token_logprobs = meta_info["output_token_logprobs"] - completion_tokens = meta_info["completion_tokens"] - - actual_output_logprobs_len = len(output_token_logprobs) - if actual_output_logprobs_len != completion_tokens: - raise UpstreamResponseError( - "invalid chat completion response: " - f"len(output_token_logprobs)={actual_output_logprobs_len} " - f"!= completion_tokens={completion_tokens}. " - f"Please check whether you use the correct SGLang branch which has fix the tokenizer batch decode issue." - ) + response, assistant_message, completion_token_ids = await loop.run_in_executor( + backend.cpu_executor, _parse_and_validate_response, result["response_body"] + ) - completion_token_ids = [t[1] for t in output_token_logprobs] - - # --- Phase 3: update state (lock held briefly) --- + # commit state under the lock async with session.lock: if session.closing: logger.warning(f"Session {session_id} closed during proxy, skipping state update") return backend.build_proxy_response(result) - if session.num_assistant != expected_num_assistant: - logger.warning( - f"Session {session_id} state changed during proxy " - f"(expected num_assistant={expected_num_assistant}, " - f"got {session.num_assistant}), skipping state update" + logger.error( + f"Session {session_id} invariant violation: num_assistant={session.num_assistant} " + f"!= expected={expected_num_assistant} under the in-flight gate; this should be unreachable" + ) + raise SessionInvariantError( + f"session state changed under the in-flight gate (session_id={session_id})" ) - return backend.build_proxy_response(result) - session.update_pretokenized_state( request_messages, assistant_message, @@ -229,7 +265,6 @@ async def chat_completions(request: Request, session_id: str): completion_token_ids=completion_token_ids, max_trim_tokens=registry.tito_tokenizer.max_trim_tokens, ) - record = SessionRecord( timestamp=time.time(), method=request.method, @@ -239,11 +274,12 @@ async def chat_completions(request: Request, session_id: str): response=response, ) session.append_record(record) - # --- lock released here --- - return backend.build_proxy_response(result) finally: - _inflight_chat["count"] -= 1 + if claimed: + # single-threaded event loop: a plain write is atomic; no other coroutine + # mutates this session's flag without the lock, and finally runs on cancellation. + session.chat_inflight = False @app.api_route("/sessions/{session_id}/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"]) async def session_proxy(request: Request, session_id: str, path: str): diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index b2dde47c7c..388dc83a43 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -1803,6 +1803,15 @@ def add_session_arguments(parser): help="Message roles allowed to be appended after the pretokenized " "assistant prefix in TITO sessions (default: tool).", ) + parser.add_argument( + "--session-server-cpu-workers", + type=int, + default=min(16, os.cpu_count() or 1), + help="Max worker threads for the session server's bounded CPU thread pool " + "(offloads stateless JSON parse/dump and response validation off the event loop). " + "Higher values improve event-loop responsiveness under load but raise peak memory; " + "the GIL means stdlib JSON does not gain CPU throughput from more threads.", + ) return parser def add_user_provided_function_arguments(parser): diff --git a/tests/benchmark/bench_session_responsiveness.py b/tests/benchmark/bench_session_responsiveness.py new file mode 100644 index 0000000000..1f55fc939d --- /dev/null +++ b/tests/benchmark/bench_session_responsiveness.py @@ -0,0 +1,504 @@ +"""Event-loop-responsiveness micro-benchmark for the standalone SessionServer. + +Why this exists (directional / non-blocking): +The session server offloads stateless CPU work (large JSON parse/dump of the +chat request + response, including hundreds-of-KB `routed_experts` blobs) off the +single asyncio event loop onto a bounded `SessionServer.cpu_executor`. The claim +is that this keeps the one event loop responsive under heavy-response load so the +liveness probe `GET /health` (handled inline on the loop) stays fast. It does NOT +claim higher total CPU throughput: the GIL still serializes the Python JSON work, +so this measures latency/responsiveness, not aggregate CPU. + +Honest reporting note: deeper per-stage server-internal timing is intentionally +NOT instrumented into the production hot path (no probes added to chat_completions). +The offloaded stage's cost is reported indirectly via the measured single-response +parse cost (`json.loads` of one heavy response body), which is the per-call CPU +component the executor offload moves off the event loop. + +This file is named `bench_*` so pytest does NOT auto-collect it (no flaky timing +test in CI). Run it directly: + + # Run once, print a human-readable block: + python tests/benchmark/bench_session_responsiveness.py + + # Run once, persist a reviewable JSON artifact (used for before/after): + python tests/benchmark/bench_session_responsiveness.py --label after \ + --json-out .humanize/.../benchmarks/session-responsiveness-after.json + + # Compare two persisted runs into a markdown verdict: + python tests/benchmark/bench_session_responsiveness.py --compare \ + before.json after.json --out compare.md + +Method: fire K concurrent chats across K DISTINCT sessions (distinct sessions are +not gated, so they run in parallel), each producing a large response that forces a +CPU-heavy `json.loads` in `_parse_and_validate_response`. Concurrently, a separate +thread polls `GET /health` every ~10ms and records each round-trip latency. We +report chat throughput, response body size, and `/health` latency percentiles. All +timing is client-side `time.perf_counter`. +""" + +from __future__ import annotations + +import argparse +import json +import logging +import statistics +import threading +import time +from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import patch + +import requests + +# Quiet the uvicorn access logs from both servers so the summary block is the +# only thing on stdout (the per-request "GET /health 200 OK" lines bury it). +for _name in ("uvicorn", "uvicorn.access", "uvicorn.error"): + logging.getLogger(_name).setLevel(logging.WARNING) + +from miles.rollout.session.session_server import SessionServer +from miles.utils.http_utils import find_available_port +from miles.utils.test_utils.mock_sglang_server import MockSGLangServer, ProcessResult, with_mock_server +from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer + +# --- Tunable constants (kept modest so a run finishes well under a minute) --- +HF_CHECKPOINT = "Qwen/Qwen3-0.6B" +K_CHATS = 64 # concurrent chats, one per distinct (ungated) session +# The session server json.loads the WHOLE response body to validate it, but the +# `routed_experts` value is opaque to it (never decoded). To isolate the CPU +# parse cost the offload targets — without inflating the on-loop body I/O that +# scales with raw byte size — we make routed_experts a STRUCTURED blob (nested +# numeric arrays). Structured JSON parses ~7-8x slower per byte than one big +# string, so a modest ~2 MiB body costs ~15-30ms to parse. Run INLINE on the loop +# (pre-offload) that stalls every /health probe queued behind it; offloaded to +# cpu_executor it runs off-loop and the probe stays fast. +BLOB_ROWS = 900 # rows of BLOB_ROW_WIDTH floats -> ~1.5 MiB body, ~18ms parse/call +BLOB_ROW_WIDTH = 256 +HEALTH_POLL_INTERVAL_S = 0.01 # ~10ms between /health probes +HEALTH_POLL_DURATION_S = 4.0 # how long the health poller runs (also caps the load window) +CHAT_TEXT = "responsiveness-ok" + + +def _make_large_blob() -> list: + # Nested numeric arrays, like a routed_experts logits buffer: expensive to + # json.loads per byte (many tokens), which is exactly the CPU work the + # session server offloads off the event loop. + row = [round(i * 0.001, 4) for i in range(BLOB_ROW_WIDTH)] + return [list(row) for _ in range(BLOB_ROWS)] + + +def _patch_mock_chat_response_heavy(blob: list): + """Patch the mock chat response into the shape the session server validates + (output_token_logprobs as (logprob, token_id)), and inject a large STRUCTURED + routed_experts blob into meta_info so the session server's json.loads of the + response body is genuinely CPU-heavy. The blob is opaque to the session server + (it only validates output_token_logprobs / message), matching production where + routed_experts is an opaque buffer the session layer never decodes. + """ + original_chat_response = MockSGLangServer._compute_chat_completions_response + + def patched_chat_response(self, payload: dict) -> dict: + response = original_chat_response(self, payload) + choice = response["choices"][0] + logprobs_content = choice["logprobs"]["content"] + output_token_logprobs = [ + (item["logprob"], self.tokenizer.convert_tokens_to_ids(item["token"])) for item in logprobs_content + ] + choice["meta_info"]["output_token_logprobs"] = output_token_logprobs + choice["meta_info"]["completion_tokens"] = len(output_token_logprobs) + choice["meta_info"]["routed_experts"] = blob + return response + + return patch.object(MockSGLangServer, "_compute_chat_completions_response", new=patched_chat_response) + + +@contextmanager +def _router_env(process_fn, blob: list, *, latency: float = 0.0): + with _patch_mock_chat_response_heavy(blob): + with with_mock_server(model_name=HF_CHECKPOINT, process_fn=process_fn, latency=latency) as backend: + args = SimpleNamespace( + miles_router_timeout=30, + hf_checkpoint=HF_CHECKPOINT, + chat_template_path=None, + trajectory_manager="linear_trajectory", + tito_allowed_append_roles=["tool", "system"], + # Make the session server request + echo a big routed_experts blob, + # so the per-response json.loads is genuinely CPU-heavy. + use_rollout_routing_replay=True, + ) + server_obj = SessionServer(args, backend_url=backend.url) + + port = find_available_port(31000) + server = UvicornThreadServer(server_obj.app, host="127.0.0.1", port=port) + server.start() + # uvicorn.Config(log_level="info") reconfigures these on startup, so + # re-quiet them now that both servers are up (keeps stdout to the summary). + for _name in ("uvicorn", "uvicorn.access", "uvicorn.error"): + logging.getLogger(_name).setLevel(logging.WARNING) + url = f"http://127.0.0.1:{port}" + try: + yield SimpleNamespace(url=url, backend=backend, server=server) + finally: + server.stop() + + +def _create_session(url: str) -> str: + resp = requests.post(f"{url}/sessions", timeout=5.0) + assert resp.status_code == 200, resp.text + return resp.json()["session_id"] + + +def _chat(url: str, session_id: str) -> tuple[int, int]: + """Returns (status_code, response_body_size_bytes).""" + resp = requests.post( + f"{url}/sessions/{session_id}/v1/chat/completions", + json={"messages": [{"role": "user", "content": "drive-load"}]}, + timeout=60.0, + ) + return resp.status_code, len(resp.content) + + +class _HealthPoller(threading.Thread): + """Polls GET /health on a fixed cadence and records each round-trip latency + (seconds, client-side perf_counter). Runs until stop() or duration elapses.""" + + def __init__(self, url: str, interval_s: float, max_duration_s: float): + super().__init__(daemon=True) + self.url = f"{url}/health" + self.interval_s = interval_s + self.max_duration_s = max_duration_s + self.latencies_s: list[float] = [] + self.errors = 0 + self._stop_evt = threading.Event() + self._session = requests.Session() + + def stop(self) -> None: + self._stop_evt.set() + + def run(self) -> None: + deadline = time.perf_counter() + self.max_duration_s + while not self._stop_evt.is_set() and time.perf_counter() < deadline: + t0 = time.perf_counter() + try: + resp = self._session.get(self.url, timeout=5.0) + dt = time.perf_counter() - t0 + if resp.status_code == 200: + self.latencies_s.append(dt) + else: + self.errors += 1 + except requests.RequestException: + self.errors += 1 + time.sleep(self.interval_s) + + +def _pct(values_ms: list[float], q: float) -> float: + if not values_ms: + return float("nan") + ordered = sorted(values_ms) + idx = min(len(ordered) - 1, int(q * len(ordered))) + return ordered[idx] + + +def _measure_per_stage_cpu_ms(blob: list) -> dict: + """Benchmark-side (NOT production hot-path) timing of the three stateless CPU + stages the session server offloads to its cpu_executor: request parse, request + dump (with Miles-owned input_ids), and response parse+validate. `parse_ms` uses + plain json.loads so it is identical and comparable across before/after builds; + `validate_ms` calls the real helper and is recorded only where it exists (the + pre-offload build predates the module-level helpers -> null).""" + # Heavy response parse (the dominant offloaded stage): plain json.loads, so it + # runs identically on the pre-offload build too. + sample_resp = { + "choices": [ + { + "message": {"role": "assistant", "content": "x"}, + "meta_info": { + "output_token_logprobs": [[-0.01 * i, 100 + i] for i in range(64)], + "completion_tokens": 64, + "routed_experts": blob, + }, + } + ] + } + sample_resp_body = json.dumps(sample_resp).encode() + _t = time.perf_counter() + json.loads(sample_resp_body) + parse_ms = (time.perf_counter() - _t) * 1000 + + # Request dump (encode the body with Miles-owned input_ids): plain json.dumps. + sample_req = { + "messages": [{"role": "user", "content": "drive-load"}], + "input_ids": list(range(4096)), + "logprobs": True, + "return_meta_info": True, + } + _t = time.perf_counter() + json.dumps(sample_req).encode() + dump_ms = (time.perf_counter() - _t) * 1000 + + # Full parse+validate via the real helper (after-only; null where absent). + validate_ms = None + try: + from miles.rollout.session.sessions import _parse_and_validate_response + + _t = time.perf_counter() + _parse_and_validate_response(sample_resp_body) + validate_ms = (time.perf_counter() - _t) * 1000 + except Exception: + validate_ms = None + + return {"parse_ms": parse_ms, "dump_ms": dump_ms, "validate_ms": validate_ms} + + +def run_bench() -> dict: + blob = _make_large_blob() + # The CPU cost the offload targets, measured up front (benchmark-side, not in + # the production hot path) so the artifact states how heavy each offloaded + # stage is. parse_ms is the dominant one and is comparable across builds. + per_stage = _measure_per_stage_cpu_ms(blob) + parse_ms = per_stage["parse_ms"] + + def process_fn(_prompt: str) -> ProcessResult: + # routed_experts is injected by the response patch (kept off the str-typed + # dataclass field); this just drives a small assistant message. + return ProcessResult(text=CHAT_TEXT, finish_reason="stop") + + with _router_env(process_fn, blob) as env: + session_ids = [_create_session(env.url) for _ in range(K_CHATS)] + + # Warm one chat so we can log the real response body size before the run. + warm_status, warm_size = _chat(env.url, session_ids[0]) + assert warm_status == 200, f"warm-up chat failed: {warm_status}" + + # Baseline: /health latency with no chat load (single thread, short burst). + baseline = _HealthPoller(env.url, HEALTH_POLL_INTERVAL_S, max_duration_s=0.7) + baseline.start() + baseline.join() + baseline_ms = [s * 1000 for s in baseline.latencies_s] + + # Under load: poll /health while waves of K concurrent chats keep the + # backend + cpu_executor busy. A single wave of K first-turn chats is + # fast, so we fire back-to-back waves until the poll window elapses; this + # sustains heavy-response load long enough to sample /health meaningfully. + # Each chat uses its own fresh, distinct session (distinct sessions are + # ungated, so they run in parallel). + poller = _HealthPoller(env.url, HEALTH_POLL_INTERVAL_S, HEALTH_POLL_DURATION_S) + poller.start() + + results: list[tuple[int, int]] = [] + waves = 0 + t0 = time.perf_counter() + with ThreadPoolExecutor(max_workers=K_CHATS) as pool: + while time.perf_counter() - t0 < HEALTH_POLL_DURATION_S: + fresh_ids = [_create_session(env.url) for _ in range(K_CHATS)] + futures = [pool.submit(_chat, env.url, sid) for sid in fresh_ids] + results.extend(f.result(timeout=120.0) for f in futures) + waves += 1 + chat_wall_s = time.perf_counter() - t0 + + poller.stop() + poller.join() + + statuses = [s for s, _ in results] + sizes = [sz for _, sz in results] + ok = sum(1 for s in statuses if s == 200) + load_ms = [s * 1000 for s in poller.latencies_s] + + total = len(results) + return { + "k_chats": K_CHATS, + "waves": waves, + "parse_ms": parse_ms, + "dump_ms": per_stage["dump_ms"], + "validate_ms": per_stage["validate_ms"], + "total_chats": total, + "ok_chats": ok, + "failed_chats": total - ok, + "response_body_bytes": warm_size if not sizes else max(sizes), + "chat_wall_s": chat_wall_s, + "chat_throughput_per_s": ok / chat_wall_s if chat_wall_s > 0 else float("nan"), + "health_samples_under_load": len(load_ms), + "health_errors_under_load": poller.errors, + "health_baseline_samples": len(baseline_ms), + # Computed percentiles (persisted so the artifact is reviewable without + # re-deriving from raw samples); raw samples kept alongside for audit. + "health_baseline_p50_ms": _pct(baseline_ms, 0.50), + "health_baseline_p95_ms": _pct(baseline_ms, 0.95), + "health_baseline_max_ms": max(baseline_ms) if baseline_ms else float("nan"), + "health_load_p50_ms": _pct(load_ms, 0.50), + "health_load_p95_ms": _pct(load_ms, 0.95), + "health_load_p99_ms": _pct(load_ms, 0.99), + "health_load_max_ms": max(load_ms) if load_ms else float("nan"), + "health_load_mean_ms": statistics.mean(load_ms) if load_ms else float("nan"), + "baseline_ms": baseline_ms, + "load_ms": load_ms, + } + + +def _fmt_block(r: dict) -> str: + base = r["baseline_ms"] + load = r["load_ms"] + validate_str = f"{r['validate_ms']:.1f}ms" if r.get("validate_ms") is not None else "n/a (pre-offload)" + lines = [ + "=" * 64, + "SessionServer event-loop responsiveness benchmark", + "=" * 64, + f" concurrency / wave (K) : {r['k_chats']}", + f" waves driven : {r['waves']} (total {r['total_chats']} chats)", + f" response body size (actual) : {r['response_body_bytes'] / 1024:.1f} KiB", + f" per-stage CPU (offloaded) : parse={r['parse_ms']:.1f}ms dump={r.get('dump_ms', float('nan')):.2f}ms " + f"validate={validate_str}", + f" chats ok / failed : {r['ok_chats']} / {r['failed_chats']}", + f" chat wall-clock : {r['chat_wall_s']:.3f} s", + f" chat throughput : {r['chat_throughput_per_s']:.1f} chats/s", + "-" * 64, + f" /health baseline (no load) : n={r['health_baseline_samples']}, " + f"p50={_pct(base, 0.50):.2f}ms p95={_pct(base, 0.95):.2f}ms " + f"max={(max(base) if base else float('nan')):.2f}ms", + f" /health UNDER LOAD : n={r['health_samples_under_load']}, " + f"errors={r['health_errors_under_load']}", + f" p50 = {_pct(load, 0.50):.2f} ms", + f" p95 = {_pct(load, 0.95):.2f} ms", + f" p99 = {_pct(load, 0.99):.2f} ms", + f" max = {(max(load) if load else float('nan')):.2f} ms", + f" mean= {(statistics.mean(load) if load else float('nan')):.2f} ms", + "=" * 64, + ] + return "\n".join(lines) + + +def _persist(result: dict, path: str, label: str | None, commit: str | None, dirty: bool | None) -> None: + payload = dict(result) + payload["label"] = label + payload["commit"] = commit + payload["dirty"] = dirty + payload["blob_rows"] = BLOB_ROWS + payload["blob_row_width"] = BLOB_ROW_WIDTH + payload["health_poll_interval_s"] = HEALTH_POLL_INTERVAL_S + payload["health_poll_duration_s"] = HEALTH_POLL_DURATION_S + with open(path, "w") as f: + json.dump(payload, f, indent=2) + print(f"[bench] wrote {path}") + + +def _compare(before_path: str, after_path: str, out_path: str | None) -> str: + with open(before_path) as f: + b = json.load(f) + with open(after_path) as f: + a = json.load(f) + + def _delta(metric: str) -> str: + bv, av = b.get(metric), a.get(metric) + if bv is None or av is None: + return "n/a" + d = av - bv + pct = (d / bv * 100) if bv else float("nan") + return f"{d:+.2f} ms ({pct:+.1f}%)" + + rows = [ + ("/health p50 (load)", "health_load_p50_ms"), + ("/health p95 (load)", "health_load_p95_ms"), + ("/health p99 (load)", "health_load_p99_ms"), + ("/health max (load)", "health_load_max_ms"), + ("/health p95 (baseline)", "health_baseline_p95_ms"), + ] + # p95/p99 verdict: separate material improvements from noise-level changes. + noise_ms = 25.0 + verdict_lines = [] + for label, key in (("p95", "health_load_p95_ms"), ("p99", "health_load_p99_ms")): + bv, av = b.get(key), a.get(key) + if bv is None or av is None: + verdict_lines.append(f"- {label}: n/a") + continue + if av < bv - noise_ms: + verdict_lines.append( + f"- {label}: IMPROVED (before {bv:.1f}ms -> after {av:.1f}ms, beyond ±{noise_ms:.0f}ms noise)" + ) + elif abs(av - bv) <= noise_ms: + verdict_lines.append( + f"- {label}: NO REGRESSION (before {bv:.1f}ms -> after {av:.1f}ms, within ±{noise_ms:.0f}ms noise)" + ) + else: + verdict_lines.append( + f"- {label}: REGRESSED (before {bv:.1f}ms -> after {av:.1f}ms, beyond {noise_ms:.0f}ms noise)" + ) + + lines = [ + "# Session-server responsiveness: before vs after offload", + "", + f"- before: `{before_path}` commit `{b.get('commit')}` dirty={b.get('dirty')} label={b.get('label')}", + f"- after : `{after_path}` commit `{a.get('commit')}` dirty={a.get('dirty')} label={a.get('label')}", + f"- K={a.get('k_chats')} chats/wave, response body ~{a.get('response_body_bytes', 0) / 1024:.0f} KiB, " + f"blob {a.get('blob_rows')}x{a.get('blob_row_width')} floats", + "", + "| metric | before | after | delta |", + "|---|---|---|---|", + ] + for label, key in rows: + bv, av = b.get(key), a.get(key) + bs = f"{bv:.2f}" if isinstance(bv, (int, float)) else "n/a" + as_ = f"{av:.2f}" if isinstance(av, (int, float)) else "n/a" + lines.append(f"| {label} | {bs} ms | {as_} ms | {_delta(key)} |") + b_validate = f"{b['validate_ms']:.1f} ms" if b.get("validate_ms") is not None else "n/a" + a_validate = f"{a['validate_ms']:.1f} ms" if a.get("validate_ms") is not None else "n/a" + lines += [ + "| chat throughput | " + f"{b.get('chat_throughput_per_s', float('nan')):.1f}/s | {a.get('chat_throughput_per_s', float('nan')):.1f}/s | " + f"{(a.get('chat_throughput_per_s', 0) - b.get('chat_throughput_per_s', 0)):+.1f}/s |", + "| /health errors (load) | " + f"{b.get('health_errors_under_load')} | {a.get('health_errors_under_load')} | — |", + "| per-call parse (json.loads) | " + f"{b.get('parse_ms', float('nan')):.1f} ms | {a.get('parse_ms', float('nan')):.1f} ms | {_delta('parse_ms')} |", + f"| per-call validate (helper) | {b_validate} | {a_validate} | — |", + "", + "## p95/p99 verdict (under load)", + *verdict_lines, + "", + "## Interpretation", + "Under K concurrent heavy responses the inline build serializes every per-response `json.loads` " + "ON the single event loop, so the parses stack: the loop is blocked for roughly K x parse_ms before " + "it can service a queued `/health` probe (the measured before-p95 ~= K x single-parse cost). Offloading " + "the parse to the bounded cpu_executor frees the loop to service `/health` between awaits and to drive " + "more chat waves, so both `/health` tail latency and chat throughput improve markedly at this scale. The " + "GIL still serializes the Python parse work, so this is a responsiveness/tail-latency effect, not an " + "aggregate-CPU gain. Single short windows are noisy (a before-p95 can rest on one stall), so pool samples " + "across iterations and inspect the per-iteration spread before trusting any delta.", + "", + ] + text = "\n".join(lines) + if out_path: + with open(out_path, "w") as f: + f.write(text) + print(f"[bench] wrote {out_path}") + return text + + +def main() -> None: + parser = argparse.ArgumentParser(description="SessionServer event-loop responsiveness benchmark") + parser.add_argument("--json-out", default=None, help="persist the run as a JSON artifact at this path") + parser.add_argument("--label", default=None, help="label recorded in the JSON artifact (e.g. before/after)") + parser.add_argument("--commit", default=None, help="commit SHA recorded in the artifact") + parser.add_argument("--dirty", action="store_true", help="record that the checkout was dirty") + parser.add_argument( + "--compare", + nargs=2, + metavar=("BEFORE", "AFTER"), + default=None, + help="compare two persisted JSON artifacts instead of running the benchmark", + ) + parser.add_argument("--out", default=None, help="write the comparison markdown to this path") + parsed = parser.parse_args() + + if parsed.compare: + print(_compare(parsed.compare[0], parsed.compare[1], parsed.out)) + return + + result = run_bench() + print(_fmt_block(result)) + if parsed.json_out: + _persist(result, parsed.json_out, parsed.label, parsed.commit, parsed.dirty) + + +if __name__ == "__main__": + main() diff --git a/tests/benchmark/bench_session_server_overhead.py b/tests/benchmark/bench_session_server_overhead.py new file mode 100644 index 0000000000..2d8f69c914 --- /dev/null +++ b/tests/benchmark/bench_session_server_overhead.py @@ -0,0 +1,501 @@ +"""CPU-only micro-benchmark for Session Server per-turn overhead. + +This benchmark measures the session-layer work without starting uvicorn, opening +HTTP sockets, or calling a model backend. It drives the same +``SessionRegistry`` / ``LinearTrajectory`` TITO path and the same response +parse/validate helper that the standalone session server uses after the backend +returns bytes. + +Run it directly: + + python tests/benchmark/bench_session_server_overhead.py \ + --sessions 32 --turns 4 --input-tokens 64 --output-tokens 64 --r3-scale 1000 + +The reported "reply latency" is CPU-only overhead for one synthetic turn: +request JSON parse, TITO tokenization, request JSON dump with Miles-owned +``input_ids``, response parse/validate, and writing the record into in-memory +session state. Synthetic response construction is done before the measured loop, +so the numbers do not include model/backend generation. +""" + +from __future__ import annotations + +import argparse +import asyncio +import base64 +import json +import math +import statistics +import time +from dataclasses import dataclass +from pathlib import Path +from types import SimpleNamespace +from typing import Any + +from miles.rollout.session.linear_trajectory import MAX_ASSISTANT_ROLLBACK_STEPS, SessionRegistry +from miles.rollout.session.session_types import SessionRecord +from miles.rollout.session.sessions import _dump_request_body, _parse_and_validate_response, _parse_request_body +from miles.utils.chat_template_utils import get_tito_tokenizer, resolve_fixed_chat_template +from miles.utils.processing_utils import load_tokenizer + +DEFAULT_HF_CHECKPOINT = "Qwen/Qwen3-0.6B" +DEFAULT_TITO_MODEL = "qwen3" +DEFAULT_ALLOWED_APPEND_ROLES = ["user"] + + +@dataclass(frozen=True) +class TurnSpec: + request_body: bytes + response_body: bytes + expected_prompt_token_ids: list[int] + content_input_tokens: int + content_output_tokens: int + completion_tokens: int + r3_raw_bytes: int + r3_json_chars: int + + +def _positive_int(value: str) -> int: + parsed = int(value) + if parsed <= 0: + raise argparse.ArgumentTypeError(f"expected a positive integer, got {value!r}") + return parsed + + +def _non_negative_int(value: str) -> int: + parsed = int(value) + if parsed < 0: + raise argparse.ArgumentTypeError(f"expected a non-negative integer, got {value!r}") + return parsed + + +def _pct(values: list[float], q: float) -> float: + if not values: + return float("nan") + ordered = sorted(values) + idx = max(0, min(len(ordered) - 1, math.ceil(q * len(ordered)) - 1)) + return ordered[idx] + + +def _summary(values: list[float]) -> dict[str, float]: + return { + "mean_ms": statistics.mean(values) if values else float("nan"), + "p50_ms": _pct(values, 0.50), + "p95_ms": _pct(values, 0.95), + "p99_ms": _pct(values, 0.99), + "max_ms": max(values) if values else float("nan"), + } + + +def _find_repeatable_token_id(tokenizer) -> int: + for text in (" x", " a", " the", " token", " 0", "A"): + for token_id in tokenizer.encode(text, add_special_tokens=False): + decoded = tokenizer.decode([token_id], skip_special_tokens=False) + if tokenizer.encode(decoded, add_special_tokens=False) == [token_id]: + return token_id + raise RuntimeError("could not find a repeatable one-token text unit for this tokenizer") + + +def _make_text_with_token_count(tokenizer, token_id: int, token_count: int) -> tuple[str, list[int]]: + if token_count == 0: + return "", [] + token_ids = [token_id] * token_count + text = tokenizer.decode(token_ids, skip_special_tokens=False) + roundtrip_ids = tokenizer.encode(text, add_special_tokens=False) + if len(roundtrip_ids) != token_count: + raise RuntimeError( + "repeatable token changed length after decode/encode roundtrip: " + f"requested={token_count}, actual={len(roundtrip_ids)}" + ) + return text, roundtrip_ids + + +def _make_r3_blob(raw_bytes: int) -> str: + if raw_bytes == 0: + return "" + pattern = bytes(range(251)) + repeats = math.ceil(raw_bytes / len(pattern)) + raw = (pattern * repeats)[:raw_bytes] + return base64.b64encode(raw).decode("ascii") + + +def _completion_token_ids( + tito_tokenizer, tokenizer, messages: list[dict[str, Any]], assistant_message: dict[str, Any] +): + prompt_text = tito_tokenizer.render_messages(messages, add_generation_prompt=True, tokenize=False) + full_text = tito_tokenizer.render_messages( + messages + [assistant_message], + add_generation_prompt=False, + tokenize=False, + ) + if not full_text.startswith(prompt_text): + raise RuntimeError("assistant response does not extend the rendered prompt") + return tokenizer.encode(full_text[len(prompt_text) :], add_special_tokens=False) + + +def _build_response_body( + assistant_message: dict[str, Any], + completion_token_ids: list[int], + r3_blob: str, +) -> bytes: + output_token_logprobs = [ + [-((idx % 1024) + 1) / 1024.0, token_id] for idx, token_id in enumerate(completion_token_ids) + ] + response = { + "id": "synthetic-session-overhead", + "object": "chat.completion", + "created": 0, + "model": "synthetic", + "choices": [ + { + "index": 0, + "message": assistant_message, + "finish_reason": "stop", + "meta_info": { + "completion_tokens": len(completion_token_ids), + "output_token_logprobs": output_token_logprobs, + "routed_experts": r3_blob, + }, + } + ], + } + return json.dumps(response, separators=(",", ":")).encode() + + +def _build_turn_specs(tokenizer, tito_tokenizer, turns: int, input_tokens: int, output_tokens: int, r3_scale: int): + token_id = _find_repeatable_token_id(tokenizer) + history: list[dict[str, Any]] = [] + specs: list[TurnSpec] = [] + + for _turn_idx in range(turns): + input_text, input_content_ids = _make_text_with_token_count(tokenizer, token_id, input_tokens) + output_text, output_content_ids = _make_text_with_token_count(tokenizer, token_id, output_tokens) + + user_message = {"role": "user", "content": input_text} + assistant_message = {"role": "assistant", "content": output_text} + request_messages = [dict(message) for message in history] + [user_message] + + prompt_token_ids = tito_tokenizer.render_messages( + request_messages, + add_generation_prompt=True, + tokenize=True, + ) + completion_token_ids = _completion_token_ids(tito_tokenizer, tokenizer, request_messages, assistant_message) + + r3_token_count = max(0, len(prompt_token_ids) + len(completion_token_ids) - 1) + r3_raw_bytes = r3_token_count * r3_scale + r3_blob = _make_r3_blob(r3_raw_bytes) + + request_body = json.dumps({"messages": request_messages}, separators=(",", ":")).encode() + response_body = _build_response_body(assistant_message, completion_token_ids, r3_blob) + specs.append( + TurnSpec( + request_body=request_body, + response_body=response_body, + expected_prompt_token_ids=prompt_token_ids, + content_input_tokens=len(input_content_ids), + content_output_tokens=len(output_content_ids), + completion_tokens=len(completion_token_ids), + r3_raw_bytes=r3_raw_bytes, + r3_json_chars=len(r3_blob), + ) + ) + + history = request_messages + [assistant_message] + + return specs + + +def _make_registry(tokenizer, tito_tokenizer) -> SessionRegistry: + args = SimpleNamespace(generate_multi_samples=False) + return SessionRegistry(args, tokenizer, tito_tokenizer=tito_tokenizer) + + +async def _run_one_turn(session, registry: SessionRegistry, spec: TurnSpec, samples: dict[str, list[float]]) -> None: + turn_start = time.perf_counter() + + stage_start = time.perf_counter() + request_body = _parse_request_body(spec.request_body) + samples["request_parse_ms"].append((time.perf_counter() - stage_start) * 1000) + + request_body["logprobs"] = True + request_body["return_meta_info"] = True + request_body["return_routed_experts"] = True + request_body["no_stop_trim"] = False + request_messages = request_body["messages"] + + stage_start = time.perf_counter() + async with session.lock: + prompt_token_ids = session.prepare_pretokenized( + request_messages, + tools=request_body.get("tools"), + tito_tokenizer=registry.tito_tokenizer, + ) + expected_num_assistant = session.num_assistant + samples["tokenization_ms"].append((time.perf_counter() - stage_start) * 1000) + + request_body["input_ids"] = prompt_token_ids + stage_start = time.perf_counter() + _dump_request_body(request_body) + samples["request_dump_ms"].append((time.perf_counter() - stage_start) * 1000) + + stage_start = time.perf_counter() + response, assistant_message, completion_token_ids = _parse_and_validate_response(spec.response_body) + samples["response_parse_validate_ms"].append((time.perf_counter() - stage_start) * 1000) + + stage_start = time.perf_counter() + async with session.lock: + if session.num_assistant != expected_num_assistant: + raise RuntimeError("session state changed during a single-threaded benchmark turn") + session.update_pretokenized_state( + request_messages, + assistant_message, + prompt_token_ids=prompt_token_ids, + completion_token_ids=completion_token_ids, + max_trim_tokens=registry.tito_tokenizer.max_trim_tokens, + ) + record = SessionRecord( + timestamp=time.time(), + method="POST", + path="/v1/chat/completions", + status_code=200, + request=request_body, + response=response, + ) + session.append_record(record) + samples["record_store_ms"].append((time.perf_counter() - stage_start) * 1000) + + samples["reply_latency_ms"].append((time.perf_counter() - turn_start) * 1000) + + +async def _validate_specs_once(tokenizer, tito_tokenizer, specs: list[TurnSpec]) -> None: + registry = _make_registry(tokenizer, tito_tokenizer) + session = registry.get_session(registry.create_session()) + + for spec in specs: + request_body = _parse_request_body(spec.request_body) + request_messages = request_body["messages"] + + async with session.lock: + prompt_token_ids = session.prepare_pretokenized( + request_messages, + tools=request_body.get("tools"), + tito_tokenizer=registry.tito_tokenizer, + ) + expected_num_assistant = session.num_assistant + + if prompt_token_ids != spec.expected_prompt_token_ids: + raise RuntimeError( + "TITO prompt ids differ from canonical full render: " + f"expected={len(spec.expected_prompt_token_ids)} tokens, actual={len(prompt_token_ids)} tokens" + ) + + response, assistant_message, completion_token_ids = _parse_and_validate_response(spec.response_body) + + async with session.lock: + if session.num_assistant != expected_num_assistant: + raise RuntimeError("session state changed during benchmark spec validation") + session.update_pretokenized_state( + request_messages, + assistant_message, + prompt_token_ids=prompt_token_ids, + completion_token_ids=completion_token_ids, + max_trim_tokens=registry.tito_tokenizer.max_trim_tokens, + ) + session.append_record( + SessionRecord( + timestamp=time.time(), + method="POST", + path="/v1/chat/completions", + status_code=200, + request=request_body, + response=response, + ) + ) + + +async def _run_workload(tokenizer, tito_tokenizer, specs: list[TurnSpec], num_sessions: int): + registry = _make_registry(tokenizer, tito_tokenizer) + sessions = [registry.get_session(registry.create_session()) for _ in range(num_sessions)] + samples: dict[str, list[float]] = { + "request_parse_ms": [], + "tokenization_ms": [], + "request_dump_ms": [], + "response_parse_validate_ms": [], + "record_store_ms": [], + "reply_latency_ms": [], + } + + wall_start = time.perf_counter() + for spec in specs: + for session in sessions: + await _run_one_turn(session, registry, spec, samples) + wall_s = time.perf_counter() - wall_start + return samples, wall_s + + +def run_bench(args) -> dict[str, Any]: + if args.chat_template_path is not None: + chat_template_path = args.chat_template_path + chat_template_kwargs = None + elif args.tito_model == "default": + chat_template_path = None + chat_template_kwargs = None + else: + chat_template_path, chat_template_kwargs = resolve_fixed_chat_template( + args.tito_model, args.allowed_append_roles + ) + + tokenizer = load_tokenizer(args.hf_checkpoint, chat_template_path=chat_template_path, trust_remote_code=True) + tito_tokenizer = get_tito_tokenizer( + tokenizer, + tokenizer_type=args.tito_model, + chat_template_kwargs=chat_template_kwargs, + allowed_append_roles=args.allowed_append_roles, + ) + specs = _build_turn_specs( + tokenizer, + tito_tokenizer, + turns=args.turns, + input_tokens=args.input_tokens, + output_tokens=args.output_tokens, + r3_scale=args.r3_scale, + ) + + asyncio.run(_validate_specs_once(tokenizer, tito_tokenizer, specs)) + samples, wall_s = asyncio.run(_run_workload(tokenizer, tito_tokenizer, specs, args.sessions)) + total_turns = args.sessions * args.turns + content_tokens = args.sessions * sum(spec.content_input_tokens + spec.content_output_tokens for spec in specs) + output_tokens = args.sessions * sum(spec.content_output_tokens for spec in specs) + completion_tokens = args.sessions * sum(spec.completion_tokens for spec in specs) + retained_r3_raw_bytes = args.sessions * sum( + spec.r3_raw_bytes for spec in specs[-(MAX_ASSISTANT_ROLLBACK_STEPS + 1) :] + ) + + return { + "sessions": args.sessions, + "turns_per_session": args.turns, + "total_turns": total_turns, + "input_tokens_per_turn": args.input_tokens, + "output_tokens_per_turn": args.output_tokens, + "r3_scale_raw_bytes_per_token": args.r3_scale, + "hf_checkpoint": args.hf_checkpoint, + "tito_model": args.tito_model, + "allowed_append_roles": args.allowed_append_roles, + "chat_template_path": chat_template_path, + "chat_template_kwargs": chat_template_kwargs, + "wall_s": wall_s, + "throughput_turns_per_s": total_turns / wall_s if wall_s > 0 else float("nan"), + "throughput_content_tokens_per_s": content_tokens / wall_s if wall_s > 0 else float("nan"), + "throughput_completion_tokens_per_s": completion_tokens / wall_s if wall_s > 0 else float("nan"), + "throughput_output_content_tokens_per_s": output_tokens / wall_s if wall_s > 0 else float("nan"), + "retained_r3_raw_bytes_estimate": retained_r3_raw_bytes, + "turn_specs": [ + { + "turn_index": idx, + "prompt_tokens": len(spec.expected_prompt_token_ids), + "completion_tokens": spec.completion_tokens, + "content_input_tokens": spec.content_input_tokens, + "content_output_tokens": spec.content_output_tokens, + "r3_raw_bytes": spec.r3_raw_bytes, + "r3_json_chars": spec.r3_json_chars, + "request_body_bytes": len(spec.request_body), + "response_body_bytes": len(spec.response_body), + } + for idx, spec in enumerate(specs) + ], + "metrics": {name: _summary(values) for name, values in samples.items()}, + "raw_samples_ms": samples if args.include_raw_samples else None, + } + + +def _fmt_ms_stats(stats: dict[str, float]) -> str: + return ( + f"mean={stats['mean_ms']:.3f}ms p50={stats['p50_ms']:.3f}ms " + f"p95={stats['p95_ms']:.3f}ms p99={stats['p99_ms']:.3f}ms max={stats['max_ms']:.3f}ms" + ) + + +def _fmt_block(result: dict[str, Any]) -> str: + metrics = result["metrics"] + last_spec = result["turn_specs"][-1] + lines = [ + "=" * 72, + "Session Server CPU overhead benchmark", + "=" * 72, + f" sessions x turns : {result['sessions']} x {result['turns_per_session']} " + f"({result['total_turns']} turns)", + f" content tokens / turn : input={result['input_tokens_per_turn']} " + f"output={result['output_tokens_per_turn']}", + f" r3 raw bytes / token : {result['r3_scale_raw_bytes_per_token']}", + f" tokenizer / TITO : {result['hf_checkpoint']} / {result['tito_model']}", + f" final-turn prompt/completion : {last_spec['prompt_tokens']} / {last_spec['completion_tokens']} tokens", + f" final-turn response body : {last_spec['response_body_bytes'] / 1024:.1f} KiB", + f" retained r3 estimate : {result['retained_r3_raw_bytes_estimate'] / 1024 / 1024:.1f} MiB raw", + "-" * 72, + f" tokenization : {_fmt_ms_stats(metrics['tokenization_ms'])}", + f" record store : {_fmt_ms_stats(metrics['record_store_ms'])}", + f" reply latency : {_fmt_ms_stats(metrics['reply_latency_ms'])}", + "-" * 72, + f" request parse : {_fmt_ms_stats(metrics['request_parse_ms'])}", + f" request dump : {_fmt_ms_stats(metrics['request_dump_ms'])}", + f" response parse+validate : {_fmt_ms_stats(metrics['response_parse_validate_ms'])}", + "-" * 72, + f" wall clock : {result['wall_s']:.3f}s", + f" throughput : {result['throughput_turns_per_s']:.1f} turns/s", + f" content-token throughput : {result['throughput_content_tokens_per_s']:.1f} tokens/s", + f" completion-token throughput : {result['throughput_completion_tokens_per_s']:.1f} tokens/s", + "=" * 72, + ] + return "\n".join(lines) + + +def _write_json(result: dict[str, Any], path: str) -> None: + payload = dict(result) + if payload.get("raw_samples_ms") is None: + payload.pop("raw_samples_ms", None) + with Path(path).open("w") as f: + json.dump(payload, f, indent=2) + print(f"[bench] wrote {path}") + + +def main() -> None: + parser = argparse.ArgumentParser(description="CPU-only Session Server overhead benchmark") + parser.add_argument("--sessions", type=_positive_int, default=32, help="number of sessions to create") + parser.add_argument("--turns", type=_positive_int, default=4, help="turns per session") + parser.add_argument("--input-tokens", type=_non_negative_int, default=64, help="new user-content tokens per turn") + parser.add_argument( + "--output-tokens", + type=_non_negative_int, + default=64, + help="assistant-content tokens per turn", + ) + parser.add_argument( + "--r3-scale", + type=_non_negative_int, + default=1000, + help="raw routed_experts bytes per accumulated token", + ) + parser.add_argument("--hf-checkpoint", default=DEFAULT_HF_CHECKPOINT, help="tokenizer checkpoint or local path") + parser.add_argument("--tito-model", default=DEFAULT_TITO_MODEL, help="TITO tokenizer family") + parser.add_argument( + "--allowed-append-roles", + nargs="+", + default=DEFAULT_ALLOWED_APPEND_ROLES, + help="roles allowed after the pretokenized prefix", + ) + parser.add_argument("--chat-template-path", default=None, help="explicit chat template path") + parser.add_argument("--json-out", default=None, help="persist the run as a JSON artifact") + parser.add_argument( + "--include-raw-samples", action="store_true", help="include every per-turn sample in JSON output" + ) + args = parser.parse_args() + + result = run_bench(args) + print(_fmt_block(result)) + if args.json_out: + _write_json(result, args.json_out) + + +if __name__ == "__main__": + main() diff --git a/tests/fast/router/test_session_race_conditions.py b/tests/fast/router/test_session_race_conditions.py index 95c6a69cba..8b80ff3c46 100644 --- a/tests/fast/router/test_session_race_conditions.py +++ b/tests/fast/router/test_session_race_conditions.py @@ -1,24 +1,30 @@ """E2E session stress tests. -Contract under test (with split-lock / session.closing): -- Phase 1 (prepare) and Phase 3 (state update) hold session.lock briefly; - Phase 2 (proxy to SGLang) does NOT hold the lock. -- Concurrent same-session requests can overlap at the backend (Phase 2), - but state updates (Phase 3) are serialized; stale-update guard - (expected_num_assistant check) ensures only one concurrent writer wins. -- Different sessions can run in parallel (no global lock). -- Per-session clients can run turn-by-turn without idle gaps while global load stays parallel. -- Delete marks session.closing=True, acquires session.lock, then removes. - Because the lock is not held during Phase 2, delete can proceed while a - chat request is mid-proxy; the chat's Phase 3 will see closing=True and - skip the state update gracefully. -- Chat requests to a closing session get 404 immediately (pre-lock check). -- Chat requests arriving while delete waits for lock get 404 (double-check after lock). +Contract under test (per-session in-flight gate): +- A session admits at most one in-flight chat completion. A second concurrent + same-session chat fast-fails 409 ("session already has an in-flight chat + completion") at slot-claim time, before the body is read/parsed and before + the backend is hit; it never reaches the backend. +- The in-flight slot is released on every exit path (success, malformed JSON, + prepare error, upstream non-200, transport 502, response validation failure, + state-update error, client cancel/disconnect), and only by the request that + claimed it: a request that got 409/404 does not clear the owner's slot. +- Different sessions still run in parallel (no global lock); per-session clients + can run turn-by-turn without idle gaps while global load stays parallel. +- Delete marks session.closing=True, acquires session.lock, then removes. The + lock is not held during the proxy, so delete can proceed while a chat is + mid-proxy; that chat's commit sees closing=True and skips the state update. +- Chat to a closing session gets 404 immediately, and closing (404) has + priority over busy (409). - Concurrent deletes on the same session: second delete gets 404. +- The upstream body is passed through faithfully; stale framing headers + (content-length / transfer-encoding / content-encoding) are stripped. """ from __future__ import annotations +import asyncio +import logging import time from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager @@ -26,8 +32,17 @@ from unittest.mock import patch import requests - +from starlette.responses import Response + +from miles.rollout.session.linear_trajectory import LinearTrajectory +from miles.rollout.session.session_errors import ( + MessageValidationError, + SessionInvariantError, + TokenizationError, + UpstreamResponseError, +) from miles.rollout.session.session_server import SessionServer +from miles.rollout.session.sessions import _parse_and_validate_response from miles.utils.http_utils import find_available_port from miles.utils.test_utils.mock_sglang_server import MockSGLangServer, ProcessResult, with_mock_server from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer @@ -55,9 +70,67 @@ def patched_chat_response(self, payload: dict) -> dict: return patch.object(MockSGLangServer, "_compute_chat_completions_response", new=patched_chat_response) +def _patch_mock_chat_response_bad_first(): + """Like `_patch_mock_chat_response`, but the FIRST chat response is missing + `meta_info` so the session server raises UpstreamResponseError (502). The + second and later responses are valid, so a retry after the failure can 200. + """ + original_chat_response = MockSGLangServer._compute_chat_completions_response + state = {"calls": 0} + + def patched_chat_response(self, payload: dict) -> dict: + response = original_chat_response(self, payload) + choice = response["choices"][0] + logprobs_content = choice["logprobs"]["content"] + output_token_logprobs = [ + (item["logprob"], self.tokenizer.convert_tokens_to_ids(item["token"])) for item in logprobs_content + ] + choice["meta_info"] = { + "output_token_logprobs": output_token_logprobs, + "completion_tokens": len(output_token_logprobs), + } + state["calls"] += 1 + if state["calls"] == 1: + # Strip meta_info from the first response only -> 502 on first chat. + choice.pop("meta_info", None) + return response + + return patch.object(MockSGLangServer, "_compute_chat_completions_response", new=patched_chat_response) + + +def _patch_mock_chat_response_bad_logprob_first(bad_logprob="bad-logprob"): + """Like `_patch_mock_chat_response`, but the FIRST chat response carries a + non-numeric logprob value (entry[0]) so the session server rejects it with + UpstreamResponseError (502) instead of letting the bad value flow into + Sample.rollout_log_probs downstream. Later responses are valid. + """ + original_chat_response = MockSGLangServer._compute_chat_completions_response + state = {"calls": 0} + + def patched_chat_response(self, payload: dict) -> dict: + response = original_chat_response(self, payload) + choice = response["choices"][0] + logprobs_content = choice["logprobs"]["content"] + output_token_logprobs = [ + (item["logprob"], self.tokenizer.convert_tokens_to_ids(item["token"])) for item in logprobs_content + ] + state["calls"] += 1 + if state["calls"] == 1 and output_token_logprobs: + # Corrupt the first response's leading logprob value only. + _logprob, token_id = output_token_logprobs[0][0], output_token_logprobs[0][1] + output_token_logprobs[0] = (bad_logprob, token_id) + choice["meta_info"] = { + "output_token_logprobs": output_token_logprobs, + "completion_tokens": len(output_token_logprobs), + } + return response + + return patch.object(MockSGLangServer, "_compute_chat_completions_response", new=patched_chat_response) + + @contextmanager -def _router_env(process_fn, *, latency: float = 0.0): - with _patch_mock_chat_response(): +def _router_env(process_fn, *, latency: float = 0.0, response_patch=None): + with (response_patch or _patch_mock_chat_response)(): with with_mock_server(model_name=HF_CHECKPOINT, process_fn=process_fn, latency=latency) as backend: args = SimpleNamespace( miles_router_timeout=30, @@ -93,46 +166,97 @@ def _chat(url: str, session_id: str, payload: dict, timeout: float = 20.0) -> re ) -class TestSessionConcurrencyContracts: - def test_same_session_concurrent_requests_reach_backend(self): - """With the split-lock, same-session requests CAN overlap at the backend. +def _wait_for_backend_requests(backend, count: int, timeout: float = 5.0) -> None: + """Block until the backend has logged exactly `count` requests. + + Using a backend-arrival barrier instead of sleeping makes the "request is + parked in proxy" precondition deterministic: once the entry is logged the + owner has claimed the in-flight slot and is sitting in the latency window. + """ + deadline = time.time() + timeout + while time.time() < deadline: + if len(backend.request_log) == count: + return + time.sleep(0.005) + raise AssertionError(f"backend did not reach {count} requests in {timeout}s (saw {len(backend.request_log)})") - Phase 2 (proxy) runs without the lock, so concurrent requests are not - serialized at the backend level. Phase 3 state updates are still - serialized; the stale-update guard ensures only one writer wins per - generation, so no state corruption occurs. + +class TestSessionConcurrencyContracts: + def test_same_session_second_chat_returns_409(self): + """A session admits one in-flight chat; concurrents fast-fail 409. + + Park chat A in proxy (held by backend latency) and confirm via the + arrival barrier that A holds the slot. Contenders B/C/D on the same + session must each get 409 without ever reaching the backend, and they + must not release A's slot. The 409 is returned at slot-claim time, + before the contender's body is read and before the backend is hit, so + each contender returns near-instantly rather than waiting out A's + backend latency. After A finishes 200, the slot is free and a fresh + same-session chat succeeds. """ + # Latency comfortably larger than the time to fire three contenders. + latency = 0.5 + def process_fn(prompt: str) -> ProcessResult: return ProcessResult(text="concurrent-ok", finish_reason="stop") - with _router_env(process_fn, latency=0.2) as env: + with _router_env(process_fn, latency=latency) as env: session_id = _create_session(env.url) + payload = {"messages": [{"role": "user", "content": "park-in-proxy"}]} - # Warm up one assistant checkpoint so repeated identical retry payloads are valid. - warmup_payload = {"messages": [{"role": "user", "content": "warmup"}]} - warmup_resp = _chat(env.url, session_id, warmup_payload) - assert warmup_resp.status_code == 200 - assistant = warmup_resp.json()["choices"][0]["message"] + env.backend.reset_stats() + with ThreadPoolExecutor(max_workers=1) as pool: + # Fire A and wait until it is parked in proxy holding the slot. + chat_a = pool.submit(_chat, env.url, session_id, payload, 30.0) + _wait_for_backend_requests(env.backend, 1) + + # Contenders on the SAME session while A is parked -> 409 each. + # Each contender's wall-clock is measured: the gate rejects at + # slot-claim time (before body read, before backend), so a + # contender must NOT block for the owner's backend latency. + contender_codes = [] + contender_elapsed_s = [] + for _ in range(3): + t0 = time.perf_counter() + resp = _chat(env.url, session_id, payload, timeout=10.0) + contender_elapsed_s.append(time.perf_counter() - t0) + contender_codes.append(resp.status_code) + assert resp.status_code == 409, f"contender should be 409, got {resp.status_code}" + assert resp.json()["error"] == "session already has an in-flight chat completion" + + # Each 409 returned well under the owner's backend latency, + # proving the contender did not block on A's parked proxy. Half + # the latency is a generous ceiling vs. the near-instant gate. + assert all(dt < latency / 2 for dt in contender_elapsed_s), ( + f"a contender blocked on the owner's backend latency " + f"(latency={latency}s, contender elapsed={contender_elapsed_s})" + ) + + # Contenders never reached the backend: still exactly A's request. + assert len(env.backend.request_log) == 1 + + # A finishes 200; a 409 contender did NOT clear A's slot. + chat_a_resp = chat_a.result(timeout=30.0) - retry_payload = { + assert chat_a_resp.status_code == 200 + assert contender_codes == [409, 409, 409] + assert len(env.backend.request_log) == 1 + + # Slot was released on A's success: a fresh same-session chat works. + # The follow-up must be an append-only continuation of A's committed + # trajectory, so build it from A's assistant message. + assistant = chat_a_resp.json()["choices"][0]["message"] + follow_up = { "messages": [ - {"role": "user", "content": "warmup"}, + {"role": "user", "content": "park-in-proxy"}, assistant, - {"role": "system", "content": "retry-from-assistant-checkpoint"}, + {"role": "system", "content": "continue-after-A"}, ] } - - env.backend.reset_stats() - with ThreadPoolExecutor(max_workers=4) as pool: - futures = [pool.submit(_chat, env.url, session_id, retry_payload) for _ in range(4)] - responses = [f.result(timeout=30.0) for f in futures] - - # All requests should succeed (200) — no 500s. - assert all(resp.status_code == 200 for resp in responses) - assert len(env.backend.request_log) == 4 - # With split-lock, concurrent backend access is expected (not == 1). - assert env.backend.max_concurrent >= 1 + after = _chat(env.url, session_id, follow_up, timeout=20.0) + assert after.status_code == 200 + assert len(env.backend.request_log) == 2 def test_different_sessions_can_run_in_parallel(self): def process_fn(prompt: str) -> ProcessResult: @@ -192,10 +316,11 @@ def run_session_worker(session_id: str, idx: int) -> bool: assert env.backend.max_concurrent >= 4 def test_delete_can_proceed_while_chat_is_mid_proxy(self): - """With split-lock, delete can acquire the lock while chat is in Phase 2. + """Delete can acquire the lock while a chat is mid-proxy. - The inflight chat's Phase 3 sees session.closing=True and skips - state update gracefully. Both chat and delete complete without error. + The lock is not held during the proxy, so delete proceeds; the in-flight + chat's commit step sees session.closing=True and skips the state update. + Both chat and delete complete without error. """ def process_fn(prompt: str) -> ProcessResult: @@ -235,10 +360,10 @@ def test_chat_during_delete_returns_404(self): """Chat requests arriving after delete sets closing=True get 404. Timeline: - 1. Chat A starts, acquires lock (Phase 1), releases it, proxying (Phase 2) - 2. Delete arrives, sets session.closing=True, acquires lock, removes session - 3. Chat B arrives, sees session.closing=True, returns 404 immediately - 4. Chat A's Phase 3 sees closing=True, skips state update, returns 200 + 1. Chat A starts, claims the in-flight slot, and is proxying to backend. + 2. Delete arrives, sets session.closing=True, acquires lock, removes session. + 3. Chat B arrives, sees session.closing=True, returns 404 immediately. + 4. Chat A's commit sees closing=True, skips the state update, returns 200. """ def process_fn(prompt: str) -> ProcessResult: @@ -348,11 +473,13 @@ def process_fn(prompt: str) -> ProcessResult: get_resp = requests.get(f"{env.url}/sessions/{session_id}", timeout=5.0) assert get_resp.status_code == 404 - def test_multiple_chats_queued_then_delete(self): - """Multiple chat requests queued behind session.lock, then delete. + def test_concurrent_chats_then_delete(self): + """Concurrent same-session chats do not queue under the gate. - After delete marks closing=True, queued chats that acquire the lock - should check closing and return 404. + One chat parks in proxy holding the slot; a couple more same-session + chats fired concurrently get 409 (gate, before the backend). A delete + then returns 204 and the parked chat returns 200 (commit skips the + state update on closing). No 500s; codes are a subset of {200,409,404}. """ def process_fn(prompt: str) -> ProcessResult: @@ -362,35 +489,33 @@ def process_fn(prompt: str) -> ProcessResult: session_id = _create_session(env.url) payload = {"messages": [{"role": "user", "content": "queued"}]} - with ThreadPoolExecutor(max_workers=6) as pool: - # Fire 3 chats (first holds lock, others queue) - chat_futures = [pool.submit(_chat, env.url, session_id, payload, 30.0) for _ in range(3)] + with ThreadPoolExecutor(max_workers=2) as pool: + # Park the owner chat in proxy holding the in-flight slot. + owner = pool.submit(_chat, env.url, session_id, payload, 30.0) + _wait_for_backend_requests(env.backend, 1) - # Wait for first to reach backend - deadline = time.time() + 5.0 - while time.time() < deadline: - if env.backend.request_log: - break - time.sleep(0.01) + # Concurrent same-session chats hit the gate -> 409. + contender_codes = [_chat(env.url, session_id, payload, timeout=10.0).status_code for _ in range(2)] - # Now delete - sets closing, waits for first chat to finish + # Delete the session while the owner is still parked. delete_future = pool.submit( requests.delete, f"{env.url}/sessions/{session_id}", timeout=30.0, ) - results = [f.result(timeout=30.0) for f in chat_futures] + owner_resp = owner.result(timeout=30.0) delete_resp = delete_future.result(timeout=30.0) assert delete_resp.status_code == 204 + assert owner_resp.status_code == 200 + assert contender_codes == [409, 409] - # At least one chat must succeed (the one holding the lock when - # delete arrived). Others may get 200 (acquired lock before - # closing) or 404 (saw closing=True). No 500s allowed. - status_codes = [r.status_code for r in results] - assert all(c in (200, 404) for c in status_codes), f"Unexpected status codes: {status_codes}" - assert 200 in status_codes, f"Expected at least one 200, got {status_codes}" + # No 500s; every chat code is in the allowed set, and exactly one + # chat (the owner) reached the backend. + all_chat_codes = [owner_resp.status_code, *contender_codes] + assert all(c in (200, 409, 404) for c in all_chat_codes), f"Unexpected status codes: {all_chat_codes}" + assert len(env.backend.request_log) == 1 def test_rapid_create_chat_delete_cycles(self): """Rapidly create, chat, and delete sessions to stress the lifecycle. @@ -420,3 +545,691 @@ def lifecycle_cycle(idx: int) -> bool: results = [f.result(timeout=60.0) for f in futures] assert all(results) + + +class TestSlotReleaseAfterError: + """The in-flight slot is freed on failing exit paths: a fresh legal chat on + the same session after a failure must succeed (200, not 409). + """ + + def test_slot_released_after_malformed_request_json(self): + """Malformed JSON body errors (500-class) before the backend; the slot + is released so a subsequent normal chat on the same session returns 200. + """ + + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="ok", finish_reason="stop") + + with _router_env(process_fn) as env: + session_id = _create_session(env.url) + + bad = requests.post( + f"{env.url}/sessions/{session_id}/v1/chat/completions", + data="{not json", + headers={"content-type": "application/json"}, + timeout=10.0, + ) + assert bad.status_code >= 500 + # The malformed body never reached the backend. + assert len(env.backend.request_log) == 0 + + good = _chat(env.url, session_id, {"messages": [{"role": "user", "content": "hi"}]}, timeout=20.0) + assert good.status_code == 200 + assert len(env.backend.request_log) == 1 + + def test_slot_released_after_response_validation_failure(self): + """An upstream response missing meta_info raises UpstreamResponseError + (502); the slot is released so a subsequent normal chat returns 200. + """ + + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="ok", finish_reason="stop") + + with _router_env(process_fn, response_patch=_patch_mock_chat_response_bad_first) as env: + session_id = _create_session(env.url) + + bad = _chat(env.url, session_id, {"messages": [{"role": "user", "content": "hi"}]}, timeout=20.0) + assert bad.status_code == 502 + assert len(env.backend.request_log) == 1 + + good = _chat(env.url, session_id, {"messages": [{"role": "user", "content": "hi again"}]}, timeout=20.0) + assert good.status_code == 200 + assert len(env.backend.request_log) == 2 + + def test_bad_logprob_value_rejected_without_committing(self): + """A successful (200) upstream response carrying a non-numeric logprob + value must return 502, commit NOTHING to the session (no record, no + accumulated token id), and release the slot so the next legal chat 200s. + Guards against the bad logprob flowing into rollout_log_probs downstream. + """ + + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="ok", finish_reason="stop") + + with _router_env(process_fn, response_patch=_patch_mock_chat_response_bad_logprob_first) as env: + session_id = _create_session(env.url) + + bad = _chat(env.url, session_id, {"messages": [{"role": "user", "content": "hi"}]}, timeout=20.0) + assert bad.status_code == 502 + assert len(env.backend.request_log) == 1 + + # Nothing from the malformed response was committed. + state = requests.get(f"{env.url}/sessions/{session_id}", timeout=5.0).json() + assert state["records"] == [] + assert state["metadata"]["accumulated_token_ids"] == [] + + good = _chat(env.url, session_id, {"messages": [{"role": "user", "content": "hi again"}]}, timeout=20.0) + assert good.status_code == 200 + assert len(env.backend.request_log) == 2 + + after = requests.get(f"{env.url}/sessions/{session_id}", timeout=5.0).json() + assert len(after["records"]) == 1 + assert len(after["metadata"]["accumulated_token_ids"]) > 0 + + def test_non_finite_logprob_rejected_without_committing(self): + """Raw upstream NaN is accepted by Python's default json decoder; the + session boundary must still reject it before recording training data. + """ + + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="ok", finish_reason="stop") + + original_do_proxy = SessionServer.do_proxy + state = {"calls": 0} + + async def one_shot_do_proxy(self, request, path, body=None, headers=None): + state["calls"] += 1 + if state["calls"] == 1: + return { + "request_body": body, + "response_body": ( + b'{"choices":[{"message":{"role":"assistant","content":"ok"},' + b'"meta_info":{"output_token_logprobs":[[NaN,562]],"completion_tokens":1}}]}' + ), + "status_code": 200, + "headers": {"content-type": "application/json"}, + } + return await original_do_proxy(self, request, path, body=body, headers=headers) + + with _router_env(process_fn) as env: + session_id = _create_session(env.url) + with patch.object(SessionServer, "do_proxy", one_shot_do_proxy): + bad = _chat(env.url, session_id, {"messages": [{"role": "user", "content": "hi"}]}, timeout=20.0) + assert bad.status_code == 502 + + session_state = requests.get(f"{env.url}/sessions/{session_id}", timeout=5.0).json() + assert session_state["records"] == [] + assert session_state["metadata"]["accumulated_token_ids"] == [] + + good = _chat( + env.url, + session_id, + {"messages": [{"role": "user", "content": "hi again"}]}, + timeout=20.0, + ) + assert good.status_code == 200 + + +def _normal_messages(content: str) -> dict: + return {"messages": [{"role": "user", "content": content}]} + + +class TestSlotReleaseInjectedFailures: + """Deterministic slot-release coverage for the remaining enumerated failure + paths. Each test injects a one-shot failure on the first same-session chat, + then proves the slot was released: the NEXT same-session chat returns 200 + (no residual 409). Failures are injected via class-level mock patches so the + server thread's instance/session is affected; ordering is enforced by a + per-test call counter, not by sleeps. + """ + + def test_slot_released_after_prepare_validation_error(self): + """`prepare_pretokenized` raising MessageValidationError (400) before the + backend releases the slot; the next normal chat returns 200. + """ + + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="ok", finish_reason="stop") + + original_prepare = LinearTrajectory.prepare_pretokenized + state = {"calls": 0} + + def one_shot_prepare(self, *args, **kwargs): + state["calls"] += 1 + if state["calls"] == 1: + raise MessageValidationError("injected prepare failure") + return original_prepare(self, *args, **kwargs) + + with _router_env(process_fn) as env: + session_id = _create_session(env.url) + with patch.object(LinearTrajectory, "prepare_pretokenized", one_shot_prepare): + bad = _chat(env.url, session_id, _normal_messages("hi"), timeout=20.0) + assert bad.status_code == 400 + # prepare failed before the backend was hit. + assert len(env.backend.request_log) == 0 + + good = _chat(env.url, session_id, _normal_messages("hi again"), timeout=20.0) + assert good.status_code == 200 + assert len(env.backend.request_log) == 1 + + def test_slot_released_after_upstream_non_200(self): + """An upstream non-200 passes through (400) without recording; the slot + is released so the next normal chat returns 200. + """ + + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="ok", finish_reason="stop") + + original_do_proxy = SessionServer.do_proxy + state = {"calls": 0} + + async def one_shot_do_proxy(self, request, path, body=None, headers=None): + state["calls"] += 1 + if state["calls"] == 1: + return { + "request_body": body, + "response_body": b'{"error":"bad"}', + "status_code": 400, + "headers": {"content-type": "application/json"}, + } + return await original_do_proxy(self, request, path, body=body, headers=headers) + + with _router_env(process_fn) as env: + session_id = _create_session(env.url) + with patch.object(SessionServer, "do_proxy", one_shot_do_proxy): + bad = _chat(env.url, session_id, _normal_messages("hi"), timeout=20.0) + assert bad.status_code == 400 + assert bad.json()["error"] == "bad" + + good = _chat(env.url, session_id, _normal_messages("hi again"), timeout=20.0) + assert good.status_code == 200 + + def test_slot_released_after_transport_502(self): + """A transport-error 502 (do_proxy's real error shape) passes through; the + slot is released so the next normal chat returns 200. + """ + + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="ok", finish_reason="stop") + + original_do_proxy = SessionServer.do_proxy + state = {"calls": 0} + + async def one_shot_do_proxy(self, request, path, body=None, headers=None): + state["calls"] += 1 + if state["calls"] == 1: + return { + "request_body": body, + "response_body": b'{"error":"backend transport error"}', + "status_code": 502, + "headers": {"content-type": "application/json"}, + } + return await original_do_proxy(self, request, path, body=body, headers=headers) + + with _router_env(process_fn) as env: + session_id = _create_session(env.url) + with patch.object(SessionServer, "do_proxy", one_shot_do_proxy): + bad = _chat(env.url, session_id, _normal_messages("hi"), timeout=20.0) + assert bad.status_code == 502 + + good = _chat(env.url, session_id, _normal_messages("hi again"), timeout=20.0) + assert good.status_code == 200 + + def test_slot_released_after_state_update_error(self): + """`update_pretokenized_state` raising TokenizationError (500) after a 200 + proxy releases the slot; the next normal chat returns 200. + """ + + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="ok", finish_reason="stop") + + original_update = LinearTrajectory.update_pretokenized_state + state = {"calls": 0} + + def one_shot_update(self, *args, **kwargs): + state["calls"] += 1 + if state["calls"] == 1: + raise TokenizationError("injected state-update failure") + return original_update(self, *args, **kwargs) + + with _router_env(process_fn) as env: + session_id = _create_session(env.url) + with patch.object(LinearTrajectory, "update_pretokenized_state", one_shot_update): + bad = _chat(env.url, session_id, _normal_messages("hi"), timeout=20.0) + assert bad.status_code == 500 + # The proxy did run; the failure is in the commit step. + assert len(env.backend.request_log) == 1 + + good = _chat(env.url, session_id, _normal_messages("hi again"), timeout=20.0) + assert good.status_code == 200 + assert len(env.backend.request_log) == 2 + + def test_invariant_mismatch_returns_500_and_releases_slot(self, caplog): + """If session.num_assistant changes between the prepare-segment capture + and the commit-segment check, the commit raises SessionInvariantError + (500) instead of silently skipping the state update; the slot is then + released so a fresh session's normal chat returns 200. + + The mismatch is forced deterministically (no sleeps): prepare_pretokenized + stashes the live session object, then do_proxy mutates that session's + num_assistant on its first call — an out-of-band change the in-flight + gate is supposed to make impossible. The commit check then trips. + + The server runs in a uvicorn thread in the same process, so its `logging` + records propagate to the root logger and `caplog` captures them. We assert + the commit segment emitted an ERROR naming the invariant and the session + id (matching the `logger.error(...)` in `chat_completions`), and that a + normal fresh chat never emits that ERROR (no false trigger). + """ + + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="ok", finish_reason="stop") + + holder: dict = {} + original_prepare = LinearTrajectory.prepare_pretokenized + original_do_proxy = SessionServer.do_proxy + proxy_state = {"calls": 0} + + def stashing_prepare(self, *args, **kwargs): + result = original_prepare(self, *args, **kwargs) + # Stash the live session so the do_proxy wrapper can mutate it after + # expected_num_assistant has already been captured. + holder["session"] = self + return result + + async def mutating_do_proxy(self, request, path, body=None, headers=None): + proxy_state["calls"] += 1 + if proxy_state["calls"] == 1: + # Out-of-band state change mid-proxy, after the prepare segment + # captured expected_num_assistant -> commit-time check must trip. + holder["session"].num_assistant += 1 + return await original_do_proxy(self, request, path, body=body, headers=headers) + + with _router_env(process_fn) as env: + session_id = _create_session(env.url) + with ( + patch.object(LinearTrajectory, "prepare_pretokenized", stashing_prepare), + patch.object(SessionServer, "do_proxy", mutating_do_proxy), + caplog.at_level(logging.ERROR), + ): + bad = _chat(env.url, session_id, _normal_messages("hi"), timeout=20.0) + assert bad.status_code == 500, f"invariant mismatch must be 500, got {bad.status_code}" + # Body carries SessionInvariantError's message (the gate did not + # silently 200-skip the commit). + error = bad.json()["error"] + assert "session state changed under the in-flight gate" in error + assert SessionInvariantError(error).status_code == 500 + # The proxy ran (the mismatch is raised in the commit step after it). + assert len(env.backend.request_log) == 1 + + # The commit segment logged an ERROR naming the invariant and the + # session id (matching the `logger.error(...)` in chat_completions). + invariant_errors = [ + rec + for rec in caplog.records + if rec.levelno == logging.ERROR + and "invariant" in rec.getMessage() + and session_id in rec.getMessage() + ] + assert invariant_errors, ( + "expected an ERROR log naming the invariant and the session id; " + f"saw records: {[(r.levelname, r.getMessage()) for r in caplog.records]}" + ) + + # Patches removed. The slot must have been released on the 500 exit + # path: a follow-up on the SAME session is NOT a stuck 409. (Its + # bumped num_assistant left the prefix state inconsistent, so the + # legal-continuation code may be a non-200 error, but never 409.) + same_session_followup = _chat(env.url, session_id, _normal_messages("hi same"), timeout=20.0) + assert same_session_followup.status_code != 409, ( + f"slot was not released after the 500: same-session follow-up got " + f"{same_session_followup.status_code} (a stuck 409)" + ) + + # A FRESH session's normal chat returns 200, confirming the event + # loop stayed live and the gate path recovers cleanly. Capture ERROR + # records during this clean chat to prove no false invariant trigger. + fresh_session_id = _create_session(env.url) + with caplog.at_level(logging.ERROR): + caplog.clear() + good = _chat(env.url, fresh_session_id, _normal_messages("hi again"), timeout=20.0) + assert good.status_code == 200 + false_triggers = [ + rec for rec in caplog.records if rec.levelno == logging.ERROR and "invariant" in rec.getMessage() + ] + assert not false_triggers, ( + "a normal chat must not emit the invariant ERROR; " + f"saw: {[r.getMessage() for r in false_triggers]}" + ) + + def test_slot_released_after_client_cancel_mid_proxy(self): + """If the handler is cancelled mid-proxy (client disconnect), the request + errors at the client side but the `finally` releases the slot; the next + normal same-session chat returns 200. + """ + + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="ok", finish_reason="stop") + + original_do_proxy = SessionServer.do_proxy + state = {"calls": 0} + + async def one_shot_do_proxy(self, request, path, body=None, headers=None): + state["calls"] += 1 + if state["calls"] == 1: + raise asyncio.CancelledError() + return await original_do_proxy(self, request, path, body=body, headers=headers) + + with _router_env(process_fn) as env: + session_id = _create_session(env.url) + with patch.object(SessionServer, "do_proxy", one_shot_do_proxy): + try: + _chat(env.url, session_id, _normal_messages("hi"), timeout=20.0) + except requests.exceptions.RequestException: + pass + + good = _chat(env.url, session_id, _normal_messages("hi again"), timeout=20.0) + assert good.status_code == 200 + + def test_slot_released_after_real_client_disconnect_mid_proxy(self): + """A REAL client disconnect mid-proxy releases the slot and leaves the + session usable by its next LEGAL continuation (200), not a stuck 409. + + Harness note: a real client socket abort does NOT cancel the in-flight + handler in this harness (uvicorn + httpx backend). The handler stays + parked in the proxy until the backend responds, then runs to completion + and commits the (abandoned) backend response; the slot is released on + that normal-completion `finally` (the `claimed` guard), not via an early + cancellation. So after a disconnect the session has advanced exactly one + committed turn. Recovery is therefore via a LEGAL append-only + continuation of that committed turn — a fresh first-turn message would + instead fail the append-only prefix check (400), which is expected, not + a slot leak. + + We (1) fire a chat with a tiny client timeout against a high backend + `latency` so `requests` aborts mid-proxy; (2) bounded-poll until a probe + chat is no longer 409 (slot released — no fixed ordering sleep); then + (3) read the committed turn via GET /sessions/{id} and prove a legal + append-only continuation (committed user + assistant, then an allowed + appended `system` message) returns 200. + """ + + # High latency: the client times out long before the backend responds, + # so the abort lands while the handler is parked mid-proxy. + latency = 1.0 + + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="ok", finish_reason="stop") + + with _router_env(process_fn, latency=latency) as env: + session_id = _create_session(env.url) + + env.backend.reset_stats() + # Fire the owner with a tiny timeout so `requests` aborts the + # connection while the handler is parked in the latency window. + try: + _chat(env.url, session_id, _normal_messages("disconnect-me"), timeout=0.1) + except requests.exceptions.RequestException: + pass + # Confirm the request actually reached the backend (handler was + # genuinely parked mid-proxy when the client aborted). + _wait_for_backend_requests(env.backend, 1) + + # The same-session slot must free up. Bounded-poll a probe chat until + # it is NOT a 409: the abort did not cancel the parked handler, so the + # gate stays busy until the owner completes, then releases. We use an + # invalid (empty) probe payload so the probe itself never commits a + # turn — it only reads the gate state (409 while busy, non-409 once + # released) and leaves the committed trajectory untouched. + deadline = time.time() + 10.0 + released = False + last_status = None + while time.time() < deadline: + last_status = _chat(env.url, session_id, {"messages": []}, timeout=20.0).status_code + if last_status != 409: + released = True + break + time.sleep(0.05) + assert released, f"slot not released after real disconnect (still {last_status} after retries)" + + # The disconnected owner's backend response committed one turn. Read + # it back and build a LEGAL append-only continuation: the committed + # user + assistant messages, then an allowed appended `system` + # message (mirrors the retry payload in + # test_same_session_second_chat_returns_409). + get_resp = requests.get(f"{env.url}/sessions/{session_id}", timeout=5.0) + assert get_resp.status_code == 200 + records = get_resp.json()["records"] + assert records, "the abandoned backend response should have committed one turn" + committed = records[-1] + committed_user = committed["request"]["messages"] + committed_assistant = committed["response"]["choices"][0]["message"] + continuation = { + "messages": [ + *committed_user, + committed_assistant, + {"role": "system", "content": "continue-after-disconnect"}, + ] + } + after = _chat(env.url, session_id, continuation, timeout=20.0) + assert after.status_code == 200, ( + f"legal same-session continuation after a real disconnect must 200, got {after.status_code}: " + f"{after.text}" + ) + + +class TestBuildProxyResponse: + """Unit tests for SessionServer.build_proxy_response passthrough fidelity. + No running server needed: with no hf_checkpoint, setup_session_routes + returns early so construction is light. + """ + + def _server(self) -> SessionServer: + return SessionServer(SimpleNamespace(miles_router_timeout=30), backend_url="http://unused") + + def test_json_200_body_and_headers_passthrough(self): + server = self._server() + body = b'{"object":"chat.completion","choices":[]}' + result = { + "response_body": body, + "status_code": 200, + "headers": { + "content-type": "application/json", + "content-length": "999", + "transfer-encoding": "chunked", + "content-encoding": "gzip", + }, + } + resp = server.build_proxy_response(result) + assert isinstance(resp, Response) + assert resp.status_code == 200 + assert resp.body == body + lowered = {k.lower() for k in resp.headers.keys()} + assert resp.headers["content-type"].startswith("application/json") + # The stale upstream content-length is dropped; transfer/content-encoding + # are absent. Starlette recomputes content-length from the actual body, so + # if present it must equal the body length (never the stale upstream "999"). + assert "transfer-encoding" not in lowered + assert "content-encoding" not in lowered + assert resp.headers.get("content-length", str(len(body))) == str(len(body)) + + def test_non_json_200_body_and_media_type_preserved(self): + server = self._server() + body = b"plain text" + result = { + "response_body": body, + "status_code": 200, + "headers": {"content-type": "text/plain"}, + } + resp = server.build_proxy_response(result) + assert isinstance(resp, Response) + assert resp.status_code == 200 + assert resp.body == body + assert resp.media_type == "text/plain" + assert resp.headers["content-type"].startswith("text/plain") + + def test_synthetic_502_status_and_body_passthrough(self): + server = self._server() + body = b'{"error":"backend transport error"}' + result = { + "response_body": body, + "status_code": 502, + "headers": {"content-type": "application/json"}, + } + resp = server.build_proxy_response(result) + assert isinstance(resp, Response) + assert resp.status_code == 502 + assert resp.body == body + + def test_compressed_upstream_strips_stale_framing(self): + server = self._server() + body = b'{"ok":true}' + result = { + "response_body": body, + "status_code": 200, + "headers": { + "content-type": "application/json", + "content-encoding": "gzip", + "content-length": "5", + }, + } + resp = server.build_proxy_response(result) + assert isinstance(resp, Response) + assert resp.body == body + lowered = {k.lower() for k in resp.headers.keys()} + # content-encoding stripped; the stale upstream content-length ("5") is + # dropped — Starlette recomputes from the body, never carries the stale one. + assert "content-encoding" not in lowered + assert resp.headers.get("content-length", str(len(body))) == str(len(body)) + + def test_generic_session_proxy_route_passes_through(self): + """The generic /sessions/{id}/{path} route proxies and passes the body + through: /abort_request on the mock backend returns {"status":"ok"} 200. + """ + + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="ok", finish_reason="stop") + + with _router_env(process_fn) as env: + session_id = _create_session(env.url) + resp = requests.post(f"{env.url}/sessions/{session_id}/abort_request", timeout=10.0) + assert resp.status_code == 200 + assert resp.json() == {"status": "ok"} + + +class TestPassthroughFidelity: + """A successful chat response is passed through faithfully and stale + framing headers from upstream are not copied to the client. + """ + + def test_successful_response_body_and_headers_passthrough(self): + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="passthrough-ok", finish_reason="stop") + + with _router_env(process_fn) as env: + session_id = _create_session(env.url) + resp = _chat(env.url, session_id, {"messages": [{"role": "user", "content": "hi"}]}, timeout=20.0) + assert resp.status_code == 200 + + # Body matches the upstream mock shape (id/object/choices content) + # passed through unchanged, including the meta_info the mock adds. + client_body = resp.json() + assert client_body["object"] == "chat.completion" + assert client_body["id"].startswith("chatcmpl-") + assert client_body["choices"][0]["message"]["content"] == "passthrough-ok" + assert "meta_info" in client_body["choices"][0] + + # Framing headers copied from upstream must have been stripped; the + # body is intact (decodable JSON), so requests itself did not need + # transfer/content-encoding framing from upstream. + lowered = {k.lower() for k in resp.headers.keys()} + assert "transfer-encoding" not in lowered + assert "content-encoding" not in lowered + assert resp.headers.get("content-type", "").startswith("application/json") + + +class TestResponseTokenIdValidation: + """A malformed-but-200 upstream response must be rejected (502) rather than + yielding non-integer token ids that would corrupt the stored trajectory.""" + + @staticmethod + def _payload(output_token_logprobs, completion_tokens=1) -> bytes: + import json + + return json.dumps( + { + "choices": [ + { + "message": {"role": "assistant", "content": "ok"}, + "meta_info": { + "output_token_logprobs": output_token_logprobs, + "completion_tokens": completion_tokens, + }, + } + ] + } + ).encode() + + def test_valid_integer_token_ids_accepted(self): + _resp, _msg, ids = _parse_and_validate_response(self._payload([[-0.1, 123], [-0.2, 456]], completion_tokens=2)) + assert ids == [123, 456] + + def test_non_integer_token_ids_rejected(self): + # str / None / float / bool token ids, a non-pair string entry, a short + # entry, and a bool completion_tokens must all raise UpstreamResponseError. + bad_cases = [ + ([[-0.1, "x"]], 1), + ([[-0.1, None]], 1), + ([[-0.1, 1.2]], 1), + ([[-0.1, True]], 1), + (["ab"], 1), + ([[123]], 1), + ([[-0.1, 123]], True), + ] + for logprobs, completion_tokens in bad_cases: + try: + _parse_and_validate_response(self._payload(logprobs, completion_tokens)) + except UpstreamResponseError: + continue + raise AssertionError( + f"expected UpstreamResponseError for {logprobs!r}, completion_tokens={completion_tokens!r}" + ) + + def test_non_numeric_logprob_values_rejected(self): + # entry[0] (the logprob) flows into Sample.rollout_log_probs downstream + # (openai_endpoint_utils), so a str / None / bool / non-finite logprob with an + # otherwise-valid token id must still raise rather than being accepted. + bad_cases = [ + [["bad-logprob", 562]], + [[None, 562]], + [[True, 562]], + [[float("nan"), 562]], + [[float("inf"), 562]], + [[float("-inf"), 562]], + ] + for logprobs in bad_cases: + try: + _parse_and_validate_response(self._payload(logprobs, completion_tokens=1)) + except UpstreamResponseError: + continue + raise AssertionError(f"expected UpstreamResponseError for logprob in {logprobs!r}") + + def test_non_standard_json_constants_rejected(self): + for constant in ("NaN", "Infinity", "-Infinity"): + raw = ( + b'{"choices":[{"message":{"role":"assistant","content":"ok"},' + b'"meta_info":{"output_token_logprobs":[[' + constant.encode() + b',562]],"completion_tokens":1}}]}' + ) + try: + _parse_and_validate_response(raw) + except UpstreamResponseError: + continue + raise AssertionError(f"expected UpstreamResponseError for JSON constant {constant}") + + def test_sglang_triples_with_token_text_accepted(self): + # SGLang emits [logprob, token_id, token_text] triples; len > 2 is normal + # and must NOT be rejected. Integer logprobs (e.g. 0) are also valid numbers. + _resp, _msg, ids = _parse_and_validate_response( + self._payload([[-0.1, 123, "he"], [0, 456, "llo"]], completion_tokens=2) + ) + assert ids == [123, 456]