diff --git a/grpc_servicer/pyproject.toml b/grpc_servicer/pyproject.toml
index bea1e10f9..893255b98 100644
--- a/grpc_servicer/pyproject.toml
+++ b/grpc_servicer/pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "smg-grpc-servicer"
 version = "0.5.2"
-description = "SMG gRPC servicer implementations for LLM inference engines (vLLM, SGLang, MLX)"
+description = "SMG gRPC servicer implementations for LLM inference engines (vLLM, SGLang, MLX, TokenSpeed)"
 requires-python = ">=3.10"
 dependencies = [
     "smg-grpc-proto>=0.4.6",
@@ -36,6 +36,23 @@ sglang = ["sglang>=0.5.10"]
 # without this floor, installing [mlx] against an older proto build would
 # crash at import time when smg_grpc_servicer.mlx.server runs.
 mlx = ["smg-grpc-proto>=0.4.7", "mlx>=0.22.0", "mlx-lm>=0.22.0"]
+# Note: there is intentionally no ``tokenspeed`` extra. TokenSpeed is not
+# published to PyPI; it is installed out-of-tree from the lightseekorg
+# checkout via ``scripts/ci_install_tokenspeed.sh`` (CI) or a manual
+# ``pip install -e ./tokenspeed/python`` (local dev). An extra named
+# ``tokenspeed`` would imply ``pip install smg-grpc-servicer[tokenspeed]``
+# yields a working tokenspeed setup; it does not.
+test = [
+    "pytest>=8.0.0",
+    "pytest-asyncio>=0.23.0",
+]
+
+[tool.pytest.ini_options]
+asyncio_mode = "auto"
+testpaths = ["tests"]
+markers = [
+    "tokenspeed: tests that require TokenSpeed",
+]
 
 [project.urls]
 Homepage = "https://github.com/lightseekorg/smg"
diff --git a/grpc_servicer/smg_grpc_servicer/tokenspeed/__init__.py b/grpc_servicer/smg_grpc_servicer/tokenspeed/__init__.py
new file mode 100644
index 000000000..d5ced6c52
--- /dev/null
+++ b/grpc_servicer/smg_grpc_servicer/tokenspeed/__init__.py
@@ -0,0 +1,11 @@
+"""TokenSpeed gRPC servicer implementation.
+
+Mirrors smg_grpc_servicer.vllm / smg_grpc_servicer.sglang. Wraps TokenSpeed's
+AsyncLLM (main-process async frontend) behind the dedicated
+``tokenspeed.grpc.scheduler.TokenSpeedScheduler`` service; the Rust router's
+``DetectBackendStep`` recognises TokenSpeed workers from that service name.
+"""
+
+from smg_grpc_servicer.tokenspeed.servicer import TokenSpeedSchedulerServicer
+
+__all__ = ["TokenSpeedSchedulerServicer"]
diff --git a/grpc_servicer/smg_grpc_servicer/tokenspeed/__main__.py b/grpc_servicer/smg_grpc_servicer/tokenspeed/__main__.py
new file mode 100644
index 000000000..b4e6fb0e6
--- /dev/null
+++ b/grpc_servicer/smg_grpc_servicer/tokenspeed/__main__.py
@@ -0,0 +1,71 @@
+"""CLI entrypoint for the TokenSpeed gRPC server.
+
+Usage::
+
+    python -m smg_grpc_servicer.tokenspeed --model <model-path> --host 127.0.0.1 --port 50051
+
+All :class:`tokenspeed.runtime.utils.server_args.ServerArgs` flags are accepted
+verbatim (we reuse TokenSpeed's own ``prepare_server_args`` so there is no
+flag drift between the HTTP and gRPC frontends).
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+import sys
+
+import uvloop
+from tokenspeed.runtime.utils.server_args import prepare_server_args
+
+from smg_grpc_servicer.tokenspeed.server import serve_grpc
+
+
+def main(argv: list[str] | None = None) -> None:
+    if argv is None:
+        argv = sys.argv[1:]
+
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s [%(name)s] %(levelname)s %(message)s",
+    )
+
+    # TokenSpeed's ``ServerArgs.resolve_kernel_backends`` defaults
+    # ``sampling_backend`` to ``"greedy"`` when the user doesn't pass
+    # ``--sampling-backend``. The greedy backend is argmax-only and
+    # ignores per-request ``temperature``/``top_p``/``top_k`` — fine for
+    # the legacy CLI where users opt in to sampling explicitly, but
+    # disastrous for a gateway-fronted gRPC servicer where per-request
+    # sampling params arrive on every call. With Llama-3.2-1B the
+    # always-argmax behavior collapses into single-token loops
+    # (\\n×N, ' ('×N, "no"×N) within a few hundred steps and
+    # generation runs to ``max_new_tokens`` — the smg e2e function-calling
+    # suite makes this directly observable. Force a sampling-respecting
+    # default unless the operator explicitly chose one.
+    if not any(a == "--sampling-backend" or a.startswith("--sampling-backend=") for a in argv):
+        argv = [*argv, "--sampling-backend", "flashinfer"]
+
+    # TokenSpeed's logprob computation is gated by ``--enable-output-logprobs``
+    # (default OFF, see ``ServerArgs.enable_output_logprobs``); without the
+    # flag, requests asking for logprobs receive empty arrays rather than an
+    # error. The smg gateway's OpenAI-compat path expects per-token logprobs
+    # whenever ``logprobs=True`` is set, so flip the flag on by default for a
+    # gateway-fronted gRPC servicer. Operators who want the smaller CUDA-graph
+    # footprint can pass ``--enable-output-logprobs=False`` explicitly.
+    # ``--enable-top-logprobs`` is intentionally NOT injected: TokenSpeed
+    # raises at startup when it's set (the path is not yet implemented).
+    if not any(
+        a == "--enable-output-logprobs" or a.startswith("--enable-output-logprobs=") for a in argv
+    ):
+        argv = [*argv, "--enable-output-logprobs"]
+
+    server_args = prepare_server_args(argv)
+    # ``prepare_server_args`` also runs TokenSpeed's shared env/resource
+    # setup (the scheduler subprocesses read the env vars it primes), so we
+    # reuse it here rather than duplicating that setup in this entrypoint.
+    asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+    asyncio.run(serve_grpc(server_args))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/grpc_servicer/smg_grpc_servicer/tokenspeed/health_servicer.py b/grpc_servicer/smg_grpc_servicer/tokenspeed/health_servicer.py
new file mode 100644
index 000000000..d6b04a62a
--- /dev/null
+++ b/grpc_servicer/smg_grpc_servicer/tokenspeed/health_servicer.py
@@ -0,0 +1,130 @@
+"""Standard ``grpc.health.v1.Health`` servicer for the TokenSpeed backend.
+
+Mirrors ``smg_grpc_servicer.sglang.health_servicer.SGLangHealthServicer`` —
+same service-name semantics, same lifecycle (NOT_SERVING → SERVING → NOT_SERVING),
+same ``Check``/``Watch`` contract — but sources liveness signals from a TokenSpeed
+:class:`AsyncLLM` instead of an SGLang ``GrpcRequestManager``.
+
+The Rust router uses this health check to auto-detect the backend runtime.
+TokenSpeed ships its own ``tokenspeed.grpc.scheduler.TokenSpeedScheduler``
+service identity (see ``proto/tokenspeed_scheduler.proto``) so the probe
+distinguishes TokenSpeed workers from real SGLang workers regardless of any
+wire-level message-type sharing between the two backends.
+"""
+
+from __future__ import annotations
+
+import logging
+import time
+from collections.abc import AsyncIterator
+from typing import TYPE_CHECKING
+
+import grpc
+from grpc_health.v1 import health_pb2, health_pb2_grpc
+from smg_grpc_proto.generated import tokenspeed_scheduler_pb2
+
+if TYPE_CHECKING:
+    from tokenspeed.runtime.engine.async_llm import AsyncLLM
+
+logger = logging.getLogger(__name__)
+
+# Seconds of scheduler silence — with pending requests — before we report
+# NOT_SERVING. Matches the SGLang equivalent so oncall dashboards are aligned.
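+# (Raise it and slow-but-alive schedulers look healthy for longer; lower it
+# and long legitimate prefill stalls start flapping readiness.)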
+STUCK_SCHEDULER_THRESHOLD_SEC = 30.0 + +# Source the advertised service name from the proto descriptor so a future +# ``package`` or ``service`` rename in tokenspeed_scheduler.proto stays in +# sync without a hand-edited string here. +TOKENSPEED_SCHEDULER_SERVICE_NAME = tokenspeed_scheduler_pb2.DESCRIPTOR.services_by_name[ + "TokenSpeedScheduler" +].full_name + + +class TokenSpeedHealthServicer(health_pb2_grpc.HealthServicer): + """Health servicer that tracks TokenSpeed's AsyncLLM liveness. + + Advertises two service levels: + + * ``""`` (empty) — overall server health, flipped to SERVING once the + warmup request succeeds and back to NOT_SERVING on shutdown. + * ``tokenspeed.grpc.scheduler.TokenSpeedScheduler`` — readiness: the + base status, plus a scheduler-responsiveness check (if there are + pending requests but the scheduler hasn't pushed output for >30s, + report NOT_SERVING). + """ + + OVERALL_SERVER = "" + TOKENSPEED_SERVICE = TOKENSPEED_SCHEDULER_SERVICE_NAME + + def __init__(self, async_llm: AsyncLLM, scheduler_info: dict): + self.async_llm = async_llm + self.scheduler_info = scheduler_info + self._serving_status: dict[str, int] = { + self.OVERALL_SERVER: health_pb2.HealthCheckResponse.NOT_SERVING, + self.TOKENSPEED_SERVICE: health_pb2.HealthCheckResponse.NOT_SERVING, + } + logger.info("TokenSpeed gRPC health service initialized") + + def set_serving(self) -> None: + """Flip both services to SERVING (call after successful warmup).""" + self._serving_status[self.OVERALL_SERVER] = health_pb2.HealthCheckResponse.SERVING + self._serving_status[self.TOKENSPEED_SERVICE] = health_pb2.HealthCheckResponse.SERVING + logger.info("TokenSpeed gRPC health status -> SERVING") + + def set_not_serving(self) -> None: + """Flip both services to NOT_SERVING (call on shutdown).""" + self._serving_status[self.OVERALL_SERVER] = health_pb2.HealthCheckResponse.NOT_SERVING + self._serving_status[self.TOKENSPEED_SERVICE] = health_pb2.HealthCheckResponse.NOT_SERVING + logger.info("TokenSpeed gRPC health status -> NOT_SERVING") + + async def Check( + self, + request: health_pb2.HealthCheckRequest, + context: grpc.aio.ServicerContext, + ) -> health_pb2.HealthCheckResponse: + service_name = request.service + logger.debug("Health check request for service=%r", service_name) + + if self.async_llm.gracefully_exit: + return health_pb2.HealthCheckResponse(status=health_pb2.HealthCheckResponse.NOT_SERVING) + + if service_name == self.OVERALL_SERVER: + return health_pb2.HealthCheckResponse( + status=self._serving_status.get( + self.OVERALL_SERVER, health_pb2.HealthCheckResponse.NOT_SERVING + ) + ) + + if service_name == self.TOKENSPEED_SERVICE: + base = self._serving_status.get( + self.TOKENSPEED_SERVICE, health_pb2.HealthCheckResponse.NOT_SERVING + ) + if base != health_pb2.HealthCheckResponse.SERVING: + return health_pb2.HealthCheckResponse(status=base) + + # Scheduler-stuck check: pending work but no recent output. 
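+            # e.g. 3 pending rids and 45s of scheduler silence -> NOT_SERVING;
+            # an idle worker (no pending rids) never trips this check, however
+            # stale ``last_receive_tstamp`` may be.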
+ time_since_last_receive = time.time() - self.async_llm.last_receive_tstamp + pending = len(self.async_llm.rid_to_state) + if time_since_last_receive > STUCK_SCHEDULER_THRESHOLD_SEC and pending > 0: + logger.warning( + "Scheduler appears stuck: %.1fs since last receive, %d pending requests", + time_since_last_receive, + pending, + ) + return health_pb2.HealthCheckResponse( + status=health_pb2.HealthCheckResponse.NOT_SERVING + ) + + return health_pb2.HealthCheckResponse(status=health_pb2.HealthCheckResponse.SERVING) + + context.set_code(grpc.StatusCode.NOT_FOUND) + context.set_details(f"Unknown service: {service_name}") + return health_pb2.HealthCheckResponse(status=health_pb2.HealthCheckResponse.SERVICE_UNKNOWN) + + async def Watch( + self, + request: health_pb2.HealthCheckRequest, + context: grpc.aio.ServicerContext, + ) -> AsyncIterator[health_pb2.HealthCheckResponse]: + # K8s probes use Check, not Watch — we emit the current status once. + yield await self.Check(request, context) diff --git a/grpc_servicer/smg_grpc_servicer/tokenspeed/scheduler_launcher.py b/grpc_servicer/smg_grpc_servicer/tokenspeed/scheduler_launcher.py new file mode 100644 index 000000000..64acb18fa --- /dev/null +++ b/grpc_servicer/smg_grpc_servicer/tokenspeed/scheduler_launcher.py @@ -0,0 +1,60 @@ +"""Scheduler subprocess launcher for the TokenSpeed gRPC server. + +Mirrors ``smg_grpc_servicer.sglang.scheduler_launcher`` but delegates to +TokenSpeed's ``_launch_subprocesses``: we get back a fully-initialised +``AsyncLLM`` along with the scheduler info dict. All scheduler/DP-controller +spawning, multiprocessing start-method, and env priming already live inside +``_launch_subprocesses`` — we only wrap it to return what the gRPC server +cares about and to keep the call site symmetric with the sibling backends. +""" + +from __future__ import annotations + +import logging +from typing import Any + +from tokenspeed.runtime.engine.async_llm import AsyncLLM +from tokenspeed.runtime.entrypoints.engine import _launch_subprocesses +from tokenspeed.runtime.utils.server_args import PortArgs, ServerArgs + +logger = logging.getLogger(__name__) + + +def launch_engine( + server_args: ServerArgs, + port_args: PortArgs | None = None, +) -> tuple[AsyncLLM, dict[str, Any]]: + """Launch TokenSpeed scheduler subprocess(es) and the main-process AsyncLLM. + + Returns: + A tuple ``(async_llm, scheduler_info)``. ``async_llm`` is the live + :class:`AsyncLLM` that the gRPC servicer will drive. ``scheduler_info`` + is the dict rank-0 sent back once its scheduler was ready (contains + e.g. ``max_total_num_tokens``, ``max_req_input_len``, ...). + + Raises: + RuntimeError: If rank-0 scheduler fails to initialize. The original + ``_launch_subprocesses`` surfaces this by re-raising the EOF/assertion + error — we propagate it unchanged. + """ + async_llm, _template_manager, scheduler_info = _launch_subprocesses( + server_args=server_args, + port_args=port_args, + ) + + # Non-zero rank nodes return (None, None, None) from _launch_subprocesses + # and block forever on the dummy health server — they never reach the gRPC + # server. Guard against callers relying on this return on secondary nodes. + if async_llm is None: + raise RuntimeError( + "launch_engine() returned no AsyncLLM. This means the current node " + "is not rank 0 in a multi-node deployment, or the scheduler died " + "during initialization. Only rank 0 may serve gRPC traffic." 
+ ) + + logger.info( + "TokenSpeed engine ready: max_total_num_tokens=%s max_req_input_len=%s", + scheduler_info.get("max_total_num_tokens"), + scheduler_info.get("max_req_input_len"), + ) + return async_llm, scheduler_info diff --git a/grpc_servicer/smg_grpc_servicer/tokenspeed/server.py b/grpc_servicer/smg_grpc_servicer/tokenspeed/server.py new file mode 100644 index 000000000..bbe67e69a --- /dev/null +++ b/grpc_servicer/smg_grpc_servicer/tokenspeed/server.py @@ -0,0 +1,195 @@ +"""Standalone TokenSpeed gRPC server — mirrors ``smg_grpc_servicer.sglang.server``.""" + +from __future__ import annotations + +import asyncio +import logging +import os +import signal +import threading +import time +from concurrent import futures + +import grpc +from grpc_health.v1 import health_pb2_grpc +from grpc_reflection.v1alpha import reflection +from smg_grpc_proto import tokenspeed_scheduler_pb2_grpc +from smg_grpc_proto.generated import tokenspeed_scheduler_pb2 +from tokenspeed.runtime.utils.server_args import ServerArgs + +from smg_grpc_servicer.tokenspeed.health_servicer import TokenSpeedHealthServicer +from smg_grpc_servicer.tokenspeed.scheduler_launcher import launch_engine +from smg_grpc_servicer.tokenspeed.servicer import TokenSpeedSchedulerServicer + +logger = logging.getLogger(__name__) + + +async def serve_grpc(server_args: ServerArgs) -> None: + """Run the TokenSpeed gRPC server until a shutdown signal is received.""" + + logger.info("Launching TokenSpeed scheduler + AsyncLLM...") + async_llm, scheduler_info = launch_engine(server_args) + + server = grpc.aio.server( + futures.ThreadPoolExecutor(max_workers=10), + options=[ + ("grpc.max_send_message_length", 1024 * 1024 * 256), + ("grpc.max_receive_message_length", 1024 * 1024 * 256), + # Match SGLang's more-permissive keepalive defaults so long + # prefill stalls don't trip GOAWAY in the Rust client. + ("grpc.http2.min_recv_ping_interval_without_data_ms", 10000), + ("grpc.keepalive_permit_without_calls", True), + ], + ) + + health_servicer = TokenSpeedHealthServicer( + async_llm=async_llm, + scheduler_info=scheduler_info, + ) + health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server) + + servicer = TokenSpeedSchedulerServicer( + async_llm=async_llm, + server_args=server_args, + scheduler_info=scheduler_info, + health_servicer=health_servicer, + ) + tokenspeed_scheduler_pb2_grpc.add_TokenSpeedSchedulerServicer_to_server(servicer, server) + + service_names = ( + tokenspeed_scheduler_pb2.DESCRIPTOR.services_by_name["TokenSpeedScheduler"].full_name, + "grpc.health.v1.Health", + reflection.SERVICE_NAME, + ) + reflection.enable_server_reflection(service_names, server) + + listen_addr = f"{server_args.host}:{server_args.port}" + server.add_insecure_port(listen_addr) + logger.info("TokenSpeed gRPC server listening on %s", listen_addr) + + await server.start() + + # Warmup on a background thread so the async server can handle the probe. + warmup_thread = threading.Thread( + target=_wait_and_warmup, + args=(server_args, health_servicer), + daemon=True, + ) + warmup_thread.start() + + loop = asyncio.get_running_loop() + stop_event = asyncio.Event() + + def _signal_handler() -> None: + logger.info("Received shutdown signal") + stop_event.set() + + for sig in (signal.SIGTERM, signal.SIGINT): + try: + loop.add_signal_handler(sig, _signal_handler) + except NotImplementedError: + # Windows and some exotic envs don't support loop.add_signal_handler. 
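+            # There, shutdown relies on KeyboardInterrupt propagating out of
+            # ``asyncio.run``; the ``finally`` block below still performs the
+            # servicer/server stop sequence during task cancellation.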
+ pass + + try: + await stop_event.wait() + finally: + logger.info("Shutting down TokenSpeed gRPC server") + try: + await servicer.shutdown() + except Exception: # noqa: BLE001 + logger.exception("servicer.shutdown() raised") + await server.stop(5.0) + if warmup_thread.is_alive(): + warmup_thread.join(timeout=5.0) + + +def _wait_and_warmup( + server_args: ServerArgs, + health_servicer: TokenSpeedHealthServicer, +) -> None: + """Probe the gRPC server until it can generate one token, then set SERVING. + + We hit the external port (not the in-process servicer) so the warmup + exercises the same code path a production caller would — including the + gRPC transport, proto codec, and scheduler IPC. + """ + if os.getenv("TOKENSPEED_SKIP_GRPC_WARMUP", "0").lower() in ("1", "true", "yes"): + logger.info("TOKENSPEED_SKIP_GRPC_WARMUP=1 — skipping warmup") + health_servicer.set_serving() + return + + grpc_url = f"{server_args.host}:{server_args.port}" + channel = grpc.insecure_channel( + grpc_url, + options=[ + ("grpc.max_send_message_length", 1024 * 1024 * 256), + ("grpc.max_receive_message_length", 1024 * 1024 * 256), + ], + ) + stub = tokenspeed_scheduler_pb2_grpc.TokenSpeedSchedulerStub(channel) + + # Wait until GetModelInfo round-trips — that's the quickest confirmation + # that the gRPC server is both bound and has a live AsyncLLM behind it. + deadline = time.time() + 180 + connected = False + while time.time() < deadline: + try: + stub.GetModelInfo( + tokenspeed_scheduler_pb2.GetModelInfoRequest(), + timeout=5, + ) + connected = True + break + except Exception as e: # noqa: BLE001 + logger.debug("Warmup: GetModelInfo not ready yet: %s", e) + time.sleep(1) + + if not connected: + logger.error("TokenSpeed gRPC warmup failed: GetModelInfo never succeeded") + channel.close() + return + + # TokenSpeed serves generative LLMs only (the proto has no Embed RPC), so + # the warmup is always a 1-token generate. + warmup_ok = False + try: + warmup = tokenspeed_scheduler_pb2.GenerateRequest( + request_id=f"WARMUP_{time.time()}", + tokenized=tokenspeed_scheduler_pb2.TokenizedInput( + input_ids=[0], + original_text="warmup", + ), + sampling_params=tokenspeed_scheduler_pb2.SamplingParams( + temperature=0.0, + max_new_tokens=1, + ), + stream=False, + ) + final = None + for resp in stub.Generate(warmup, timeout=600): + final = resp + if final is None or not final.HasField("complete"): + logger.warning( + "Warmup Generate returned no Complete frame (last=%r)", + final, + ) + else: + logger.info("Warmup generation succeeded") + warmup_ok = True + except Exception as e: # noqa: BLE001 + logger.warning("TokenSpeed warmup failed: %s", e) + finally: + channel.close() + + # NOT_SERVING keeps the pod out of K8s readiness rotation when warmup + # never produced a Complete frame. + if warmup_ok: + health_servicer.set_serving() + logger.info("TokenSpeed gRPC server is ready to serve") + else: + logger.error( + "TokenSpeed gRPC warmup did not produce a complete frame; " + "health stays NOT_SERVING. K8s readiness will keep this " + "worker out of rotation until manually restarted." + ) diff --git a/grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py b/grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py new file mode 100644 index 000000000..8d9387f1b --- /dev/null +++ b/grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py @@ -0,0 +1,937 @@ +"""TokenSpeed gRPC servicer. 
+ +Implements the ``tokenspeed.grpc.scheduler.TokenSpeedScheduler`` gRPC service +on top of TokenSpeed's :class:`tokenspeed.runtime.engine.async_llm.AsyncLLM` — +the main-process async frontend that replaced ``TokenizerManager`` in the +AsyncLLM refactor. + +Wire identity & message catalog +------------------------------- +TokenSpeed ships a fully independent proto (``proto/tokenspeed_scheduler.proto``) +with a distinct package, service, and message catalog. The Rust gateway's +``DetectBackendStep`` identifies the worker natively from the service name — +no SGLang-look-alike hack, no runtime marker probe. The proto's field set is +intentionally minimal (top-tier LLM serving only): no Embed, no +GetTokenizer, no SubscribeKvEvents, no multimodal, no PD-disaggregated +serving, no LoRA, no hidden-state forwarding, no classifier outputs. +Anything in that list has to be added to the proto first; it doesn't ride +on a shared SGLang message anymore. +""" + +from __future__ import annotations + +import asyncio +import dataclasses +import json +import logging +import os +import re +import time +from collections.abc import AsyncIterator +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any + +import grpc +from google.protobuf.struct_pb2 import Struct +from google.protobuf.timestamp_pb2 import Timestamp +from smg_grpc_proto import tokenspeed_scheduler_pb2_grpc +from smg_grpc_proto.generated import tokenspeed_scheduler_pb2 + +from smg_grpc_servicer.tokenspeed.health_servicer import TokenSpeedHealthServicer + +if TYPE_CHECKING: + # Type-only imports — not resolved at module load so the servicer is + # importable in test environments that stub AsyncLLM / ServerArgs. + from tokenspeed.runtime.engine.async_llm import AsyncLLM + from tokenspeed.runtime.utils.server_args import ServerArgs + +logger = logging.getLogger(__name__) + +HEALTH_CHECK_TIMEOUT = int(os.getenv("TOKENSPEED_HEALTH_CHECK_TIMEOUT", "20")) + + +def _lazy_generate_req_input(): + """Late import for ``tokenspeed.runtime.engine.io_struct.GenerateReqInput``. + + Kept lazy so the top of this module loads in test environments that stub + the TokenSpeed engine surface (unit tests don't need a fully-working + TokenSpeed install to exercise proto ↔ request-input conversion). + """ + from tokenspeed.runtime.engine.io_struct import GenerateReqInput + + return GenerateReqInput + + +def _finish_reason_to_dict(reason: Any) -> dict | None: + """Normalise a TokenSpeed finish reason into a dict. + + TokenSpeed emits ``BaseFinishReason``-style objects (or an already- + normalised dict) in ``meta_info["finish_reason"]``; downstream code + expects a dict with at minimum ``{"type": ...}`` and optionally + ``{"matched": int|str}``. ``None`` means "still running". + + We duck-type on ``to_json()`` so the servicer module loads without + pulling in TokenSpeed's full request-processing graph. Unknown shapes + raise ``TypeError`` rather than silently flipping ``length`` / ``abort`` + to ``stop`` — the caller maps that to ``StatusCode.INTERNAL``. + """ + if reason is None or isinstance(reason, dict): + return reason + to_json = getattr(reason, "to_json", None) + if callable(to_json): + result = to_json() + if isinstance(result, dict): + return result + raise TypeError( + f"finish_reason {type(reason).__name__!r}.to_json() returned " + f"{type(result).__name__!r}; expected dict with at least 'type'." + ) + raise TypeError( + f"Unknown finish_reason shape {type(reason).__name__!r}; expected " + f"a dict or an object with a to_json() method." 
+ ) + + +class TokenSpeedSchedulerServicer(tokenspeed_scheduler_pb2_grpc.TokenSpeedSchedulerServicer): + """gRPC servicer exposing TokenSpeed's AsyncLLM over the dedicated TokenSpeed proto.""" + + def __init__( + self, + async_llm: AsyncLLM, + server_args: ServerArgs, + scheduler_info: dict, + health_servicer: TokenSpeedHealthServicer | None = None, + ): + self.async_llm = async_llm + self.server_args = server_args + self.scheduler_info = scheduler_info + self.health_servicer = health_servicer + self.start_time = time.time() + + # Drive AsyncLLM's output-dispatch loop. This is idempotent — the + # first caller creates the handle loop; subsequent callers (including + # the HealthCheck RPC) are no-ops thanks to ``no_create_loop``. + self.async_llm.auto_create_handle_loop() + + logger.info("TokenSpeedSchedulerServicer initialized") + + # ------------------------------------------------------------------ + # Generate (server-streaming) + # ------------------------------------------------------------------ + + async def Generate( + self, + request: tokenspeed_scheduler_pb2.GenerateRequest, + context: grpc.aio.ServicerContext, + ) -> AsyncIterator[tokenspeed_scheduler_pb2.GenerateResponse]: + rid = request.request_id + logger.info("Generate request %s (stream=%s)", rid, request.stream) + + try: + req_obj = self._build_generate_req(request) + except ValueError as e: + await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + return + + # For n>1, tokenspeed's batch handler generates fresh UUIDs per + # sub-request and tags each streamed dict with a sequential + # ``index`` (see tokenizer_manager.py::_handle_batch_request). + # Non-streaming n>1 yields a *list* of final dicts instead. We + # handle both shapes below. + expanded_rid = getattr(req_obj, "rid", None) + + # When the client sets ``no_stop_trim``, the matched stop token must + # remain in the proto's ``output_ids`` so the gateway-side detokenizer + # can render it (relevant when ``skip_special_tokens=False`` is also + # set). Capture once and thread through the response builders. + no_stop_trim = bool(request.sampling_params.no_stop_trim) + + aborted = False + try: + async for output in self.async_llm.generate_request(req_obj): + # Non-streaming n>1 emits a list of final dicts in one yield. 
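+                # e.g. a non-streaming n=2 request yields once with
+                #   [{"index": 0, "output_ids": [...], "meta_info": {...}},
+                #    {"index": 1, "output_ids": [...], "meta_info": {...}}]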
+ if isinstance(output, list): + for idx, item in enumerate(output): + item_reason = _finish_reason_to_dict( + item.get("meta_info", {}).get("finish_reason") + ) + if item_reason and item_reason.get("type") == "abort": + code = _abort_status_code(item_reason) + await context.abort(code, item_reason.get("message") or "aborted") + return + ci = int(item.get("index", idx)) + yield self._complete_response( + rid, item, item_reason, ci, no_stop_trim=no_stop_trim + ) + continue + + meta = output.get("meta_info", {}) + reason_dict = _finish_reason_to_dict(meta.get("finish_reason")) + is_finished = reason_dict is not None + + if reason_dict is not None and reason_dict.get("type") == "abort": + code = _abort_status_code(reason_dict) + await context.abort(code, reason_dict.get("message") or "aborted") + return + + choice_index = int(output.get("index", 0)) + + if request.stream: + yield self._chunk_response( + rid, output, reason_dict, choice_index, no_stop_trim=no_stop_trim + ) + if is_finished: + yield self._complete_response( + rid, output, reason_dict, choice_index, no_stop_trim=no_stop_trim + ) + elif is_finished: + yield self._complete_response( + rid, output, reason_dict, choice_index, no_stop_trim=no_stop_trim + ) + + except ValueError as e: + logger.warning("Generate invalid request %s: %s", rid, e) + await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + except asyncio.CancelledError: + # Client disconnected — sweep every scheduler-side rid we minted + # (including the per-choice ``{rid}-n{i}`` children n>1 creates) + # so abandoned requests don't keep consuming GPU work. + aborted = True + if isinstance(expanded_rid, list): + for r in expanded_rid: + self.async_llm.abort_request(r) + else: + self.async_llm.abort_request(rid) + raise + except grpc.aio.AbortError: + raise + except Exception as e: + logger.exception("Generate failed for request %s", rid) + await context.abort(grpc.StatusCode.INTERNAL, str(e)) + finally: + # Defensive cleanup — the scheduler owns rid_to_state, but if the + # stream was torn down before finish we need to notify it. When + # n>1 we expanded rid to a list of per-choice ids, so walk them. + if not aborted: + rids_to_check = ( + list(expanded_rid) + if isinstance(expanded_rid, list) + else ([expanded_rid] if isinstance(expanded_rid, str) else []) + ) + for r in rids_to_check: + state = self.async_llm.rid_to_state.get(r) + if state is not None and not getattr(state, "finished", False): + self.async_llm.abort_request(r) + + # ------------------------------------------------------------------ + # HealthCheck (unary) + # ------------------------------------------------------------------ + + async def HealthCheck( + self, + request: tokenspeed_scheduler_pb2.HealthCheckRequest, + context: grpc.aio.ServicerContext, + ) -> tokenspeed_scheduler_pb2.HealthCheckResponse: + """Deep health probe — sends a 1-token generation to the scheduler. + + Mirrors SGLang's contract exactly: if the scheduler pushes *any* + output within ``HEALTH_CHECK_TIMEOUT`` seconds, we consider it alive. + We bypass the normal AsyncLLM lock/metrics by crafting a dedicated + request with ``log_metrics=False`` so health checks don't skew + Prometheus counters. + """ + rid = f"HEALTH_CHECK_{time.time()}" + + if self.async_llm.gracefully_exit: + return tokenspeed_scheduler_pb2.HealthCheckResponse( + healthy=False, message="Server is shutting down" + ) + + # TokenSpeed only serves generative LLMs at this layer (the proto + # has no Embed RPC), so the probe is always a 1-token generate. 
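+        # The probe mirrors the warmup request in server.py (input_ids=[0],
+        # max_new_tokens=1, temperature=0.0) but stays in-process and sets
+        # log_metrics=False so probes never skew Prometheus counters.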
+ GenerateReqInput = _lazy_generate_req_input() + probe = GenerateReqInput( + input_ids=[0], + sampling_params={"max_new_tokens": 1, "temperature": 0.0}, + log_metrics=False, + ) + probe.rid = rid + + tic = time.time() + + async def _drive_probe() -> bool: + try: + async for _ in self.async_llm.generate_request(probe): + return True + except Exception as e: # noqa: BLE001 — the probe is best-effort. + logger.warning("Health probe failed: %s", e) + return False + return False + + task = asyncio.create_task(_drive_probe()) + try: + while time.time() - tic < HEALTH_CHECK_TIMEOUT: + await asyncio.sleep(0.5) + # Any scheduler push after we started counts as healthy. + if self.async_llm.last_receive_tstamp > tic: + return tokenspeed_scheduler_pb2.HealthCheckResponse( + healthy=True, + message="Health check passed", + ) + if task.done(): + return tokenspeed_scheduler_pb2.HealthCheckResponse( + healthy=bool(task.result()), + message=( + "Health check passed" + if task.result() + else "Scheduler returned no output" + ), + ) + finally: + if not task.done(): + task.cancel() + # Best-effort cleanup: the probe rid shouldn't linger. + self.async_llm.abort_request(rid) + + return tokenspeed_scheduler_pb2.HealthCheckResponse( + healthy=False, + message=f"Health check timeout after {HEALTH_CHECK_TIMEOUT}s", + ) + + # ------------------------------------------------------------------ + # Abort (unary) + # ------------------------------------------------------------------ + + async def Abort( + self, + request: tokenspeed_scheduler_pb2.AbortRequest, + _context: grpc.aio.ServicerContext, + ) -> tokenspeed_scheduler_pb2.AbortResponse: + """Abort the request + any per-choice expansions from n>1. + + Generate rewrites ``n>1`` requests into a list of rids + ``[{request_id}-n0, {request_id}-n1, ...]`` so TokenSpeed's batch + path sees unique rids. Aborting only the original ``request_id`` + would leave those children running — we sweep them all. + """ + rid = request.request_id + logger.info("Abort request %s", rid) + state_map = self.async_llm.rid_to_state + + # Anchored regex avoids matching unrelated rids like "{rid}-name". 
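+        # e.g. rid="req-7" sweeps "req-7", "req-7-n0" and "req-7-n12",
+        # but leaves "req-7-name" and "req-77-n0" alone.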
+ child_pattern = re.compile(rf"^{re.escape(rid)}-n\d+$") + targets = [r for r in state_map if r == rid or child_pattern.match(r)] + + try: + for r in targets: + self.async_llm.abort_request(r) + known = bool(targets) + return tokenspeed_scheduler_pb2.AbortResponse( + success=known, + message=( + f"Aborted {len(targets)} request(s) for {rid}" + if known + else f"Request {rid} not found" + ), + ) + except Exception as e: + logger.exception("Abort failed for %s", rid) + return tokenspeed_scheduler_pb2.AbortResponse(success=False, message=str(e)) + + # ------------------------------------------------------------------ + # GetModelInfo (unary) + # ------------------------------------------------------------------ + + async def GetModelInfo( + self, + _request: tokenspeed_scheduler_pb2.GetModelInfoRequest, + _context: grpc.aio.ServicerContext, + ) -> tokenspeed_scheduler_pb2.GetModelInfoResponse: + model_config = self.async_llm.model_config + hf_config = getattr(model_config, "hf_config", None) + + eos = getattr(hf_config, "eos_token_id", None) if hf_config else None + if isinstance(eos, int): + eos_token_ids = [eos] + elif isinstance(eos, list): + eos_token_ids = list(eos) + else: + eos_token_ids = [] + + max_req_input_len = self.scheduler_info.get("max_req_input_len") or ( + self.async_llm.max_req_input_len or 0 + ) + + # TokenSpeed's GetModelInfoResponse intentionally drops + # ``is_generation`` (always true), ``supports_vision`` (always false), + # and ``id2label_json`` / ``num_labels`` (not a classifier serving + # path). The Rust client fills those slots back in when translating + # to its SGLang-shaped wrapper. + # Upstream renamed ``ServerArgs.model_path`` → ``ServerArgs.model`` + # and ``ServerArgs.tokenizer_path`` → ``ServerArgs.tokenizer`` + # alongside the ``--model-path`` → ``--model`` flag rename. Old + # versions still set the ``_path`` form; new ones set the bare + # form. Pick whichever is populated so the servicer works against + # both. + model_path = getattr(self.server_args, "model", None) or getattr( + self.server_args, "model_path", "" + ) + tokenizer_path = getattr(self.server_args, "tokenizer", None) or getattr( + self.server_args, "tokenizer_path", "" + ) + return tokenspeed_scheduler_pb2.GetModelInfoResponse( + model_path=model_path, + tokenizer_path=tokenizer_path or "", + preferred_sampling_params=self.server_args.preferred_sampling_params or "", + weight_version="", + served_model_name=(self.server_args.served_model_name or model_path), + max_context_length=int(self.async_llm.context_len), + vocab_size=int(model_config.vocab_size), + model_type=(getattr(hf_config, "model_type", "") or "") if hf_config else "", + architectures=(getattr(hf_config, "architectures", []) or []) if hf_config else [], + eos_token_ids=eos_token_ids, + pad_token_id=(getattr(hf_config, "pad_token_id", 0) or 0) if hf_config else 0, + bos_token_id=(getattr(hf_config, "bos_token_id", 0) or 0) if hf_config else 0, + max_req_input_len=int(max_req_input_len), + ) + + # ------------------------------------------------------------------ + # GetServerInfo (unary) + # ------------------------------------------------------------------ + + async def GetServerInfo( + self, + _request: tokenspeed_scheduler_pb2.GetServerInfoRequest, + _context: grpc.aio.ServicerContext, + ) -> tokenspeed_scheduler_pb2.GetServerInfoResponse: + # TokenSpeed's ``ServerArgs`` is a dataclass, but tests sometimes pass + # a plain namespace. Fall back to ``__dict__`` so both shapes work. 
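+        # (The ``not isinstance(..., type)`` guard matters because
+        # ``dataclasses.is_dataclass`` is also True for the dataclass class
+        # itself, and ``dataclasses.asdict`` only accepts instances.)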
if dataclasses.is_dataclass(self.server_args) and not isinstance(self.server_args, type):
+            server_args_dict = dataclasses.asdict(self.server_args)
+        else:
+            server_args_dict = dict(getattr(self.server_args, "__dict__", {}))
+        server_args_struct = Struct()
+        server_args_struct.update(_make_json_serializable(server_args_dict))
+
+        scheduler_info_struct = Struct()
+        scheduler_info_struct.update(_make_json_serializable(dict(self.scheduler_info)))
+
+        uptime = time.time() - self.start_time
+        start_timestamp = Timestamp()
+        start_timestamp.FromSeconds(int(self.start_time))
+
+        try:
+            import tokenspeed  # local import: avoid module-load-time dependency
+
+            version = getattr(tokenspeed, "__version__", "unknown")
+        except Exception:  # noqa: BLE001 — fall back gracefully.
+            version = "unknown"
+
+        return tokenspeed_scheduler_pb2.GetServerInfoResponse(
+            server_args=server_args_struct,
+            scheduler_info=scheduler_info_struct,
+            active_requests=len(self.async_llm.rid_to_state),
+            is_paused=False,
+            uptime_seconds=float(uptime),
+            tokenspeed_version=version,
+            start_time=start_timestamp,
+            max_total_num_tokens=int(self.scheduler_info.get("max_total_num_tokens", 0)),
+        )
+
+    # ------------------------------------------------------------------
+    # GetLoads (unary) — bridges to TokenSpeed's scheduler-side load metrics
+    # ------------------------------------------------------------------
+
+    async def GetLoads(
+        self,
+        _request: tokenspeed_scheduler_pb2.GetLoadsRequest,
+        context: grpc.aio.ServicerContext,
+    ) -> tokenspeed_scheduler_pb2.GetLoadsResponse:
+        """Return per-DP-rank scheduler load by RPC-ing the scheduler subprocess.
+
+        ``AsyncLLM`` inherits ``SchedulerControlClient.get_load`` which sends
+        ``GetLoadReqInput`` over the engine_core_client zmq channel and awaits
+        a ``List[GetLoadReqOutput]`` reply (one per DP rank). Each reply carries
+        the live counts the scheduler computes in ``event_loop._get_load``:
+        ``num_reqs`` (running + waiting), ``num_waiting_reqs``, and
+        ``num_pages`` (KV pages currently in use). We map those to the
+        ``SchedulerLoad`` proto plus a coarse aggregate so the router-side
+        consumer matches what it gets from SGLang.
+        """
+        try:
+            load_outputs = await asyncio.wait_for(
+                self.async_llm.get_load(), timeout=HEALTH_CHECK_TIMEOUT
+            )
+        except asyncio.TimeoutError:
+            # ``asyncio.wait_for`` raises ``asyncio.TimeoutError``, which only
+            # became an alias of the builtin ``TimeoutError`` in Python 3.11;
+            # catching the asyncio name keeps this DEADLINE_EXCEEDED mapping
+            # working on the package's 3.10 floor.
+            await context.abort(
+                grpc.StatusCode.DEADLINE_EXCEEDED,
+                f"tokenspeed scheduler did not respond to GetLoad within {HEALTH_CHECK_TIMEOUT}s",
+            )
+            return
+        except Exception as e:  # noqa: BLE001
+            logger.exception("GetLoads failed")
+            await context.abort(grpc.StatusCode.INTERNAL, str(e))
+            return
+
+        page_size = int(getattr(self.async_llm.server_args, "page_size", 1) or 1)
+        # ``max_total_num_tokens`` lives on the scheduler-side ``scheduler_info``
+        # dict that ``launch_engine`` plumbed through at boot — not directly on
+        # AsyncLLM. Fall back to ``server_args.max_total_num_tokens`` (used in
+        # tests' SimpleNamespace stubs).
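+        # Worked example of the mapping below: page_size=16 and
+        # lo.num_pages=100 give num_used_tokens=1600; against
+        # max_total_num_tokens=32768 that reports token_usage ~= 0.049.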
+ max_total_num_tokens = int( + (self.scheduler_info.get("max_total_num_tokens") if self.scheduler_info else None) + or getattr(self.async_llm.server_args, "max_total_num_tokens", 0) + or 0 + ) + + scheduler_loads: list[tokenspeed_scheduler_pb2.SchedulerLoad] = [] + total_running = 0 + total_waiting = 0 + token_usages: list[float] = [] + for lo in load_outputs: + num_running = max(0, int(lo.num_reqs) - int(lo.num_waiting_reqs)) + num_used_tokens = int(lo.num_pages) * page_size + token_usage = ( + num_used_tokens / max_total_num_tokens if max_total_num_tokens > 0 else 0.0 + ) + scheduler_loads.append( + tokenspeed_scheduler_pb2.SchedulerLoad( + dp_rank=int(lo.dp_rank), + num_running_reqs=num_running, + num_waiting_reqs=int(lo.num_waiting_reqs), + num_total_reqs=int(lo.num_reqs), + num_used_tokens=num_used_tokens, + max_total_num_tokens=max_total_num_tokens, + token_usage=token_usage, + ) + ) + total_running += num_running + total_waiting += int(lo.num_waiting_reqs) + token_usages.append(token_usage) + + aggregate = tokenspeed_scheduler_pb2.AggregateMetrics( + total_running_reqs=total_running, + total_waiting_reqs=total_waiting, + total_reqs=total_running + total_waiting, + avg_token_usage=(sum(token_usages) / len(token_usages)) if token_usages else 0.0, + ) + + return tokenspeed_scheduler_pb2.GetLoadsResponse( + timestamp=datetime.now(timezone.utc).isoformat(), + version="tokenspeed", + dp_rank_count=len(scheduler_loads), + loads=scheduler_loads, + aggregate=aggregate, + ) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + async def shutdown(self, drain_timeout_secs: float = 30.0) -> None: + """Graceful shutdown — drain in-flight requests, then kill scheduler children. + + AsyncLLM's ``sigterm_watchdog`` polls ``gracefully_exit`` every 5s, + drains ``rid_to_state`` and finally calls + ``kill_process_tree(getpid, include_parent=True)``. That works in + steady-state but the gRPC server's main coroutine may unwind before + the watchdog ticks again, in which case the scheduler subprocesses + outlive the parent and end up orphaned. To avoid that, we: + + 1. Flag ``gracefully_exit`` so AsyncLLM stops accepting work and + the watchdog will eventually run its own cleanup. + 2. Wait up to ``drain_timeout_secs`` for ``rid_to_state`` to empty. + 3. Forcibly kill the subprocess tree (``include_parent=False``) so + the scheduler children are reaped regardless of whether the + watchdog tick fires before this coroutine returns. Idempotent + with the watchdog's own ``kill_process_tree`` call. + """ + self.async_llm.gracefully_exit = True + if self.health_servicer: + self.health_servicer.set_not_serving() + + deadline = time.monotonic() + drain_timeout_secs + while time.monotonic() < deadline: + if not getattr(self.async_llm, "rid_to_state", None): + break + await asyncio.sleep(0.5) + else: + logger.warning( + "shutdown drain timed out after %.1fs with %d in-flight requests; " + "killing scheduler children anyway", + drain_timeout_secs, + len(getattr(self.async_llm, "rid_to_state", {}) or {}), + ) + + # Reap the scheduler subprocesses without taking down our own PID; + # server.py's stop sequence still needs us alive to finish gRPC drain. 
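+        # ``include_parent=False`` below is what keeps that safe: only the
+        # scheduler children of this PID are reaped, never this process.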
+ try: + from tokenspeed.runtime.utils.process import kill_process_tree + except ImportError: + logger.exception( + "Could not import tokenspeed.runtime.utils.process.kill_process_tree; " + "scheduler subprocesses may be orphaned" + ) + return + kill_process_tree(os.getpid(), include_parent=False) + + def _build_generate_req(self, request: tokenspeed_scheduler_pb2.GenerateRequest): + """Translate proto GenerateRequest → TokenSpeed GenerateReqInput. + + Keeps the router's pre-tokenized inputs intact (``input_ids`` set, + ``text`` left blank) so the TokenSpeed InputProcessor skips its own + tokenizer pass. + """ + if not request.HasField("tokenized"): + raise ValueError("GenerateRequest.tokenized is required") + + input_ids = list(request.tokenized.input_ids) + if not input_ids: + raise ValueError("GenerateRequest.tokenized.input_ids is empty") + + sampling = self._sampling_params_from_proto( + request.sampling_params, + reasoning_parser=getattr(self.server_args, "reasoning_parser", None), + ) + + GenerateReqInput = _lazy_generate_req_input() + obj = GenerateReqInput( + input_ids=input_ids, + sampling_params=sampling, + stream=bool(request.stream), + return_logprob=bool(request.return_logprob), + # ``logprob_start_len`` is ``optional int32`` on the wire — use + # presence-tracking, not the proto3 zero-default, to distinguish + # "client omitted" (→ SGLang's ``-1`` = no input logprobs) from + # an explicit ``0`` (→ start input logprobs at position 0). + logprob_start_len=( + request.logprob_start_len if request.HasField("logprob_start_len") else -1 + ), + top_logprobs_num=int(request.top_logprobs_num or 0), + token_ids_logprob=( + list(request.token_ids_logprob) if request.token_ids_logprob else None + ), + # Hidden-state forwarding, multimodal inputs, PD-disaggregated + # serving, LoRA hot-swap and ``log_metrics`` are intentionally + # absent from TokenSpeed's wire — leaving the engine defaults in + # place keeps the call shape simple. + ) + # Older tokenspeed's ``normalize_batch_and_arguments`` treats n>1 as + # batched and asserts ``rid`` is a list in that case. One gRPC + # request carries one rid; expand it to a list of deterministic + # per-choice rids when the caller asked for multiple samples so the + # assert doesn't fire (and the scheduler can still deduplicate). + n = sampling.get("n", 1) or 1 + if n > 1: + obj.rid = [f"{request.request_id}-n{i}" for i in range(n)] + else: + obj.rid = request.request_id + + # NOTE: We deliberately do NOT set ``obj.text`` even when the proto + # carries ``original_text``. TokenSpeed's HTTP serving_chat passes + # ``input_ids=[...], text=None`` to the engine; setting both fields + # has been observed to perturb the engine's input-processor path + # (some validators and normalizers branch on whether text is + # populated). Matching the HTTP shape — ids only, text=None — + # eliminates one source of HTTP-vs-gRPC divergence. + + return obj + + @staticmethod + def _sampling_params_from_proto( + params: tokenspeed_scheduler_pb2.SamplingParams, + *, + reasoning_parser: str | None = None, + ) -> dict[str, Any]: + """Build the dict that ``GenerateReqInput.sampling_params`` expects. + + TokenSpeed's :class:`SamplingParams` consumes this dict via + ``SamplingParams(**obj.sampling_params)``, so field names must match + the Python class (``max_new_tokens``, ``stop``, ``stop_token_ids``, ...). + """ + out: dict[str, Any] = {} + + # All sampling scalars in tokenspeed_scheduler.proto are declared + # ``optional`` (matching ``vllm_engine.proto``). 
We use + # ``HasField()`` to forward only the values the client explicitly + # set; absent fields fall through to the engine's own + # ``SamplingParams.__init__`` defaults. This eliminates the old + # truthy-check pitfall that silently dropped ``temperature=0`` + # (BFCL's intent for greedy decoding) AND the warmup-default-zero + # crash where invalid ``top_p=0.0`` / ``repetition_penalty=0.0`` + # would reach the engine from internal probe paths. + # + # When ``temperature=0`` does reach the engine (HasField=True for + # an explicitly-sent ``0.0``), the engine + # (``sampling_params.py:104-107``) sets ``top_k=1`` to engage + # greedy decoding. That's the path BFCL relies on. + for _field in ( + "max_new_tokens", + "temperature", + "top_p", + "top_k", + "min_p", + "frequency_penalty", + "presence_penalty", + "repetition_penalty", + ): + if params.HasField(_field): + out[_field] = getattr(params, _field) + + if params.min_new_tokens: + # ``min_new_tokens`` is non-optional; 0 is the "no minimum" sentinel. + out["min_new_tokens"] = params.min_new_tokens + + # Lists + if params.stop: + out["stop"] = list(params.stop) + if params.stop_token_ids: + out["stop_token_ids"] = list(params.stop_token_ids) + + # Bools (always forwarded) + out["skip_special_tokens"] = bool(params.skip_special_tokens) + out["spaces_between_special_tokens"] = bool(params.spaces_between_special_tokens) + out["ignore_eos"] = bool(params.ignore_eos) + # When set, tokenspeed's detokenizer keeps the matched stop token in + # the rendered text (see ``runtime/engine/detokenizer.py``); we also + # suppress the servicer-side ``output_ids`` strip in + # ``_generated_output_ids`` so the EOS reaches the gateway's + # detokenizer when ``skip_special_tokens=False``. + out["no_stop_trim"] = bool(params.no_stop_trim) + + # n (OpenAI-compat, passthrough) + if params.n: + out["n"] = params.n + if params.logit_bias: + out["logit_bias"] = dict(params.logit_bias) + + # Constraint types — exactly one may be set. + if params.HasField("regex"): + out["regex"] = params.regex + elif params.HasField("json_schema"): + # Mirror tokenspeed serving_chat.py: when the engine is + # running with a reasoning parser that has an xgrammar + # template (e.g. ``gpt-oss`` → ``harmony``), wrap the user's + # JSON schema as a structural tag so the grammar only + # activates inside the response channel. Without this, + # xgrammar fights the Harmony channel preamble + # (``<|channel|>analysis<|message|>…``) and the model stalls + # until ``max_tokens``. + wrapped: str | None = None + if reasoning_parser: + try: + from tokenspeed.runtime.grammar.reasoning_structural_tag import ( + structural_tag_for_reasoning_json_schema, + ) + + wrapped = structural_tag_for_reasoning_json_schema( + reasoning_parser, json.loads(params.json_schema) + ) + except ImportError: + wrapped = None + if wrapped is not None: + out["structural_tag"] = wrapped + else: + out["json_schema"] = params.json_schema + elif params.HasField("ebnf_grammar"): + out["ebnf"] = params.ebnf_grammar + elif params.HasField("structural_tag"): + out["structural_tag"] = params.structural_tag + + return out + + def _generated_output_ids( + self, + output: dict, + reason_dict: dict | None, + *, + no_stop_trim: bool = False, + ) -> list[int]: + """Return just the newly-generated tokens from a TokenSpeed output dict. + + TokenSpeed's AsyncLLM has two quirks that the SGLang gRPC proto contract + doesn't expect, both of which break the smg gateway's detokenization + layer and downstream tool-call parsing: + + 1. 
``output_ids`` is prefixed with the Llama-3 chat-template assistant + header: ``[<|eot_id|>, <|start_header_id|>, "assistant", + <|end_header_id|>, "\\n\\n", ...generated..., ]``. The + ``skip_special_tokens=True`` detokenization strips the 128xxx + control tokens but keeps the word tokens ``"assistant"`` (78191) + and ``"\\n\\n"`` (271), so the final text looks like + ``assistant\\n\\n{"name": ...}``. The ``llama`` tool parser's + ``serde_json::from_str`` can't handle leading non-JSON prefix and + silently returns zero tool calls. + 2. The trailing stop token (e.g. ``<|eom_id|>`` = 128008) is included + in ``output_ids``; SGLang excludes it. If the gateway ever runs + with ``skip_special_tokens=False`` the stop leaks into the decoded + text and breaks JSON parsing for the same reason. + + Slicing the last ``meta_info.completion_tokens`` tokens gives us the + bare generated sequence that SGLang's ``token_ids`` would carry, and + we then defensively drop any trailing matched stop token. The + per-choice ``matched_stop`` fires in a separate proto field, so no + information is lost. + """ + raw = list(output.get("output_ids") or []) + if not raw: + return raw + completion = output.get("meta_info", {}).get("completion_tokens") + if isinstance(completion, int) and 0 < completion <= len(raw): + token_ids = raw[-completion:] + else: + token_ids = raw + if not no_stop_trim and reason_dict and reason_dict.get("type") == "stop": + matched = reason_dict.get("matched") + if isinstance(matched, int) and token_ids and token_ids[-1] == matched: + token_ids = token_ids[:-1] + return token_ids + + def _chunk_response( + self, + rid: str, + output: dict, + reason_dict: dict | None, + choice_index: int = 0, + *, + no_stop_trim: bool = False, + ) -> tokenspeed_scheduler_pb2.GenerateResponse: + meta = output.get("meta_info", {}) + token_ids = self._generated_output_ids(output, reason_dict, no_stop_trim=no_stop_trim) + return tokenspeed_scheduler_pb2.GenerateResponse( + request_id=rid, + chunk=tokenspeed_scheduler_pb2.GenerateStreamChunk( + token_ids=token_ids, + prompt_tokens=int(meta.get("prompt_tokens", 0)), + completion_tokens=int(meta.get("completion_tokens", len(token_ids))), + cached_tokens=int(meta.get("cached_tokens", 0)), + output_logprobs=self._convert_output_logprobs_to_proto(output, len(token_ids)), + index=choice_index, + ), + ) + + def _complete_response( + self, + rid: str, + output: dict, + reason_dict: dict | None, + choice_index: int = 0, + *, + no_stop_trim: bool = False, + ) -> tokenspeed_scheduler_pb2.GenerateResponse: + meta = output.get("meta_info", {}) + token_ids = self._generated_output_ids(output, reason_dict, no_stop_trim=no_stop_trim) + + finish_reason = "stop" + matched_kwargs: dict[str, Any] = {} + if reason_dict: + kind = reason_dict.get("type") + if kind == "length": + finish_reason = "length" + elif kind == "abort": + finish_reason = "abort" + matched = reason_dict.get("matched") + if isinstance(matched, int): + matched_kwargs["matched_token_id"] = matched + elif isinstance(matched, str): + matched_kwargs["matched_stop_str"] = matched + + return tokenspeed_scheduler_pb2.GenerateResponse( + request_id=rid, + complete=tokenspeed_scheduler_pb2.GenerateComplete( + output_ids=token_ids, + finish_reason=finish_reason, + prompt_tokens=int(meta.get("prompt_tokens", 0)), + completion_tokens=int(meta.get("completion_tokens", len(token_ids))), + cached_tokens=int(meta.get("cached_tokens", 0)), + output_logprobs=self._convert_output_logprobs_to_proto(output, len(token_ids)), + 
index=choice_index, + **matched_kwargs, + ), + ) + + @staticmethod + def _convert_output_logprobs_to_proto( + output: dict, n_keep: int + ) -> tokenspeed_scheduler_pb2.OutputLogProbs | None: + """Build an ``OutputLogProbs`` proto from a tokenspeed output dict. + + TokenSpeed accumulates the request's logprobs in per-request state + across chunks; ``meta_info["output_token_logprobs"]`` is therefore the + running cumulative list of detokenized + ``(logprob: float, token_id: int, text: Optional[str])`` tuples, and + ``meta_info["output_top_logprobs"]`` is the parallel list of top-K + alternatives per position (each entry is ``None`` or a list of the + same tuple shape). + + We slice the cumulative list down to just **this frame's tokens** by + taking the last ``len(output["output_ids"])`` entries — that's how + many new tokens this frame emitted — and then keep only the first + ``n_keep`` of those, so the alignment matches whatever + ``_generated_output_ids`` returned (it strips a trailing stop token + when the finish reason is ``stop``, leaving the last logprob entry + with no corresponding output id). + + Returns ``None`` when there are no logprobs to emit — either the + client did not request them, or the server was started without + ``--enable-output-logprobs`` (in which case TokenSpeed silently + leaves these meta_info lists empty rather than raising). + """ + if n_keep <= 0: + return None + meta = output.get("meta_info", {}) or {} + raw_token = meta.get("output_token_logprobs") or [] + if not raw_token: + return None + n_chunk = len(output.get("output_ids", []) or []) + if n_chunk <= 0: + return None + + raw_top = meta.get("output_top_logprobs") or [] + chunk_token = raw_token[-n_chunk:] if len(raw_token) >= n_chunk else raw_token + chunk_top = raw_top[-n_chunk:] if len(raw_top) >= n_chunk else raw_top + delta_token = chunk_token[:n_keep] + delta_top = chunk_top[:n_keep] + + top_proto = [] + for entry in delta_top: + if entry: + top_proto.append( + tokenspeed_scheduler_pb2.TopLogProbs( + values=[t[0] for t in entry], + token_ids=[t[1] for t in entry], + ) + ) + else: + # Position with no top-K data (e.g. ``--enable-top-logprobs`` + # is not yet implemented in TokenSpeed; we still emit a + # placeholder per position so the gateway can align indices). 
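+                # i.e. positions are never dropped: whatever per-position
+                # top-K data exists stays index-aligned with the
+                # ``token_ids``/``token_logprobs`` lists built below.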
+ top_proto.append(tokenspeed_scheduler_pb2.TopLogProbs()) + + return tokenspeed_scheduler_pb2.OutputLogProbs( + token_logprobs=[t[0] for t in delta_token], + token_ids=[t[1] for t in delta_token], + top_logprobs=top_proto, + ) + + +def _abort_status_code(reason: dict) -> grpc.StatusCode: + status_code = reason.get("status_code") + if status_code == 400: + return grpc.StatusCode.INVALID_ARGUMENT + if status_code in (408, 504): + return grpc.StatusCode.DEADLINE_EXCEEDED + if status_code == 429: + return grpc.StatusCode.RESOURCE_EXHAUSTED + return grpc.StatusCode.INTERNAL + + +def _make_json_serializable(obj: Any) -> Any: + """Flatten an arbitrary dataclass/config graph into JSON-safe primitives.""" + if obj is None or isinstance(obj, str | int | float | bool): + return obj + if isinstance(obj, list | tuple | set): + return [_make_json_serializable(x) for x in obj] + if isinstance(obj, dict): + return {str(k): _make_json_serializable(v) for k, v in obj.items()} + return str(obj) diff --git a/grpc_servicer/tests/__init__.py b/grpc_servicer/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/grpc_servicer/tests/conftest.py b/grpc_servicer/tests/conftest.py new file mode 100644 index 000000000..3ceadba4f --- /dev/null +++ b/grpc_servicer/tests/conftest.py @@ -0,0 +1,22 @@ +"""Pytest configuration for smg-grpc-servicer unit tests. + +Adds the parent directory to ``sys.path`` so editable installs work +without needing ``pip install -e``, and declares an asyncio-mode default. +""" + +from __future__ import annotations + +import pathlib +import sys + +import pytest + +_HERE = pathlib.Path(__file__).resolve().parent +_PKG_ROOT = _HERE.parent + +if str(_PKG_ROOT) not in sys.path: + sys.path.insert(0, str(_PKG_ROOT)) + + +def pytest_configure(config: pytest.Config) -> None: + config.addinivalue_line("markers", "tokenspeed: tests that require TokenSpeed") diff --git a/grpc_servicer/tests/test_tokenspeed_health_servicer.py b/grpc_servicer/tests/test_tokenspeed_health_servicer.py new file mode 100644 index 000000000..df4856af1 --- /dev/null +++ b/grpc_servicer/tests/test_tokenspeed_health_servicer.py @@ -0,0 +1,98 @@ +"""Unit tests for ``smg_grpc_servicer.tokenspeed.health_servicer``.""" + +from __future__ import annotations + +import time +from dataclasses import dataclass, field +from typing import Any +from unittest.mock import MagicMock + +import grpc +import pytest +from grpc_health.v1 import health_pb2 # noqa: E402 +from smg_grpc_servicer.tokenspeed.health_servicer import ( # noqa: E402 + TokenSpeedHealthServicer, +) + + +@dataclass +class FakeEngine: + gracefully_exit: bool = False + last_receive_tstamp: float = 0.0 + rid_to_state: dict[str, Any] = field(default_factory=dict) + + +@pytest.fixture +def servicer() -> TokenSpeedHealthServicer: + return TokenSpeedHealthServicer( + async_llm=FakeEngine(), + scheduler_info={}, + ) + + +@pytest.mark.asyncio +async def test_initial_state_is_not_serving(servicer: TokenSpeedHealthServicer): + ctx = MagicMock(spec=grpc.aio.ServicerContext) + resp = await servicer.Check(health_pb2.HealthCheckRequest(service=""), ctx) + assert resp.status == health_pb2.HealthCheckResponse.NOT_SERVING + + +@pytest.mark.asyncio +async def test_set_serving_flips_both_levels(servicer: TokenSpeedHealthServicer): + ctx = MagicMock(spec=grpc.aio.ServicerContext) + servicer.set_serving() + + # overall + resp = await servicer.Check(health_pb2.HealthCheckRequest(service=""), ctx) + assert resp.status == health_pb2.HealthCheckResponse.SERVING + + # specific + 
resp = await servicer.Check( + health_pb2.HealthCheckRequest(service=servicer.TOKENSPEED_SERVICE), ctx + ) + assert resp.status == health_pb2.HealthCheckResponse.SERVING + + +@pytest.mark.asyncio +async def test_shutdown_flips_back(servicer: TokenSpeedHealthServicer): + ctx = MagicMock(spec=grpc.aio.ServicerContext) + servicer.set_serving() + servicer.async_llm.gracefully_exit = True + resp = await servicer.Check(health_pb2.HealthCheckRequest(service=""), ctx) + assert resp.status == health_pb2.HealthCheckResponse.NOT_SERVING + + +@pytest.mark.asyncio +async def test_unknown_service_returns_unknown(servicer: TokenSpeedHealthServicer): + ctx = MagicMock(spec=grpc.aio.ServicerContext) + resp = await servicer.Check(health_pb2.HealthCheckRequest(service="bogus.Service"), ctx) + assert resp.status == health_pb2.HealthCheckResponse.SERVICE_UNKNOWN + ctx.set_code.assert_called_once_with(grpc.StatusCode.NOT_FOUND) + + +@pytest.mark.asyncio +async def test_stuck_scheduler_flips_to_not_serving( + servicer: TokenSpeedHealthServicer, +): + ctx = MagicMock(spec=grpc.aio.ServicerContext) + servicer.set_serving() + # Simulate "pending requests, but scheduler hasn't pushed output for 45s" + servicer.async_llm.last_receive_tstamp = time.time() - 45 + servicer.async_llm.rid_to_state["rid-1"] = object() + + resp = await servicer.Check( + health_pb2.HealthCheckRequest(service=servicer.TOKENSPEED_SERVICE), ctx + ) + assert resp.status == health_pb2.HealthCheckResponse.NOT_SERVING + + +@pytest.mark.asyncio +async def test_recent_activity_keeps_serving(servicer: TokenSpeedHealthServicer): + ctx = MagicMock(spec=grpc.aio.ServicerContext) + servicer.set_serving() + servicer.async_llm.last_receive_tstamp = time.time() - 1 + servicer.async_llm.rid_to_state["rid-1"] = object() + resp = await servicer.Check( + health_pb2.HealthCheckRequest(service=servicer.TOKENSPEED_SERVICE), ctx + ) + assert resp.status == health_pb2.HealthCheckResponse.SERVING diff --git a/grpc_servicer/tests/test_tokenspeed_servicer.py b/grpc_servicer/tests/test_tokenspeed_servicer.py new file mode 100644 index 000000000..89ed5c549 --- /dev/null +++ b/grpc_servicer/tests/test_tokenspeed_servicer.py @@ -0,0 +1,1103 @@ +"""Unit tests for ``smg_grpc_servicer.tokenspeed.servicer``. + +Runs against a minimal ``FakeAsyncLLM`` that implements only the AsyncLLM +surface the servicer actually touches. We *do* require TokenSpeed to be +importable (the servicer takes real request classes from ``tokenspeed.*``), +so the whole module is skipped when TokenSpeed is not installed. +""" + +from __future__ import annotations + +import asyncio +from collections.abc import Callable +from dataclasses import dataclass, field +from types import SimpleNamespace +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import grpc +import pytest + +pytest.importorskip( + "smg_grpc_proto", + reason="smg-grpc-proto must be installed to test the servicer", +) + +from smg_grpc_proto.generated import tokenspeed_scheduler_pb2 # noqa: E402 +from smg_grpc_servicer.tokenspeed import servicer as _servicer_module # noqa: E402 +from smg_grpc_servicer.tokenspeed.servicer import ( # noqa: E402 + TokenSpeedSchedulerServicer, + _abort_status_code, + _finish_reason_to_dict, + _make_json_serializable, +) + +# --------------------------------------------------------------------------- +# Stub request class. The servicer lazily imports ``GenerateReqInput`` so +# tests can substitute a minimal local stand-in without pulling in +# TokenSpeed's full scheduler graph. 
(No ``EmbeddingReqInput`` — the slim
+# TokenSpeed proto removed the Embed RPC.)
+# ---------------------------------------------------------------------------
+
+
+class _StubReq:
+    """Minimal stand-in with the attributes the servicer sets on req objects."""
+
+    def __init__(self, **kwargs):
+        # Defaults first, so a rid/text passed via kwargs is not clobbered;
+        # the servicer is also free to assign both after construction.
+        self.rid = None
+        self.text = None
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+
+
+class StubGenerateReqInput(_StubReq):
+    pass
+
+
+@pytest.fixture(autouse=True)
+def _stub_request_inputs(monkeypatch):
+    """Redirect the servicer's lazy GenerateReqInput import to a local stub."""
+    monkeypatch.setattr(_servicer_module, "_lazy_generate_req_input", lambda: StubGenerateReqInput)
+    yield
+
+
+# ---------------------------------------------------------------------------
+# Local fake finish-reason classes. The servicer duck-types on ``.to_json()``
+# so tests don't need to import TokenSpeed's request_types module (which
+# pulls in the full scheduler graph and breaks in minimal test envs).
+# ---------------------------------------------------------------------------
+
+
+class FINISH_MATCHED_TOKEN:
+    def __init__(self, matched):
+        self.matched = matched
+
+    def to_json(self):
+        return {"type": "stop", "matched": self.matched}
+
+
+class FINISH_MATCHED_STR:
+    def __init__(self, matched):
+        self.matched = matched
+
+    def to_json(self):
+        return {"type": "stop", "matched": self.matched}
+
+
+class FINISH_LENGTH:
+    def __init__(self, length):
+        self.length = length
+
+    def to_json(self):
+        return {"type": "length", "length": self.length}
+
+
+class FINISH_ABORT:
+    def __init__(self, message="Unknown error"):
+        self.message = message
+
+    def to_json(self):
+        return {"type": "abort", "message": self.message}
+
+
+# ---------------------------------------------------------------------------
+# FakeAsyncLLM — minimal stand-in for TokenSpeed's AsyncLLM in unit tests.
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class _FakeState:
+    finished: bool = False
+
+
+@dataclass
+class FakeAsyncLLM:
+    """Implements just enough AsyncLLM surface to drive the servicer."""
+
+    outputs: list[dict] = field(default_factory=list)
+    is_generation: bool = True
+    context_len: int = 8192
+    max_req_input_len: int | None = 4096
+    # Captured state — the servicer mutates/inspects these.
+    rid_to_state: dict[str, _FakeState] = field(default_factory=dict)
+    gracefully_exit: bool = False
+    last_receive_tstamp: float = 0.0
+    handle_loop_started: bool = False
+    aborted_rids: list[str] = field(default_factory=list)
+    # Override hook: a callable producing outputs per request, used for
+    # tests that need dynamic yields (e.g. cancel mid-stream).
+    generate_fn: Callable[[Any], Any] | None = None
+
+    # Load fixture consumed by ``get_load``. Defaults to empty (no DP ranks
+    # reported, as before the scheduler boots); tests assign ``load_outputs``
+    # directly to assert the GetLoads proto-mapping semantics.
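+    # Illustrative shape, mirroring the SimpleNamespace fixtures used in
+    # TestGetLoads below (only the attributes GetLoads actually reads):
+    #   engine.load_outputs = [
+    #       SimpleNamespace(dp_rank=0, num_reqs=3, num_waiting_reqs=1, num_pages=100),
+    #   ]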
+ load_outputs: list[Any] = field(default_factory=list) + max_total_num_tokens: int = 8192 + + server_args: Any = field( + default_factory=lambda: SimpleNamespace( + model_path="fake-model", + tokenizer_path="fake-model", + served_model_name="fake-model", + preferred_sampling_params=None, + page_size=16, + ) + ) + model_config: Any = field( + default_factory=lambda: SimpleNamespace( + vocab_size=32000, + is_multimodal=False, + hf_config=SimpleNamespace( + eos_token_id=2, + pad_token_id=0, + bos_token_id=1, + model_type="llama", + architectures=["LlamaForCausalLM"], + ), + ) + ) + + def auto_create_handle_loop(self) -> None: + self.handle_loop_started = True + + def abort_request(self, rid: str) -> None: + self.aborted_rids.append(rid) + self.rid_to_state.pop(rid, None) + + async def get_load(self): + # Mirror SchedulerControlClient.get_load — returns the configured + # ``load_outputs`` so tests can drive proto-mapping assertions. + return list(self.load_outputs) + + async def generate_request(self, obj): + # Record the request so tests can assert on what was forwarded. + # ``_build_generate_req`` rewrites ``rid`` to a list of per-choice ids + # when n>1; register state for each so the cancel sweep can abort them + # individually (and so dict assignment doesn't crash on a list key). + rid_attr = getattr(obj, "rid", None) or "no-rid" + rids = list(rid_attr) if isinstance(rid_attr, list) else [rid_attr] + for r in rids: + self.rid_to_state[r] = _FakeState() + if self.generate_fn is not None: + async for out in self.generate_fn(obj): + self.last_receive_tstamp = 9999.0 # anything > tic + yield out + return + for out in self.outputs: + self.last_receive_tstamp = 9999.0 + yield out + for r in rids: + self.rid_to_state[r].finished = True + + +@pytest.fixture +def fake_engine() -> FakeAsyncLLM: + return FakeAsyncLLM() + + +@pytest.fixture +def servicer(fake_engine: FakeAsyncLLM) -> TokenSpeedSchedulerServicer: + return TokenSpeedSchedulerServicer( + async_llm=fake_engine, + server_args=fake_engine.server_args, + scheduler_info={ + "max_total_num_tokens": 100000, + "max_req_input_len": 4096, + }, + ) + + +class _FakeAbortError(grpc.aio.AbortError): + """Stand-in for grpc.aio.AbortError raised by our mock context.abort().""" + + def __init__(self, code: grpc.StatusCode, details: str): + super().__init__() + self.code = code + self.details = details + + def __str__(self) -> str: # makes pytest.raises(match=...) useful + return f"ABORT({self.code.name}, {self.details})" + + +def _make_context() -> MagicMock: + """Build a grpc.aio.ServicerContext whose ``abort()`` raises AbortError. + + Real gRPC servicer contexts raise ``grpc.aio.AbortError`` from + ``context.abort()``. The servicer has a dedicated ``except + grpc.aio.AbortError: raise`` branch to let that propagate cleanly, so + the mock reproduces that behaviour. 
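+
+    Illustrative use (mirrors ``test_empty_input_ids_rejected`` below)::
+
+        ctx = _make_context()
+        req = _make_generate_request(input_ids=[])
+        with pytest.raises(_FakeAbortError) as exc:
+            async for _ in servicer.Generate(req, ctx):
+                pass
+        assert exc.value.code == grpc.StatusCode.INVALID_ARGUMENT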
+ """ + ctx = MagicMock(spec=grpc.aio.ServicerContext) + + async def _abort(code, details): + raise _FakeAbortError(code, details) + + ctx.abort = AsyncMock(side_effect=_abort) + ctx.set_code = MagicMock() + ctx.set_details = MagicMock() + return ctx + + +# --------------------------------------------------------------------------- +# Pure-helper tests +# --------------------------------------------------------------------------- + + +class TestFinishReasonToDict: + def test_none(self): + assert _finish_reason_to_dict(None) is None + + def test_length(self): + assert _finish_reason_to_dict(FINISH_LENGTH(length=42)) == { + "type": "length", + "length": 42, + } + + def test_matched_token(self): + assert _finish_reason_to_dict(FINISH_MATCHED_TOKEN(matched=7)) == { + "type": "stop", + "matched": 7, + } + + def test_matched_str(self): + assert _finish_reason_to_dict(FINISH_MATCHED_STR(matched="")) == { + "type": "stop", + "matched": "", + } + + def test_abort(self): + out = _finish_reason_to_dict(FINISH_ABORT(message="boom")) + assert out["type"] == "abort" + assert out["message"] == "boom" + + def test_passthrough_dict(self): + d = {"type": "stop", "matched": "foo"} + assert _finish_reason_to_dict(d) is d + + def test_unknown_raises_typeerror(self): + # Unknown shapes raise TypeError rather than coercing to a fake + # ``stop`` dict: silently flipping length/abort to stop and leaking + # repr() into the user-facing matched_stop_str field would corrupt + # the OpenAI ``finish_reason`` semantics. The Generate handler's + # ``except Exception`` turns the TypeError into INTERNAL. + with pytest.raises(TypeError, match="Unknown finish_reason shape"): + _finish_reason_to_dict("weird") + with pytest.raises(TypeError, match="Unknown finish_reason shape"): + _finish_reason_to_dict(42) + + +class TestAbortStatusCode: + @pytest.mark.parametrize( + "status_code, expected", + [ + (400, grpc.StatusCode.INVALID_ARGUMENT), + (408, grpc.StatusCode.DEADLINE_EXCEEDED), + (504, grpc.StatusCode.DEADLINE_EXCEEDED), + (429, grpc.StatusCode.RESOURCE_EXHAUSTED), + (500, grpc.StatusCode.INTERNAL), + (None, grpc.StatusCode.INTERNAL), + ], + ) + def test_mapping(self, status_code, expected): + assert _abort_status_code({"status_code": status_code}) == expected + + +class TestMakeJsonSerializable: + def test_primitives(self): + assert _make_json_serializable(1) == 1 + assert _make_json_serializable("x") == "x" + assert _make_json_serializable(True) is True + assert _make_json_serializable(None) is None + + def test_list_tuple_set(self): + assert _make_json_serializable([1, "a"]) == [1, "a"] + assert _make_json_serializable((1, 2)) == [1, 2] + assert _make_json_serializable({1, 2, 3}) in ( + [1, 2, 3], + [1, 3, 2], + [2, 1, 3], + [2, 3, 1], + [3, 1, 2], + [3, 2, 1], + ) + + def test_nested_dict(self): + assert _make_json_serializable({"a": [1, {"b": 2}]}) == {"a": [1, {"b": 2}]} + + def test_exotic_types_coerced_to_str(self): + class Foo: + def __str__(self): + return "foo-str" + + assert _make_json_serializable(Foo()) == "foo-str" + + +# --------------------------------------------------------------------------- +# Sampling params conversion +# --------------------------------------------------------------------------- + + +class TestSamplingParamsConversion: + def test_defaults_not_forwarded(self): + params = tokenspeed_scheduler_pb2.SamplingParams() + out = TokenSpeedSchedulerServicer._sampling_params_from_proto(params) + # proto3 defaults (0 / False / "") should not end up as TokenSpeed + # overrides — only the 
always-forwarded bool fields appear. + assert "temperature" not in out + assert "top_p" not in out + assert "top_k" not in out + assert "max_new_tokens" not in out + # always-forwarded bools + assert out["skip_special_tokens"] is False + assert out["spaces_between_special_tokens"] is False + assert out["ignore_eos"] is False + + def test_numeric_fields_forwarded(self): + params = tokenspeed_scheduler_pb2.SamplingParams( + temperature=0.7, + top_p=0.9, + top_k=50, + min_p=0.05, + frequency_penalty=0.1, + presence_penalty=0.2, + repetition_penalty=1.1, + max_new_tokens=128, + min_new_tokens=4, + ) + out = TokenSpeedSchedulerServicer._sampling_params_from_proto(params) + assert out["temperature"] == pytest.approx(0.7) + assert out["top_p"] == pytest.approx(0.9) + assert out["top_k"] == 50 + assert out["min_p"] == pytest.approx(0.05) + assert out["frequency_penalty"] == pytest.approx(0.1) + assert out["presence_penalty"] == pytest.approx(0.2) + assert out["repetition_penalty"] == pytest.approx(1.1) + assert out["max_new_tokens"] == 128 + assert out["min_new_tokens"] == 4 + + def test_stop_lists_and_logit_bias(self): + params = tokenspeed_scheduler_pb2.SamplingParams( + stop=["\n\n", ""], + stop_token_ids=[2, 0], + logit_bias={"100": -10.0, "200": 10.0}, + ) + out = TokenSpeedSchedulerServicer._sampling_params_from_proto(params) + assert out["stop"] == ["\n\n", ""] + assert out["stop_token_ids"] == [2, 0] + assert out["logit_bias"] == {"100": -10.0, "200": 10.0} + + @pytest.mark.parametrize( + "setter, key, value", + [ + (lambda p: setattr(p, "regex", "a.*"), "regex", "a.*"), + (lambda p: setattr(p, "json_schema", "{}"), "json_schema", "{}"), + (lambda p: setattr(p, "ebnf_grammar", "g"), "ebnf", "g"), + (lambda p: setattr(p, "structural_tag", "tag"), "structural_tag", "tag"), + ], + ) + def test_constraints(self, setter, key, value): + params = tokenspeed_scheduler_pb2.SamplingParams() + setter(params) + out = TokenSpeedSchedulerServicer._sampling_params_from_proto(params) + assert out[key] == value + + def test_json_schema_no_reasoning_parser_passes_through(self): + params = tokenspeed_scheduler_pb2.SamplingParams(json_schema='{"type": "object"}') + out = TokenSpeedSchedulerServicer._sampling_params_from_proto(params, reasoning_parser=None) + assert out["json_schema"] == '{"type": "object"}' + assert "structural_tag" not in out + + def test_json_schema_with_reasoning_parser_wraps_as_structural_tag(self, monkeypatch): + import sys + import types + + fake_module = types.ModuleType("tokenspeed.runtime.grammar.reasoning_structural_tag") + captured: dict[str, Any] = {} + + def _fake_wrap(rp: str, schema: Any) -> str: + captured["rp"] = rp + captured["schema"] = schema + return '{"wrapped": "tag"}' + + fake_module.structural_tag_for_reasoning_json_schema = _fake_wrap + monkeypatch.setitem( + sys.modules, + "tokenspeed.runtime.grammar.reasoning_structural_tag", + fake_module, + ) + + params = tokenspeed_scheduler_pb2.SamplingParams(json_schema='{"type": "object"}') + out = TokenSpeedSchedulerServicer._sampling_params_from_proto( + params, reasoning_parser="gpt-oss" + ) + + assert "json_schema" not in out + assert out["structural_tag"] == '{"wrapped": "tag"}' + assert captured["rp"] == "gpt-oss" + assert captured["schema"] == {"type": "object"} + + def test_json_schema_unknown_parser_falls_back_to_raw(self, monkeypatch): + import sys + import types + + fake_module = types.ModuleType("tokenspeed.runtime.grammar.reasoning_structural_tag") + fake_module.structural_tag_for_reasoning_json_schema = lambda 
rp, s: None + monkeypatch.setitem( + sys.modules, + "tokenspeed.runtime.grammar.reasoning_structural_tag", + fake_module, + ) + + params = tokenspeed_scheduler_pb2.SamplingParams(json_schema='{"type": "object"}') + out = TokenSpeedSchedulerServicer._sampling_params_from_proto( + params, reasoning_parser="unknown-parser" + ) + + assert out["json_schema"] == '{"type": "object"}' + assert "structural_tag" not in out + + +# --------------------------------------------------------------------------- +# Generate RPC +# --------------------------------------------------------------------------- + + +def _make_generate_request( + *, + request_id: str = "rid-1", + input_ids: list[int] | None = None, + stream: bool = False, + max_new_tokens: int = 16, +) -> tokenspeed_scheduler_pb2.GenerateRequest: + return tokenspeed_scheduler_pb2.GenerateRequest( + request_id=request_id, + tokenized=tokenspeed_scheduler_pb2.TokenizedInput( + # Preserve explicit empty-list inputs (for "rejects empty ids" test); + # only fall back to the default if the caller didn't supply any. + input_ids=(input_ids if input_ids is not None else [1, 2, 3, 4]), + original_text="hello", + ), + sampling_params=tokenspeed_scheduler_pb2.SamplingParams( + temperature=0.0, + max_new_tokens=max_new_tokens, + ), + stream=stream, + ) + + +class TestGenerate: + @pytest.mark.asyncio + async def test_non_streaming_emits_complete( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + # TokenSpeed's AsyncLLM includes the trailing matched-stop token in + # ``output_ids`` (and prepends chat-template header tokens — modeled in + # ``test_strips_chat_template_prefix`` below). The servicer normalizes + # these out before the proto goes to the smg gateway so the tool + # parsers see the same tokens they would from the SGLang path. Here we + # check the matched-stop trim: ``raw=[10,11,12]`` with ``matched=12`` + # should arrive as ``[10,11]`` on the wire, and the matched id still + # rides in the ``matched_token_id`` field. + fake_engine.outputs = [ + { + "text": "hi", + "output_ids": [10, 11, 12], + "meta_info": { + "prompt_tokens": 4, + "completion_tokens": 3, + "cached_tokens": 0, + "finish_reason": FINISH_MATCHED_TOKEN(matched=12), + }, + } + ] + ctx = _make_context() + req = _make_generate_request(stream=False) + + frames = [frame async for frame in servicer.Generate(req, ctx)] + assert len(frames) == 1 + frame = frames[0] + assert frame.request_id == "rid-1" + assert frame.HasField("complete") + complete = frame.complete + assert list(complete.output_ids) == [10, 11] + assert complete.finish_reason == "stop" + assert complete.matched_token_id == 12 + assert complete.prompt_tokens == 4 + # Meta's completion_tokens passes through unchanged — matches SGLang's + # ``meta_info.get("completion_tokens")`` convention — even though the + # on-the-wire ``output_ids`` drops the stop token. + assert complete.completion_tokens == 3 + ctx.abort.assert_not_called() + + @pytest.mark.asyncio + async def test_strips_chat_template_prefix( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + """Reproducer for the bug where ``assistant\\n\\n`` leaked into the + decoded text and broke the ``llama`` tool-call parser. + + Real-world capture on Llama-3.2-1B-Instruct with a function-calling + prompt — ``output_ids`` was 27 tokens: 5 chat-template header tokens + (``<|eot_id|>, <|start_header_id|>, "assistant", <|end_header_id|>, + "\\n\\n"``) + 21 generated JSON tokens + 1 ``<|eom_id|>`` stop. 
With + ``skip_special_tokens=True`` only the 128xxx control tokens get + stripped at detokenization time, so the word token ``"assistant"`` + (78191) and ``"\\n\\n"`` (271) leaked into the text and flipped + ``serde_json::from_str`` from succeeding on clean JSON to failing on + ``assistant\\n\\n{...}``. + + The servicer now slices to the last ``completion_tokens`` tokens so + downstream detokenization only sees the actual generated content. + """ + fake_engine.outputs = [ + { + "text": '{"name": "add", "parameters": {"a": 3, "b": 5}}', + # Shape observed in the wild: [<|eot|>, <|start|>, "assistant", + # <|end|>, "\n\n", ...21 json tokens, <|eom|>] = 27 tokens. + # ``completion_tokens`` in TokenSpeed's meta covers the content + # *plus* the stop token, so 21 + 1 = 22. + "output_ids": [ + 128009, + 128006, + 78191, + 128007, + 271, + *range(9000, 9021), + 128008, + ], + "meta_info": { + "prompt_tokens": 200, + "completion_tokens": 22, + "cached_tokens": 0, + "finish_reason": FINISH_MATCHED_TOKEN(matched=128008), + }, + } + ] + ctx = _make_context() + req = _make_generate_request(stream=False) + + frames = [frame async for frame in servicer.Generate(req, ctx)] + complete = frames[0].complete + # Header tokens dropped via the ``raw[-completion_tokens:]`` slice; + # trailing stop token dropped because ``matched == token_ids[-1]``. + assert list(complete.output_ids) == list(range(9000, 9021)) + assert complete.matched_token_id == 128008 + # meta_info.completion_tokens passes through; only ``output_ids`` is + # normalized. Keeps the tokenspeed servicer's wire contract aligned + # with the SGLang reference. + assert complete.completion_tokens == 22 + + @pytest.mark.asyncio + async def test_streaming_emits_chunks_then_complete( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + fake_engine.outputs = [ + { + "text": "hi", + "output_ids": [10], # delta chunk 1 + "meta_info": { + "prompt_tokens": 4, + "completion_tokens": 1, + "cached_tokens": 0, + "finish_reason": None, + }, + }, + { + "text": "hi there", + "output_ids": [11, 12], # delta chunk 2 + finish + "meta_info": { + "prompt_tokens": 4, + "completion_tokens": 3, + "cached_tokens": 0, + "finish_reason": FINISH_LENGTH(length=16), + }, + }, + ] + ctx = _make_context() + req = _make_generate_request(stream=True) + + frames = [frame async for frame in servicer.Generate(req, ctx)] + # Expect: 2 chunks + 1 complete (emitted alongside the final chunk). + # ``completion_tokens`` here (3) exceeds this chunk's delta length (2), + # so the slice falls back to the raw delta. Length-finish has no + # matched stop to strip either, so token_ids pass through. 
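+        # Worked through: keeping raw[-completion_tokens:] would need 3
+        # entries from this 2-token delta, so the fallback emits the chunk's
+        # [11, 12] unchanged.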
+ assert len(frames) == 3 + assert frames[0].HasField("chunk") + assert list(frames[0].chunk.token_ids) == [10] + assert frames[1].HasField("chunk") + assert list(frames[1].chunk.token_ids) == [11, 12] + assert frames[2].HasField("complete") + assert frames[2].complete.finish_reason == "length" + assert list(frames[2].complete.output_ids) == [11, 12] + + @pytest.mark.asyncio + async def test_empty_input_ids_rejected( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + ctx = _make_context() + req = _make_generate_request(input_ids=[]) + + with pytest.raises(_FakeAbortError) as exc: + async for _ in servicer.Generate(req, ctx): + pass + assert exc.value.code == grpc.StatusCode.INVALID_ARGUMENT + ctx.abort.assert_awaited_once() + + @pytest.mark.asyncio + async def test_abort_finish_reason_surfaces_as_grpc_error( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + fake_engine.outputs = [ + { + "text": "", + "output_ids": [], + "meta_info": { + "prompt_tokens": 0, + "completion_tokens": 0, + "cached_tokens": 0, + "finish_reason": { + "type": "abort", + "message": "client disconnected", + "status_code": 400, + }, + }, + } + ] + ctx = _make_context() + req = _make_generate_request() + + with pytest.raises(_FakeAbortError) as exc: + async for _ in servicer.Generate(req, ctx): + pass + assert exc.value.code == grpc.StatusCode.INVALID_ARGUMENT + + @pytest.mark.asyncio + async def test_cancel_calls_abort_request( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + """Cancelling the Generate task should tell the scheduler to drop the rid.""" + + started = asyncio.Event() + + async def never_finish(_obj): + started.set() + # Block forever so we can cancel from outside. ``yield`` is + # unreachable but keeps this an async generator. + await asyncio.sleep(30) + yield {} # pragma: no cover + + fake_engine.generate_fn = never_finish + ctx = _make_context() + req = _make_generate_request() + + gen = servicer.Generate(req, ctx) + task = asyncio.create_task(_drain(gen)) + await started.wait() + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + assert "rid-1" in fake_engine.aborted_rids + + @pytest.mark.asyncio + async def test_cancel_aborts_all_n_children( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + """n>1 expands rid to a list of per-choice ids; cancel must sweep them all. + + _build_generate_req rewrites ``rid`` to ``[rid-n0, rid-n1, ...]`` so + TokenSpeed's batch path sees unique rids per choice. If Generate's + cancel handler aborts only the original rid, the child scheduler + requests keep consuming GPU work. This test guards that edge. + """ + started = asyncio.Event() + + async def never_finish(_obj): + started.set() + await asyncio.sleep(30) + yield {} # pragma: no cover + + fake_engine.generate_fn = never_finish + ctx = _make_context() + req = _make_generate_request() + req.sampling_params.n = 3 + + gen = servicer.Generate(req, ctx) + task = asyncio.create_task(_drain(gen)) + await started.wait() + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + # Every per-choice rid must have had abort_request called. 
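+        # FakeAsyncLLM.generate_request registered one _FakeState per child
+        # rid, so all three children must appear in aborted_rids; the
+        # superset check (>=) also tolerates extra aborts such as the
+        # parent rid itself.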
+        assert set(fake_engine.aborted_rids) >= {"rid-1-n0", "rid-1-n1", "rid-1-n2"}
+
+
+async def _drain(async_gen):
+    async for _ in async_gen:
+        pass
+
+
+# ---------------------------------------------------------------------------
+# Abort / HealthCheck / GetModelInfo / GetServerInfo / GetLoads
+#
+# Note: TokenSpeed's slim proto removes Embed / GetTokenizer / SubscribeKvEvents
+# entirely, so there are no tests for them — the methods aren't on the
+# servicer surface.
+# ---------------------------------------------------------------------------
+
+
+class TestAbortRpc:
+    @pytest.mark.asyncio
+    async def test_abort_known(
+        self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer
+    ):
+        fake_engine.rid_to_state["rid-1"] = _FakeState()
+        resp = await servicer.Abort(
+            tokenspeed_scheduler_pb2.AbortRequest(request_id="rid-1"),
+            _make_context(),
+        )
+        assert resp.success is True
+        assert "rid-1" in fake_engine.aborted_rids
+
+    @pytest.mark.asyncio
+    async def test_abort_unknown(
+        self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer
+    ):
+        resp = await servicer.Abort(
+            tokenspeed_scheduler_pb2.AbortRequest(request_id="missing"),
+            _make_context(),
+        )
+        assert resp.success is False
+        # Nothing to abort — no state for "missing" or any "missing-n*" child.
+        assert fake_engine.aborted_rids == []
+
+    @pytest.mark.asyncio
+    async def test_abort_sweeps_n_children(
+        self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer
+    ):
+        """Abort("rid-1") must sweep the per-choice rids Generate mints
+        when ``sampling_params.n > 1`` (``rid-1-n0``, ``rid-1-n1``, ...).
+        """
+        for child in ("rid-1-n0", "rid-1-n1", "rid-1-n2"):
+            fake_engine.rid_to_state[child] = _FakeState()
+        # An unrelated rid the sweep must NOT touch.
+        fake_engine.rid_to_state["unrelated-rid"] = _FakeState()
+
+        resp = await servicer.Abort(
+            tokenspeed_scheduler_pb2.AbortRequest(request_id="rid-1"),
+            _make_context(),
+        )
+        assert resp.success is True
+        assert sorted(fake_engine.aborted_rids) == [
+            "rid-1-n0",
+            "rid-1-n1",
+            "rid-1-n2",
+        ]
+
+
+class TestHealthCheck:
+    @pytest.mark.asyncio
+    async def test_reports_shutdown(
+        self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer
+    ):
+        fake_engine.gracefully_exit = True
+        resp = await servicer.HealthCheck(
+            tokenspeed_scheduler_pb2.HealthCheckRequest(), _make_context()
+        )
+        assert resp.healthy is False
+        assert "shutting down" in resp.message.lower()
+
+    @pytest.mark.asyncio
+    async def test_reports_healthy_when_scheduler_pushes_output(
+        self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer
+    ):
+        # generate_request yields once and updates last_receive_tstamp, which
+        # is what the health RPC watches for.
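+        # One 1-token, length-finished output is enough: FakeAsyncLLM bumps
+        # last_receive_tstamp on every yield, and that recent-output signal
+        # is what HealthCheck reports as healthy.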
+ fake_engine.outputs = [ + { + "text": "", + "output_ids": [99], + "meta_info": {"finish_reason": FINISH_LENGTH(length=1)}, + } + ] + resp = await servicer.HealthCheck( + tokenspeed_scheduler_pb2.HealthCheckRequest(), _make_context() + ) + assert resp.healthy is True + + +class TestGetModelInfo: + @pytest.mark.asyncio + async def test_basic_fields( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + resp = await servicer.GetModelInfo( + tokenspeed_scheduler_pb2.GetModelInfoRequest(), _make_context() + ) + assert resp.model_path == "fake-model" + assert resp.vocab_size == 32000 + assert resp.max_context_length == 8192 + assert list(resp.eos_token_ids) == [2] + assert resp.model_type == "llama" + assert list(resp.architectures) == ["LlamaForCausalLM"] + + +class TestGetServerInfo: + @pytest.mark.asyncio + async def test_returns_scheduler_info( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + fake_engine.rid_to_state["a"] = _FakeState() + fake_engine.rid_to_state["b"] = _FakeState() + resp = await servicer.GetServerInfo( + tokenspeed_scheduler_pb2.GetServerInfoRequest(), _make_context() + ) + assert resp.active_requests == 2 + assert resp.max_total_num_tokens == 100000 + assert resp.tokenspeed_version + + @pytest.mark.asyncio + async def test_uses_tokenspeed_service_bases(self, servicer: TokenSpeedSchedulerServicer): + """TokenSpeed's servicer inherits the dedicated + ``TokenSpeedSchedulerServicer`` stub — identity is carried by the + proto package/service name, not by a field inside ``server_args``. + Guard the inheritance so nobody reverts to ``SglangSchedulerServicer`` + under the impression that 'wire shape is the same'; the wire shape + is the same, the *service path* is not, and the Rust router routes + on the service path. + """ + from smg_grpc_proto.generated import tokenspeed_scheduler_pb2_grpc + + assert isinstance(servicer, tokenspeed_scheduler_pb2_grpc.TokenSpeedSchedulerServicer) + + +class TestGetLoads: + @pytest.mark.asyncio + async def test_no_dp_ranks_returns_empty( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + # Bridge returns an empty list (e.g. before scheduler boots) — proto + # comes back with 0 ranks but still validly populated for the router. + fake_engine.load_outputs = [] + resp = await servicer.GetLoads(tokenspeed_scheduler_pb2.GetLoadsRequest(), _make_context()) + assert resp.dp_rank_count == 0 + assert resp.version == "tokenspeed" + assert list(resp.loads) == [] + assert resp.aggregate.total_running_reqs == 0 + assert resp.aggregate.total_waiting_reqs == 0 + + @pytest.mark.asyncio + async def test_maps_load_output_fields( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + # 2 DP ranks. rank 0 has 3 reqs (2 running, 1 waiting) and 100 pages + # used; rank 1 has 1 reqs (1 running, 0 waiting) and 200 pages used. + # page_size=16 (from fake_engine.server_args), max_total_num_tokens=100000 + # (from the servicer fixture's scheduler_info). 
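+        # Expected derived values, worked from these numbers:
+        #   rank 0: used = 100 pages * 16 = 1600 tokens → usage 1600 / 100000
+        #   rank 1: used = 200 pages * 16 = 3200 tokens → usage 3200 / 100000
+        #   aggregate avg_token_usage = (0.016 + 0.032) / 2 = 0.024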
+ fake_engine.load_outputs = [ + SimpleNamespace(dp_rank=0, num_reqs=3, num_waiting_reqs=1, num_pages=100), + SimpleNamespace(dp_rank=1, num_reqs=1, num_waiting_reqs=0, num_pages=200), + ] + resp = await servicer.GetLoads(tokenspeed_scheduler_pb2.GetLoadsRequest(), _make_context()) + assert resp.dp_rank_count == 2 + assert len(resp.loads) == 2 + # rank 0 + l0 = resp.loads[0] + assert l0.dp_rank == 0 + assert l0.num_running_reqs == 2 # num_reqs - num_waiting_reqs + assert l0.num_waiting_reqs == 1 + assert l0.num_total_reqs == 3 + assert l0.num_used_tokens == 100 * 16 # pages * page_size + assert l0.max_total_num_tokens == 100000 + assert l0.token_usage == pytest.approx(100 * 16 / 100000) + # rank 1 + l1 = resp.loads[1] + assert l1.dp_rank == 1 + assert l1.num_running_reqs == 1 + assert l1.num_used_tokens == 200 * 16 + # aggregate + assert resp.aggregate.total_running_reqs == 3 + assert resp.aggregate.total_waiting_reqs == 1 + assert resp.aggregate.total_reqs == 4 + assert resp.aggregate.avg_token_usage == pytest.approx( + (100 * 16 / 100000 + 200 * 16 / 100000) / 2 + ) + + @pytest.mark.asyncio + async def test_scheduler_timeout_aborts_with_deadline_exceeded( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer, monkeypatch + ): + # If the scheduler subprocess never replies, the bridge call hangs. + # The servicer wraps it in ``asyncio.wait_for`` and aborts with + # DEADLINE_EXCEEDED rather than blocking the gRPC call indefinitely. + async def _hang(): + await asyncio.sleep(60) + return [] + + fake_engine.get_load = _hang # type: ignore[method-assign] + monkeypatch.setattr(_servicer_module, "HEALTH_CHECK_TIMEOUT", 0.05) + ctx = _make_context() + with pytest.raises(_FakeAbortError) as exc: + await servicer.GetLoads(tokenspeed_scheduler_pb2.GetLoadsRequest(), ctx) + assert exc.value.code == grpc.StatusCode.DEADLINE_EXCEEDED + + +# --------------------------------------------------------------------------- +# _build_generate_req semantics (pre-tokenized input) +# --------------------------------------------------------------------------- + + +class TestBuildGenerateReq: + def test_preserves_input_ids(self, servicer: TokenSpeedSchedulerServicer): + req = _make_generate_request(input_ids=[11, 22, 33], stream=True) + obj = servicer._build_generate_req(req) + assert obj.input_ids == [11, 22, 33] + assert obj.rid == "rid-1" + assert obj.stream is True + assert obj.sampling_params["max_new_tokens"] == 16 + + def test_rejects_missing_tokenized(self, servicer: TokenSpeedSchedulerServicer): + req = tokenspeed_scheduler_pb2.GenerateRequest(request_id="x") + with pytest.raises(ValueError, match="tokenized"): + servicer._build_generate_req(req) + + +# --------------------------------------------------------------------------- +# Output logprobs proto conversion +# --------------------------------------------------------------------------- + + +class TestConvertOutputLogprobsToProto: + """``_convert_output_logprobs_to_proto`` reads the cumulative + ``meta_info["output_token_logprobs"]`` / ``output_top_logprobs`` lists + that TokenSpeed accumulates per request, slices the last + ``len(output_ids)`` entries (the tokens this frame emitted), and keeps + the first ``n_keep`` so the result aligns with whatever + ``_generated_output_ids`` returned (which may have stripped a trailing + stop token).""" + + def test_returns_none_when_logprobs_empty(self): + # ``--enable-output-logprobs`` not set on the server → the keys exist + # in meta_info but the lists are empty. 
Must not return a half-built + # proto in this case (gateway would treat empty as "logprobs missing"). + out = { + "output_ids": [10, 20, 30], + "meta_info": {"output_token_logprobs": [], "output_top_logprobs": []}, + } + assert TokenSpeedSchedulerServicer._convert_output_logprobs_to_proto(out, n_keep=3) is None + + def test_returns_none_when_keys_missing(self): + # Logprobs not requested at all → meta_info lacks the keys entirely. + out = {"output_ids": [10, 20, 30], "meta_info": {}} + assert TokenSpeedSchedulerServicer._convert_output_logprobs_to_proto(out, n_keep=3) is None + + def test_returns_none_when_n_keep_zero(self): + # Stop-token strip can leave n_keep == 0 for a 1-token frame whose + # only token was the stop. Don't emit a proto with a length mismatch. + out = { + "output_ids": [99], + "meta_info": { + "output_token_logprobs": [(-0.1, 99, None)], + "output_top_logprobs": [None], + }, + } + assert TokenSpeedSchedulerServicer._convert_output_logprobs_to_proto(out, n_keep=0) is None + + def test_non_streaming_full_output(self): + # Non-streaming: output_ids covers the entire generation; cumulative + # meta_info matches it exactly. n_keep == len(output_ids) → emit all. + out = { + "output_ids": [10, 20, 30], + "meta_info": { + "output_token_logprobs": [ + (-0.5, 10, None), + (-0.3, 20, None), + (-0.1, 30, None), + ], + "output_top_logprobs": [None, None, None], + }, + } + proto = TokenSpeedSchedulerServicer._convert_output_logprobs_to_proto(out, n_keep=3) + assert proto is not None + assert list(proto.token_logprobs) == pytest.approx([-0.5, -0.3, -0.1]) + assert list(proto.token_ids) == [10, 20, 30] + assert len(proto.top_logprobs) == 3 + # ``None`` entries in raw_top translate to empty TopLogProbs placeholders. + for tl in proto.top_logprobs: + assert list(tl.values) == [] + assert list(tl.token_ids) == [] + + def test_streaming_chunk_emits_only_delta(self): + # Streaming chunk: output_ids has just the new tokens for this chunk, + # but meta_info is cumulative across the entire request. The slice + # ``[-len(output_ids):]`` on the cumulative list must yield exactly + # the delta this chunk represents. + out = { + "output_ids": [40, 50], # 2 new tokens this chunk + "meta_info": { + # cumulative: 4 prior tokens + 2 new + "output_token_logprobs": [ + (-1.1, 10, None), + (-1.2, 20, None), + (-1.3, 30, None), + (-1.4, 99, None), + (-0.7, 40, None), + (-0.6, 50, None), + ], + "output_top_logprobs": [None] * 6, + }, + } + proto = TokenSpeedSchedulerServicer._convert_output_logprobs_to_proto(out, n_keep=2) + assert proto is not None + assert list(proto.token_logprobs) == pytest.approx([-0.7, -0.6]) + assert list(proto.token_ids) == [40, 50] + + def test_top_k_alternatives(self): + # When the user requests top_logprobs=3, each position in + # output_top_logprobs is a list of K (logprob, token_id, text) tuples. + # Translate each into a TopLogProbs proto with parallel value/id arrays. 
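+        # Each (logprob, token_id, text) tuple contributes t[0] to values
+        # and t[1] to token_ids; the detokenized text member is dropped on
+        # the wire.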
+ out = { + "output_ids": [40], + "meta_info": { + "output_token_logprobs": [(-0.7, 40, None)], + "output_top_logprobs": [ + [(-0.7, 40, None), (-1.2, 41, None), (-2.5, 42, None)], + ], + }, + } + proto = TokenSpeedSchedulerServicer._convert_output_logprobs_to_proto(out, n_keep=1) + assert proto is not None + assert len(proto.top_logprobs) == 1 + tl = proto.top_logprobs[0] + assert list(tl.values) == pytest.approx([-0.7, -1.2, -2.5]) + assert list(tl.token_ids) == [40, 41, 42] + + def test_strips_stop_token_alignment(self): + # When ``_generated_output_ids`` strips a trailing stop token, + # n_keep == len(output_ids) - 1. The converter must take the first + # n_keep entries of this frame's cumulative slice — emitting the + # logprob for the stripped stop token would misalign with the + # ``token_ids`` field on the proto. + out = { + "output_ids": [10, 20, 99], # 99 = stop, will be stripped → n_keep=2 + "meta_info": { + "output_token_logprobs": [ + (-0.5, 10, None), + (-0.3, 20, None), + (-0.1, 99, None), # logprob for the stop we just stripped + ], + "output_top_logprobs": [None, None, None], + }, + } + proto = TokenSpeedSchedulerServicer._convert_output_logprobs_to_proto(out, n_keep=2) + assert proto is not None + # Note: 99's logprob is dropped; emitted logprobs match the kept tokens. + assert list(proto.token_logprobs) == pytest.approx([-0.5, -0.3]) + assert list(proto.token_ids) == [10, 20]
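+
+    def test_mixed_none_and_populated_top_entries(self):
+        # Illustrative extra case, not from a captured trace: one position
+        # without top-K data followed by one with it. Per the placeholder
+        # behaviour documented on ``_convert_output_logprobs_to_proto``,
+        # the None position must become an empty TopLogProbs so that
+        # top_logprobs[i] stays aligned with token_ids[i].
+        out = {
+            "output_ids": [40, 50],
+            "meta_info": {
+                "output_token_logprobs": [(-0.7, 40, None), (-0.6, 50, None)],
+                "output_top_logprobs": [
+                    None,
+                    [(-0.6, 50, None), (-1.9, 51, None)],
+                ],
+            },
+        }
+        proto = TokenSpeedSchedulerServicer._convert_output_logprobs_to_proto(out, n_keep=2)
+        assert proto is not None
+        assert list(proto.token_ids) == [40, 50]
+        assert list(proto.top_logprobs[0].token_ids) == []
+        assert list(proto.top_logprobs[1].token_ids) == [50, 51]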