feat(grpc_servicer): add TokenSpeed servicer (Part 2/3) #1464
base: feat/grpc-servicer-tokenspeed
Changes from all commits: 6f84101, 6bb18d2, 93038d1, a812f5c, 8a3e651
New file (+11 lines): TokenSpeed servicer package `__init__`

```python
"""TokenSpeed gRPC servicer implementation.

Mirrors smg_grpc_servicer.vllm / smg_grpc_servicer.sglang. Wraps TokenSpeed's
AsyncLLM (main-process async frontend) behind the SGLang gRPC service so the
existing Rust router (which auto-detects the SGLang proto) can route traffic
to TokenSpeed without needing a new client.
"""

from smg_grpc_servicer.tokenspeed.servicer import TokenSpeedSchedulerServicer

__all__ = ["TokenSpeedSchedulerServicer"]
```
New file (+71 lines): CLI entrypoint (`__main__`) for the TokenSpeed gRPC server

```python
"""CLI entrypoint for the TokenSpeed gRPC server.

Usage::

    python -m smg_grpc_servicer.tokenspeed --model <model> --host 127.0.0.1 --port 50051

All :class:`tokenspeed.runtime.utils.server_args.ServerArgs` flags are accepted
verbatim (we reuse TokenSpeed's own ``prepare_server_args`` so there is no
flag drift between the HTTP and gRPC frontends).
"""

from __future__ import annotations

import asyncio
import logging
import sys

import uvloop
from tokenspeed.runtime.utils.server_args import prepare_server_args

from smg_grpc_servicer.tokenspeed.server import serve_grpc


def main(argv: list[str] | None = None) -> None:
    if argv is None:
        argv = sys.argv[1:]

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(name)s] %(levelname)s %(message)s",
    )

    # TokenSpeed's ``ServerArgs.resolve_kernel_backends`` defaults
    # ``sampling_backend`` to ``"greedy"`` when the user doesn't pass
    # ``--sampling-backend``. The greedy backend is argmax-only and
    # ignores per-request ``temperature``/``top_p``/``top_k`` — fine for
    # the legacy CLI where users opt in to sampling explicitly, but
    # disastrous for a gateway-fronted gRPC servicer where per-request
    # sampling params arrive on every call. With Llama-3.2-1B the
    # always-argmax behavior collapses into single-token loops
    # (\n×N, ' ('×N, "no"×N) within a few hundred steps and
    # generation runs to ``max_new_tokens`` — the smg e2e function-calling
    # suite makes this directly observable. Force a sampling-respecting
    # default unless the operator explicitly chose one.
    if not any(a == "--sampling-backend" or a.startswith("--sampling-backend=") for a in argv):
        argv = [*argv, "--sampling-backend", "flashinfer"]

    # TokenSpeed's logprob computation is gated by ``--enable-output-logprobs``
    # (default OFF, see ``ServerArgs.enable_output_logprobs``); without the
    # flag, requests asking for logprobs receive empty arrays rather than an
    # error. The smg gateway's OpenAI-compat path expects per-token logprobs
    # whenever ``logprobs=True`` is set, so flip the flag on by default for a
    # gateway-fronted gRPC servicer. Operators who want the smaller CUDA-graph
    # footprint can pass ``--enable-output-logprobs=False`` explicitly.
    # ``--enable-top-logprobs`` is intentionally NOT injected: TokenSpeed
    # raises at startup when it's set (the path is not yet implemented).
    if not any(
        a == "--enable-output-logprobs" or a.startswith("--enable-output-logprobs=") for a in argv
    ):
        argv = [*argv, "--enable-output-logprobs"]

    server_args = prepare_server_args(argv)
    # The scheduler processes will read these env vars; make sure we ran
    # through TokenSpeed's shared env/resource setup path instead of
    # duplicating it here.
    asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
    asyncio.run(serve_grpc(server_args))


if __name__ == "__main__":
    main()
```
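The two `any(...)` guards in `main()` implement the same small pattern: append a default flag unless the operator already passed it, in either `--flag value` or `--flag=value` form. A standalone sketch of that pattern follows; the `inject_default` helper and the example argv are illustrative, not part of the PR.

```python
# Illustrative only: a generic version of the default-injection guards used
# in main() above. `inject_default` is a hypothetical helper, not PR code.
def inject_default(argv: list[str], flag: str, *values: str) -> list[str]:
    """Append ``flag`` (plus its values) unless it was already passed as
    ``--flag value`` or ``--flag=value``."""
    if any(a == flag or a.startswith(f"{flag}=") for a in argv):
        return argv
    return [*argv, flag, *values]


argv = ["--model", "<model>", "--port", "50051"]
argv = inject_default(argv, "--sampling-backend", "flashinfer")
argv = inject_default(argv, "--enable-output-logprobs")
# An explicit operator choice wins: "--sampling-backend=greedy" in the input
# argv would have suppressed the flashinfer default entirely.
print(argv)
```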
New file (+130 lines): `grpc.health.v1` health servicer for the TokenSpeed backend

```python
"""Standard ``grpc.health.v1.Health`` servicer for the TokenSpeed backend.

Mirrors ``smg_grpc_servicer.sglang.health_servicer.SGLangHealthServicer`` —
same service-name semantics, same lifecycle (NOT_SERVING → SERVING → NOT_SERVING),
same ``check/watch`` contract — but sources liveness signals from a TokenSpeed
:class:`AsyncLLM` instead of an SGLang ``GrpcRequestManager``.

The Rust router uses this health check to auto-detect the backend runtime.
TokenSpeed ships its own ``tokenspeed.grpc.scheduler.TokenSpeedScheduler``
service identity (see ``proto/tokenspeed_scheduler.proto``) so the probe
distinguishes TokenSpeed workers from real SGLang workers regardless of any
wire-level message-type sharing between the two backends.
"""

from __future__ import annotations

import logging
import time
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING

import grpc
from grpc_health.v1 import health_pb2, health_pb2_grpc
from smg_grpc_proto.generated import tokenspeed_scheduler_pb2

if TYPE_CHECKING:
    from tokenspeed.runtime.engine.async_llm import AsyncLLM

logger = logging.getLogger(__name__)

# Seconds of scheduler silence — with pending requests — before we report
# NOT_SERVING. Matches the SGLang equivalent so oncall dashboards are aligned.
STUCK_SCHEDULER_THRESHOLD_SEC = 30.0

# Source the advertised service name from the proto descriptor so a future
# ``package`` or ``service`` rename in tokenspeed_scheduler.proto stays in
# sync without a hand-edited string here.
TOKENSPEED_SCHEDULER_SERVICE_NAME = tokenspeed_scheduler_pb2.DESCRIPTOR.services_by_name[
    "TokenSpeedScheduler"
].full_name


class TokenSpeedHealthServicer(health_pb2_grpc.HealthServicer):
    """Health servicer that tracks TokenSpeed's AsyncLLM liveness.

    Advertises two service levels:

    * ``""`` (empty) — overall server health, flipped to SERVING once the
      warmup request succeeds and back to NOT_SERVING on shutdown.
    * ``tokenspeed.grpc.scheduler.TokenSpeedScheduler`` — readiness: the
      base status, plus a scheduler-responsiveness check (if there are
      pending requests but the scheduler hasn't pushed output for >30s,
      report NOT_SERVING).
    """

    OVERALL_SERVER = ""
    TOKENSPEED_SERVICE = TOKENSPEED_SCHEDULER_SERVICE_NAME

    def __init__(self, async_llm: AsyncLLM, scheduler_info: dict):
        self.async_llm = async_llm
        self.scheduler_info = scheduler_info
        self._serving_status: dict[str, int] = {
            self.OVERALL_SERVER: health_pb2.HealthCheckResponse.NOT_SERVING,
            self.TOKENSPEED_SERVICE: health_pb2.HealthCheckResponse.NOT_SERVING,
        }
        logger.info("TokenSpeed gRPC health service initialized")

    def set_serving(self) -> None:
        """Flip both services to SERVING (call after successful warmup)."""
        self._serving_status[self.OVERALL_SERVER] = health_pb2.HealthCheckResponse.SERVING
        self._serving_status[self.TOKENSPEED_SERVICE] = health_pb2.HealthCheckResponse.SERVING
        logger.info("TokenSpeed gRPC health status -> SERVING")

    def set_not_serving(self) -> None:
        """Flip both services to NOT_SERVING (call on shutdown)."""
        self._serving_status[self.OVERALL_SERVER] = health_pb2.HealthCheckResponse.NOT_SERVING
        self._serving_status[self.TOKENSPEED_SERVICE] = health_pb2.HealthCheckResponse.NOT_SERVING
        logger.info("TokenSpeed gRPC health status -> NOT_SERVING")

    async def Check(
        self,
        request: health_pb2.HealthCheckRequest,
        context: grpc.aio.ServicerContext,
    ) -> health_pb2.HealthCheckResponse:
        service_name = request.service
        logger.debug("Health check request for service=%r", service_name)

        if self.async_llm.gracefully_exit:
            return health_pb2.HealthCheckResponse(status=health_pb2.HealthCheckResponse.NOT_SERVING)

        if service_name == self.OVERALL_SERVER:
            return health_pb2.HealthCheckResponse(
                status=self._serving_status.get(
                    self.OVERALL_SERVER, health_pb2.HealthCheckResponse.NOT_SERVING
                )
            )

        if service_name == self.TOKENSPEED_SERVICE:
            base = self._serving_status.get(
                self.TOKENSPEED_SERVICE, health_pb2.HealthCheckResponse.NOT_SERVING
            )
            if base != health_pb2.HealthCheckResponse.SERVING:
                return health_pb2.HealthCheckResponse(status=base)

            # Scheduler-stuck check: pending work but no recent output.
            time_since_last_receive = time.time() - self.async_llm.last_receive_tstamp
            pending = len(self.async_llm.rid_to_state)
            if time_since_last_receive > STUCK_SCHEDULER_THRESHOLD_SEC and pending > 0:
                logger.warning(
                    "Scheduler appears stuck: %.1fs since last receive, %d pending requests",
                    time_since_last_receive,
                    pending,
                )
                return health_pb2.HealthCheckResponse(
                    status=health_pb2.HealthCheckResponse.NOT_SERVING
                )

            return health_pb2.HealthCheckResponse(status=health_pb2.HealthCheckResponse.SERVING)

        context.set_code(grpc.StatusCode.NOT_FOUND)
        context.set_details(f"Unknown service: {service_name}")
        return health_pb2.HealthCheckResponse(status=health_pb2.HealthCheckResponse.SERVICE_UNKNOWN)

    async def Watch(
        self,
        request: health_pb2.HealthCheckRequest,
        context: grpc.aio.ServicerContext,
    ) -> AsyncIterator[health_pb2.HealthCheckResponse]:
        # K8s probes use Check, not Watch — we emit the current status once.
        yield await self.Check(request, context)
```
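Since the Rust router keys backend detection off the advertised health-service name, the probe can be reproduced with a plain `grpc.health.v1` client. A minimal sketch, assuming a worker listening on `127.0.0.1:50051`; the hard-coded service string matches what the servicer derives from the proto descriptor above, and the snippet itself is illustrative rather than part of the PR.

```python
# Illustrative probe: ask the worker's standard health service for the
# TokenSpeed-specific service name to distinguish it from an SGLang worker.
import asyncio

import grpc
from grpc_health.v1 import health_pb2, health_pb2_grpc


async def probe(target: str = "127.0.0.1:50051") -> bool:
    async with grpc.aio.insecure_channel(target) as channel:
        stub = health_pb2_grpc.HealthStub(channel)
        resp = await stub.Check(
            health_pb2.HealthCheckRequest(
                service="tokenspeed.grpc.scheduler.TokenSpeedScheduler"
            )
        )
        # SERVING means warmup finished and the stuck-scheduler check passed;
        # NOT_SERVING covers startup, shutdown, and a silent scheduler with
        # pending requests.
        return resp.status == health_pb2.HealthCheckResponse.SERVING


if __name__ == "__main__":
    print(asyncio.run(probe()))
```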
New file (+60 lines): scheduler subprocess launcher (`launch_engine`)

```python
"""Scheduler subprocess launcher for the TokenSpeed gRPC server.

Mirrors ``smg_grpc_servicer.sglang.scheduler_launcher`` but delegates to
TokenSpeed's ``_launch_subprocesses``: we get back a fully-initialised
``AsyncLLM`` along with the scheduler info dict. All scheduler/DP-controller
spawning, multiprocessing start-method, and env priming already live inside
``_launch_subprocesses`` — we only wrap it to return what the gRPC server
cares about and to keep the call site symmetric with the sibling backends.
"""

from __future__ import annotations

import logging
from typing import Any

from tokenspeed.runtime.engine.async_llm import AsyncLLM
from tokenspeed.runtime.entrypoints.engine import _launch_subprocesses
from tokenspeed.runtime.utils.server_args import PortArgs, ServerArgs

logger = logging.getLogger(__name__)


def launch_engine(
    server_args: ServerArgs,
    port_args: PortArgs | None = None,
) -> tuple[AsyncLLM, dict[str, Any]]:
    """Launch TokenSpeed scheduler subprocess(es) and the main-process AsyncLLM.

    Returns:
        A tuple ``(async_llm, scheduler_info)``. ``async_llm`` is the live
        :class:`AsyncLLM` that the gRPC servicer will drive. ``scheduler_info``
        is the dict rank-0 sent back once its scheduler was ready (contains
        e.g. ``max_total_num_tokens``, ``max_req_input_len``, ...).

    Raises:
        RuntimeError: If the rank-0 scheduler fails to initialize. The original
        ``_launch_subprocesses`` surfaces this by re-raising the EOF/assertion
        error — we propagate it unchanged.
    """
    async_llm, _template_manager, scheduler_info = _launch_subprocesses(
        server_args=server_args,
        port_args=port_args,
    )

    # Non-zero rank nodes return (None, None, None) from _launch_subprocesses
    # and block forever on the dummy health server — they never reach the gRPC
    # server. Guard against callers relying on this return on secondary nodes.
    if async_llm is None:
        raise RuntimeError(
            "launch_engine() returned no AsyncLLM. This means the current node "
            "is not rank 0 in a multi-node deployment, or the scheduler died "
            "during initialization. Only rank 0 may serve gRPC traffic."
        )

    logger.info(
        "TokenSpeed engine ready: max_total_num_tokens=%s max_req_input_len=%s",
        scheduler_info.get("max_total_num_tokens"),
        scheduler_info.get("max_req_input_len"),
    )
    return async_llm, scheduler_info
```
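`serve_grpc` itself (imported by `__main__` from `server.py`) is not shown in this diff, but the files above suggest how the pieces compose. A rough sketch under stated assumptions: the module paths, the generated registration helper name, and the warmup/shutdown ordering are guesses, not the PR's actual implementation.

```python
# Rough sketch only: how launch_engine(), the scheduler servicer, and the
# health servicer presumably fit together inside serve_grpc(). The module
# paths below and the commented-out registration helper are assumptions.
import grpc
from grpc_health.v1 import health_pb2_grpc

from smg_grpc_servicer.tokenspeed.health_servicer import TokenSpeedHealthServicer  # path assumed
from smg_grpc_servicer.tokenspeed.scheduler_launcher import launch_engine  # path assumed


async def serve_grpc_sketch(server_args) -> None:
    async_llm, scheduler_info = launch_engine(server_args)

    server = grpc.aio.server()
    # Scheduler-service registration; helper name assumed from the usual
    # pb2_grpc naming convention for a `TokenSpeedScheduler` service:
    # tokenspeed_scheduler_pb2_grpc.add_TokenSpeedSchedulerServicer_to_server(
    #     TokenSpeedSchedulerServicer(async_llm, scheduler_info), server)
    health = TokenSpeedHealthServicer(async_llm, scheduler_info)
    health_pb2_grpc.add_HealthServicer_to_server(health, server)

    server.add_insecure_port(f"{server_args.host}:{server_args.port}")
    await server.start()
    health.set_serving()  # flip to SERVING only once warmup has succeeded
    try:
        await server.wait_for_termination()
    finally:
        health.set_not_serving()
```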
Review comment on the package `__init__` docstring:

🟡 Nit: This docstring is stale — it describes the opposite of what the implementation does. The servicer does NOT wrap behind "the SGLang gRPC service"; it uses its own `tokenspeed.grpc.scheduler.TokenSpeedScheduler` proto. The Rust router does NOT "auto-detect the SGLang proto"; `DetectBackendStep` identifies TokenSpeed natively from the service name. And there IS a new Rust client (`TokenSpeedSchedulerClient`).
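A corrected package docstring along the lines of the reviewer's points might read as follows; this is a suggestion sketch derived from the comment above, not text from the PR.

```python
# Suggested rewrite of the package __init__ docstring, following the review
# comment; the wording here is illustrative.
"""TokenSpeed gRPC servicer implementation.

Mirrors smg_grpc_servicer.vllm / smg_grpc_servicer.sglang. Wraps TokenSpeed's
AsyncLLM (main-process async frontend) behind TokenSpeed's own
``tokenspeed.grpc.scheduler.TokenSpeedScheduler`` gRPC service. The Rust
router's ``DetectBackendStep`` identifies TokenSpeed workers from this service
name and talks to them through the new ``TokenSpeedSchedulerClient``.
"""
```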