From 6b696e537b2bbfa4a6b1995b1146a59c8256b93d Mon Sep 17 00:00:00 2001 From: key4ng Date: Fri, 8 May 2026 12:37:18 -0700 Subject: [PATCH 01/24] feat(grpc): add TokenSpeed gRPC client and router wiring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds TokenSpeed as a first-class GPU backend on the Rust side: a self- contained `tokenspeed.grpc.scheduler.TokenSpeedScheduler` proto, the `TokenSpeedSchedulerClient` wrapper that translates SGLang-shaped request/response types at the boundary, and the model_gateway router plumbing (client dispatch, runtime detection, harmony/regular request builders, multimodal and embedding stages). This is part 1 of 3 splitting #1351: - PR1 (this): Rust gRPC + protocol - PR2: Python servicer (grpc_servicer) - PR3: CI workflows + e2e tests PR1 is functionally inert without PR2 — the router can dial a TokenSpeed worker, but the worker process lives in the Python servicer landed by PR2. Addresses CatherineSue's review on #1351: - shorten the TokenSpeed RuntimeType doc in protocols/worker.rs - drop the verbose TokenSpeed note in grpc_client/build.rs - restore the concise module doc in detect_backend.rs, just adding tokenspeed to the existing health-check ordering Signed-off-by: key4ng --- crates/grpc_client/build.rs | 5 +- .../proto/tokenspeed_scheduler.proto | 300 ++++++++ .../python/smg_grpc_proto/__init__.py | 6 +- crates/grpc_client/src/lib.rs | 7 + crates/grpc_client/src/sglang_scheduler.rs | 20 +- .../grpc_client/src/tokenspeed_scheduler.rs | 711 ++++++++++++++++++ crates/protocols/src/worker.rs | 5 + crates/tokenizer/src/chat_template.rs | 50 +- model_gateway/src/routers/grpc/client.rs | 125 ++- .../grpc/common/stages/request_execution.rs | 4 + .../grpc/harmony/stages/request_building.rs | 49 ++ model_gateway/src/routers/grpc/multimodal.rs | 7 + .../src/routers/grpc/proto_wrapper.rs | 16 +- .../stages/embedding/request_building.rs | 10 + .../workflow/steps/local/detect_backend.rs | 13 +- model_gateway/src/workflow/steps/util.rs | 15 +- 16 files changed, 1320 insertions(+), 23 deletions(-) create mode 100644 crates/grpc_client/proto/tokenspeed_scheduler.proto create mode 100644 crates/grpc_client/src/tokenspeed_scheduler.rs diff --git a/crates/grpc_client/build.rs b/crates/grpc_client/build.rs index 9809b80b1..f03ed71c1 100644 --- a/crates/grpc_client/build.rs +++ b/crates/grpc_client/build.rs @@ -2,6 +2,7 @@ fn main() -> Result<(), Box> { // Rebuild triggers println!("cargo:rerun-if-changed=proto/common.proto"); println!("cargo:rerun-if-changed=proto/sglang_scheduler.proto"); + println!("cargo:rerun-if-changed=proto/tokenspeed_scheduler.proto"); println!("cargo:rerun-if-changed=proto/vllm_engine.proto"); println!("cargo:rerun-if-changed=proto/trtllm_service.proto"); println!("cargo:rerun-if-changed=proto/mlx_engine.proto"); @@ -20,7 +21,8 @@ fn main() -> Result<(), Box> { .extern_path(".smg.grpc.common", "crate::common_proto") .type_attribute("GetModelInfoResponse", "#[derive(serde::Serialize)]") // vllm + trtllm ServerInfo have only primitive fields. - // sglang's contains prost_types::{Struct,Timestamp} so it's handled separately. + // sglang's and tokenspeed's contain prost_types::{Struct,Timestamp}; + // those are handled separately at the wrapper layer. 
.type_attribute( "vllm.grpc.engine.GetServerInfoResponse", "#[derive(serde::Serialize)]", @@ -40,6 +42,7 @@ fn main() -> Result<(), Box> { "proto/vllm_engine.proto", "proto/trtllm_service.proto", "proto/mlx_engine.proto", + "proto/tokenspeed_scheduler.proto", ], &["proto"], )?; diff --git a/crates/grpc_client/proto/tokenspeed_scheduler.proto b/crates/grpc_client/proto/tokenspeed_scheduler.proto new file mode 100644 index 000000000..02d649ae7 --- /dev/null +++ b/crates/grpc_client/proto/tokenspeed_scheduler.proto @@ -0,0 +1,300 @@ +syntax = "proto3"; + +package tokenspeed.grpc.scheduler; + +import "google/protobuf/timestamp.proto"; +import "google/protobuf/struct.proto"; + +// Service definition for TokenSpeed scheduler communication. +// +// TokenSpeed has its own service identity AND its own message shapes — wire +// definition is fully self-contained, with zero dependencies on +// ``sglang_scheduler.proto``. The message catalog is intentionally minimal: +// it covers what TokenSpeed's top-tier LLMs (Kimi K2, MiniMax M2, Qwen 3, +// gpt-oss, DeepSeek V4) actually need today, and nothing more. Anything +// SGLang-specific (PD-disaggregated serving, LoRA hot-swap, multimodal, +// classifier outputs, hidden-state forwarding, embeddings) is deliberately +// out of scope and lands here only when an explicit TokenSpeed use case +// shows up. +service TokenSpeedScheduler { + // Submit a generation request (server-streaming for token-by-token). + rpc Generate(GenerateRequest) returns (stream GenerateResponse); + + // Liveness + readiness probe. + rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse); + + // Cancel a running request. + rpc Abort(AbortRequest) returns (AbortResponse); + + // Static info about the loaded model. + rpc GetModelInfo(GetModelInfoRequest) returns (GetModelInfoResponse); + + // Runtime info about the server. + rpc GetServerInfo(GetServerInfoRequest) returns (GetServerInfoResponse); + + // Per-DP-rank load metrics (used by router for least-load). + rpc GetLoads(GetLoadsRequest) returns (GetLoadsResponse); +} + +// ===================== +// Sampling +// ===================== + +// IMPORTANT: proto3 numeric defaults (0) do NOT match semantic defaults +// (temperature=1.0, top_p=1.0, top_k=-1). All sampling scalars are +// declared ``optional`` so presence is preserved on the wire — the +// servicer uses ``HasField()`` to distinguish "client explicitly set 0" +// from "client didn't send anything." Without this, ``temperature=0`` +// (a valid request for greedy decoding) is indistinguishable from the +// proto3 default and would be silently dropped by truthy-check guards. +// +// ``min_new_tokens`` is left non-optional because 0 is its semantic +// "no minimum" sentinel. +message SamplingParams { + optional float temperature = 1; + optional float top_p = 2; + optional int32 top_k = 3; + optional float min_p = 4; + optional float frequency_penalty = 5; + optional float presence_penalty = 6; + optional float repetition_penalty = 7; + + optional uint32 max_new_tokens = 8; + uint32 min_new_tokens = 9; + + repeated string stop = 10; + repeated uint32 stop_token_ids = 11; + bool ignore_eos = 12; + + bool skip_special_tokens = 13; + bool spaces_between_special_tokens = 14; + + // Number of samples (n in OpenAI API). + uint32 n = 15; + + // Per-token logit bias. + map logit_bias = 16; + + // Structured generation. Currently xfailed in e2e (tokenspeed#361), + // but the wire shape stays so wiring it later doesn't bump the proto. 
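+  // At most one constraint can be set per request (proto3 oneof semantics).
+  // Illustrative values, not taken from this patch: ``regex = "[0-9]+"`` to
+  // force digit-only output, or ``json_schema`` carrying a serialized JSON
+  // Schema document such as ``{"type": "object", "required": ["name"]}``.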
+ oneof constraint { + string regex = 17; + string json_schema = 18; + string ebnf_grammar = 19; + string structural_tag = 20; + } + + // When true, generation does not strip the trailing matched stop token + // from ``output_ids`` (matches SGLang's ``no_stop_trim``). Combined with + // ``skip_special_tokens=False`` it lets the gateway-side detokenizer + // render the EOS marker in the visible response — required for the + // ``test_no_stop_trim_with_skip_special_false`` e2e check and for any + // downstream logic that needs the raw stop token in the output stream. + bool no_stop_trim = 22; + + // Escape hatch for backend-specific knobs without bumping the proto. + google.protobuf.Struct custom_params = 21; +} + +// ===================== +// Generate +// ===================== + +message GenerateRequest { + string request_id = 1; + + // Tokenized input (router does its own tokenization). + TokenizedInput tokenized = 2; + + SamplingParams sampling_params = 3; + + // Logprob options. + bool return_logprob = 4; + // Optional so the servicer can distinguish "client omitted" (use SGLang's + // ``-1`` default = no input logprobs) from an explicit value like 0. + optional int32 logprob_start_len = 5; + int32 top_logprobs_num = 6; + repeated uint32 token_ids_logprob = 7; + + // Whether the client wants stream chunks (otherwise: complete-only). + bool stream = 8; +} + +message TokenizedInput { + repeated uint32 input_ids = 1; + // Original text — purely cosmetic; the tokenizer pass is skipped because + // input_ids is set. Used in worker logs for traceability. + string original_text = 2; +} + +message GenerateResponse { + string request_id = 1; + + oneof response { + GenerateStreamChunk chunk = 2; + GenerateComplete complete = 3; + } +} + +message GenerateStreamChunk { + // Generated tokens since the previous chunk. + repeated uint32 token_ids = 1; + + uint32 prompt_tokens = 2; + uint32 completion_tokens = 3; + uint32 cached_tokens = 4; + + OutputLogProbs output_logprobs = 5; + + // For ordering when n>1. + uint32 index = 6; +} + +message GenerateComplete { + repeated uint32 output_ids = 1; + + // OpenAI-compatible: "stop", "length", "abort", "tool_calls". + string finish_reason = 2; + + uint32 prompt_tokens = 3; + uint32 completion_tokens = 4; + uint32 cached_tokens = 5; + + OutputLogProbs output_logprobs = 6; + + // Which stop matched (for clients that care which `stop` triggered). 
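+  // ``matched_token_id`` reports a hit from ``stop_token_ids`` (or EOS),
+  // ``matched_stop_str`` a hit from the string ``stop`` list — the same
+  // split as SGLang's ``GenerateComplete.matched_stop`` oneof.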
+ oneof matched_stop { + uint32 matched_token_id = 7; + string matched_stop_str = 8; + } + + uint32 index = 9; +} + +message OutputLogProbs { + repeated float token_logprobs = 1; + repeated uint32 token_ids = 2; + repeated TopLogProbs top_logprobs = 3; +} + +message TopLogProbs { + repeated float values = 1; + repeated uint32 token_ids = 2; +} + +// ===================== +// Management +// ===================== + +message HealthCheckRequest {} +message HealthCheckResponse { + bool healthy = 1; + string message = 2; +} + +message AbortRequest { + string request_id = 1; + string reason = 2; +} +message AbortResponse { + bool success = 1; + string message = 2; +} + +// ===================== +// Model & Server Info +// ===================== + +message GetModelInfoRequest {} +message GetModelInfoResponse { + string model_path = 1; + string tokenizer_path = 2; + string served_model_name = 3; + string model_type = 4; + repeated string architectures = 5; + + int32 max_context_length = 6; + int32 max_req_input_len = 7; + int32 vocab_size = 8; + + repeated int32 eos_token_ids = 9; + int32 pad_token_id = 10; + int32 bos_token_id = 11; + + string weight_version = 12; + string preferred_sampling_params = 13; // JSON string or empty +} + +message GetServerInfoRequest {} +message GetServerInfoResponse { + google.protobuf.Struct server_args = 1; + google.protobuf.Struct scheduler_info = 2; + + int32 active_requests = 3; + bool is_paused = 4; + double uptime_seconds = 5; + int32 max_total_num_tokens = 6; + + string tokenspeed_version = 7; + google.protobuf.Timestamp start_time = 8; +} + +// ===================== +// Loads +// ===================== + +message GetLoadsRequest { + optional int32 dp_rank = 1; + // Sections: "core" (default), "memory", "queues". Pass "all" for everything. 
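+  // e.g. ``include = ["core", "memory"]`` additionally fills the optional
+  // ``MemoryMetrics`` on each ``SchedulerLoad``; the router currently
+  // requests ``["core"]`` only (see ``GrpcClient::get_loads``).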
+ repeated string include = 2; +} + +message GetLoadsResponse { + string timestamp = 1; + string version = 2; + int32 dp_rank_count = 3; + repeated SchedulerLoad loads = 4; + AggregateMetrics aggregate = 5; +} + +message SchedulerLoad { + int32 dp_rank = 1; + + int32 num_running_reqs = 2; + int32 num_waiting_reqs = 3; + int32 num_total_reqs = 4; + int32 num_used_tokens = 5; + int32 max_total_num_tokens = 6; + int32 max_running_requests = 7; + + double token_usage = 8; + double gen_throughput = 9; + double cache_hit_rate = 10; + double utilization = 11; + + optional MemoryMetrics memory = 12; + optional QueueMetrics queues = 13; +} + +message MemoryMetrics { + double weight_gb = 1; + double kv_cache_gb = 2; + double graph_gb = 3; + int32 token_capacity = 4; +} + +message QueueMetrics { + int32 waiting = 1; + int32 grammar = 2; + int32 paused = 3; + int32 retracted = 4; +} + +message AggregateMetrics { + int32 total_running_reqs = 1; + int32 total_waiting_reqs = 2; + int32 total_reqs = 3; + double avg_token_usage = 4; + double avg_throughput = 5; + double avg_utilization = 6; +} diff --git a/crates/grpc_client/python/smg_grpc_proto/__init__.py b/crates/grpc_client/python/smg_grpc_proto/__init__.py index 6a19b4aea..f7eac4a3e 100644 --- a/crates/grpc_client/python/smg_grpc_proto/__init__.py +++ b/crates/grpc_client/python/smg_grpc_proto/__init__.py @@ -1,4 +1,4 @@ -"""SMG gRPC Proto - Protocol definitions for SGLang, vLLM, TRT-LLM, and MLX.""" +"""SMG gRPC Proto - Protocol definitions for SGLang, TokenSpeed, vLLM, TRT-LLM, and MLX.""" from importlib.metadata import version @@ -14,6 +14,8 @@ sglang_encoder_pb2_grpc, sglang_scheduler_pb2, sglang_scheduler_pb2_grpc, + tokenspeed_scheduler_pb2, + tokenspeed_scheduler_pb2_grpc, trtllm_service_pb2, trtllm_service_pb2_grpc, vllm_engine_pb2, @@ -25,6 +27,8 @@ "sglang_scheduler_pb2_grpc", "sglang_encoder_pb2", "sglang_encoder_pb2_grpc", + "tokenspeed_scheduler_pb2", + "tokenspeed_scheduler_pb2_grpc", "vllm_engine_pb2", "vllm_engine_pb2_grpc", "trtllm_service_pb2", diff --git a/crates/grpc_client/src/lib.rs b/crates/grpc_client/src/lib.rs index 77c4fa5a2..a26528953 100644 --- a/crates/grpc_client/src/lib.rs +++ b/crates/grpc_client/src/lib.rs @@ -10,6 +10,7 @@ pub mod common_proto { pub mod mlx_engine; pub mod sglang_scheduler; pub mod tokenizer_bundle; +pub mod tokenspeed_scheduler; pub mod trtllm_service; pub mod vllm_engine; @@ -18,6 +19,12 @@ use std::sync::Arc; pub use mlx_engine::{proto as mlx_proto, MlxEngineClient}; pub use sglang_scheduler::{proto as sglang_proto, SglangSchedulerClient}; +// TokenSpeed has a fully independent wire definition (see +// ``proto/tokenspeed_scheduler.proto``) — distinct service, distinct +// messages with intentionally trimmed field sets aimed at top-tier LLM +// workloads. The client wraps that wire and translates to/from SGLang-shaped +// types at the boundary so the router's dispatch enums don't proliferate. 
+pub use tokenspeed_scheduler::{tokenspeed_proto, TokenSpeedSchedulerClient}; use tonic::metadata::MetadataMap; pub use trtllm_service::{proto as trtllm_proto, TrtllmServiceClient}; pub use vllm_engine::{proto as vllm_proto, VllmEngineClient}; diff --git a/crates/grpc_client/src/sglang_scheduler.rs b/crates/grpc_client/src/sglang_scheduler.rs index edfc43e0f..8f59ee554 100644 --- a/crates/grpc_client/src/sglang_scheduler.rs +++ b/crates/grpc_client/src/sglang_scheduler.rs @@ -434,7 +434,7 @@ impl SglangSchedulerClient { } /// Build gRPC SamplingParams from ChatCompletionRequest - fn build_grpc_sampling_params_from_chat( + pub(crate) fn build_grpc_sampling_params_from_chat( request: &ChatCompletionRequest, tool_call_constraint: Option<(String, String)>, ) -> Result { @@ -444,6 +444,16 @@ impl SglangSchedulerClient { // Hardcode to true: gRPC backends return raw token IDs, not decoded text. // Detokenization happens on the SMG Rust side (StopDecoder/Sequence). + // + // Note: TokenSpeed's HTTP serving_chat sets this to false when tools are + // present (serving_chat.py:178-179) — but mirroring that on the gRPC + // path measurably HURTS BFCL accuracy. We tested it: simple_python + // dropped from ~88.75 % to 79 %, parallel_multiple from ~84.5 % to + // 60.5 %. With skip_special_tokens=false the engine emits the + // ``<|tool_call_*|>`` special tokens in the raw output stream, and the + // SMG-side detokenizer + kimik2 tool-call parser then double-counts or + // misframes them. Keep it at true so SMG sees normal tokens and + // applies its own parsing. let skip_special_tokens = true; Ok(proto::SamplingParams { @@ -542,7 +552,7 @@ impl SglangSchedulerClient { } /// Build gRPC SamplingParams from ResponsesRequest - fn build_grpc_sampling_params_from_responses( + pub(crate) fn build_grpc_sampling_params_from_responses( request: &ResponsesRequest, constraint: Option<(String, String)>, ) -> Result { @@ -635,7 +645,7 @@ impl SglangSchedulerClient { } /// Build gRPC SamplingParams from CreateMessageRequest - fn build_grpc_sampling_params_from_messages( + pub(crate) fn build_grpc_sampling_params_from_messages( request: &CreateMessageRequest, tool_call_constraint: Option<(String, String)>, ) -> Result { @@ -698,7 +708,7 @@ impl SglangSchedulerClient { Ok(grpc_request) } - fn build_grpc_sampling_params_from_completion( + pub(crate) fn build_grpc_sampling_params_from_completion( request: &CompletionRequest, ) -> Result { let stop_sequences = match &request.stop { @@ -781,7 +791,7 @@ impl SglangSchedulerClient { } } - fn build_sampling_params_from_plain( + pub(crate) fn build_sampling_params_from_plain( params: Option<&GenerateSamplingParams>, ) -> Result { let mut sampling = proto::SamplingParams { diff --git a/crates/grpc_client/src/tokenspeed_scheduler.rs b/crates/grpc_client/src/tokenspeed_scheduler.rs new file mode 100644 index 000000000..d4d02cc3c --- /dev/null +++ b/crates/grpc_client/src/tokenspeed_scheduler.rs @@ -0,0 +1,711 @@ +//! gRPC client for the TokenSpeed scheduler service. +//! +//! TokenSpeed has a fully independent wire definition (see +//! ``proto/tokenspeed_scheduler.proto``) — distinct package +//! (``tokenspeed.grpc.scheduler``), distinct service, distinct messages with +//! intentionally trimmed field sets aimed at the top-tier LLM workloads +//! (Kimi K2, MiniMax M2, Qwen 3, gpt-oss, DeepSeek V4). Anything SGLang has +//! that doesn't apply here (PD-disaggregated serving, multimodal inputs, +//! LoRA hot-swap, hidden-state forwarding, embeddings, classifier outputs, +//! 
tokenizer streaming, KV-event subscription) is simply not on TokenSpeed's +//! wire. +//! +//! Internally this client still leverages SGLang's +//! ``build_grpc_sampling_params_from_*`` helpers because the source-of-truth +//! is an OpenAI request and most fields map identically. We translate from +//! the SGLang-shaped ``GenerateRequest`` into a TokenSpeed-shaped one at the +//! wire boundary, and translate the streamed response back so the router's +//! ``ProtoGenerateStreamChunk`` / ``ProtoGenerateComplete`` accessors can +//! operate on a familiar shape. When TokenSpeed needs a field SGLang lacks, +//! add it to the proto and extend the translator — not the router. + +use std::{ + pin::Pin, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + task::{Context, Poll}, + time::Duration, +}; + +use openai_protocol::{ + chat::ChatCompletionRequest, completion::CompletionRequest, generate::GenerateRequest, + messages::CreateMessageRequest, responses::ResponsesRequest, +}; +use tonic::{transport::Channel, Request, Streaming}; +use tracing::{debug, warn}; + +use crate::{ + sglang_scheduler::{proto as sglang, SglangSchedulerClient}, + BoxedTraceInjector, NoopTraceInjector, +}; + +#[expect(clippy::allow_attributes)] +pub mod tokenspeed_proto { + #![allow(clippy::all, clippy::absolute_paths, unused_qualifications)] + tonic::include_proto!("tokenspeed.grpc.scheduler"); +} + +/// Fire-and-forget abort sender used by [`AbortOnDropStream`]. The closure +/// captures the TokenSpeed client that owns the stream so ``Drop`` can +/// dispatch the abort RPC over the same connection without ``Drop`` itself +/// being async. Local to this module — SGLang's equivalent stream type +/// holds its own client field directly and doesn't need this indirection. +type AbortDispatcher = Arc; + +/// Auto-aborting wrapper around the TokenSpeed generate stream. +/// +/// Yields ``sglang::GenerateResponse`` (translated from the on-wire +/// ``tokenspeed_proto::GenerateResponse``) so the router-side +/// ``ProtoGenerateStreamChunk`` / ``ProtoGenerateComplete`` accessors can +/// keep operating on a single shape. Sends an Abort RPC on Drop unless +/// ``mark_completed`` was called first — same lifecycle contract as +/// ``sglang_scheduler::AbortOnDropStream``. +pub struct AbortOnDropStream { + inner: Streaming, + request_id: String, + abort_dispatcher: AbortDispatcher, + aborted: Arc, +} + +impl AbortOnDropStream { + pub fn new( + stream: Streaming, + request_id: String, + abort_dispatcher: AbortDispatcher, + ) -> Self { + debug!( + "Created TokenSpeed AbortOnDropStream for request {}", + request_id + ); + Self { + inner: stream, + request_id, + abort_dispatcher, + aborted: Arc::new(AtomicBool::new(false)), + } + } + + pub fn mark_completed(&self) { + self.aborted.store(true, Ordering::Release); + debug!("TokenSpeed request {} marked as completed", self.request_id); + } +} + +impl Drop for AbortOnDropStream { + fn drop(&mut self) { + if self + .aborted + .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire) + .is_err() + { + return; + } + debug!( + "TokenSpeed stream dropped without completion for request {}, sending abort", + self.request_id + ); + (self.abort_dispatcher)(self.request_id.clone()); + } +} + +impl futures::Stream for AbortOnDropStream { + // Yield SGLang-shaped responses so the router's wrapper enums don't need + // a TokenSpeed variant for every chunk-accessor. 
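+    // Concretely: a TokenSpeed ``Complete`` frame is surfaced as
+    // ``sglang::generate_response::Response::Complete``, so finish-reason and
+    // usage handling downstream is identical to the SGLang path.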
+ type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match Pin::new(&mut self.inner).poll_next(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), + Poll::Ready(Some(Ok(ts_resp))) => { + Poll::Ready(Some(Ok(translate::generate_response(ts_resp)))) + } + } + } +} + +/// gRPC client for the TokenSpeed scheduler. +#[derive(Clone)] +pub struct TokenSpeedSchedulerClient { + client: tokenspeed_proto::token_speed_scheduler_client::TokenSpeedSchedulerClient, + trace_injector: BoxedTraceInjector, +} + +impl TokenSpeedSchedulerClient { + pub async fn connect(endpoint: &str) -> Result> { + Self::connect_with_trace_injector(endpoint, Arc::new(NoopTraceInjector)).await + } + + pub async fn connect_with_trace_injector( + endpoint: &str, + trace_injector: BoxedTraceInjector, + ) -> Result> { + debug!("Connecting to TokenSpeed scheduler at {}", endpoint); + + let http_endpoint = if let Some(addr) = endpoint.strip_prefix("grpc://") { + format!("http://{addr}") + } else { + endpoint.to_string() + }; + + // Same channel knobs as SglangSchedulerClient — independent of the + // service being called and proven load-appropriate in prod. + let channel = Channel::from_shared(http_endpoint)? + .http2_keep_alive_interval(Duration::from_secs(30)) + .keep_alive_timeout(Duration::from_secs(10)) + .keep_alive_while_idle(true) + .tcp_keepalive(Some(Duration::from_secs(60))) + .tcp_nodelay(true) + .http2_adaptive_window(true) + .initial_stream_window_size(Some(16 * 1024 * 1024)) + .initial_connection_window_size(Some(32 * 1024 * 1024)) + .connect() + .await?; + + let client = + tokenspeed_proto::token_speed_scheduler_client::TokenSpeedSchedulerClient::new(channel); + + Ok(Self { + client, + trace_injector, + }) + } + + #[must_use] + pub fn with_trace_injector(mut self, trace_injector: BoxedTraceInjector) -> Self { + self.trace_injector = trace_injector; + self + } + + /// Submit a generation request. + /// + /// Accepts an SGLang-shaped request for symmetry with the router's + /// existing dispatch path; the translation to TokenSpeed's slimmer wire + /// shape (drops mm_inputs, disagg, LoRA, hidden-states, etc.) happens + /// here at the wire boundary. 
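+    ///
+    /// Rough call shape (illustrative; variable names are not from this
+    /// patch): build an SGLang-shaped request via one of the ``build_*``
+    /// helpers below, then ``let stream = client.generate(req).await?``,
+    /// poll the returned [`AbortOnDropStream`], and call ``mark_completed()``
+    /// after the final frame so ``Drop`` doesn't fire a spurious Abort RPC.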
+ pub async fn generate( + &self, + req: sglang::GenerateRequest, + ) -> Result { + let request_id = req.request_id.clone(); + let ts_req = translate::generate_request(req); + + let mut client = self.client.clone(); + let mut request = Request::new(ts_req); + + if let Err(e) = self.trace_injector.inject(request.metadata_mut()) { + warn!("Failed to inject trace context: {}", e); + } + + let response = client.generate(request).await?; + + Ok(AbortOnDropStream::new( + response.into_inner(), + request_id, + tokenspeed_abort_dispatcher(self.clone()), + )) + } + + pub async fn health_check(&self) -> Result { + debug!("Sending TokenSpeed health check request"); + let request = Request::new(tokenspeed_proto::HealthCheckRequest {}); + let mut client = self.client.clone(); + let response = client.health_check(request).await?; + let r = response.into_inner(); + Ok(sglang::HealthCheckResponse { + healthy: r.healthy, + message: r.message, + }) + } + + pub async fn abort_request( + &self, + request_id: String, + reason: String, + ) -> Result<(), tonic::Status> { + debug!( + "Sending TokenSpeed abort for {} (reason: {})", + request_id, reason + ); + let request = Request::new(tokenspeed_proto::AbortRequest { + request_id: request_id.clone(), + reason, + }); + let mut client = self.client.clone(); + let response = client.abort(request).await?; + debug!( + "TokenSpeed abort response for {}: success={}, message={}", + request_id, + response.get_ref().success, + response.get_ref().message + ); + Ok(()) + } + + pub async fn get_model_info(&self) -> Result { + let request = Request::new(tokenspeed_proto::GetModelInfoRequest {}); + let mut client = self.client.clone(); + let response = client.get_model_info(request).await?; + Ok(translate::model_info(response.into_inner())) + } + + pub async fn get_server_info(&self) -> Result { + let request = Request::new(tokenspeed_proto::GetServerInfoRequest {}); + let mut client = self.client.clone(); + let response = client.get_server_info(request).await?; + Ok(translate::server_info(response.into_inner())) + } + + pub async fn get_loads( + &self, + include: Vec, + ) -> Result { + let request = Request::new(tokenspeed_proto::GetLoadsRequest { + dp_rank: None, + include, + }); + let mut client = self.client.clone(); + let response = client.get_loads(request).await?; + Ok(translate::loads(response.into_inner())) + } + + // ── Request builders ────────────────────────────────────────────── + // + // These produce SGLang-shaped requests so the router's existing + // ``ProtoGenerateRequest::Sglang`` plumbing is reused. The wire-side + // translation to TokenSpeed shape happens inside ``generate()`` above. + // + // Sampling-param construction delegates to SglangSchedulerClient's + // ``pub(crate)`` helpers — same OpenAI source, same semantics. 
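+    //
+    // Illustrative flow for a chat request (names as defined in this file):
+    //
+    //   let req = client.build_generate_request_from_chat(
+    //       request_id, &body, processed_text, token_ids, None)?;
+    //   let stream = client.generate(req).await?;
+    //
+    // ``generate()`` converts ``req`` via ``translate::generate_request``
+    // before it reaches TokenSpeed's wire.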
+ + #[expect( + clippy::unused_self, + reason = "receiver kept for API parity with SglangSchedulerClient" + )] + pub fn build_generate_request_from_chat( + &self, + request_id: String, + body: &ChatCompletionRequest, + processed_text: String, + token_ids: Vec, + tool_call_constraint: Option<(String, String)>, + ) -> Result { + let sampling_params = SglangSchedulerClient::build_grpc_sampling_params_from_chat( + body, + tool_call_constraint, + )?; + Ok(sglang::GenerateRequest { + request_id, + tokenized: Some(sglang::TokenizedInput { + original_text: processed_text, + input_ids: token_ids, + }), + sampling_params: Some(sampling_params), + return_logprob: body.logprobs, + logprob_start_len: -1, + top_logprobs_num: body.top_logprobs.unwrap_or(0) as i32, + stream: body.stream, + ..Default::default() + }) + } + + #[expect( + clippy::unused_self, + reason = "receiver kept for API parity with SglangSchedulerClient" + )] + pub fn build_plain_generate_request( + &self, + request_id: String, + body: &GenerateRequest, + original_text: Option, + token_ids: Vec, + ) -> Result { + let sampling_params = + SglangSchedulerClient::build_sampling_params_from_plain(body.sampling_params.as_ref())?; + Ok(sglang::GenerateRequest { + request_id, + tokenized: Some(sglang::TokenizedInput { + original_text: original_text.unwrap_or_default(), + input_ids: token_ids, + }), + sampling_params: Some(sampling_params), + return_logprob: body.return_logprob.unwrap_or(false), + logprob_start_len: body.logprob_start_len.unwrap_or(-1), + top_logprobs_num: body.top_logprobs_num.unwrap_or(0), + token_ids_logprob: body.token_ids_logprob.clone().unwrap_or_default(), + stream: body.stream, + ..Default::default() + }) + } + + #[expect( + clippy::unused_self, + reason = "receiver kept for API parity with SglangSchedulerClient" + )] + pub fn build_generate_request_from_responses( + &self, + request_id: String, + body: &ResponsesRequest, + processed_text: String, + token_ids: Vec, + constraint: Option<(String, String)>, + ) -> Result { + let sampling_params = + SglangSchedulerClient::build_grpc_sampling_params_from_responses(body, constraint)?; + Ok(sglang::GenerateRequest { + request_id, + tokenized: Some(sglang::TokenizedInput { + original_text: processed_text, + input_ids: token_ids, + }), + sampling_params: Some(sampling_params), + stream: body.stream.unwrap_or(false), + ..Default::default() + }) + } + + #[expect( + clippy::unused_self, + reason = "receiver kept for API parity with SglangSchedulerClient" + )] + pub fn build_generate_request_from_messages( + &self, + request_id: String, + body: &CreateMessageRequest, + processed_text: String, + token_ids: Vec, + tool_call_constraint: Option<(String, String)>, + ) -> Result { + let sampling_params = SglangSchedulerClient::build_grpc_sampling_params_from_messages( + body, + tool_call_constraint, + )?; + Ok(sglang::GenerateRequest { + request_id, + tokenized: Some(sglang::TokenizedInput { + original_text: processed_text, + input_ids: token_ids, + }), + sampling_params: Some(sampling_params), + stream: body.stream.unwrap_or(false), + ..Default::default() + }) + } + + #[expect( + clippy::unused_self, + reason = "receiver kept for API parity with SglangSchedulerClient" + )] + pub fn build_generate_request_from_completion( + &self, + request_id: String, + body: &CompletionRequest, + original_text: String, + token_ids: Vec, + ) -> Result { + let sampling_params = + SglangSchedulerClient::build_grpc_sampling_params_from_completion(body)?; + Ok(sglang::GenerateRequest { + request_id, + 
tokenized: Some(sglang::TokenizedInput { + original_text, + input_ids: token_ids, + }), + sampling_params: Some(sampling_params), + return_logprob: body.logprobs.is_some(), + logprob_start_len: -1, + top_logprobs_num: body.logprobs.unwrap_or(0) as i32, + stream: body.stream, + ..Default::default() + }) + } +} + +/// Spawn a fire-and-forget abort RPC against TokenSpeed when an +/// ``AbortOnDropStream`` is dropped without completion. +fn tokenspeed_abort_dispatcher(client: TokenSpeedSchedulerClient) -> AbortDispatcher { + Arc::new(move |request_id: String| { + let client = client.clone(); + let request_id_for_log = request_id.clone(); + #[expect( + clippy::disallowed_methods, + reason = "fire-and-forget abort on Drop is intentional" + )] + tokio::spawn(async move { + if let Err(e) = client + .abort_request(request_id, "Stream dropped".to_string()) + .await + { + warn!( + "Failed to send TokenSpeed abort on drop for request {}: {}", + request_id_for_log, e + ); + } + }); + }) +} + +// ── Wire-boundary translation ───────────────────────────────────────── +// +// Maps SGLang-shaped types (used internally by the router) to TokenSpeed's +// slimmer wire types and back. Fields TokenSpeed doesn't carry on the wire +// (mm_inputs, disagg, LoRA, hidden states, embeddings, etc.) are dropped on +// the way out; fields TokenSpeed doesn't return are filled with defaults on +// the way in. When the protocols genuinely diverge — i.e. TokenSpeed needs +// a field SGLang doesn't have — extend this module rather than threading +// new variants through proto_wrapper. +mod translate { + use super::{sglang, tokenspeed_proto as ts}; + + pub(super) fn sampling_params(s: sglang::SamplingParams) -> ts::SamplingParams { + // sglang's proto declares numeric scalars as non-optional, so the Rust + // router has already substituted semantic defaults (e.g. + // ``temperature=1.0``, ``top_p=1.0``, ``repetition_penalty=1.0``) + // before getting here. tokenspeed's proto declares the same fields + // as ``optional`` so the servicer can use ``HasField()`` to + // distinguish presence — wrap the (already-defaulted) sglang values + // in ``Some(...)`` to mark them as explicitly set on the wire. This + // preserves the pre-fix behavior while letting future direct-to- + // tokenspeed clients use ``None`` to mean "let the engine default + // apply" (e.g. for health-probe / warmup paths that would otherwise + // hit ``top_p must be in (0, 1], got 0.0``). 
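+        //
+        // Illustrative contrast (the second form is not produced by this path):
+        //   temperature: Some(0.0) -> HasField() is true on the servicer side,
+        //                             explicit greedy decoding is honored
+        //   temperature: None      -> HasField() is false, the engine default
+        //                             (1.0) applies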
+ ts::SamplingParams { + temperature: Some(s.temperature), + top_p: Some(s.top_p), + top_k: Some(s.top_k), + min_p: Some(s.min_p), + frequency_penalty: Some(s.frequency_penalty), + presence_penalty: Some(s.presence_penalty), + repetition_penalty: Some(s.repetition_penalty), + max_new_tokens: s.max_new_tokens, + min_new_tokens: s.min_new_tokens, + stop: s.stop, + stop_token_ids: s.stop_token_ids, + ignore_eos: s.ignore_eos, + skip_special_tokens: s.skip_special_tokens, + spaces_between_special_tokens: s.spaces_between_special_tokens, + no_stop_trim: s.no_stop_trim, + n: s.n, + logit_bias: s.logit_bias, + constraint: s.constraint.map(constraint), + custom_params: s.custom_params, + } + } + + fn constraint(c: sglang::sampling_params::Constraint) -> ts::sampling_params::Constraint { + match c { + sglang::sampling_params::Constraint::Regex(r) => { + ts::sampling_params::Constraint::Regex(r) + } + sglang::sampling_params::Constraint::JsonSchema(s) => { + ts::sampling_params::Constraint::JsonSchema(s) + } + sglang::sampling_params::Constraint::EbnfGrammar(g) => { + ts::sampling_params::Constraint::EbnfGrammar(g) + } + sglang::sampling_params::Constraint::StructuralTag(t) => { + ts::sampling_params::Constraint::StructuralTag(t) + } + } + } + + pub(super) fn generate_request(r: sglang::GenerateRequest) -> ts::GenerateRequest { + ts::GenerateRequest { + request_id: r.request_id, + tokenized: r.tokenized.map(|t| ts::TokenizedInput { + input_ids: t.input_ids, + original_text: t.original_text, + }), + sampling_params: r.sampling_params.map(sampling_params), + return_logprob: r.return_logprob, + // SGLang's wire-side `logprob_start_len` is non-optional `i32` + // with `-1` as the "no input logprobs" sentinel; TokenSpeed's + // proto makes the field `optional` so the servicer can tell + // "unset" from "explicit 0". Always wrap in `Some(...)` so + // existing SGLang-shaped callers preserve their sentinel + // through to the Python side. + logprob_start_len: Some(r.logprob_start_len), + top_logprobs_num: r.top_logprobs_num, + token_ids_logprob: r.token_ids_logprob, + stream: r.stream, + // Fields TokenSpeed has no concept of: + // r.mm_inputs, r.disaggregated_params, r.custom_logit_processor, + // r.timestamp, r.input_embeds, r.lora_id, r.data_parallel_rank, + // r.log_metrics, r.return_hidden_states + // — silently dropped here. Routing multimodal / disagg / LoRA + // requests to a TokenSpeed worker is a router-level config bug, + // not something this layer should try to paper over. + } + } + + pub(super) fn generate_response(r: ts::GenerateResponse) -> sglang::GenerateResponse { + let response = r.response.map(|resp| match resp { + ts::generate_response::Response::Chunk(c) => { + sglang::generate_response::Response::Chunk(stream_chunk(c)) + } + ts::generate_response::Response::Complete(c) => { + sglang::generate_response::Response::Complete(complete(c)) + } + }); + sglang::GenerateResponse { + request_id: r.request_id, + response, + } + } + + fn stream_chunk(c: ts::GenerateStreamChunk) -> sglang::GenerateStreamChunk { + sglang::GenerateStreamChunk { + token_ids: c.token_ids, + prompt_tokens: c.prompt_tokens, + completion_tokens: c.completion_tokens, + cached_tokens: c.cached_tokens, + output_logprobs: c.output_logprobs.map(output_logprobs), + // Fields not on TokenSpeed's wire — defaulted. 
+ hidden_states: vec![], + input_logprobs: None, + index: c.index, + } + } + + fn complete(c: ts::GenerateComplete) -> sglang::GenerateComplete { + let matched_stop = c.matched_stop.map(|m| match m { + ts::generate_complete::MatchedStop::MatchedTokenId(id) => { + sglang::generate_complete::MatchedStop::MatchedTokenId(id) + } + ts::generate_complete::MatchedStop::MatchedStopStr(s) => { + sglang::generate_complete::MatchedStop::MatchedStopStr(s) + } + }); + sglang::GenerateComplete { + output_ids: c.output_ids, + finish_reason: c.finish_reason, + prompt_tokens: c.prompt_tokens, + completion_tokens: c.completion_tokens, + cached_tokens: c.cached_tokens, + output_logprobs: c.output_logprobs.map(output_logprobs), + // Not on TokenSpeed's wire. + all_hidden_states: vec![], + input_logprobs: None, + matched_stop, + index: c.index, + } + } + + fn output_logprobs(o: ts::OutputLogProbs) -> sglang::OutputLogProbs { + sglang::OutputLogProbs { + token_logprobs: o.token_logprobs, + token_ids: o.token_ids, + top_logprobs: o + .top_logprobs + .into_iter() + .map(|t| sglang::TopLogProbs { + values: t.values, + token_ids: t.token_ids, + }) + .collect(), + } + } + + pub(super) fn model_info(r: ts::GetModelInfoResponse) -> sglang::GetModelInfoResponse { + sglang::GetModelInfoResponse { + model_path: r.model_path, + tokenizer_path: r.tokenizer_path, + // TokenSpeed only serves generative LLMs at this layer; classifier + // / embedding models are out of scope. Hard-code accordingly. + is_generation: true, + preferred_sampling_params: r.preferred_sampling_params, + weight_version: r.weight_version, + served_model_name: r.served_model_name, + max_context_length: r.max_context_length, + vocab_size: r.vocab_size, + supports_vision: false, + model_type: r.model_type, + eos_token_ids: r.eos_token_ids, + pad_token_id: r.pad_token_id, + bos_token_id: r.bos_token_id, + max_req_input_len: r.max_req_input_len, + architectures: r.architectures, + id2label_json: String::new(), + num_labels: 0, + default_sampling_params_json: String::new(), + } + } + + pub(super) fn server_info(r: ts::GetServerInfoResponse) -> sglang::GetServerInfoResponse { + sglang::GetServerInfoResponse { + server_args: r.server_args, + scheduler_info: r.scheduler_info, + active_requests: r.active_requests, + is_paused: r.is_paused, + // TokenSpeed scheduler doesn't track this — router doesn't read + // it for TokenSpeed either, so a fixed 0 is fine. + last_receive_timestamp: 0.0, + uptime_seconds: r.uptime_seconds, + // sglang_version field on the SGLang struct is the runtime version; + // for TokenSpeed we surface the TokenSpeed version through the same + // slot since downstream metric labels keep the field name. 
+ sglang_version: r.tokenspeed_version, + server_type: "grpc".to_string(), + start_time: r.start_time, + max_total_num_tokens: r.max_total_num_tokens, + } + } + + pub(super) fn loads(r: ts::GetLoadsResponse) -> sglang::GetLoadsResponse { + sglang::GetLoadsResponse { + timestamp: r.timestamp, + version: r.version, + dp_rank_count: r.dp_rank_count, + loads: r.loads.into_iter().map(scheduler_load).collect(), + aggregate: r.aggregate.map(aggregate_metrics), + } + } + + fn scheduler_load(s: ts::SchedulerLoad) -> sglang::SchedulerLoad { + sglang::SchedulerLoad { + dp_rank: s.dp_rank, + num_running_reqs: s.num_running_reqs, + num_waiting_reqs: s.num_waiting_reqs, + num_total_reqs: s.num_total_reqs, + num_used_tokens: s.num_used_tokens, + max_total_num_tokens: s.max_total_num_tokens, + token_usage: s.token_usage, + gen_throughput: s.gen_throughput, + cache_hit_rate: s.cache_hit_rate, + utilization: s.utilization, + max_running_requests: s.max_running_requests, + memory: s.memory.map(|m| sglang::MemoryMetrics { + weight_gb: m.weight_gb, + kv_cache_gb: m.kv_cache_gb, + graph_gb: m.graph_gb, + token_capacity: m.token_capacity, + }), + // TokenSpeed's wire intentionally omits speculative / LoRA / + // disaggregation metrics — fill the SGLang-shaped slots with + // None so callers ignore them. + speculative: None, + lora: None, + disaggregation: None, + queues: s.queues.map(|q| sglang::QueueMetrics { + waiting: q.waiting, + grammar: q.grammar, + paused: q.paused, + retracted: q.retracted, + }), + } + } + + fn aggregate_metrics(a: ts::AggregateMetrics) -> sglang::AggregateMetrics { + sglang::AggregateMetrics { + total_running_reqs: a.total_running_reqs, + total_waiting_reqs: a.total_waiting_reqs, + total_reqs: a.total_reqs, + avg_token_usage: a.avg_token_usage, + avg_throughput: a.avg_throughput, + avg_utilization: a.avg_utilization, + } + } +} diff --git a/crates/protocols/src/worker.rs b/crates/protocols/src/worker.rs index 03d6d7547..cd7897eb4 100644 --- a/crates/protocols/src/worker.rs +++ b/crates/protocols/src/worker.rs @@ -197,6 +197,8 @@ pub enum RuntimeType { Trtllm, /// MLX runtime (Apple Silicon). Mlx, + /// TokenSpeed runtime. + TokenSpeed, /// External OpenAI-compatible API (not local inference). External, } @@ -216,6 +218,7 @@ impl std::fmt::Display for RuntimeType { RuntimeType::Vllm => write!(f, "vllm"), RuntimeType::Trtllm => write!(f, "trtllm"), RuntimeType::Mlx => write!(f, "mlx"), + RuntimeType::TokenSpeed => write!(f, "tokenspeed"), RuntimeType::External => write!(f, "external"), } } @@ -235,6 +238,8 @@ impl std::str::FromStr for RuntimeType { Ok(RuntimeType::Trtllm) } else if s.eq_ignore_ascii_case("mlx") { Ok(RuntimeType::Mlx) + } else if s.eq_ignore_ascii_case("tokenspeed") { + Ok(RuntimeType::TokenSpeed) } else if s.eq_ignore_ascii_case("external") { Ok(RuntimeType::External) } else { diff --git a/crates/tokenizer/src/chat_template.rs b/crates/tokenizer/src/chat_template.rs index 2ac3b50c7..97ca4e550 100644 --- a/crates/tokenizer/src/chat_template.rs +++ b/crates/tokenizer/src/chat_template.rs @@ -620,11 +620,59 @@ fn render_chat_template( // Convert messages to minijinja::Value (messages already processed by router) let minijinja_messages: Vec = messages.iter().map(Value::from_serialize).collect(); + // Strip the OpenAI tool wrapper for downstream rendering: convert + // ``[{"type": "function", "function": {...}}, ...]`` into the bare list + // ``[{...}, ...]``. 
This matches what TokenSpeed's HTTP path does in + // ``serving_chat.py:188`` (``[item.function.model_dump() for item in + // tools]``). Empirically, Kimi-K2.5 + TokenSpeed-NVFP4 BFCL accuracy is + // significantly higher when the model sees the bare-inner JSON form + // (HTTP path: simple_python 92.25 %) than either the wrapped form + // (SMG pre-fix: 86 %) or the TS-namespace form produced by the model's + // own ``encode_tools_to_typescript_style`` (88.25 %). The chat template + // falls through to ``{{ tools | tojson }}`` whenever ``tools_ts_str`` + // is empty/missing — so feeding the bare-inner shape and letting the + // template's JSON branch fire reproduces the HTTP path exactly. + let stripped_tools_json: Option> = params.tools.as_ref().map(|arr| { + arr.iter() + .map(|t| { + let v = serde_json::to_value(t).unwrap_or(serde_json::Value::Null); + // If this is an OpenAI-style wrapped tool ({"type":"function","function":{...}}), + // unwrap to the inner function dict. Then mirror what TokenSpeed HTTP's + // ``serving_chat.py:188`` does — call Pydantic's ``model_dump()`` on the + // inner function model, which emits *all* fields of the function schema + // including the ``strict`` field with its default ``false``. Without this + // SMG produces 108 tokens vs HTTP's 111 (3 missing tokens for ``,"strict":false``) + // and the model's BFCL accuracy stays ~5 pp below HTTP. Adding ``strict:false`` + // to each function dict closes that gap by matching HTTP byte-for-byte. + match v { + serde_json::Value::Object(ref m) => { + if m.get("type").and_then(|x| x.as_str()) == Some("function") { + if let Some(serde_json::Value::Object(inner)) = m.get("function") { + let mut inner_with_default = inner.clone(); + // Match Pydantic's ChatCompletionToolFunction.model_dump(): + // emit ``strict`` even when not set in the source request, + // defaulting to ``false`` (the OpenAI-API default). + inner_with_default + .entry("strict") + .or_insert(serde_json::Value::Bool(false)); + return serde_json::Value::Object(inner_with_default); + } + } + v + } + other => other, + } + }) + .collect() + }); + // Use Value::UNDEFINED for missing optional params so they are truly "undefined" // in the template context, matching HuggingFace Python behavior. Many chat templates // use `{% if tools is defined %}` guards — passing null (none) instead of undefined // would bypass those guards since `none` IS defined, causing `tools | length` to fail. - let tools_value = params.tools.map_or(Value::UNDEFINED, Value::from_serialize); + let tools_value = stripped_tools_json + .as_ref() + .map_or(Value::UNDEFINED, Value::from_serialize); let documents_value = params .documents .map_or(Value::UNDEFINED, Value::from_serialize); diff --git a/model_gateway/src/routers/grpc/client.rs b/model_gateway/src/routers/grpc/client.rs index 81cfc8f11..bab1d1aa0 100644 --- a/model_gateway/src/routers/grpc/client.rs +++ b/model_gateway/src/routers/grpc/client.rs @@ -8,7 +8,7 @@ use openai_protocol::{ }; use smg_grpc_client::{ tokenizer_bundle, tokenizer_bundle::StreamBundle, MlxEngineClient, SglangSchedulerClient, - TrtllmServiceClient, VllmEngineClient, + TokenSpeedSchedulerClient, TrtllmServiceClient, VllmEngineClient, }; use crate::routers::grpc::{ @@ -23,13 +23,17 @@ pub struct HealthCheckResponse { pub message: String, } -/// Polymorphic gRPC client that wraps SGLang, vLLM, TensorRT-LLM, or MLX +/// Wraps the per-backend gRPC clients. 
TokenSpeed has its own service but +/// reuses SGLang-shaped wrapper variants where the wire shapes line up +/// after translation; RPCs absent on a backend's wire return +/// ``Status::unimplemented``. #[derive(Clone)] pub enum GrpcClient { Sglang(SglangSchedulerClient), Vllm(VllmEngineClient), Trtllm(TrtllmServiceClient), Mlx(MlxEngineClient), + TokenSpeed(TokenSpeedSchedulerClient), } impl GrpcClient { @@ -137,6 +141,32 @@ impl GrpcClient { matches!(self, Self::Mlx(_)) } + #[expect( + clippy::panic, + reason = "typed accessor: caller guarantees variant via is_tokenspeed() check" + )] + pub fn as_tokenspeed(&self) -> &TokenSpeedSchedulerClient { + match self { + Self::TokenSpeed(client) => client, + _ => panic!("Expected TokenSpeed client"), + } + } + + #[expect( + clippy::panic, + reason = "typed accessor: caller guarantees variant via is_tokenspeed() check" + )] + pub fn as_tokenspeed_mut(&mut self) -> &mut TokenSpeedSchedulerClient { + match self { + Self::TokenSpeed(client) => client, + _ => panic!("Expected TokenSpeed client"), + } + } + + pub fn is_tokenspeed(&self) -> bool { + matches!(self, Self::TokenSpeed(_)) + } + pub async fn connect( url: &str, runtime_type: &str, @@ -146,6 +176,9 @@ impl GrpcClient { "vllm" => Ok(Self::Vllm(VllmEngineClient::connect(url).await?)), "trtllm" | "tensorrt-llm" => Ok(Self::Trtllm(TrtllmServiceClient::connect(url).await?)), "mlx" => Ok(Self::Mlx(MlxEngineClient::connect(url).await?)), + "tokenspeed" => Ok(Self::TokenSpeed( + TokenSpeedSchedulerClient::connect(url).await?, + )), _ => Err(format!("Unknown runtime type: {runtime_type}").into()), } } @@ -182,6 +215,13 @@ impl GrpcClient { message: resp.message, }) } + Self::TokenSpeed(client) => { + let resp = client.health_check().await?; + Ok(HealthCheckResponse { + healthy: resp.healthy, + message: resp.message, + }) + } } } @@ -191,24 +231,31 @@ impl GrpcClient { Self::Vllm(client) => Ok(ModelInfo::Vllm(client.get_model_info().await?)), Self::Trtllm(client) => Ok(ModelInfo::Trtllm(client.get_model_info().await?)), Self::Mlx(client) => Ok(ModelInfo::Mlx(client.get_model_info().await?)), + Self::TokenSpeed(client) => { + Ok(ModelInfo::Sglang(Box::new(client.get_model_info().await?))) + } } } /// Get the full load response from the backend. - /// Only supported for SGLang backends. Returns per-DP-rank load metrics. + /// Supported for SGLang and TokenSpeed backends. pub async fn get_loads(&self) -> Result { match self { Self::Sglang(client) => { let resp = client.get_loads(vec!["core".to_string()]).await?; Ok(WorkerLoadResponse::from(resp)) } + Self::TokenSpeed(client) => { + let resp = client.get_loads(vec!["core".to_string()]).await?; + Ok(WorkerLoadResponse::from(resp)) + } _ => Err(tonic::Status::unimplemented( "GetLoads RPC not supported for this backend", )), } } - /// Subscribe to KV cache events (all backends). + /// Subscribe to KV cache events (SGLang / vLLM / TRT-LLM only). 
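+    /// TokenSpeed and MLX have no KV-event stream on their wire; their arms
+    /// below return ``Status::unimplemented``.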
pub async fn subscribe_kv_events( &self, start_seq: u64, @@ -220,6 +267,9 @@ impl GrpcClient { Self::Mlx(_) => Err(tonic::Status::unimplemented( "SubscribeKvEvents RPC not supported for MLX backend", )), + Self::TokenSpeed(_) => Err(tonic::Status::unimplemented( + "SubscribeKvEvents RPC not supported for TokenSpeed backend", + )), } } @@ -231,6 +281,9 @@ impl GrpcClient { Self::Vllm(client) => Ok(ServerInfo::Vllm(client.get_server_info().await?)), Self::Trtllm(client) => Ok(ServerInfo::Trtllm(client.get_server_info().await?)), Self::Mlx(client) => Ok(ServerInfo::Mlx(client.get_server_info().await?)), + Self::TokenSpeed(client) => Ok(ServerInfo::Sglang(Box::new( + client.get_server_info().await?, + ))), } } @@ -243,6 +296,14 @@ impl GrpcClient { Self::Vllm(client) => client.get_tokenizer().await, Self::Trtllm(client) => client.get_tokenizer().await, Self::Mlx(client) => client.get_tokenizer().await, + // Status::unimplemented (not a String error) so the fallback in + // tokenizer_registration's downcast_ref::() check + // skips TokenSpeed workers silently. + Self::TokenSpeed(_) => { + return Err(Box::new(tonic::Status::unimplemented( + "TokenSpeed backend does not support GetTokenizer RPC", + ))); + } }?; tokenizer_bundle::validate_bundle_sha256(&bundle).map_err(|e| { @@ -280,6 +341,10 @@ impl GrpcClient { let stream = client.generate(*boxed_req).await?; Ok(ProtoStream::Mlx(stream)) } + (Self::TokenSpeed(client), ProtoGenerateRequest::Sglang(boxed_req)) => { + let stream = client.generate(*boxed_req).await?; + Ok(ProtoStream::TokenSpeed(stream)) + } #[expect( clippy::panic, reason = "client and request types are always matched by construction in the pipeline" @@ -301,6 +366,11 @@ impl GrpcClient { let resp = client.embed(*boxed_req).await?; Ok(ProtoEmbedComplete::Vllm(resp)) } + // TokenSpeed dropped the Embed RPC from its wire — top-tier + // LLMs aren't embedding models, so the proto doesn't carry one. + (Self::TokenSpeed(_), _) => Err(tonic::Status::unimplemented( + "TokenSpeed backend does not support embedding", + )), (Self::Mlx(_), _) => Err(tonic::Status::unimplemented( "MLX backend does not support embedding", )), @@ -382,6 +452,22 @@ impl GrpcClient { )?; Ok(ProtoGenerateRequest::Mlx(Box::new(req))) } + // TokenSpeed's wire intentionally has no multimodal fields + // (top-tier LLMs are text-only today). Reject if the assembly + // stage produced any — that's a router-config bug. 
+ Self::TokenSpeed(client) => { + if multimodal_inputs.is_some() { + return Err("TokenSpeed backend does not support multimodal inputs".to_string()); + } + let req = client.build_generate_request_from_chat( + request_id, + body, + processed_text, + token_ids, + tool_constraints, + )?; + Ok(ProtoGenerateRequest::Sglang(Box::new(req))) + } } } @@ -455,6 +541,19 @@ impl GrpcClient { )?; Ok(ProtoGenerateRequest::Mlx(Box::new(req))) } + Self::TokenSpeed(client) => { + if multimodal_inputs.is_some() { + return Err("TokenSpeed backend does not support multimodal inputs".to_string()); + } + let req = client.build_generate_request_from_messages( + request_id, + body, + processed_text, + token_ids, + tool_constraints, + )?; + Ok(ProtoGenerateRequest::Sglang(Box::new(req))) + } } } @@ -502,6 +601,15 @@ impl GrpcClient { )?; Ok(ProtoGenerateRequest::Mlx(Box::new(req))) } + Self::TokenSpeed(client) => { + let req = client.build_generate_request_from_completion( + request_id, + body, + original_text, + token_ids, + )?; + Ok(ProtoGenerateRequest::Sglang(Box::new(req))) + } } } @@ -549,6 +657,15 @@ impl GrpcClient { )?; Ok(ProtoGenerateRequest::Mlx(Box::new(req))) } + Self::TokenSpeed(client) => { + let req = client.build_plain_generate_request( + request_id, + body, + original_text, + token_ids, + )?; + Ok(ProtoGenerateRequest::Sglang(Box::new(req))) + } } } } diff --git a/model_gateway/src/routers/grpc/common/stages/request_execution.rs b/model_gateway/src/routers/grpc/common/stages/request_execution.rs index 07cf75496..1447e025b 100644 --- a/model_gateway/src/routers/grpc/common/stages/request_execution.rs +++ b/model_gateway/src/routers/grpc/common/stages/request_execution.rs @@ -114,6 +114,10 @@ impl PipelineStage for RequestExecutionStage { } Some(RuntimeType::Trtllm) | Some(RuntimeType::Mlx) + // TokenSpeed shares the SGLang proto but doesn't + // ship PD-disaggregation support yet — treat it + // like the other non-PD backends here. + | Some(RuntimeType::TokenSpeed) | Some(RuntimeType::External) | Some(RuntimeType::Unspecified) => { error!( diff --git a/model_gateway/src/routers/grpc/harmony/stages/request_building.rs b/model_gateway/src/routers/grpc/harmony/stages/request_building.rs index d084d66f3..0cc042dc7 100644 --- a/model_gateway/src/routers/grpc/harmony/stages/request_building.rs +++ b/model_gateway/src/routers/grpc/harmony/stages/request_building.rs @@ -277,6 +277,55 @@ impl PipelineStage for HarmonyRequestBuildingStage { }; ProtoGenerateRequest::Mlx(Box::new(req)) } + // TokenSpeed: builder produces an SGLang-shaped request so the + // ``ProtoGenerateRequest::Sglang`` plumbing carries it; the + // wire-side translation to TokenSpeed shape happens inside the + // client's ``generate()``. Multimodal is intentionally not + // supported here — the harmony path is text-only today. + GrpcClient::TokenSpeed(tokenspeed_client) => { + let req = match &ctx.input.request_type { + RequestType::Chat(request) => { + let body = modified_request.as_deref().unwrap_or_else(|| request.as_ref()); + tokenspeed_client + .build_generate_request_from_chat( + request_id, + body, + placeholder_processed_text, + token_ids, + tool_constraints, + ) + .map_err(|e| { + error!(function = "HarmonyRequestBuildingStage::execute", error = %e, "Failed to build TokenSpeed generate request"); + error::bad_request("invalid_request_parameters", format!("Invalid request parameters: {e}")) + })? 
+ } + RequestType::Responses(request) => tokenspeed_client + .build_generate_request_from_responses( + request_id, + request.as_ref(), + placeholder_processed_text, + token_ids, + tool_constraints, + ) + .map_err(|e| { + error!(function = "HarmonyRequestBuildingStage::execute", error = %e, "Failed to build TokenSpeed generate request from responses"); + error::bad_request("invalid_request_parameters", format!("Invalid request parameters: {e}")) + })?, + RequestType::Embedding(_) => { + return Err(error::bad_request( + "harmony_embedding_not_supported", + "Embedding requests are not supported with Harmony models".to_string(), + )); + } + _ => { + return Err(error::bad_request( + "unsupported_request_type", + "Unsupported request type for Harmony models".to_string(), + )); + } + }; + ProtoGenerateRequest::Sglang(Box::new(req)) + } }; // Inject Harmony stop token IDs into sampling params for ALL Harmony requests diff --git a/model_gateway/src/routers/grpc/multimodal.rs b/model_gateway/src/routers/grpc/multimodal.rs index d28c7e3fc..f279b6156 100644 --- a/model_gateway/src/routers/grpc/multimodal.rs +++ b/model_gateway/src/routers/grpc/multimodal.rs @@ -708,6 +708,13 @@ pub(crate) fn assemble_multimodal_data( GrpcClient::Mlx(_) => unreachable!( "caller rejects multimodal for MLX in build_chat_request/build_messages_request" ), + // TokenSpeed's wire intentionally has no multimodal fields. The + // detect-backend / preparation stages never enable multimodal for a + // text-only top-tier LLM, so reaching this arm is a router-config + // bug rather than a user error. + GrpcClient::TokenSpeed(_) => unreachable!( + "TokenSpeed backend does not support multimodal; preparation stage should reject earlier" + ), } } diff --git a/model_gateway/src/routers/grpc/proto_wrapper.rs b/model_gateway/src/routers/grpc/proto_wrapper.rs index 971ff388b..6b053e3ef 100644 --- a/model_gateway/src/routers/grpc/proto_wrapper.rs +++ b/model_gateway/src/routers/grpc/proto_wrapper.rs @@ -11,6 +11,7 @@ use smg_grpc_client::{ mlx_proto::{self as mlx}, sglang_proto::{self as sglang, generate_complete::MatchedStop as SglangMatchedStop}, sglang_scheduler::AbortOnDropStream as SglangStream, + tokenspeed_scheduler::AbortOnDropStream as TokenSpeedStream, trtllm_proto::{self as trtllm, generate_complete::MatchedStop as TrtllmMatchedStop}, trtllm_service::AbortOnDropStream as TrtllmStream, vllm_engine::AbortOnDropStream as VllmStream, @@ -918,12 +919,20 @@ impl ProtoGenerateComplete { } } -/// Unified stream wrapper +/// Unified stream wrapper. +/// +/// TokenSpeed has its own variant because its underlying stream type differs +/// from SGLang's ([`tokenspeed_scheduler::AbortOnDropStream`] vs +/// [`sglang_scheduler::AbortOnDropStream`]) — the wire is independent. Both +/// yield ``sglang::GenerateResponse``-shaped items (TokenSpeed translates at +/// the boundary), so the chunk / complete accessors below don't need a +/// dedicated TokenSpeed variant. 
pub enum ProtoStream { Sglang(SglangStream), Vllm(VllmStream), Trtllm(TrtllmStream), Mlx(MlxStream), + TokenSpeed(TokenSpeedStream), } impl ProtoStream { @@ -946,6 +955,10 @@ impl ProtoStream { .next() .await .map(|result| result.map(|r| ProtoGenerateResponse::Mlx(Box::new(r)))), + Self::TokenSpeed(stream) => stream + .next() + .await + .map(|result| result.map(|r| ProtoGenerateResponse::Sglang(Box::new(r)))), } } @@ -956,6 +969,7 @@ impl ProtoStream { Self::Vllm(stream) => stream.mark_completed(), Self::Trtllm(stream) => stream.mark_completed(), Self::Mlx(stream) => stream.mark_completed(), + Self::TokenSpeed(stream) => stream.mark_completed(), } } } diff --git a/model_gateway/src/routers/grpc/regular/stages/embedding/request_building.rs b/model_gateway/src/routers/grpc/regular/stages/embedding/request_building.rs index b9ee0a845..c69eef5b9 100644 --- a/model_gateway/src/routers/grpc/regular/stages/embedding/request_building.rs +++ b/model_gateway/src/routers/grpc/regular/stages/embedding/request_building.rs @@ -96,6 +96,16 @@ impl PipelineStage for EmbeddingRequestBuildingStage { "MLX embedding is not supported via gRPC", )); } + GrpcClient::TokenSpeed(_) => { + error!( + function = "EmbeddingRequestBuildingStage::execute", + "TokenSpeed backend does not support embeddings" + ); + return Err(error::not_implemented( + "unsupported_backend", + "TokenSpeed backend does not support embeddings", + )); + } }; ctx.state.proto_request = Some(ProtoRequest::Embed(proto_req)); diff --git a/model_gateway/src/workflow/steps/local/detect_backend.rs b/model_gateway/src/workflow/steps/local/detect_backend.rs index 3295d6f82..fbcb9c1dd 100644 --- a/model_gateway/src/workflow/steps/local/detect_backend.rs +++ b/model_gateway/src/workflow/steps/local/detect_backend.rs @@ -1,8 +1,8 @@ //! Backend runtime detection step. //! -//! Detects the runtime type (sglang, vllm, trtllm, mlx) for both HTTP and gRPC workers. +//! Detects the runtime type (sglang, vllm, trtllm, tokenspeed, mlx) for both HTTP and gRPC workers. //! - HTTP: probes `/v1/models` (owned_by field), falls back to unique endpoints. -//! - gRPC: tries sglang → vllm → trtllm → mlx health checks sequentially. +//! - gRPC: tries sglang → vllm → trtllm → tokenspeed → mlx health checks sequentially. use std::time::Duration; @@ -43,8 +43,11 @@ async fn detect_grpc_backend( } } - // Try each runtime sequentially (most common first), skipping the hint we already tried - for runtime in &["sglang", "vllm", "trtllm", "mlx"] { + // Try each runtime sequentially, ordered by expected frequency so the + // common case finishes after one probe. Each backend speaks its own + // gRPC service, so order is purely a latency optimisation, not a + // correctness condition. 
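+    // e.g. with runtime_hint = Some("tokenspeed") the hinted probe above has
+    // already run, so the loop below skips "tokenspeed" and tries the rest.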
+ for runtime in &["sglang", "vllm", "trtllm", "tokenspeed", "mlx"] { if Some(*runtime) == runtime_hint { continue; } @@ -57,7 +60,7 @@ async fn detect_grpc_backend( } Err(format!( - "gRPC backend detection failed for {url} (tried sglang, vllm, trtllm, mlx)" + "gRPC backend detection failed for {url} (tried sglang, vllm, trtllm, tokenspeed, mlx)" )) } diff --git a/model_gateway/src/workflow/steps/util.rs b/model_gateway/src/workflow/steps/util.rs index efd6753e3..ff7d17208 100644 --- a/model_gateway/src/workflow/steps/util.rs +++ b/model_gateway/src/workflow/steps/util.rs @@ -88,17 +88,22 @@ pub(crate) async fn try_grpc_reachable(url: &str, timeout_secs: u64) -> Result<( format!("grpc://{}", strip_protocol(url)) }; - let (sglang, vllm, trtllm, mlx) = tokio::join!( + let (sglang, vllm, trtllm, mlx, tokenspeed) = tokio::join!( do_grpc_health_check(&grpc_url, timeout_secs, "sglang"), do_grpc_health_check(&grpc_url, timeout_secs, "vllm"), do_grpc_health_check(&grpc_url, timeout_secs, "trtllm"), do_grpc_health_check(&grpc_url, timeout_secs, "mlx"), + do_grpc_health_check(&grpc_url, timeout_secs, "tokenspeed"), ); - match (sglang, vllm, trtllm, mlx) { - (Ok(()), _, _, _) | (_, Ok(()), _, _) | (_, _, Ok(()), _) | (_, _, _, Ok(())) => Ok(()), - (Err(e1), Err(e2), Err(e3), Err(e4)) => Err(format!( - "gRPC not reachable (tried sglang, vllm, trtllm, mlx): sglang={e1}, vllm={e2}, trtllm={e3}, mlx={e4}", + match (sglang, vllm, trtllm, mlx, tokenspeed) { + (Ok(()), _, _, _, _) + | (_, Ok(()), _, _, _) + | (_, _, Ok(()), _, _) + | (_, _, _, Ok(()), _) + | (_, _, _, _, Ok(())) => Ok(()), + (Err(e1), Err(e2), Err(e3), Err(e4), Err(e5)) => Err(format!( + "gRPC not reachable (tried sglang, vllm, trtllm, mlx, tokenspeed): sglang={e1}, vllm={e2}, trtllm={e3}, mlx={e4}, tokenspeed={e5}", )), } } From 25375e5df40ae658a815421fb85ab71207659cca Mon Sep 17 00:00:00 2001 From: key4ng Date: Sat, 9 May 2026 12:19:16 -0700 Subject: [PATCH 02/24] =?UTF-8?q?refactor(grpc):=20extract=20OpenAI?= =?UTF-8?q?=E2=86=92sampling-params=20helpers=20to=20a=20common=20module?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The 5 ``build_*_sampling_params_from_*`` helpers (chat / responses / messages / completion / plain) plus their constraint helpers were sitting on ``SglangSchedulerClient`` even though the OpenAI mapping is backend-neutral. TokenSpeed's client was reaching across to call them through ``SglangSchedulerClient::*`` which suggested SGLang owned the OpenAI→sampling translation, when it doesn't. Move them to ``crate::sampling_params`` as free functions. Both the SGLang and TokenSpeed clients (and any future client that wants the same mapping) now call ``crate::sampling_params::build_*`` directly. The return type is still ``sglang::SamplingParams`` because that proto happens to be the most permissive shape across our supported backends; TokenSpeed translates to its own slimmer shape at the wire boundary. When a backend grows a sampling field SGLang lacks, this is the place to add it. No behavior change. Tests stay green; the call sites in ``build_generate_request_from_*`` are mechanically updated. 
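In sketch form — helper and argument names exactly as they appear in this
diff, surrounding request-building code elided — a chat-path call site goes
from the SGLang-owned associated function to the shared free function:

    // before: the mapping read as if SGLang owned the OpenAI→sampling translation
    //     SglangSchedulerClient::build_grpc_sampling_params_from_chat(body, tool_call_constraint)?
    // after: any client calls the backend-neutral module directly
    let sampling_params =
        crate::sampling_params::build_grpc_sampling_params_from_chat(body, tool_call_constraint)?;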
Signed-off-by: key4ng --- crates/grpc_client/src/lib.rs | 1 + crates/grpc_client/src/sampling_params.rs | 370 +++++++++++++++++ crates/grpc_client/src/sglang_scheduler.rs | 382 +----------------- .../grpc_client/src/tokenspeed_scheduler.rs | 12 +- 4 files changed, 387 insertions(+), 378 deletions(-) create mode 100644 crates/grpc_client/src/sampling_params.rs diff --git a/crates/grpc_client/src/lib.rs b/crates/grpc_client/src/lib.rs index a26528953..79dddc333 100644 --- a/crates/grpc_client/src/lib.rs +++ b/crates/grpc_client/src/lib.rs @@ -8,6 +8,7 @@ pub mod common_proto { tonic::include_proto!("smg.grpc.common"); } pub mod mlx_engine; +pub mod sampling_params; pub mod sglang_scheduler; pub mod tokenizer_bundle; pub mod tokenspeed_scheduler; diff --git a/crates/grpc_client/src/sampling_params.rs b/crates/grpc_client/src/sampling_params.rs new file mode 100644 index 000000000..b0c2d3d5f --- /dev/null +++ b/crates/grpc_client/src/sampling_params.rs @@ -0,0 +1,370 @@ +//! Backend-neutral OpenAI → sampling-params builders. +//! +//! These helpers translate OpenAI request shapes (Chat, Responses, Messages, +//! Completion, plain `GenerateSamplingParams`) into a sampling-params struct +//! shared by the SGLang and TokenSpeed gRPC clients. They live here rather +//! than on `SglangSchedulerClient` because the OpenAI mapping is independent +//! of the wire backend. +//! +//! The return type is currently [`sglang::SamplingParams`] because that proto +//! happens to be the most permissive shape across our supported backends. +//! Other backends (TokenSpeed) translate from this shape to their own slimmer +//! wire format at the boundary. When a backend grows a sampling field SGLang +//! lacks, this is the place to add it (and consider whether a neutral +//! intermediate struct is worth introducing). + +use openai_protocol::{ + chat::ChatCompletionRequest, + common::{ResponseFormat, StringOrArray}, + completion::CompletionRequest, + messages::CreateMessageRequest, + responses::ResponsesRequest, + sampling_params::SamplingParams as GenerateSamplingParams, +}; +use tracing::warn; + +use crate::sglang_scheduler::proto; + +/// Build gRPC `SamplingParams` from a `ChatCompletionRequest`. +pub fn build_grpc_sampling_params_from_chat( + request: &ChatCompletionRequest, + tool_call_constraint: Option<(String, String)>, +) -> Result { + let stop_sequences = extract_stop_strings(request); + + let max_new_tokens = request.max_completion_tokens; + + // Hardcode to true: gRPC backends return raw token IDs, not decoded text. + // Detokenization happens on the SMG Rust side (StopDecoder/Sequence). + // + // Note: TokenSpeed's HTTP serving_chat sets this to false when tools are + // present (serving_chat.py:178-179) — but mirroring that on the gRPC + // path measurably HURTS BFCL accuracy. We tested it: simple_python + // dropped from ~88.75 % to 79 %, parallel_multiple from ~84.5 % to + // 60.5 %. With skip_special_tokens=false the engine emits the + // ``<|tool_call_*|>`` special tokens in the raw output stream, and the + // SMG-side detokenizer + kimik2 tool-call parser then double-counts or + // misframes them. Keep it at true so SMG sees normal tokens and + // applies its own parsing. 
+ let skip_special_tokens = true; + + Ok(proto::SamplingParams { + temperature: request.temperature.unwrap_or(1.0), + top_p: request.top_p.unwrap_or(1.0), + top_k: request.top_k.unwrap_or(-1), + min_p: request.min_p.unwrap_or(0.0), + frequency_penalty: request.frequency_penalty.unwrap_or(0.0), + presence_penalty: request.presence_penalty.unwrap_or(0.0), + repetition_penalty: request.repetition_penalty.unwrap_or(1.0), + max_new_tokens, + stop: stop_sequences, + stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(), + skip_special_tokens, + spaces_between_special_tokens: true, // Default from Python SamplingParams + ignore_eos: request.ignore_eos, + no_stop_trim: request.no_stop_trim, + n: request.n.unwrap_or(1), + constraint: build_constraint_for_chat(request, tool_call_constraint)?, + ..Default::default() + }) +} + +/// Build gRPC `SamplingParams` from a `ResponsesRequest`. +/// +/// Used by Harmony models only. Regular models use the Chat API path. +/// Constraints come from the Harmony preparation stage (`structural_tag`) +/// or tool handling. +pub fn build_grpc_sampling_params_from_responses( + request: &ResponsesRequest, + constraint: Option<(String, String)>, +) -> Result { + let max_new_tokens = request.max_output_tokens; + + Ok(proto::SamplingParams { + temperature: request.temperature.unwrap_or(1.0), + top_p: request.top_p.unwrap_or(1.0), + top_k: request.top_k, + min_p: request.min_p, + frequency_penalty: request.frequency_penalty.unwrap_or(0.0), + presence_penalty: request.presence_penalty.unwrap_or(0.0), + repetition_penalty: request.repetition_penalty, + max_new_tokens, + stop: vec![], // Does not pass through request.stop yet (follow-up fix) + stop_token_ids: vec![], // Handled by Harmony stop tokens + skip_special_tokens: false, // Keep special tokens for Harmony + spaces_between_special_tokens: true, + ignore_eos: false, + no_stop_trim: false, + n: 1, // Responses API doesn't support n>1 + constraint: build_constraint_for_responses(constraint)?, + ..Default::default() + }) +} + +/// Build gRPC `SamplingParams` from a `CreateMessageRequest` (Anthropic +/// Messages API). +pub fn build_grpc_sampling_params_from_messages( + request: &CreateMessageRequest, + tool_call_constraint: Option<(String, String)>, +) -> Result { + let stop_sequences = request.stop_sequences.clone().unwrap_or_default(); + + // Hardcode to true: gRPC backends return raw token IDs, not decoded text. + let skip_special_tokens = true; + + Ok(proto::SamplingParams { + temperature: request.temperature.unwrap_or(1.0) as f32, + top_p: request.top_p.unwrap_or(1.0) as f32, + top_k: request.top_k.map(|v| v as i32).unwrap_or(-1), + min_p: 0.0, + frequency_penalty: 0.0, + presence_penalty: 0.0, + repetition_penalty: 1.0, + max_new_tokens: Some(request.max_tokens), + stop: stop_sequences, + stop_token_ids: vec![], + skip_special_tokens, + spaces_between_special_tokens: true, + ignore_eos: false, + no_stop_trim: false, + n: 1, + constraint: build_constraint_for_responses(tool_call_constraint)?, + ..Default::default() + }) +} + +/// Build gRPC `SamplingParams` from a `CompletionRequest` +/// (`/v1/completions`). 
+pub fn build_grpc_sampling_params_from_completion( + request: &CompletionRequest, +) -> Result { + let stop_sequences = match &request.stop { + Some(StringOrArray::String(s)) => vec![s.clone()], + Some(StringOrArray::Array(arr)) => arr.clone(), + None => vec![], + }; + + let constraint = build_single_constraint_from_completion(request)?; + + Ok(proto::SamplingParams { + temperature: request.temperature.unwrap_or(1.0), + top_p: request.top_p.unwrap_or(1.0), + top_k: request.top_k.unwrap_or(-1), + min_p: request.min_p.unwrap_or(0.0), + frequency_penalty: request.frequency_penalty.unwrap_or(0.0), + presence_penalty: request.presence_penalty.unwrap_or(0.0), + repetition_penalty: request.repetition_penalty.unwrap_or(1.0), + max_new_tokens: request.max_tokens, + min_new_tokens: request.min_tokens.unwrap_or(0), + stop: stop_sequences, + stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(), + skip_special_tokens: request.skip_special_tokens, + spaces_between_special_tokens: true, + ignore_eos: request.ignore_eos, + no_stop_trim: request.no_stop_trim, + n: request.n.unwrap_or(1), + constraint, + ..Default::default() + }) +} + +/// Build gRPC `SamplingParams` from the plain `GenerateSamplingParams` +/// shape used by `/generate`. +pub fn build_sampling_params_from_plain( + params: Option<&GenerateSamplingParams>, +) -> Result { + let mut sampling = proto::SamplingParams { + temperature: 1.0, + top_p: 1.0, + top_k: -1, + repetition_penalty: 1.0, + n: 1, + skip_special_tokens: true, + spaces_between_special_tokens: true, + ..Default::default() + }; + + let Some(p) = params else { + return Ok(sampling); + }; + + macro_rules! map_field { + ($field:ident) => { + if let Some(val) = p.$field { + sampling.$field = val; + } + }; + } + + map_field!(temperature); + map_field!(top_p); + map_field!(top_k); + map_field!(frequency_penalty); + map_field!(presence_penalty); + map_field!(repetition_penalty); + map_field!(min_p); + map_field!(ignore_eos); + map_field!(skip_special_tokens); + map_field!(no_stop_trim); + + if let Some(stop) = &p.stop { + match stop { + StringOrArray::String(s) => sampling.stop.push(s.clone()), + StringOrArray::Array(arr) => sampling.stop.extend(arr.clone()), + } + } + + if let Some(stop_token_ids) = &p.stop_token_ids { + sampling.stop_token_ids.clone_from(stop_token_ids); + } + + sampling.max_new_tokens = p.max_new_tokens; + + if let Some(min_new_tokens) = p.min_new_tokens { + sampling.min_new_tokens = min_new_tokens; + } + + if let Some(n) = p.n { + sampling.n = n; + } + + sampling.constraint = build_single_constraint_from_plain(p)?; + + Ok(sampling) +} + +// --------------------------------------------------------------------------- +// Constraint helpers +// --------------------------------------------------------------------------- + +fn extract_stop_strings(request: &ChatCompletionRequest) -> Vec { + match &request.stop { + Some(StringOrArray::String(s)) => vec![s.clone()], + Some(StringOrArray::Array(arr)) => arr.clone(), + None => vec![], + } +} + +fn build_constraint_for_chat( + request: &ChatCompletionRequest, + tool_call_constraint: Option<(String, String)>, +) -> Result, String> { + let mut constraints = Vec::new(); + + match &request.response_format { + Some(ResponseFormat::JsonObject) => { + let schema = serde_json::json!({"type": "object"}); + let schema_str = serde_json::to_string(&schema) + .map_err(|e| format!("Failed to serialize JSON schema: {e}"))?; + constraints.push(proto::sampling_params::Constraint::JsonSchema(schema_str)); + } + 
Some(ResponseFormat::JsonSchema { json_schema }) => { + let schema_str = serde_json::to_string(&json_schema.schema) + .map_err(|e| format!("Failed to serialize JSON schema: {e}"))?; + constraints.push(proto::sampling_params::Constraint::JsonSchema(schema_str)); + } + Some(ResponseFormat::Text) | None => {} + } + + if let Some(ebnf) = &request.ebnf { + constraints.push(proto::sampling_params::Constraint::EbnfGrammar( + ebnf.clone(), + )); + } + + if let Some(regex) = &request.regex { + constraints.push(proto::sampling_params::Constraint::Regex(regex.clone())); + } + + // If response_format already set a constraint, drop the tool constraint + // (matches SGLang HTTP behavior where response_format takes priority). + if let Some((constraint_type, constraint_value)) = tool_call_constraint { + if constraints.is_empty() { + let tool_constraint = match constraint_type.as_str() { + "structural_tag" => { + proto::sampling_params::Constraint::StructuralTag(constraint_value) + } + "json_schema" => proto::sampling_params::Constraint::JsonSchema(constraint_value), + "ebnf" => proto::sampling_params::Constraint::EbnfGrammar(constraint_value), + "regex" => proto::sampling_params::Constraint::Regex(constraint_value), + _ => return Err(format!("Unknown constraint type: {constraint_type}")), + }; + constraints.push(tool_constraint); + } else { + warn!("Constrained decoding is not compatible with tool calls, dropping tool constraint"); + } + } + + match constraints.len() { + 0 => Ok(None), + 1 => Ok(constraints.pop()), + _ => Err("Multiple constraints are not allowed.".to_string()), + } +} + +fn build_constraint_for_responses( + constraint: Option<(String, String)>, +) -> Result, String> { + if let Some((constraint_type, constraint_value)) = constraint { + let parsed_constraint = match constraint_type.as_str() { + "structural_tag" => proto::sampling_params::Constraint::StructuralTag(constraint_value), + "json_schema" => proto::sampling_params::Constraint::JsonSchema(constraint_value), + "ebnf" => proto::sampling_params::Constraint::EbnfGrammar(constraint_value), + "regex" => proto::sampling_params::Constraint::Regex(constraint_value), + _ => return Err(format!("Unknown constraint type: {constraint_type}")), + }; + Ok(Some(parsed_constraint)) + } else { + Ok(None) + } +} + +fn build_single_constraint_from_completion( + request: &CompletionRequest, +) -> Result, String> { + let mut constraints = Vec::new(); + if let Some(json_schema) = &request.json_schema { + constraints.push(proto::sampling_params::Constraint::JsonSchema( + json_schema.clone(), + )); + } + if let Some(regex) = &request.regex { + constraints.push(proto::sampling_params::Constraint::Regex(regex.clone())); + } + if let Some(ebnf) = &request.ebnf { + constraints.push(proto::sampling_params::Constraint::EbnfGrammar( + ebnf.clone(), + )); + } + + match constraints.len() { + 0 => Ok(None), + 1 => Ok(constraints.pop()), + _ => Err("Multiple structured constraints are not allowed".to_string()), + } +} + +fn build_single_constraint_from_plain( + params: &GenerateSamplingParams, +) -> Result, String> { + let mut constraints = Vec::new(); + if let Some(json_schema) = ¶ms.json_schema { + constraints.push(proto::sampling_params::Constraint::JsonSchema( + json_schema.clone(), + )); + } + if let Some(regex) = ¶ms.regex { + constraints.push(proto::sampling_params::Constraint::Regex(regex.clone())); + } + if let Some(ebnf) = ¶ms.ebnf { + constraints.push(proto::sampling_params::Constraint::EbnfGrammar( + ebnf.clone(), + )); + } + + match constraints.len() { + 
0 => Ok(None), + 1 => Ok(constraints.pop()), + _ => Err("Multiple structured constraints are not allowed".to_string()), + } +} diff --git a/crates/grpc_client/src/sglang_scheduler.rs b/crates/grpc_client/src/sglang_scheduler.rs index 8f59ee554..7c1862153 100644 --- a/crates/grpc_client/src/sglang_scheduler.rs +++ b/crates/grpc_client/src/sglang_scheduler.rs @@ -9,13 +9,8 @@ use std::{ }; use openai_protocol::{ - chat::ChatCompletionRequest, - common::{ResponseFormat, StringOrArray}, - completion::CompletionRequest, - generate::GenerateRequest, - messages::CreateMessageRequest, - responses::ResponsesRequest, - sampling_params::SamplingParams as GenerateSamplingParams, + chat::ChatCompletionRequest, completion::CompletionRequest, generate::GenerateRequest, + messages::CreateMessageRequest, responses::ResponsesRequest, }; use tonic::{transport::Channel, Request, Streaming}; use tracing::{debug, warn}; @@ -338,7 +333,7 @@ impl SglangSchedulerClient { ) -> Result { // Build sampling params let sampling_params = - Self::build_grpc_sampling_params_from_chat(body, tool_call_constraint)?; + crate::sampling_params::build_grpc_sampling_params_from_chat(body, tool_call_constraint)?; let grpc_request = proto::GenerateRequest { request_id, @@ -372,7 +367,7 @@ impl SglangSchedulerClient { token_ids: Vec, ) -> Result { let sampling_params = - Self::build_sampling_params_from_plain(body.sampling_params.as_ref())?; + crate::sampling_params::build_sampling_params_from_plain(body.sampling_params.as_ref())?; let grpc_request = proto::GenerateRequest { request_id, @@ -411,7 +406,8 @@ impl SglangSchedulerClient { constraint: Option<(String, String)>, ) -> Result { // Build sampling params from ResponsesRequest - let sampling_params = Self::build_grpc_sampling_params_from_responses(body, constraint)?; + let sampling_params = + crate::sampling_params::build_grpc_sampling_params_from_responses(body, constraint)?; let grpc_request = proto::GenerateRequest { request_id, @@ -433,181 +429,6 @@ impl SglangSchedulerClient { Ok(grpc_request) } - /// Build gRPC SamplingParams from ChatCompletionRequest - pub(crate) fn build_grpc_sampling_params_from_chat( - request: &ChatCompletionRequest, - tool_call_constraint: Option<(String, String)>, - ) -> Result { - let stop_sequences = Self::extract_stop_strings(request); - - let max_new_tokens = request.max_completion_tokens; - - // Hardcode to true: gRPC backends return raw token IDs, not decoded text. - // Detokenization happens on the SMG Rust side (StopDecoder/Sequence). - // - // Note: TokenSpeed's HTTP serving_chat sets this to false when tools are - // present (serving_chat.py:178-179) — but mirroring that on the gRPC - // path measurably HURTS BFCL accuracy. We tested it: simple_python - // dropped from ~88.75 % to 79 %, parallel_multiple from ~84.5 % to - // 60.5 %. With skip_special_tokens=false the engine emits the - // ``<|tool_call_*|>`` special tokens in the raw output stream, and the - // SMG-side detokenizer + kimik2 tool-call parser then double-counts or - // misframes them. Keep it at true so SMG sees normal tokens and - // applies its own parsing. 
- let skip_special_tokens = true; - - Ok(proto::SamplingParams { - temperature: request.temperature.unwrap_or(1.0), - top_p: request.top_p.unwrap_or(1.0), - top_k: request.top_k.unwrap_or(-1), - min_p: request.min_p.unwrap_or(0.0), - frequency_penalty: request.frequency_penalty.unwrap_or(0.0), - presence_penalty: request.presence_penalty.unwrap_or(0.0), - repetition_penalty: request.repetition_penalty.unwrap_or(1.0), - max_new_tokens, - stop: stop_sequences, - stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(), - skip_special_tokens, - spaces_between_special_tokens: true, // Default from Python SamplingParams - ignore_eos: request.ignore_eos, - no_stop_trim: request.no_stop_trim, - n: request.n.unwrap_or(1), - constraint: Self::build_constraint_for_chat(request, tool_call_constraint)?, - ..Default::default() - }) - } - - /// Extract stop strings from request - fn extract_stop_strings(request: &ChatCompletionRequest) -> Vec { - match &request.stop { - Some(StringOrArray::String(s)) => vec![s.clone()], - Some(StringOrArray::Array(arr)) => arr.clone(), - None => vec![], - } - } - - /// Build constraint for structured generation - fn build_constraint_for_chat( - request: &ChatCompletionRequest, - tool_call_constraint: Option<(String, String)>, - ) -> Result, String> { - let mut constraints = Vec::new(); - - // Handle response_format constraints - match &request.response_format { - Some(ResponseFormat::JsonObject) => { - // json_object mode - constrain to valid JSON object - let schema = serde_json::json!({"type": "object"}); - let schema_str = serde_json::to_string(&schema) - .map_err(|e| format!("Failed to serialize JSON schema: {e}"))?; - constraints.push(proto::sampling_params::Constraint::JsonSchema(schema_str)); - } - Some(ResponseFormat::JsonSchema { json_schema }) => { - let schema_str = serde_json::to_string(&json_schema.schema) - .map_err(|e| format!("Failed to serialize JSON schema: {e}"))?; - constraints.push(proto::sampling_params::Constraint::JsonSchema(schema_str)); - } - Some(ResponseFormat::Text) | None => { - // No constraint for text format - } - } - - if let Some(ebnf) = &request.ebnf { - constraints.push(proto::sampling_params::Constraint::EbnfGrammar( - ebnf.clone(), - )); - } - - if let Some(regex) = &request.regex { - constraints.push(proto::sampling_params::Constraint::Regex(regex.clone())); - } - - // Handle tool call constraint from preparation stage. - // If response_format already set a constraint, drop the tool constraint - // (matches SGLang HTTP behavior where response_format takes priority). 
- if let Some((constraint_type, constraint_value)) = tool_call_constraint { - if constraints.is_empty() { - let tool_constraint = match constraint_type.as_str() { - "structural_tag" => { - proto::sampling_params::Constraint::StructuralTag(constraint_value) - } - "json_schema" => { - proto::sampling_params::Constraint::JsonSchema(constraint_value) - } - "ebnf" => proto::sampling_params::Constraint::EbnfGrammar(constraint_value), - "regex" => proto::sampling_params::Constraint::Regex(constraint_value), - _ => return Err(format!("Unknown constraint type: {constraint_type}")), - }; - constraints.push(tool_constraint); - } else { - warn!("Constrained decoding is not compatible with tool calls, dropping tool constraint"); - } - } - - match constraints.len() { - 0 => Ok(None), - 1 => Ok(constraints.pop()), - _ => Err("Multiple constraints are not allowed.".to_string()), - } - } - - /// Build gRPC SamplingParams from ResponsesRequest - pub(crate) fn build_grpc_sampling_params_from_responses( - request: &ResponsesRequest, - constraint: Option<(String, String)>, - ) -> Result { - // Used by Harmony models only. Regular models use Chat API path. - // Constraints come from Harmony preparation stage (structural_tag) or tool handling. - - let max_new_tokens = request.max_output_tokens; - - Ok(proto::SamplingParams { - temperature: request.temperature.unwrap_or(1.0), - top_p: request.top_p.unwrap_or(1.0), - top_k: request.top_k, - min_p: request.min_p, - frequency_penalty: request.frequency_penalty.unwrap_or(0.0), - presence_penalty: request.presence_penalty.unwrap_or(0.0), - repetition_penalty: request.repetition_penalty, - max_new_tokens, - stop: vec![], // Does not pass through request.stop yet (follow-up fix) - stop_token_ids: vec![], // Handled by Harmony stop tokens - skip_special_tokens: false, // Keep special tokens for Harmony - spaces_between_special_tokens: true, - ignore_eos: false, - no_stop_trim: false, - n: 1, // Responses API doesn't support n>1 - constraint: Self::build_constraint_for_responses(constraint)?, - ..Default::default() - }) - } - - /// Build constraint for Responses API - /// - /// Handles constraints from Harmony preparation stage (structural_tag for Harmony models, - /// structured output via text field, or tool call constraints). - /// - /// Note: Regular gRPC models use Chat API path with response_format, not this function. 
- fn build_constraint_for_responses( - constraint: Option<(String, String)>, - ) -> Result, String> { - if let Some((constraint_type, constraint_value)) = constraint { - let parsed_constraint = match constraint_type.as_str() { - "structural_tag" => { - // Harmony models: structural tag from preparation stage - proto::sampling_params::Constraint::StructuralTag(constraint_value) - } - "json_schema" => proto::sampling_params::Constraint::JsonSchema(constraint_value), - "ebnf" => proto::sampling_params::Constraint::EbnfGrammar(constraint_value), - "regex" => proto::sampling_params::Constraint::Regex(constraint_value), - _ => return Err(format!("Unknown constraint type: {constraint_type}")), - }; - Ok(Some(parsed_constraint)) - } else { - Ok(None) - } - } - /// Build a GenerateRequest from CreateMessageRequest (Anthropic Messages API) #[expect( clippy::unused_self, @@ -623,7 +444,7 @@ impl SglangSchedulerClient { tool_call_constraint: Option<(String, String)>, ) -> Result { let sampling_params = - Self::build_grpc_sampling_params_from_messages(body, tool_call_constraint)?; + crate::sampling_params::build_grpc_sampling_params_from_messages(body, tool_call_constraint)?; let grpc_request = proto::GenerateRequest { request_id, @@ -644,37 +465,6 @@ impl SglangSchedulerClient { Ok(grpc_request) } - /// Build gRPC SamplingParams from CreateMessageRequest - pub(crate) fn build_grpc_sampling_params_from_messages( - request: &CreateMessageRequest, - tool_call_constraint: Option<(String, String)>, - ) -> Result { - let stop_sequences = request.stop_sequences.clone().unwrap_or_default(); - - // Hardcode to true: gRPC backends return raw token IDs, not decoded text. - let skip_special_tokens = true; - - Ok(proto::SamplingParams { - temperature: request.temperature.unwrap_or(1.0) as f32, - top_p: request.top_p.unwrap_or(1.0) as f32, - top_k: request.top_k.map(|v| v as i32).unwrap_or(-1), - min_p: 0.0, - frequency_penalty: 0.0, - presence_penalty: 0.0, - repetition_penalty: 1.0, - max_new_tokens: Some(request.max_tokens), - stop: stop_sequences, - stop_token_ids: vec![], - skip_special_tokens, - spaces_between_special_tokens: true, - ignore_eos: false, - no_stop_trim: false, - n: 1, - constraint: Self::build_constraint_for_responses(tool_call_constraint)?, - ..Default::default() - }) - } - /// Build a GenerateRequest from CompletionRequest (`/v1/completions`) #[expect( clippy::unused_self, @@ -687,7 +477,7 @@ impl SglangSchedulerClient { original_text: String, token_ids: Vec, ) -> Result { - let sampling_params = Self::build_grpc_sampling_params_from_completion(body)?; + let sampling_params = crate::sampling_params::build_grpc_sampling_params_from_completion(body)?; let grpc_request = proto::GenerateRequest { request_id, @@ -708,158 +498,6 @@ impl SglangSchedulerClient { Ok(grpc_request) } - pub(crate) fn build_grpc_sampling_params_from_completion( - request: &CompletionRequest, - ) -> Result { - let stop_sequences = match &request.stop { - Some(StringOrArray::String(s)) => vec![s.clone()], - Some(StringOrArray::Array(arr)) => arr.clone(), - None => vec![], - }; - - let constraint = Self::build_single_constraint_from_completion(request)?; - - Ok(proto::SamplingParams { - temperature: request.temperature.unwrap_or(1.0), - top_p: request.top_p.unwrap_or(1.0), - top_k: request.top_k.unwrap_or(-1), - min_p: request.min_p.unwrap_or(0.0), - frequency_penalty: request.frequency_penalty.unwrap_or(0.0), - presence_penalty: request.presence_penalty.unwrap_or(0.0), - repetition_penalty: 
request.repetition_penalty.unwrap_or(1.0), - max_new_tokens: request.max_tokens, - min_new_tokens: request.min_tokens.unwrap_or(0), - stop: stop_sequences, - stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(), - skip_special_tokens: request.skip_special_tokens, - spaces_between_special_tokens: true, - ignore_eos: request.ignore_eos, - no_stop_trim: request.no_stop_trim, - n: request.n.unwrap_or(1), - constraint, - ..Default::default() - }) - } - - fn build_single_constraint_from_completion( - request: &CompletionRequest, - ) -> Result, String> { - let mut constraints = Vec::new(); - if let Some(json_schema) = &request.json_schema { - constraints.push(proto::sampling_params::Constraint::JsonSchema( - json_schema.clone(), - )); - } - if let Some(regex) = &request.regex { - constraints.push(proto::sampling_params::Constraint::Regex(regex.clone())); - } - if let Some(ebnf) = &request.ebnf { - constraints.push(proto::sampling_params::Constraint::EbnfGrammar( - ebnf.clone(), - )); - } - - match constraints.len() { - 0 => Ok(None), - 1 => Ok(constraints.pop()), - _ => Err("Multiple structured constraints are not allowed".to_string()), - } - } - - fn build_single_constraint_from_plain( - params: &GenerateSamplingParams, - ) -> Result, String> { - let mut constraints = Vec::new(); - if let Some(json_schema) = ¶ms.json_schema { - constraints.push(proto::sampling_params::Constraint::JsonSchema( - json_schema.clone(), - )); - } - if let Some(regex) = ¶ms.regex { - constraints.push(proto::sampling_params::Constraint::Regex(regex.clone())); - } - if let Some(ebnf) = ¶ms.ebnf { - constraints.push(proto::sampling_params::Constraint::EbnfGrammar( - ebnf.clone(), - )); - } - - match constraints.len() { - 0 => Ok(None), - 1 => Ok(constraints.pop()), - _ => Err("Multiple structured constraints are not allowed".to_string()), - } - } - - pub(crate) fn build_sampling_params_from_plain( - params: Option<&GenerateSamplingParams>, - ) -> Result { - let mut sampling = proto::SamplingParams { - temperature: 1.0, - top_p: 1.0, - top_k: -1, - repetition_penalty: 1.0, - n: 1, - skip_special_tokens: true, - spaces_between_special_tokens: true, - ..Default::default() - }; - - let Some(p) = params else { - return Ok(sampling); - }; - - // Simple field mappings using a macro - macro_rules! 
map_field { - ($field:ident) => { - if let Some(val) = p.$field { - sampling.$field = val; - } - }; - } - - map_field!(temperature); - map_field!(top_p); - map_field!(top_k); - map_field!(frequency_penalty); - map_field!(presence_penalty); - map_field!(repetition_penalty); - map_field!(min_p); - map_field!(ignore_eos); - map_field!(skip_special_tokens); - map_field!(no_stop_trim); - - // Handle stop sequences - if let Some(stop) = &p.stop { - match stop { - StringOrArray::String(s) => sampling.stop.push(s.clone()), - StringOrArray::Array(arr) => sampling.stop.extend(arr.clone()), - } - } - - // Handle stop token IDs - if let Some(stop_token_ids) = &p.stop_token_ids { - sampling.stop_token_ids.clone_from(stop_token_ids); - } - - // Handle max_new_tokens - sampling.max_new_tokens = p.max_new_tokens; - - // Handle min_new_tokens - if let Some(min_new_tokens) = p.min_new_tokens { - sampling.min_new_tokens = min_new_tokens; - } - - // Handle n - if let Some(n) = p.n { - sampling.n = n; - } - - // Handle constraints (exactly one allowed) - sampling.constraint = Self::build_single_constraint_from_plain(p)?; - - Ok(sampling) - } } // --------------------------------------------------------------------------- @@ -1018,7 +656,7 @@ mod tests { }; let params = - SglangSchedulerClient::build_grpc_sampling_params_from_responses(&request, None) + crate::sampling_params::build_grpc_sampling_params_from_responses(&request, None) .expect("build sampling params"); assert_eq!(params.top_k, 40); @@ -1033,7 +671,7 @@ mod tests { ..Default::default() }; let disabled_params = - SglangSchedulerClient::build_grpc_sampling_params_from_responses(&disabled, None) + crate::sampling_params::build_grpc_sampling_params_from_responses(&disabled, None) .expect("build sampling params"); assert_eq!(disabled_params.top_k, -1); } diff --git a/crates/grpc_client/src/tokenspeed_scheduler.rs b/crates/grpc_client/src/tokenspeed_scheduler.rs index d4d02cc3c..1d6f49320 100644 --- a/crates/grpc_client/src/tokenspeed_scheduler.rs +++ b/crates/grpc_client/src/tokenspeed_scheduler.rs @@ -37,7 +37,7 @@ use tonic::{transport::Channel, Request, Streaming}; use tracing::{debug, warn}; use crate::{ - sglang_scheduler::{proto as sglang, SglangSchedulerClient}, + sglang_scheduler::proto as sglang, BoxedTraceInjector, NoopTraceInjector, }; @@ -293,7 +293,7 @@ impl TokenSpeedSchedulerClient { token_ids: Vec, tool_call_constraint: Option<(String, String)>, ) -> Result { - let sampling_params = SglangSchedulerClient::build_grpc_sampling_params_from_chat( + let sampling_params = crate::sampling_params::build_grpc_sampling_params_from_chat( body, tool_call_constraint, )?; @@ -324,7 +324,7 @@ impl TokenSpeedSchedulerClient { token_ids: Vec, ) -> Result { let sampling_params = - SglangSchedulerClient::build_sampling_params_from_plain(body.sampling_params.as_ref())?; + crate::sampling_params::build_sampling_params_from_plain(body.sampling_params.as_ref())?; Ok(sglang::GenerateRequest { request_id, tokenized: Some(sglang::TokenizedInput { @@ -354,7 +354,7 @@ impl TokenSpeedSchedulerClient { constraint: Option<(String, String)>, ) -> Result { let sampling_params = - SglangSchedulerClient::build_grpc_sampling_params_from_responses(body, constraint)?; + crate::sampling_params::build_grpc_sampling_params_from_responses(body, constraint)?; Ok(sglang::GenerateRequest { request_id, tokenized: Some(sglang::TokenizedInput { @@ -379,7 +379,7 @@ impl TokenSpeedSchedulerClient { token_ids: Vec, tool_call_constraint: Option<(String, String)>, ) -> Result { - let 
sampling_params = SglangSchedulerClient::build_grpc_sampling_params_from_messages( + let sampling_params = crate::sampling_params::build_grpc_sampling_params_from_messages( body, tool_call_constraint, )?; @@ -407,7 +407,7 @@ impl TokenSpeedSchedulerClient { token_ids: Vec, ) -> Result { let sampling_params = - SglangSchedulerClient::build_grpc_sampling_params_from_completion(body)?; + crate::sampling_params::build_grpc_sampling_params_from_completion(body)?; Ok(sglang::GenerateRequest { request_id, tokenized: Some(sglang::TokenizedInput { From 76bf5720e6d6c33307383ba4d6632b8af2f60a24 Mon Sep 17 00:00:00 2001 From: key4ng Date: Sat, 9 May 2026 12:35:21 -0700 Subject: [PATCH 03/24] refactor(grpc): give TokenSpeed its own IR arms (drop SGLang impersonation) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit TokenSpeed previously rode the SGLang IR arms in proto_wrapper.rs: ``TokenSpeedSchedulerClient::generate()`` accepted ``sglang::GenerateRequest``, the streaming response was translated back into ``sglang::GenerateResponse`` at the wire boundary, and the router dispatched through ``ProtoGenerateRequest::Sglang``. This let TokenSpeed reuse the existing match arms but coupled it to SGLang's evolving schema — every SGLang field addition forced a TokenSpeed translator stub (most recently ``default_sampling_params_json``). Add native TokenSpeed arms to the router IR: - ``ProtoGenerateRequest::TokenSpeed(Box)`` - ``ProtoGenerateResponse::TokenSpeed(Box)`` - ``ProtoGenerateStreamChunk::TokenSpeed(tokenspeed::GenerateStreamChunk)`` - ``ProtoGenerateComplete::TokenSpeed(tokenspeed::GenerateComplete)`` …with ``as_tokenspeed`` / ``as_tokenspeed_mut`` / ``is_tokenspeed`` accessors mirroring the existing per-backend pattern (Mlx / Vllm / Trtllm) and ``TokenSpeed`` arms in every aggregator method (``token_ids``, ``index``, ``output_logprobs``, ``input_logprobs``, ``prompt_tokens``, ``completion_tokens``, ``cached_tokens``, ``finish_reason``, ``matched_stop_json``, ``output_ids``, ``set_stream``, ``request_id``, ``set_max_tokens_for_prefill``, ``clear_mm_inputs``, ``set_kv_transfer_params``, ``kv_transfer_params``). Client-side: - ``TokenSpeedSchedulerClient::generate()`` now takes ``tokenspeed_proto::GenerateRequest`` - ``AbortOnDropStream::Item`` is ``tokenspeed_proto::GenerateResponse`` - The 5 ``build_*_request`` builders return native ``tokenspeed_proto`` types - ``translate::generate_request`` / ``generate_response`` / ``stream_chunk`` / ``complete`` / ``output_logprobs`` are gone — the only translation kept is ``translate::sampling_params`` (a thin field map) plus the unary RPC adapters (``model_info`` / ``server_info`` / ``loads``), which still produce SGLang shapes because the router's ``ModelInfo`` / ``ServerInfo`` enums consume those — that's a separate cleanup. Router-side: - ``client.rs::generate()`` dispatch arm now matches ``(Self::TokenSpeed(_), ProtoGenerateRequest::TokenSpeed(_))``. - The 5 ``build_*_request`` paths in ``GrpcClient`` wrap into ``ProtoGenerateRequest::TokenSpeed`` instead of ``::Sglang``. - ``harmony/stages/request_building.rs`` builds ``ProtoGenerateRequest::TokenSpeed`` and grew a ``TokenSpeed`` arm in the Harmony stop-token injection match. - ``common/stages/helpers.rs::apply_sampling_defaults_to_generate_request`` early-returns for TokenSpeed (alongside Trtllm) since neither backend plumbs sampling defaults through today; the explicit arm keeps the match exhaustive. 
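As a usage sketch (types and method names exactly as introduced in this
diff; surrounding plumbing elided, and ``proto_request`` standing in for any
``ProtoGenerateRequest`` value), downstream code handles the new variant the
same way it already handles the other per-backend arms:

    // router dispatch — the arm added to client.rs::generate()
    (Self::TokenSpeed(client), ProtoGenerateRequest::TokenSpeed(boxed_req)) => {
        let stream = client.generate(*boxed_req).await?;
        Ok(ProtoStream::TokenSpeed(stream))
    }

    // typed accessors — mirror the existing Mlx / Vllm / Trtllm pattern
    if proto_request.is_tokenspeed() {
        let req: &tokenspeed::GenerateRequest = proto_request.as_tokenspeed();
        // native TokenSpeed fields; no SGLang translation anywhere on this path
    }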
PD-disagg paths (``response_collection.rs``) remain SGLang-keyed — the ``if let ProtoGenerateComplete::Sglang(...)`` checks simply won't match TokenSpeed responses, which is the correct behavior since TokenSpeed doesn't ship PD-disaggregation. Verification: - ``cargo +nightly fmt --all -- --check`` passes - ``cargo clippy -p smg-grpc-client -p smg --all-targets --all-features -- -D warnings`` passes - ``cargo check -p smg --bin smg`` passes This addresses the architectural concern raised on #1351: SGLang's proto shouldn't be the de-facto router IR. Each backend now has its own arm, matching how vLLM / MLX / TRT-LLM are integrated. Signed-off-by: key4ng --- crates/grpc_client/src/sampling_params.rs | 6 +- crates/grpc_client/src/sglang_scheduler.rs | 21 +- .../grpc_client/src/tokenspeed_scheduler.rs | 253 +++++------------- model_gateway/src/routers/grpc/client.rs | 10 +- .../src/routers/grpc/common/stages/helpers.rs | 9 +- .../grpc/harmony/stages/request_building.rs | 19 +- .../src/routers/grpc/proto_wrapper.rs | 121 +++++++-- .../workflow/steps/local/detect_backend.rs | 5 +- 8 files changed, 215 insertions(+), 229 deletions(-) diff --git a/crates/grpc_client/src/sampling_params.rs b/crates/grpc_client/src/sampling_params.rs index b0c2d3d5f..a47478793 100644 --- a/crates/grpc_client/src/sampling_params.rs +++ b/crates/grpc_client/src/sampling_params.rs @@ -89,7 +89,7 @@ pub fn build_grpc_sampling_params_from_responses( presence_penalty: request.presence_penalty.unwrap_or(0.0), repetition_penalty: request.repetition_penalty, max_new_tokens, - stop: vec![], // Does not pass through request.stop yet (follow-up fix) + stop: vec![], // Does not pass through request.stop yet (follow-up fix) stop_token_ids: vec![], // Handled by Harmony stop tokens skip_special_tokens: false, // Keep special tokens for Harmony spaces_between_special_tokens: true, @@ -291,7 +291,9 @@ fn build_constraint_for_chat( }; constraints.push(tool_constraint); } else { - warn!("Constrained decoding is not compatible with tool calls, dropping tool constraint"); + warn!( + "Constrained decoding is not compatible with tool calls, dropping tool constraint" + ); } } diff --git a/crates/grpc_client/src/sglang_scheduler.rs b/crates/grpc_client/src/sglang_scheduler.rs index 7c1862153..183d04bac 100644 --- a/crates/grpc_client/src/sglang_scheduler.rs +++ b/crates/grpc_client/src/sglang_scheduler.rs @@ -332,8 +332,10 @@ impl SglangSchedulerClient { tool_call_constraint: Option<(String, String)>, // (constraint_type, constraint_value) ) -> Result { // Build sampling params - let sampling_params = - crate::sampling_params::build_grpc_sampling_params_from_chat(body, tool_call_constraint)?; + let sampling_params = crate::sampling_params::build_grpc_sampling_params_from_chat( + body, + tool_call_constraint, + )?; let grpc_request = proto::GenerateRequest { request_id, @@ -366,8 +368,9 @@ impl SglangSchedulerClient { original_text: Option, token_ids: Vec, ) -> Result { - let sampling_params = - crate::sampling_params::build_sampling_params_from_plain(body.sampling_params.as_ref())?; + let sampling_params = crate::sampling_params::build_sampling_params_from_plain( + body.sampling_params.as_ref(), + )?; let grpc_request = proto::GenerateRequest { request_id, @@ -443,8 +446,10 @@ impl SglangSchedulerClient { multimodal_inputs: Option, tool_call_constraint: Option<(String, String)>, ) -> Result { - let sampling_params = - crate::sampling_params::build_grpc_sampling_params_from_messages(body, tool_call_constraint)?; + let sampling_params = 
crate::sampling_params::build_grpc_sampling_params_from_messages( + body, + tool_call_constraint, + )?; let grpc_request = proto::GenerateRequest { request_id, @@ -477,7 +482,8 @@ impl SglangSchedulerClient { original_text: String, token_ids: Vec, ) -> Result { - let sampling_params = crate::sampling_params::build_grpc_sampling_params_from_completion(body)?; + let sampling_params = + crate::sampling_params::build_grpc_sampling_params_from_completion(body)?; let grpc_request = proto::GenerateRequest { request_id, @@ -497,7 +503,6 @@ impl SglangSchedulerClient { Ok(grpc_request) } - } // --------------------------------------------------------------------------- diff --git a/crates/grpc_client/src/tokenspeed_scheduler.rs b/crates/grpc_client/src/tokenspeed_scheduler.rs index 1d6f49320..4c67fc567 100644 --- a/crates/grpc_client/src/tokenspeed_scheduler.rs +++ b/crates/grpc_client/src/tokenspeed_scheduler.rs @@ -10,14 +10,22 @@ //! tokenizer streaming, KV-event subscription) is simply not on TokenSpeed's //! wire. //! -//! Internally this client still leverages SGLang's -//! ``build_grpc_sampling_params_from_*`` helpers because the source-of-truth -//! is an OpenAI request and most fields map identically. We translate from -//! the SGLang-shaped ``GenerateRequest`` into a TokenSpeed-shaped one at the -//! wire boundary, and translate the streamed response back so the router's -//! ``ProtoGenerateStreamChunk`` / ``ProtoGenerateComplete`` accessors can -//! operate on a familiar shape. When TokenSpeed needs a field SGLang lacks, -//! add it to the proto and extend the translator — not the router. +//! Request/response types are TokenSpeed-native end-to-end: the stream +//! yields ``tokenspeed_proto::GenerateResponse`` and ``build_*_request`` +//! produces ``tokenspeed_proto::GenerateRequest``. The router dispatches +//! through dedicated ``ProtoGenerateRequest::TokenSpeed`` / +//! ``ProtoGenerateStreamChunk::TokenSpeed`` / +//! ``ProtoGenerateComplete::TokenSpeed`` arms — same shape as the other +//! per-backend variants. +//! +//! Sampling-params construction reuses the backend-neutral helpers in +//! ``crate::sampling_params`` (which currently return +//! ``sglang::SamplingParams``); the [`translate::sampling_params`] +//! field-mapper converts to TokenSpeed's shape at the seam. The unary RPC +//! responses (``GetModelInfo``, ``GetServerInfo``, ``GetLoads``) still +//! return SGLang-shaped types because their consumers ride the +//! ``ModelInfo`` / ``ServerInfo`` SGLang variants in the router; that's a +//! separate cleanup. use std::{ pin::Pin, @@ -36,10 +44,7 @@ use openai_protocol::{ use tonic::{transport::Channel, Request, Streaming}; use tracing::{debug, warn}; -use crate::{ - sglang_scheduler::proto as sglang, - BoxedTraceInjector, NoopTraceInjector, -}; +use crate::{sglang_scheduler::proto as sglang, BoxedTraceInjector, NoopTraceInjector}; #[expect(clippy::allow_attributes)] pub mod tokenspeed_proto { @@ -56,11 +61,9 @@ type AbortDispatcher = Arc; /// Auto-aborting wrapper around the TokenSpeed generate stream. /// -/// Yields ``sglang::GenerateResponse`` (translated from the on-wire -/// ``tokenspeed_proto::GenerateResponse``) so the router-side -/// ``ProtoGenerateStreamChunk`` / ``ProtoGenerateComplete`` accessors can -/// keep operating on a single shape. Sends an Abort RPC on Drop unless -/// ``mark_completed`` was called first — same lifecycle contract as +/// Yields ``tokenspeed_proto::GenerateResponse`` directly (no translation +/// at the seam). 
Sends an Abort RPC on Drop unless ``mark_completed`` was +/// called first — same lifecycle contract as /// ``sglang_scheduler::AbortOnDropStream``. pub struct AbortOnDropStream { inner: Streaming, @@ -111,19 +114,10 @@ impl Drop for AbortOnDropStream { } impl futures::Stream for AbortOnDropStream { - // Yield SGLang-shaped responses so the router's wrapper enums don't need - // a TokenSpeed variant for every chunk-accessor. - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match Pin::new(&mut self.inner).poll_next(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(None) => Poll::Ready(None), - Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), - Poll::Ready(Some(Ok(ts_resp))) => { - Poll::Ready(Some(Ok(translate::generate_response(ts_resp)))) - } - } + Pin::new(&mut self.inner).poll_next(cx) } } @@ -181,20 +175,14 @@ impl TokenSpeedSchedulerClient { } /// Submit a generation request. - /// - /// Accepts an SGLang-shaped request for symmetry with the router's - /// existing dispatch path; the translation to TokenSpeed's slimmer wire - /// shape (drops mm_inputs, disagg, LoRA, hidden-states, etc.) happens - /// here at the wire boundary. pub async fn generate( &self, - req: sglang::GenerateRequest, + req: tokenspeed_proto::GenerateRequest, ) -> Result { let request_id = req.request_id.clone(); - let ts_req = translate::generate_request(req); let mut client = self.client.clone(); - let mut request = Request::new(ts_req); + let mut request = Request::new(req); if let Err(e) = self.trace_injector.inject(request.metadata_mut()) { warn!("Failed to inject trace context: {}", e); @@ -274,16 +262,15 @@ impl TokenSpeedSchedulerClient { // ── Request builders ────────────────────────────────────────────── // - // These produce SGLang-shaped requests so the router's existing - // ``ProtoGenerateRequest::Sglang`` plumbing is reused. The wire-side - // translation to TokenSpeed shape happens inside ``generate()`` above. - // - // Sampling-param construction delegates to SglangSchedulerClient's - // ``pub(crate)`` helpers — same OpenAI source, same semantics. + // Produce ``tokenspeed_proto::GenerateRequest`` directly. Sampling-param + // construction delegates to ``crate::sampling_params`` (which returns + // ``sglang::SamplingParams`` because that proto is the most permissive + // shape across our backends); ``translate::sampling_params`` field-maps + // it to TokenSpeed's slimmer shape at the seam. 
#[expect( clippy::unused_self, - reason = "receiver kept for API parity with SglangSchedulerClient" + reason = "receiver kept for API parity with the other engine clients" )] pub fn build_generate_request_from_chat( &self, @@ -292,20 +279,20 @@ impl TokenSpeedSchedulerClient { processed_text: String, token_ids: Vec, tool_call_constraint: Option<(String, String)>, - ) -> Result { - let sampling_params = crate::sampling_params::build_grpc_sampling_params_from_chat( + ) -> Result { + let sglang_sampling = crate::sampling_params::build_grpc_sampling_params_from_chat( body, tool_call_constraint, )?; - Ok(sglang::GenerateRequest { + Ok(tokenspeed_proto::GenerateRequest { request_id, - tokenized: Some(sglang::TokenizedInput { + tokenized: Some(tokenspeed_proto::TokenizedInput { original_text: processed_text, input_ids: token_ids, }), - sampling_params: Some(sampling_params), + sampling_params: Some(translate::sampling_params(sglang_sampling)), return_logprob: body.logprobs, - logprob_start_len: -1, + logprob_start_len: Some(-1), top_logprobs_num: body.top_logprobs.unwrap_or(0) as i32, stream: body.stream, ..Default::default() @@ -314,7 +301,7 @@ impl TokenSpeedSchedulerClient { #[expect( clippy::unused_self, - reason = "receiver kept for API parity with SglangSchedulerClient" + reason = "receiver kept for API parity with the other engine clients" )] pub fn build_plain_generate_request( &self, @@ -322,28 +309,28 @@ impl TokenSpeedSchedulerClient { body: &GenerateRequest, original_text: Option, token_ids: Vec, - ) -> Result { - let sampling_params = - crate::sampling_params::build_sampling_params_from_plain(body.sampling_params.as_ref())?; - Ok(sglang::GenerateRequest { + ) -> Result { + let sglang_sampling = crate::sampling_params::build_sampling_params_from_plain( + body.sampling_params.as_ref(), + )?; + Ok(tokenspeed_proto::GenerateRequest { request_id, - tokenized: Some(sglang::TokenizedInput { + tokenized: Some(tokenspeed_proto::TokenizedInput { original_text: original_text.unwrap_or_default(), input_ids: token_ids, }), - sampling_params: Some(sampling_params), + sampling_params: Some(translate::sampling_params(sglang_sampling)), return_logprob: body.return_logprob.unwrap_or(false), - logprob_start_len: body.logprob_start_len.unwrap_or(-1), + logprob_start_len: Some(body.logprob_start_len.unwrap_or(-1)), top_logprobs_num: body.top_logprobs_num.unwrap_or(0), token_ids_logprob: body.token_ids_logprob.clone().unwrap_or_default(), stream: body.stream, - ..Default::default() }) } #[expect( clippy::unused_self, - reason = "receiver kept for API parity with SglangSchedulerClient" + reason = "receiver kept for API parity with the other engine clients" )] pub fn build_generate_request_from_responses( &self, @@ -352,16 +339,16 @@ impl TokenSpeedSchedulerClient { processed_text: String, token_ids: Vec, constraint: Option<(String, String)>, - ) -> Result { - let sampling_params = + ) -> Result { + let sglang_sampling = crate::sampling_params::build_grpc_sampling_params_from_responses(body, constraint)?; - Ok(sglang::GenerateRequest { + Ok(tokenspeed_proto::GenerateRequest { request_id, - tokenized: Some(sglang::TokenizedInput { + tokenized: Some(tokenspeed_proto::TokenizedInput { original_text: processed_text, input_ids: token_ids, }), - sampling_params: Some(sampling_params), + sampling_params: Some(translate::sampling_params(sglang_sampling)), stream: body.stream.unwrap_or(false), ..Default::default() }) @@ -369,7 +356,7 @@ impl TokenSpeedSchedulerClient { #[expect( clippy::unused_self, - reason = 
"receiver kept for API parity with SglangSchedulerClient" + reason = "receiver kept for API parity with the other engine clients" )] pub fn build_generate_request_from_messages( &self, @@ -378,18 +365,18 @@ impl TokenSpeedSchedulerClient { processed_text: String, token_ids: Vec, tool_call_constraint: Option<(String, String)>, - ) -> Result { - let sampling_params = crate::sampling_params::build_grpc_sampling_params_from_messages( + ) -> Result { + let sglang_sampling = crate::sampling_params::build_grpc_sampling_params_from_messages( body, tool_call_constraint, )?; - Ok(sglang::GenerateRequest { + Ok(tokenspeed_proto::GenerateRequest { request_id, - tokenized: Some(sglang::TokenizedInput { + tokenized: Some(tokenspeed_proto::TokenizedInput { original_text: processed_text, input_ids: token_ids, }), - sampling_params: Some(sampling_params), + sampling_params: Some(translate::sampling_params(sglang_sampling)), stream: body.stream.unwrap_or(false), ..Default::default() }) @@ -397,7 +384,7 @@ impl TokenSpeedSchedulerClient { #[expect( clippy::unused_self, - reason = "receiver kept for API parity with SglangSchedulerClient" + reason = "receiver kept for API parity with the other engine clients" )] pub fn build_generate_request_from_completion( &self, @@ -405,18 +392,18 @@ impl TokenSpeedSchedulerClient { body: &CompletionRequest, original_text: String, token_ids: Vec, - ) -> Result { - let sampling_params = + ) -> Result { + let sglang_sampling = crate::sampling_params::build_grpc_sampling_params_from_completion(body)?; - Ok(sglang::GenerateRequest { + Ok(tokenspeed_proto::GenerateRequest { request_id, - tokenized: Some(sglang::TokenizedInput { + tokenized: Some(tokenspeed_proto::TokenizedInput { original_text, input_ids: token_ids, }), - sampling_params: Some(sampling_params), + sampling_params: Some(translate::sampling_params(sglang_sampling)), return_logprob: body.logprobs.is_some(), - logprob_start_len: -1, + logprob_start_len: Some(-1), top_logprobs_num: body.logprobs.unwrap_or(0) as i32, stream: body.stream, ..Default::default() @@ -448,15 +435,14 @@ fn tokenspeed_abort_dispatcher(client: TokenSpeedSchedulerClient) -> AbortDispat }) } -// ── Wire-boundary translation ───────────────────────────────────────── +// ── Sampling-params translation + unary RPC adapters ───────────────── // -// Maps SGLang-shaped types (used internally by the router) to TokenSpeed's -// slimmer wire types and back. Fields TokenSpeed doesn't carry on the wire -// (mm_inputs, disagg, LoRA, hidden states, embeddings, etc.) are dropped on -// the way out; fields TokenSpeed doesn't return are filled with defaults on -// the way in. When the protocols genuinely diverge — i.e. TokenSpeed needs -// a field SGLang doesn't have — extend this module rather than threading -// new variants through proto_wrapper. +// `sampling_params` field-maps the sglang-shaped sampling params produced +// by `crate::sampling_params` into TokenSpeed's slimmer wire shape. The +// other adapters (model_info / server_info / loads) translate unary RPC +// responses into the SGLang-shaped types the router's metadata wrappers +// currently consume — that's a follow-up cleanup that can ride on top of +// the per-backend `ModelInfo` / `ServerInfo` enums. 
mod translate { use super::{sglang, tokenspeed_proto as ts}; @@ -512,103 +498,6 @@ mod translate { } } - pub(super) fn generate_request(r: sglang::GenerateRequest) -> ts::GenerateRequest { - ts::GenerateRequest { - request_id: r.request_id, - tokenized: r.tokenized.map(|t| ts::TokenizedInput { - input_ids: t.input_ids, - original_text: t.original_text, - }), - sampling_params: r.sampling_params.map(sampling_params), - return_logprob: r.return_logprob, - // SGLang's wire-side `logprob_start_len` is non-optional `i32` - // with `-1` as the "no input logprobs" sentinel; TokenSpeed's - // proto makes the field `optional` so the servicer can tell - // "unset" from "explicit 0". Always wrap in `Some(...)` so - // existing SGLang-shaped callers preserve their sentinel - // through to the Python side. - logprob_start_len: Some(r.logprob_start_len), - top_logprobs_num: r.top_logprobs_num, - token_ids_logprob: r.token_ids_logprob, - stream: r.stream, - // Fields TokenSpeed has no concept of: - // r.mm_inputs, r.disaggregated_params, r.custom_logit_processor, - // r.timestamp, r.input_embeds, r.lora_id, r.data_parallel_rank, - // r.log_metrics, r.return_hidden_states - // — silently dropped here. Routing multimodal / disagg / LoRA - // requests to a TokenSpeed worker is a router-level config bug, - // not something this layer should try to paper over. - } - } - - pub(super) fn generate_response(r: ts::GenerateResponse) -> sglang::GenerateResponse { - let response = r.response.map(|resp| match resp { - ts::generate_response::Response::Chunk(c) => { - sglang::generate_response::Response::Chunk(stream_chunk(c)) - } - ts::generate_response::Response::Complete(c) => { - sglang::generate_response::Response::Complete(complete(c)) - } - }); - sglang::GenerateResponse { - request_id: r.request_id, - response, - } - } - - fn stream_chunk(c: ts::GenerateStreamChunk) -> sglang::GenerateStreamChunk { - sglang::GenerateStreamChunk { - token_ids: c.token_ids, - prompt_tokens: c.prompt_tokens, - completion_tokens: c.completion_tokens, - cached_tokens: c.cached_tokens, - output_logprobs: c.output_logprobs.map(output_logprobs), - // Fields not on TokenSpeed's wire — defaulted. - hidden_states: vec![], - input_logprobs: None, - index: c.index, - } - } - - fn complete(c: ts::GenerateComplete) -> sglang::GenerateComplete { - let matched_stop = c.matched_stop.map(|m| match m { - ts::generate_complete::MatchedStop::MatchedTokenId(id) => { - sglang::generate_complete::MatchedStop::MatchedTokenId(id) - } - ts::generate_complete::MatchedStop::MatchedStopStr(s) => { - sglang::generate_complete::MatchedStop::MatchedStopStr(s) - } - }); - sglang::GenerateComplete { - output_ids: c.output_ids, - finish_reason: c.finish_reason, - prompt_tokens: c.prompt_tokens, - completion_tokens: c.completion_tokens, - cached_tokens: c.cached_tokens, - output_logprobs: c.output_logprobs.map(output_logprobs), - // Not on TokenSpeed's wire. 
- all_hidden_states: vec![], - input_logprobs: None, - matched_stop, - index: c.index, - } - } - - fn output_logprobs(o: ts::OutputLogProbs) -> sglang::OutputLogProbs { - sglang::OutputLogProbs { - token_logprobs: o.token_logprobs, - token_ids: o.token_ids, - top_logprobs: o - .top_logprobs - .into_iter() - .map(|t| sglang::TopLogProbs { - values: t.values, - token_ids: t.token_ids, - }) - .collect(), - } - } - pub(super) fn model_info(r: ts::GetModelInfoResponse) -> sglang::GetModelInfoResponse { sglang::GetModelInfoResponse { model_path: r.model_path, diff --git a/model_gateway/src/routers/grpc/client.rs b/model_gateway/src/routers/grpc/client.rs index bab1d1aa0..16503e7ae 100644 --- a/model_gateway/src/routers/grpc/client.rs +++ b/model_gateway/src/routers/grpc/client.rs @@ -341,7 +341,7 @@ impl GrpcClient { let stream = client.generate(*boxed_req).await?; Ok(ProtoStream::Mlx(stream)) } - (Self::TokenSpeed(client), ProtoGenerateRequest::Sglang(boxed_req)) => { + (Self::TokenSpeed(client), ProtoGenerateRequest::TokenSpeed(boxed_req)) => { let stream = client.generate(*boxed_req).await?; Ok(ProtoStream::TokenSpeed(stream)) } @@ -466,7 +466,7 @@ impl GrpcClient { token_ids, tool_constraints, )?; - Ok(ProtoGenerateRequest::Sglang(Box::new(req))) + Ok(ProtoGenerateRequest::TokenSpeed(Box::new(req))) } } } @@ -552,7 +552,7 @@ impl GrpcClient { token_ids, tool_constraints, )?; - Ok(ProtoGenerateRequest::Sglang(Box::new(req))) + Ok(ProtoGenerateRequest::TokenSpeed(Box::new(req))) } } } @@ -608,7 +608,7 @@ impl GrpcClient { original_text, token_ids, )?; - Ok(ProtoGenerateRequest::Sglang(Box::new(req))) + Ok(ProtoGenerateRequest::TokenSpeed(Box::new(req))) } } } @@ -664,7 +664,7 @@ impl GrpcClient { original_text, token_ids, )?; - Ok(ProtoGenerateRequest::Sglang(Box::new(req))) + Ok(ProtoGenerateRequest::TokenSpeed(Box::new(req))) } } } diff --git a/model_gateway/src/routers/grpc/common/stages/helpers.rs b/model_gateway/src/routers/grpc/common/stages/helpers.rs index 65ef6dded..da6a2d514 100644 --- a/model_gateway/src/routers/grpc/common/stages/helpers.rs +++ b/model_gateway/src/routers/grpc/common/stages/helpers.rs @@ -119,7 +119,10 @@ pub(crate) fn apply_sampling_defaults_to_generate_request( request_type: &RequestType, workers: Option<&WorkerSelection>, ) { - if matches!(request, ProtoGenerateRequest::Trtllm(_)) { + if matches!( + request, + ProtoGenerateRequest::Trtllm(_) | ProtoGenerateRequest::TokenSpeed(_) + ) { return; } @@ -156,7 +159,9 @@ pub(crate) fn apply_sampling_defaults_to_generate_request( }; apply_mlx_sampling_defaults(params, defaults, mask); } - ProtoGenerateRequest::Trtllm(_) => {} + // TokenSpeed and TRT-LLM are early-returned above; the arms exist + // only to keep the match exhaustive. + ProtoGenerateRequest::Trtllm(_) | ProtoGenerateRequest::TokenSpeed(_) => {} } } diff --git a/model_gateway/src/routers/grpc/harmony/stages/request_building.rs b/model_gateway/src/routers/grpc/harmony/stages/request_building.rs index 0cc042dc7..3568a6e3e 100644 --- a/model_gateway/src/routers/grpc/harmony/stages/request_building.rs +++ b/model_gateway/src/routers/grpc/harmony/stages/request_building.rs @@ -277,11 +277,9 @@ impl PipelineStage for HarmonyRequestBuildingStage { }; ProtoGenerateRequest::Mlx(Box::new(req)) } - // TokenSpeed: builder produces an SGLang-shaped request so the - // ``ProtoGenerateRequest::Sglang`` plumbing carries it; the - // wire-side translation to TokenSpeed shape happens inside the - // client's ``generate()``. 
Multimodal is intentionally not - // supported here — the harmony path is text-only today. + // TokenSpeed: builder produces a native + // ``tokenspeed::GenerateRequest``; multimodal is intentionally + // not supported here — the harmony path is text-only today. GrpcClient::TokenSpeed(tokenspeed_client) => { let req = match &ctx.input.request_type { RequestType::Chat(request) => { @@ -324,7 +322,7 @@ impl PipelineStage for HarmonyRequestBuildingStage { )); } }; - ProtoGenerateRequest::Sglang(Box::new(req)) + ProtoGenerateRequest::TokenSpeed(Box::new(req)) } }; @@ -371,6 +369,15 @@ impl PipelineStage for HarmonyRequestBuildingStage { ); } } + ProtoGenerateRequest::TokenSpeed(req) => { + if let Some(params) = req.sampling_params.as_mut() { + params.stop_token_ids.extend_from_slice(&harmony_stop_ids); + debug!( + stop_token_count = harmony_stop_ids.len(), + "Injected Harmony stop tokens into TokenSpeed sampling params" + ); + } + } } } diff --git a/model_gateway/src/routers/grpc/proto_wrapper.rs b/model_gateway/src/routers/grpc/proto_wrapper.rs index 6b053e3ef..27128b150 100644 --- a/model_gateway/src/routers/grpc/proto_wrapper.rs +++ b/model_gateway/src/routers/grpc/proto_wrapper.rs @@ -1,7 +1,8 @@ -//! Protocol buffer type wrappers for SGLang, vLLM, and TensorRT-LLM backends +//! Protocol buffer type wrappers for SGLang, vLLM, TensorRT-LLM, MLX, and TokenSpeed backends. //! -//! This module provides unified enums that wrap proto types from SGLang, vLLM, and TensorRT-LLM, -//! allowing the router to work with any backend transparently. +//! This module provides unified enums that wrap proto types from each +//! supported backend, allowing the router to work with any backend +//! transparently. use std::collections::HashMap; @@ -11,6 +12,9 @@ use smg_grpc_client::{ mlx_proto::{self as mlx}, sglang_proto::{self as sglang, generate_complete::MatchedStop as SglangMatchedStop}, sglang_scheduler::AbortOnDropStream as SglangStream, + tokenspeed_proto::{ + self as tokenspeed, generate_complete::MatchedStop as TokenSpeedMatchedStop, + }, tokenspeed_scheduler::AbortOnDropStream as TokenSpeedStream, trtllm_proto::{self as trtllm, generate_complete::MatchedStop as TrtllmMatchedStop}, trtllm_service::AbortOnDropStream as TrtllmStream, @@ -281,6 +285,7 @@ pub enum ProtoGenerateRequest { Vllm(Box), Trtllm(Box), Mlx(Box), + TokenSpeed(Box), } impl ProtoGenerateRequest { @@ -356,6 +361,30 @@ impl ProtoGenerateRequest { } } + /// Get TokenSpeed variant (panics if not TokenSpeed) + #[expect( + clippy::panic, + reason = "typed accessor: caller guarantees variant via is_tokenspeed() check" + )] + pub fn as_tokenspeed(&self) -> &tokenspeed::GenerateRequest { + match self { + Self::TokenSpeed(req) => req, + _ => panic!("Expected TokenSpeed GenerateRequest"), + } + } + + /// Get mutable TokenSpeed variant (panics if not TokenSpeed) + #[expect( + clippy::panic, + reason = "typed accessor: caller guarantees variant via is_tokenspeed() check" + )] + pub fn as_tokenspeed_mut(&mut self) -> &mut tokenspeed::GenerateRequest { + match self { + Self::TokenSpeed(req) => req, + _ => panic!("Expected TokenSpeed GenerateRequest"), + } + } + /// Check if this is SGLang pub fn is_sglang(&self) -> bool { matches!(self, Self::Sglang(_)) @@ -371,6 +400,11 @@ impl ProtoGenerateRequest { matches!(self, Self::Trtllm(_)) } + /// Check if this is TokenSpeed + pub fn is_tokenspeed(&self) -> bool { + matches!(self, Self::TokenSpeed(_)) + } + /// Set max_tokens for prefill-only execution (vLLM PD mode). 
/// The prefill request uses max_tokens=1 to trigger KV cache computation /// without generating unnecessary tokens. @@ -386,7 +420,7 @@ impl ProtoGenerateRequest { }); } } - Self::Sglang(_) | Self::Trtllm(_) | Self::Mlx(_) => { + Self::Sglang(_) | Self::Trtllm(_) | Self::Mlx(_) | Self::TokenSpeed(_) => { tracing::warn!("set_max_tokens_for_prefill called on non-vLLM request, ignoring"); } } @@ -399,6 +433,7 @@ impl ProtoGenerateRequest { Self::Sglang(req) => req.stream = stream, Self::Trtllm(req) => req.streaming = stream, Self::Mlx(req) => req.stream = stream, + Self::TokenSpeed(req) => req.stream = stream, } } @@ -416,7 +451,8 @@ impl ProtoGenerateRequest { match self { Self::Sglang(req) => req.mm_inputs = None, Self::Vllm(req) => req.mm_inputs = None, - Self::Trtllm(_) | Self::Mlx(_) => {} // TRT-LLM and MLX protos have no mm_inputs field + // TRT-LLM, MLX, and TokenSpeed protos have no mm_inputs field + Self::Trtllm(_) | Self::Mlx(_) | Self::TokenSpeed(_) => {} } } @@ -427,6 +463,7 @@ impl ProtoGenerateRequest { Self::Vllm(req) => &req.request_id, Self::Trtllm(req) => &req.request_id, Self::Mlx(req) => &req.request_id, + Self::TokenSpeed(req) => &req.request_id, } } @@ -440,7 +477,7 @@ impl ProtoGenerateRequest { remote_port, }); } - Self::Sglang(_) | Self::Trtllm(_) | Self::Mlx(_) => { + Self::Sglang(_) | Self::Trtllm(_) | Self::Mlx(_) | Self::TokenSpeed(_) => { tracing::warn!("set_kv_transfer_params called on non-vLLM request, ignoring"); } } @@ -453,6 +490,7 @@ pub enum ProtoGenerateResponse { Vllm(Box), Trtllm(Box), Mlx(Box), + TokenSpeed(Box), } impl ProtoGenerateResponse { @@ -497,6 +535,15 @@ impl ProtoGenerateResponse { } None => ProtoResponseVariant::None, }, + Self::TokenSpeed(resp) => match resp.response { + Some(tokenspeed::generate_response::Response::Chunk(chunk)) => { + ProtoResponseVariant::Chunk(ProtoGenerateStreamChunk::TokenSpeed(chunk)) + } + Some(tokenspeed::generate_response::Response::Complete(complete)) => { + ProtoResponseVariant::Complete(ProtoGenerateComplete::TokenSpeed(complete)) + } + None => ProtoResponseVariant::None, + }, } } } @@ -515,6 +562,7 @@ pub enum ProtoGenerateStreamChunk { Vllm(vllm::GenerateStreamChunk), Trtllm(trtllm::GenerateStreamChunk), Mlx(mlx::GenerateStreamChunk), + TokenSpeed(tokenspeed::GenerateStreamChunk), } impl ProtoGenerateStreamChunk { @@ -574,6 +622,11 @@ impl ProtoGenerateStreamChunk { matches!(self, Self::Mlx(_)) } + /// Check if this is TokenSpeed + pub fn is_tokenspeed(&self) -> bool { + matches!(self, Self::TokenSpeed(_)) + } + /// Get token IDs from chunk (common field) pub fn token_ids(&self) -> &[u32] { match self { @@ -581,6 +634,7 @@ impl ProtoGenerateStreamChunk { Self::Vllm(c) => &c.token_ids, Self::Trtllm(c) => &c.token_ids, Self::Mlx(c) => &c.token_ids, + Self::TokenSpeed(c) => &c.token_ids, } } @@ -592,10 +646,11 @@ impl ProtoGenerateStreamChunk { Self::Vllm(c) => c.index, Self::Trtllm(c) => c.sequence_index, Self::Mlx(c) => c.index, + Self::TokenSpeed(c) => c.index, } } - /// Get output logprobs (SGLang, vLLM, TensorRT-LLM, and MLX) + /// Get output logprobs (SGLang, vLLM, TensorRT-LLM, MLX, and TokenSpeed) pub fn output_logprobs(&self) -> Option { match self { Self::Sglang(c) => c @@ -611,6 +666,10 @@ impl ProtoGenerateStreamChunk { .output_logprobs .as_ref() .map(|lp| convert_output_logprobs!(lp)), + Self::TokenSpeed(c) => c + .output_logprobs + .as_ref() + .map(|lp| convert_output_logprobs!(lp)), } } @@ -625,8 +684,8 @@ impl ProtoGenerateStreamChunk { .input_logprobs .as_ref() .map(|lp| 
convert_input_logprobs!(lp)), - // TRT-LLM and MLX streaming chunks don't have input_logprobs - Self::Trtllm(_) | Self::Mlx(_) => None, + // TRT-LLM, MLX, and TokenSpeed streaming chunks don't have input_logprobs + Self::Trtllm(_) | Self::Mlx(_) | Self::TokenSpeed(_) => None, } } @@ -637,6 +696,7 @@ impl ProtoGenerateStreamChunk { Self::Vllm(c) => c.prompt_tokens, Self::Trtllm(c) => c.prompt_tokens, Self::Mlx(c) => c.prompt_tokens, + Self::TokenSpeed(c) => c.prompt_tokens, } } @@ -647,6 +707,7 @@ impl ProtoGenerateStreamChunk { Self::Vllm(c) => c.completion_tokens, Self::Trtllm(c) => c.completion_tokens, Self::Mlx(c) => c.completion_tokens, + Self::TokenSpeed(c) => c.completion_tokens, } } @@ -657,6 +718,7 @@ impl ProtoGenerateStreamChunk { Self::Vllm(c) => c.cached_tokens, Self::Trtllm(c) => c.cached_tokens, Self::Mlx(c) => c.cached_tokens, + Self::TokenSpeed(c) => c.cached_tokens, } } } @@ -668,6 +730,7 @@ pub enum ProtoGenerateComplete { Vllm(vllm::GenerateComplete), Trtllm(trtllm::GenerateComplete), Mlx(mlx::GenerateComplete), + TokenSpeed(tokenspeed::GenerateComplete), } impl ProtoGenerateComplete { @@ -739,6 +802,11 @@ impl ProtoGenerateComplete { matches!(self, Self::Mlx(_)) } + /// Check if this is TokenSpeed + pub fn is_tokenspeed(&self) -> bool { + matches!(self, Self::TokenSpeed(_)) + } + /// Get token IDs from either backend (output_ids in proto) pub fn token_ids(&self) -> &[u32] { match self { @@ -746,6 +814,7 @@ impl ProtoGenerateComplete { Self::Vllm(c) => &c.output_ids, Self::Trtllm(c) => &c.output_token_ids, Self::Mlx(c) => &c.output_ids, + Self::TokenSpeed(c) => &c.output_ids, } } @@ -756,6 +825,7 @@ impl ProtoGenerateComplete { Self::Vllm(c) => c.prompt_tokens, Self::Trtllm(c) => c.prompt_tokens, Self::Mlx(c) => c.prompt_tokens, + Self::TokenSpeed(c) => c.prompt_tokens, } } @@ -766,6 +836,7 @@ impl ProtoGenerateComplete { Self::Vllm(c) => c.completion_tokens, Self::Trtllm(c) => c.completion_tokens, Self::Mlx(c) => c.completion_tokens, + Self::TokenSpeed(c) => c.completion_tokens, } } @@ -776,6 +847,7 @@ impl ProtoGenerateComplete { Self::Vllm(c) => &c.finish_reason, Self::Trtllm(c) => &c.finish_reason, Self::Mlx(c) => &c.finish_reason, + Self::TokenSpeed(c) => &c.finish_reason, } } @@ -787,6 +859,7 @@ impl ProtoGenerateComplete { Self::Vllm(c) => c.index, Self::Trtllm(c) => c.sequence_index, Self::Mlx(c) => c.index, + Self::TokenSpeed(c) => c.index, } } @@ -824,6 +897,11 @@ impl ProtoGenerateComplete { Self::Mlx(c) => c .matched_stop_token_id .map(|id| serde_json::Value::Number(id.into())), + Self::TokenSpeed(c) => convert!( + &c.matched_stop, + TokenSpeedMatchedStop::MatchedTokenId, + TokenSpeedMatchedStop::MatchedStopStr + ), } } @@ -834,6 +912,7 @@ impl ProtoGenerateComplete { Self::Vllm(c) => &c.output_ids, Self::Trtllm(c) => &c.output_token_ids, Self::Mlx(c) => &c.output_ids, + Self::TokenSpeed(c) => &c.output_ids, } } @@ -844,6 +923,7 @@ impl ProtoGenerateComplete { Self::Vllm(c) => c.cached_tokens, Self::Trtllm(c) => c.cached_tokens, Self::Mlx(c) => c.cached_tokens, + Self::TokenSpeed(c) => c.cached_tokens, } } @@ -882,12 +962,12 @@ impl ProtoGenerateComplete { }) } } - // MLX does not have input_logprobs - Self::Mlx(_) => None, + // MLX and TokenSpeed do not have input_logprobs + Self::Mlx(_) | Self::TokenSpeed(_) => None, } } - /// Get output logprobs (SGLang, vLLM, TensorRT-LLM, and MLX) + /// Get output logprobs (SGLang, vLLM, TensorRT-LLM, MLX, and TokenSpeed) pub fn output_logprobs(&self) -> Option { match self { Self::Sglang(c) => c @@ -903,6 +983,10 @@ 
impl ProtoGenerateComplete { .output_logprobs .as_ref() .map(|lp| convert_output_logprobs!(lp)), + Self::TokenSpeed(c) => c + .output_logprobs + .as_ref() + .map(|lp| convert_output_logprobs!(lp)), } } @@ -914,19 +998,16 @@ impl ProtoGenerateComplete { .kv_transfer_params .as_ref() .map(|params| (params.remote_host.clone(), params.remote_port)), - Self::Sglang(_) | Self::Trtllm(_) | Self::Mlx(_) => None, + Self::Sglang(_) | Self::Trtllm(_) | Self::Mlx(_) | Self::TokenSpeed(_) => None, } } } /// Unified stream wrapper. /// -/// TokenSpeed has its own variant because its underlying stream type differs -/// from SGLang's ([`tokenspeed_scheduler::AbortOnDropStream`] vs -/// [`sglang_scheduler::AbortOnDropStream`]) — the wire is independent. Both -/// yield ``sglang::GenerateResponse``-shaped items (TokenSpeed translates at -/// the boundary), so the chunk / complete accessors below don't need a -/// dedicated TokenSpeed variant. +/// One variant per backend. Each yields its own native proto response shape; +/// the chunk / complete accessors above match on the corresponding +/// `ProtoGenerateStreamChunk` / `ProtoGenerateComplete` arm. pub enum ProtoStream { Sglang(SglangStream), Vllm(VllmStream), @@ -958,7 +1039,7 @@ impl ProtoStream { Self::TokenSpeed(stream) => stream .next() .await - .map(|result| result.map(|r| ProtoGenerateResponse::Sglang(Box::new(r)))), + .map(|result| result.map(|r| ProtoGenerateResponse::TokenSpeed(Box::new(r)))), } } diff --git a/model_gateway/src/workflow/steps/local/detect_backend.rs b/model_gateway/src/workflow/steps/local/detect_backend.rs index fbcb9c1dd..235b672af 100644 --- a/model_gateway/src/workflow/steps/local/detect_backend.rs +++ b/model_gateway/src/workflow/steps/local/detect_backend.rs @@ -43,10 +43,7 @@ async fn detect_grpc_backend( } } - // Try each runtime sequentially, ordered by expected frequency so the - // common case finishes after one probe. Each backend speaks its own - // gRPC service, so order is purely a latency optimisation, not a - // correctness condition. + // Try each runtime sequentially (most common first), skipping the hint we already tried for runtime in &["sglang", "vllm", "trtllm", "tokenspeed", "mlx"] { if Some(*runtime) == runtime_hint { continue; From e400fab3f87aaba137a852170ede48be6206bbf5 Mon Sep 17 00:00:00 2001 From: key4ng Date: Sat, 9 May 2026 13:52:42 -0700 Subject: [PATCH 04/24] revert(tokenizer): defer OpenAI tool-wrapper strip + strict:false injection The chat-template tool-shape pre-processor was correct for Kimi-K2.5 (BFCL accuracy +6 pp on simple_python, +24 pp on parallel_multiple) but breaks Mistral chat templates: their template at line 32 iterates ``tool.function.parameters.properties.items()``, which raises ``unknown method: undefined has no method named items`` once we unwrap ``{"type": "function", "function": {...}}`` into the bare inner dict. The shape a chat template expects is template-dependent, not engine-dependent. Reverting the unconditional unwrap; full rationale, accuracy data, and proposed per-model fix in docs/proposals/2026-05-09-deferred-chat-template-tools-strip.md. Affected CI lane: e2e_test/chat_completions/test_function_calling.py::TestToolChoiceMistral (20 tests failing with chat:32 render error). 
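For reference, the two tool shapes in play — the function below is
illustrative, not one from the failing suite. Mistral's template
dereferences ``tool.function.parameters.properties`` before iterating,
which only resolves on the wrapped OpenAI shape:

    # Wrapped OpenAI shape — what the Mistral template walks.
    wrapped = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
            },
        },
    }

    # Bare inner dict (what the reverted pre-processor emitted). Here
    # ``tool.function`` is undefined in minijinja, so
    # ``.parameters.properties.items()`` raises the chat:32 render error.
    bare = wrapped["function"]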
Signed-off-by: key4ng --- crates/tokenizer/src/chat_template.rs | 50 +-------------------------- 1 file changed, 1 insertion(+), 49 deletions(-) diff --git a/crates/tokenizer/src/chat_template.rs b/crates/tokenizer/src/chat_template.rs index 97ca4e550..2ac3b50c7 100644 --- a/crates/tokenizer/src/chat_template.rs +++ b/crates/tokenizer/src/chat_template.rs @@ -620,59 +620,11 @@ fn render_chat_template( // Convert messages to minijinja::Value (messages already processed by router) let minijinja_messages: Vec = messages.iter().map(Value::from_serialize).collect(); - // Strip the OpenAI tool wrapper for downstream rendering: convert - // ``[{"type": "function", "function": {...}}, ...]`` into the bare list - // ``[{...}, ...]``. This matches what TokenSpeed's HTTP path does in - // ``serving_chat.py:188`` (``[item.function.model_dump() for item in - // tools]``). Empirically, Kimi-K2.5 + TokenSpeed-NVFP4 BFCL accuracy is - // significantly higher when the model sees the bare-inner JSON form - // (HTTP path: simple_python 92.25 %) than either the wrapped form - // (SMG pre-fix: 86 %) or the TS-namespace form produced by the model's - // own ``encode_tools_to_typescript_style`` (88.25 %). The chat template - // falls through to ``{{ tools | tojson }}`` whenever ``tools_ts_str`` - // is empty/missing — so feeding the bare-inner shape and letting the - // template's JSON branch fire reproduces the HTTP path exactly. - let stripped_tools_json: Option> = params.tools.as_ref().map(|arr| { - arr.iter() - .map(|t| { - let v = serde_json::to_value(t).unwrap_or(serde_json::Value::Null); - // If this is an OpenAI-style wrapped tool ({"type":"function","function":{...}}), - // unwrap to the inner function dict. Then mirror what TokenSpeed HTTP's - // ``serving_chat.py:188`` does — call Pydantic's ``model_dump()`` on the - // inner function model, which emits *all* fields of the function schema - // including the ``strict`` field with its default ``false``. Without this - // SMG produces 108 tokens vs HTTP's 111 (3 missing tokens for ``,"strict":false``) - // and the model's BFCL accuracy stays ~5 pp below HTTP. Adding ``strict:false`` - // to each function dict closes that gap by matching HTTP byte-for-byte. - match v { - serde_json::Value::Object(ref m) => { - if m.get("type").and_then(|x| x.as_str()) == Some("function") { - if let Some(serde_json::Value::Object(inner)) = m.get("function") { - let mut inner_with_default = inner.clone(); - // Match Pydantic's ChatCompletionToolFunction.model_dump(): - // emit ``strict`` even when not set in the source request, - // defaulting to ``false`` (the OpenAI-API default). - inner_with_default - .entry("strict") - .or_insert(serde_json::Value::Bool(false)); - return serde_json::Value::Object(inner_with_default); - } - } - v - } - other => other, - } - }) - .collect() - }); - // Use Value::UNDEFINED for missing optional params so they are truly "undefined" // in the template context, matching HuggingFace Python behavior. Many chat templates // use `{% if tools is defined %}` guards — passing null (none) instead of undefined // would bypass those guards since `none` IS defined, causing `tools | length` to fail. 
- let tools_value = stripped_tools_json - .as_ref() - .map_or(Value::UNDEFINED, Value::from_serialize); + let tools_value = params.tools.map_or(Value::UNDEFINED, Value::from_serialize); let documents_value = params .documents .map_or(Value::UNDEFINED, Value::from_serialize); From 8478a636bac272bad93ecdaa97efe0b3728ade84 Mon Sep 17 00:00:00 2001 From: key4ng Date: Sat, 9 May 2026 13:52:42 -0700 Subject: [PATCH 05/24] style(grpc): trim verbose comments PR1 introduced MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cuts ~125 lines of doc-comment / inline-rationale prose without losing information. Hot spots: - tokenspeed_scheduler.rs module doc (28 → 7 lines) - tokenspeed_scheduler.proto service header + SamplingParams comment + per-field rationale (~40 lines) - sampling_params.rs module doc (14 → 4 lines) - translate::sampling_params explainer (12 → 3 lines) - inline arm comments in client.rs / multimodal.rs / harmony/stages / common/stages that just re-stated what the code already shows Behavior unchanged. ``cargo +nightly fmt --check`` passes; ``cargo clippy -p smg-grpc-client -p smg --all-targets --all-features -- -D warnings`` passes. Signed-off-by: key4ng --- .../proto/tokenspeed_scheduler.proto | 55 +++--------- crates/grpc_client/src/sampling_params.rs | 18 +--- .../grpc_client/src/tokenspeed_scheduler.rs | 88 ++++--------------- model_gateway/src/routers/grpc/client.rs | 14 +-- .../src/routers/grpc/common/stages/helpers.rs | 2 - .../grpc/common/stages/request_execution.rs | 3 - .../grpc/harmony/stages/request_building.rs | 3 - model_gateway/src/routers/grpc/multimodal.rs | 4 - 8 files changed, 31 insertions(+), 156 deletions(-) diff --git a/crates/grpc_client/proto/tokenspeed_scheduler.proto b/crates/grpc_client/proto/tokenspeed_scheduler.proto index 02d649ae7..fc5621014 100644 --- a/crates/grpc_client/proto/tokenspeed_scheduler.proto +++ b/crates/grpc_client/proto/tokenspeed_scheduler.proto @@ -5,34 +5,15 @@ package tokenspeed.grpc.scheduler; import "google/protobuf/timestamp.proto"; import "google/protobuf/struct.proto"; -// Service definition for TokenSpeed scheduler communication. -// -// TokenSpeed has its own service identity AND its own message shapes — wire -// definition is fully self-contained, with zero dependencies on -// ``sglang_scheduler.proto``. The message catalog is intentionally minimal: -// it covers what TokenSpeed's top-tier LLMs (Kimi K2, MiniMax M2, Qwen 3, -// gpt-oss, DeepSeek V4) actually need today, and nothing more. Anything -// SGLang-specific (PD-disaggregated serving, LoRA hot-swap, multimodal, -// classifier outputs, hidden-state forwarding, embeddings) is deliberately -// out of scope and lands here only when an explicit TokenSpeed use case -// shows up. +// TokenSpeed scheduler gRPC service. Self-contained wire — no dependency +// on sglang_scheduler.proto. Trimmed to text-generation only (no embed, +// no multimodal, no PD-disaggregated, no LoRA, no hidden-state forwarding). service TokenSpeedScheduler { - // Submit a generation request (server-streaming for token-by-token). rpc Generate(GenerateRequest) returns (stream GenerateResponse); - - // Liveness + readiness probe. rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse); - - // Cancel a running request. rpc Abort(AbortRequest) returns (AbortResponse); - - // Static info about the loaded model. rpc GetModelInfo(GetModelInfoRequest) returns (GetModelInfoResponse); - - // Runtime info about the server. 
rpc GetServerInfo(GetServerInfoRequest) returns (GetServerInfoResponse); - - // Per-DP-rank load metrics (used by router for least-load). rpc GetLoads(GetLoadsRequest) returns (GetLoadsResponse); } @@ -40,16 +21,9 @@ service TokenSpeedScheduler { // Sampling // ===================== -// IMPORTANT: proto3 numeric defaults (0) do NOT match semantic defaults -// (temperature=1.0, top_p=1.0, top_k=-1). All sampling scalars are -// declared ``optional`` so presence is preserved on the wire — the -// servicer uses ``HasField()`` to distinguish "client explicitly set 0" -// from "client didn't send anything." Without this, ``temperature=0`` -// (a valid request for greedy decoding) is indistinguishable from the -// proto3 default and would be silently dropped by truthy-check guards. -// -// ``min_new_tokens`` is left non-optional because 0 is its semantic -// "no minimum" sentinel. +// Sampling scalars are `optional` so the servicer can distinguish +// "client set 0" from "client unset" via `HasField()`. `min_new_tokens` +// stays non-optional because 0 is its semantic "no minimum" sentinel. message SamplingParams { optional float temperature = 1; optional float top_p = 2; @@ -75,8 +49,7 @@ message SamplingParams { // Per-token logit bias. map logit_bias = 16; - // Structured generation. Currently xfailed in e2e (tokenspeed#361), - // but the wire shape stays so wiring it later doesn't bump the proto. + // Structured generation. Currently xfailed in e2e (tokenspeed#361). oneof constraint { string regex = 17; string json_schema = 18; @@ -84,12 +57,7 @@ message SamplingParams { string structural_tag = 20; } - // When true, generation does not strip the trailing matched stop token - // from ``output_ids`` (matches SGLang's ``no_stop_trim``). Combined with - // ``skip_special_tokens=False`` it lets the gateway-side detokenizer - // render the EOS marker in the visible response — required for the - // ``test_no_stop_trim_with_skip_special_false`` e2e check and for any - // downstream logic that needs the raw stop token in the output stream. + // Keep the trailing matched stop token in `output_ids` (SGLang parity). bool no_stop_trim = 22; // Escape hatch for backend-specific knobs without bumping the proto. @@ -110,8 +78,7 @@ message GenerateRequest { // Logprob options. bool return_logprob = 4; - // Optional so the servicer can distinguish "client omitted" (use SGLang's - // ``-1`` default = no input logprobs) from an explicit value like 0. + // Optional: unset = SGLang's `-1` (no input logprobs); explicit 0 = "from start". optional int32 logprob_start_len = 5; int32 top_logprobs_num = 6; repeated uint32 token_ids_logprob = 7; @@ -122,9 +89,7 @@ message GenerateRequest { message TokenizedInput { repeated uint32 input_ids = 1; - // Original text — purely cosmetic; the tokenizer pass is skipped because - // input_ids is set. Used in worker logs for traceability. - string original_text = 2; + string original_text = 2; // cosmetic, for worker logs } message GenerateResponse { diff --git a/crates/grpc_client/src/sampling_params.rs b/crates/grpc_client/src/sampling_params.rs index a47478793..bf4a6640e 100644 --- a/crates/grpc_client/src/sampling_params.rs +++ b/crates/grpc_client/src/sampling_params.rs @@ -1,17 +1,7 @@ -//! Backend-neutral OpenAI → sampling-params builders. -//! -//! These helpers translate OpenAI request shapes (Chat, Responses, Messages, -//! Completion, plain `GenerateSamplingParams`) into a sampling-params struct -//! shared by the SGLang and TokenSpeed gRPC clients. 
They live here rather -//! than on `SglangSchedulerClient` because the OpenAI mapping is independent -//! of the wire backend. -//! -//! The return type is currently [`sglang::SamplingParams`] because that proto -//! happens to be the most permissive shape across our supported backends. -//! Other backends (TokenSpeed) translate from this shape to their own slimmer -//! wire format at the boundary. When a backend grows a sampling field SGLang -//! lacks, this is the place to add it (and consider whether a neutral -//! intermediate struct is worth introducing). +//! Backend-neutral OpenAI → sampling-params builders shared by the SGLang +//! and TokenSpeed gRPC clients. Returns [`sglang::SamplingParams`] (the +//! most permissive shape today); other backends translate from this at +//! their wire seam. use openai_protocol::{ chat::ChatCompletionRequest, diff --git a/crates/grpc_client/src/tokenspeed_scheduler.rs b/crates/grpc_client/src/tokenspeed_scheduler.rs index 4c67fc567..f06893edd 100644 --- a/crates/grpc_client/src/tokenspeed_scheduler.rs +++ b/crates/grpc_client/src/tokenspeed_scheduler.rs @@ -1,31 +1,11 @@ //! gRPC client for the TokenSpeed scheduler service. //! -//! TokenSpeed has a fully independent wire definition (see -//! ``proto/tokenspeed_scheduler.proto``) — distinct package -//! (``tokenspeed.grpc.scheduler``), distinct service, distinct messages with -//! intentionally trimmed field sets aimed at the top-tier LLM workloads -//! (Kimi K2, MiniMax M2, Qwen 3, gpt-oss, DeepSeek V4). Anything SGLang has -//! that doesn't apply here (PD-disaggregated serving, multimodal inputs, -//! LoRA hot-swap, hidden-state forwarding, embeddings, classifier outputs, -//! tokenizer streaming, KV-event subscription) is simply not on TokenSpeed's -//! wire. -//! -//! Request/response types are TokenSpeed-native end-to-end: the stream -//! yields ``tokenspeed_proto::GenerateResponse`` and ``build_*_request`` -//! produces ``tokenspeed_proto::GenerateRequest``. The router dispatches -//! through dedicated ``ProtoGenerateRequest::TokenSpeed`` / -//! ``ProtoGenerateStreamChunk::TokenSpeed`` / -//! ``ProtoGenerateComplete::TokenSpeed`` arms — same shape as the other -//! per-backend variants. -//! -//! Sampling-params construction reuses the backend-neutral helpers in -//! ``crate::sampling_params`` (which currently return -//! ``sglang::SamplingParams``); the [`translate::sampling_params`] -//! field-mapper converts to TokenSpeed's shape at the seam. The unary RPC -//! responses (``GetModelInfo``, ``GetServerInfo``, ``GetLoads``) still -//! return SGLang-shaped types because their consumers ride the -//! ``ModelInfo`` / ``ServerInfo`` SGLang variants in the router; that's a -//! separate cleanup. +//! Wire types are TokenSpeed-native end-to-end (`tokenspeed_proto::*`). +//! Sampling params come from the shared `crate::sampling_params` helpers +//! and are field-mapped to TokenSpeed's shape via [`translate::sampling_params`]. +//! The unary RPC adapters (`translate::model_info` / `server_info` / `loads`) +//! still emit SGLang-shaped types pending dedicated `ModelInfo::TokenSpeed` / +//! `ServerInfo::TokenSpeed` arms in the router. use std::{ pin::Pin, @@ -52,19 +32,11 @@ pub mod tokenspeed_proto { tonic::include_proto!("tokenspeed.grpc.scheduler"); } -/// Fire-and-forget abort sender used by [`AbortOnDropStream`]. The closure -/// captures the TokenSpeed client that owns the stream so ``Drop`` can -/// dispatch the abort RPC over the same connection without ``Drop`` itself -/// being async. 
Local to this module — SGLang's equivalent stream type -/// holds its own client field directly and doesn't need this indirection. +/// Fire-and-forget abort sender invoked from `Drop`. type AbortDispatcher = Arc; /// Auto-aborting wrapper around the TokenSpeed generate stream. -/// -/// Yields ``tokenspeed_proto::GenerateResponse`` directly (no translation -/// at the seam). Sends an Abort RPC on Drop unless ``mark_completed`` was -/// called first — same lifecycle contract as -/// ``sglang_scheduler::AbortOnDropStream``. +/// Sends Abort on Drop unless `mark_completed` ran first. pub struct AbortOnDropStream { inner: Streaming, request_id: String, @@ -145,8 +117,7 @@ impl TokenSpeedSchedulerClient { endpoint.to_string() }; - // Same channel knobs as SglangSchedulerClient — independent of the - // service being called and proven load-appropriate in prod. + // Channel knobs match the other engine clients. let channel = Channel::from_shared(http_endpoint)? .http2_keep_alive_interval(Duration::from_secs(30)) .keep_alive_timeout(Duration::from_secs(10)) @@ -261,12 +232,6 @@ impl TokenSpeedSchedulerClient { } // ── Request builders ────────────────────────────────────────────── - // - // Produce ``tokenspeed_proto::GenerateRequest`` directly. Sampling-param - // construction delegates to ``crate::sampling_params`` (which returns - // ``sglang::SamplingParams`` because that proto is the most permissive - // shape across our backends); ``translate::sampling_params`` field-maps - // it to TokenSpeed's slimmer shape at the seam. #[expect( clippy::unused_self, @@ -435,29 +400,15 @@ fn tokenspeed_abort_dispatcher(client: TokenSpeedSchedulerClient) -> AbortDispat }) } -// ── Sampling-params translation + unary RPC adapters ───────────────── -// -// `sampling_params` field-maps the sglang-shaped sampling params produced -// by `crate::sampling_params` into TokenSpeed's slimmer wire shape. The -// other adapters (model_info / server_info / loads) translate unary RPC -// responses into the SGLang-shaped types the router's metadata wrappers -// currently consume — that's a follow-up cleanup that can ride on top of -// the per-backend `ModelInfo` / `ServerInfo` enums. +// Sampling-params + unary RPC adapters: map between TokenSpeed's wire +// shape and the SGLang shape the router metadata wrappers consume. mod translate { use super::{sglang, tokenspeed_proto as ts}; pub(super) fn sampling_params(s: sglang::SamplingParams) -> ts::SamplingParams { - // sglang's proto declares numeric scalars as non-optional, so the Rust - // router has already substituted semantic defaults (e.g. - // ``temperature=1.0``, ``top_p=1.0``, ``repetition_penalty=1.0``) - // before getting here. tokenspeed's proto declares the same fields - // as ``optional`` so the servicer can use ``HasField()`` to - // distinguish presence — wrap the (already-defaulted) sglang values - // in ``Some(...)`` to mark them as explicitly set on the wire. This - // preserves the pre-fix behavior while letting future direct-to- - // tokenspeed clients use ``None`` to mean "let the engine default - // apply" (e.g. for health-probe / warmup paths that would otherwise - // hit ``top_p must be in (0, 1], got 0.0``). + // SGLang scalars are non-optional with semantic defaults already + // applied; TokenSpeed wraps in `Some(_)` so the servicer's + // `HasField()` checks distinguish "set" from "unset". 
ts::SamplingParams { temperature: Some(s.temperature), top_p: Some(s.top_p), @@ -502,8 +453,6 @@ mod translate { sglang::GetModelInfoResponse { model_path: r.model_path, tokenizer_path: r.tokenizer_path, - // TokenSpeed only serves generative LLMs at this layer; classifier - // / embedding models are out of scope. Hard-code accordingly. is_generation: true, preferred_sampling_params: r.preferred_sampling_params, weight_version: r.weight_version, @@ -529,13 +478,9 @@ mod translate { scheduler_info: r.scheduler_info, active_requests: r.active_requests, is_paused: r.is_paused, - // TokenSpeed scheduler doesn't track this — router doesn't read - // it for TokenSpeed either, so a fixed 0 is fine. last_receive_timestamp: 0.0, uptime_seconds: r.uptime_seconds, - // sglang_version field on the SGLang struct is the runtime version; - // for TokenSpeed we surface the TokenSpeed version through the same - // slot since downstream metric labels keep the field name. + // SGLang's `sglang_version` slot carries the runtime version label. sglang_version: r.tokenspeed_version, server_type: "grpc".to_string(), start_time: r.start_time, @@ -572,9 +517,6 @@ mod translate { graph_gb: m.graph_gb, token_capacity: m.token_capacity, }), - // TokenSpeed's wire intentionally omits speculative / LoRA / - // disaggregation metrics — fill the SGLang-shaped slots with - // None so callers ignore them. speculative: None, lora: None, disaggregation: None, diff --git a/model_gateway/src/routers/grpc/client.rs b/model_gateway/src/routers/grpc/client.rs index 16503e7ae..004858607 100644 --- a/model_gateway/src/routers/grpc/client.rs +++ b/model_gateway/src/routers/grpc/client.rs @@ -23,10 +23,8 @@ pub struct HealthCheckResponse { pub message: String, } -/// Wraps the per-backend gRPC clients. TokenSpeed has its own service but -/// reuses SGLang-shaped wrapper variants where the wire shapes line up -/// after translation; RPCs absent on a backend's wire return -/// ``Status::unimplemented``. +/// Wraps the per-backend gRPC clients. RPCs absent on a backend's wire +/// return `Status::unimplemented`. #[derive(Clone)] pub enum GrpcClient { Sglang(SglangSchedulerClient), @@ -296,9 +294,6 @@ impl GrpcClient { Self::Vllm(client) => client.get_tokenizer().await, Self::Trtllm(client) => client.get_tokenizer().await, Self::Mlx(client) => client.get_tokenizer().await, - // Status::unimplemented (not a String error) so the fallback in - // tokenizer_registration's downcast_ref::() check - // skips TokenSpeed workers silently. Self::TokenSpeed(_) => { return Err(Box::new(tonic::Status::unimplemented( "TokenSpeed backend does not support GetTokenizer RPC", @@ -366,8 +361,6 @@ impl GrpcClient { let resp = client.embed(*boxed_req).await?; Ok(ProtoEmbedComplete::Vllm(resp)) } - // TokenSpeed dropped the Embed RPC from its wire — top-tier - // LLMs aren't embedding models, so the proto doesn't carry one. (Self::TokenSpeed(_), _) => Err(tonic::Status::unimplemented( "TokenSpeed backend does not support embedding", )), @@ -452,9 +445,6 @@ impl GrpcClient { )?; Ok(ProtoGenerateRequest::Mlx(Box::new(req))) } - // TokenSpeed's wire intentionally has no multimodal fields - // (top-tier LLMs are text-only today). Reject if the assembly - // stage produced any — that's a router-config bug. 
Self::TokenSpeed(client) => { if multimodal_inputs.is_some() { return Err("TokenSpeed backend does not support multimodal inputs".to_string()); diff --git a/model_gateway/src/routers/grpc/common/stages/helpers.rs b/model_gateway/src/routers/grpc/common/stages/helpers.rs index da6a2d514..a0fcff107 100644 --- a/model_gateway/src/routers/grpc/common/stages/helpers.rs +++ b/model_gateway/src/routers/grpc/common/stages/helpers.rs @@ -159,8 +159,6 @@ pub(crate) fn apply_sampling_defaults_to_generate_request( }; apply_mlx_sampling_defaults(params, defaults, mask); } - // TokenSpeed and TRT-LLM are early-returned above; the arms exist - // only to keep the match exhaustive. ProtoGenerateRequest::Trtllm(_) | ProtoGenerateRequest::TokenSpeed(_) => {} } } diff --git a/model_gateway/src/routers/grpc/common/stages/request_execution.rs b/model_gateway/src/routers/grpc/common/stages/request_execution.rs index 1447e025b..950b6e4dc 100644 --- a/model_gateway/src/routers/grpc/common/stages/request_execution.rs +++ b/model_gateway/src/routers/grpc/common/stages/request_execution.rs @@ -114,9 +114,6 @@ impl PipelineStage for RequestExecutionStage { } Some(RuntimeType::Trtllm) | Some(RuntimeType::Mlx) - // TokenSpeed shares the SGLang proto but doesn't - // ship PD-disaggregation support yet — treat it - // like the other non-PD backends here. | Some(RuntimeType::TokenSpeed) | Some(RuntimeType::External) | Some(RuntimeType::Unspecified) => { diff --git a/model_gateway/src/routers/grpc/harmony/stages/request_building.rs b/model_gateway/src/routers/grpc/harmony/stages/request_building.rs index 3568a6e3e..edce1e2ee 100644 --- a/model_gateway/src/routers/grpc/harmony/stages/request_building.rs +++ b/model_gateway/src/routers/grpc/harmony/stages/request_building.rs @@ -277,9 +277,6 @@ impl PipelineStage for HarmonyRequestBuildingStage { }; ProtoGenerateRequest::Mlx(Box::new(req)) } - // TokenSpeed: builder produces a native - // ``tokenspeed::GenerateRequest``; multimodal is intentionally - // not supported here — the harmony path is text-only today. GrpcClient::TokenSpeed(tokenspeed_client) => { let req = match &ctx.input.request_type { RequestType::Chat(request) => { diff --git a/model_gateway/src/routers/grpc/multimodal.rs b/model_gateway/src/routers/grpc/multimodal.rs index f279b6156..a30d42135 100644 --- a/model_gateway/src/routers/grpc/multimodal.rs +++ b/model_gateway/src/routers/grpc/multimodal.rs @@ -708,10 +708,6 @@ pub(crate) fn assemble_multimodal_data( GrpcClient::Mlx(_) => unreachable!( "caller rejects multimodal for MLX in build_chat_request/build_messages_request" ), - // TokenSpeed's wire intentionally has no multimodal fields. The - // detect-backend / preparation stages never enable multimodal for a - // text-only top-tier LLM, so reaching this arm is a router-config - // bug rather than a user error. GrpcClient::TokenSpeed(_) => unreachable!( "TokenSpeed backend does not support multimodal; preparation stage should reject earlier" ), From 656f1c2e98ed37212be55990ea946bae6efe1f27 Mon Sep 17 00:00:00 2001 From: key4ng Date: Sun, 10 May 2026 20:41:44 -0700 Subject: [PATCH 06/24] style(grpc): clean up code comments Tightens the comments introduced (or modified) by this PR to describe behavior directly. No behavior change; literal type / path references left intact. 
Files touched: - crates/grpc_client/build.rs - crates/grpc_client/proto/tokenspeed_scheduler.proto - crates/grpc_client/src/sampling_params.rs - crates/grpc_client/src/tokenspeed_scheduler.rs - model_gateway/src/routers/grpc/client.rs - model_gateway/src/routers/grpc/proto_wrapper.rs Signed-off-by: key4ng --- crates/grpc_client/build.rs | 3 +-- crates/grpc_client/proto/tokenspeed_scheduler.proto | 10 +++++----- crates/grpc_client/src/sampling_params.rs | 9 ++++----- crates/grpc_client/src/tokenspeed_scheduler.rs | 13 ++++++------- model_gateway/src/routers/grpc/client.rs | 5 +++-- model_gateway/src/routers/grpc/proto_wrapper.rs | 6 +++--- 6 files changed, 22 insertions(+), 24 deletions(-) diff --git a/crates/grpc_client/build.rs b/crates/grpc_client/build.rs index f03ed71c1..8a8730c1a 100644 --- a/crates/grpc_client/build.rs +++ b/crates/grpc_client/build.rs @@ -20,8 +20,7 @@ fn main() -> Result<(), Box> { .build_client(true) .extern_path(".smg.grpc.common", "crate::common_proto") .type_attribute("GetModelInfoResponse", "#[derive(serde::Serialize)]") - // vllm + trtllm ServerInfo have only primitive fields. - // sglang's and tokenspeed's contain prost_types::{Struct,Timestamp}; + // Some ServerInfo protos contain prost_types::{Struct, Timestamp}; // those are handled separately at the wrapper layer. .type_attribute( "vllm.grpc.engine.GetServerInfoResponse", diff --git a/crates/grpc_client/proto/tokenspeed_scheduler.proto b/crates/grpc_client/proto/tokenspeed_scheduler.proto index fc5621014..8062050b9 100644 --- a/crates/grpc_client/proto/tokenspeed_scheduler.proto +++ b/crates/grpc_client/proto/tokenspeed_scheduler.proto @@ -5,9 +5,9 @@ package tokenspeed.grpc.scheduler; import "google/protobuf/timestamp.proto"; import "google/protobuf/struct.proto"; -// TokenSpeed scheduler gRPC service. Self-contained wire — no dependency -// on sglang_scheduler.proto. Trimmed to text-generation only (no embed, -// no multimodal, no PD-disaggregated, no LoRA, no hidden-state forwarding). +// TokenSpeed scheduler gRPC service. Fully self-contained wire definition. +// Trimmed to text-generation only (no embed, no multimodal, no +// PD-disaggregated, no LoRA, no hidden-state forwarding). service TokenSpeedScheduler { rpc Generate(GenerateRequest) returns (stream GenerateResponse); rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse); @@ -57,7 +57,7 @@ message SamplingParams { string structural_tag = 20; } - // Keep the trailing matched stop token in `output_ids` (SGLang parity). + // Keep the trailing matched stop token in `output_ids`. bool no_stop_trim = 22; // Escape hatch for backend-specific knobs without bumping the proto. @@ -78,7 +78,7 @@ message GenerateRequest { // Logprob options. bool return_logprob = 4; - // Optional: unset = SGLang's `-1` (no input logprobs); explicit 0 = "from start". + // Optional: unset = no input logprobs; explicit 0 = "from start". optional int32 logprob_start_len = 5; int32 top_logprobs_num = 6; repeated uint32 token_ids_logprob = 7; diff --git a/crates/grpc_client/src/sampling_params.rs b/crates/grpc_client/src/sampling_params.rs index bf4a6640e..4a28e7076 100644 --- a/crates/grpc_client/src/sampling_params.rs +++ b/crates/grpc_client/src/sampling_params.rs @@ -1,7 +1,6 @@ -//! Backend-neutral OpenAI → sampling-params builders shared by the SGLang -//! and TokenSpeed gRPC clients. Returns [`sglang::SamplingParams`] (the -//! most permissive shape today); other backends translate from this at -//! their wire seam. +//! 
Backend-neutral OpenAI → sampling-params builders. The return type +//! happens to be [`proto::SamplingParams`] (the most permissive shape +//! today); other backends translate from this at their wire seam. use openai_protocol::{ chat::ChatCompletionRequest, @@ -267,7 +266,7 @@ fn build_constraint_for_chat( } // If response_format already set a constraint, drop the tool constraint - // (matches SGLang HTTP behavior where response_format takes priority). + // (response_format takes priority over tool-call constraints). if let Some((constraint_type, constraint_value)) = tool_call_constraint { if constraints.is_empty() { let tool_constraint = match constraint_type.as_str() { diff --git a/crates/grpc_client/src/tokenspeed_scheduler.rs b/crates/grpc_client/src/tokenspeed_scheduler.rs index f06893edd..c8c55b6c1 100644 --- a/crates/grpc_client/src/tokenspeed_scheduler.rs +++ b/crates/grpc_client/src/tokenspeed_scheduler.rs @@ -4,8 +4,8 @@ //! Sampling params come from the shared `crate::sampling_params` helpers //! and are field-mapped to TokenSpeed's shape via [`translate::sampling_params`]. //! The unary RPC adapters (`translate::model_info` / `server_info` / `loads`) -//! still emit SGLang-shaped types pending dedicated `ModelInfo::TokenSpeed` / -//! `ServerInfo::TokenSpeed` arms in the router. +//! still emit the legacy router-side shape pending dedicated +//! `ModelInfo::TokenSpeed` / `ServerInfo::TokenSpeed` arms in the router. use std::{ pin::Pin, @@ -401,14 +401,14 @@ fn tokenspeed_abort_dispatcher(client: TokenSpeedSchedulerClient) -> AbortDispat } // Sampling-params + unary RPC adapters: map between TokenSpeed's wire -// shape and the SGLang shape the router metadata wrappers consume. +// shape and the router-side shape the metadata wrappers consume. mod translate { use super::{sglang, tokenspeed_proto as ts}; pub(super) fn sampling_params(s: sglang::SamplingParams) -> ts::SamplingParams { - // SGLang scalars are non-optional with semantic defaults already - // applied; TokenSpeed wraps in `Some(_)` so the servicer's - // `HasField()` checks distinguish "set" from "unset". + // Source scalars are non-optional with semantic defaults already + // applied; wrap in `Some(_)` so the servicer's `HasField()` checks + // distinguish "set" from "unset" on the wire. ts::SamplingParams { temperature: Some(s.temperature), top_p: Some(s.top_p), @@ -480,7 +480,6 @@ mod translate { is_paused: r.is_paused, last_receive_timestamp: 0.0, uptime_seconds: r.uptime_seconds, - // SGLang's `sglang_version` slot carries the runtime version label. sglang_version: r.tokenspeed_version, server_type: "grpc".to_string(), start_time: r.start_time, diff --git a/model_gateway/src/routers/grpc/client.rs b/model_gateway/src/routers/grpc/client.rs index 004858607..4604d75f2 100644 --- a/model_gateway/src/routers/grpc/client.rs +++ b/model_gateway/src/routers/grpc/client.rs @@ -236,7 +236,7 @@ impl GrpcClient { } /// Get the full load response from the backend. - /// Supported for SGLang and TokenSpeed backends. + /// Returns `Unimplemented` for backends without scheduler load metrics. pub async fn get_loads(&self) -> Result { match self { Self::Sglang(client) => { @@ -253,7 +253,8 @@ impl GrpcClient { } } - /// Subscribe to KV cache events (SGLang / vLLM / TRT-LLM only). + /// Subscribe to KV cache events. Returns `Unimplemented` on backends + /// without KV-event streaming. 
pub async fn subscribe_kv_events( &self, start_seq: u64, diff --git a/model_gateway/src/routers/grpc/proto_wrapper.rs b/model_gateway/src/routers/grpc/proto_wrapper.rs index 27128b150..48c7aef22 100644 --- a/model_gateway/src/routers/grpc/proto_wrapper.rs +++ b/model_gateway/src/routers/grpc/proto_wrapper.rs @@ -1,4 +1,4 @@ -//! Protocol buffer type wrappers for SGLang, vLLM, TensorRT-LLM, MLX, and TokenSpeed backends. +//! Protocol buffer type wrappers for the supported gRPC backends. //! //! This module provides unified enums that wrap proto types from each //! supported backend, allowing the router to work with any backend @@ -650,7 +650,7 @@ impl ProtoGenerateStreamChunk { } } - /// Get output logprobs (SGLang, vLLM, TensorRT-LLM, MLX, and TokenSpeed) + /// Get output logprobs. pub fn output_logprobs(&self) -> Option { match self { Self::Sglang(c) => c @@ -967,7 +967,7 @@ impl ProtoGenerateComplete { } } - /// Get output logprobs (SGLang, vLLM, TensorRT-LLM, MLX, and TokenSpeed) + /// Get output logprobs. pub fn output_logprobs(&self) -> Option { match self { Self::Sglang(c) => c From c478f1acd742434fb800cc9b968c31f1002a3118 Mon Sep 17 00:00:00 2001 From: key4ng Date: Sun, 10 May 2026 20:52:29 -0700 Subject: [PATCH 07/24] fix(grpc): apply model sampling defaults to TokenSpeed requests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The IR refactor split TokenSpeed off the SGLang request arm, which left two gaps in the sampling-defaults path: 1. `translate::model_info` populated the `preferred_sampling_params` label slot from TokenSpeed's response but left `default_sampling_params_json` empty. Worker discovery reads only the latter, so model-published defaults never reached the label map and were invisible to the request-stage injector. 2. `apply_sampling_defaults_to_generate_request` early-returned for `ProtoGenerateRequest::TokenSpeed(_)`. Even if a worker's labels carried defaults, the injector skipped the TokenSpeed arm — so a model publishing `temperature=0.7` in its generation config would apply to the other engines but TokenSpeed would fall through to the hardcoded 1.0 in the request builder. Both fixed: - `tokenspeed_scheduler.rs`: surface `preferred_sampling_params` in both `preferred_sampling_params` and `default_sampling_params_json` slots so the discovery path picks it up. - `helpers.rs`: drop TokenSpeed from the early-return, add a `TokenSpeed(req)` match arm calling a new `apply_tokenspeed_sampling_defaults`. TokenSpeed's wire declares every sampling scalar as `optional`, so the helper writes `Some(value)` rather than the bare value — preserving the set-vs-unset distinction the servicer's `HasField()` checks rely on. Verification: `cargo +nightly fmt --all -- --check` and `cargo clippy -p smg-grpc-client -p smg --all-targets --all-features -- -D warnings` both pass. 
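For context, the servicer-side presence check this preserves looks
roughly like the sketch below (the helper name and fallback constant are
illustrative; only ``HasField()`` on proto3 ``optional`` scalars is the
actual mechanism):

    DEFAULT_TEMPERATURE = 1.0  # illustrative fallback, not a servicer constant

    def resolve_temperature(params) -> float:
        # proto3 `optional` scalars support HasField(), so an explicit 0.0
        # (greedy decoding) stays distinguishable from "field not sent".
        if params.HasField("temperature"):
            return params.temperature
        return DEFAULT_TEMPERATURE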
Signed-off-by: key4ng --- .../grpc_client/src/tokenspeed_scheduler.rs | 9 +++- .../src/routers/grpc/common/stages/helpers.rs | 42 ++++++++++++++++--- 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/crates/grpc_client/src/tokenspeed_scheduler.rs b/crates/grpc_client/src/tokenspeed_scheduler.rs index c8c55b6c1..f82ec5974 100644 --- a/crates/grpc_client/src/tokenspeed_scheduler.rs +++ b/crates/grpc_client/src/tokenspeed_scheduler.rs @@ -450,11 +450,16 @@ mod translate { } pub(super) fn model_info(r: ts::GetModelInfoResponse) -> sglang::GetModelInfoResponse { + // Surface TokenSpeed's `preferred_sampling_params` JSON in both label + // fields the discovery path may consult, so worker-discovery can + // expose model-published defaults (`temperature`, `top_p`, etc.) to + // the router's default-injection stage. + let preferred = r.preferred_sampling_params; sglang::GetModelInfoResponse { model_path: r.model_path, tokenizer_path: r.tokenizer_path, is_generation: true, - preferred_sampling_params: r.preferred_sampling_params, + preferred_sampling_params: preferred.clone(), weight_version: r.weight_version, served_model_name: r.served_model_name, max_context_length: r.max_context_length, @@ -468,7 +473,7 @@ mod translate { architectures: r.architectures, id2label_json: String::new(), num_labels: 0, - default_sampling_params_json: String::new(), + default_sampling_params_json: preferred, } } diff --git a/model_gateway/src/routers/grpc/common/stages/helpers.rs b/model_gateway/src/routers/grpc/common/stages/helpers.rs index a0fcff107..f780c4879 100644 --- a/model_gateway/src/routers/grpc/common/stages/helpers.rs +++ b/model_gateway/src/routers/grpc/common/stages/helpers.rs @@ -6,7 +6,7 @@ use rand::Rng; use smg_grpc_client::{ mlx_proto, sglang_proto::{self, DisaggregatedParams}, - vllm_proto, + tokenspeed_proto, vllm_proto, }; use tracing::{debug, warn}; @@ -119,10 +119,7 @@ pub(crate) fn apply_sampling_defaults_to_generate_request( request_type: &RequestType, workers: Option<&WorkerSelection>, ) { - if matches!( - request, - ProtoGenerateRequest::Trtllm(_) | ProtoGenerateRequest::TokenSpeed(_) - ) { + if matches!(request, ProtoGenerateRequest::Trtllm(_)) { return; } @@ -159,7 +156,16 @@ pub(crate) fn apply_sampling_defaults_to_generate_request( }; apply_mlx_sampling_defaults(params, defaults, mask); } - ProtoGenerateRequest::Trtllm(_) | ProtoGenerateRequest::TokenSpeed(_) => {} + ProtoGenerateRequest::TokenSpeed(req) => { + let Some(params) = req.sampling_params.as_mut() else { + warn!( + "Cannot apply sampling defaults to TokenSpeed request without sampling_params" + ); + return; + }; + apply_tokenspeed_sampling_defaults(params, defaults, mask); + } + ProtoGenerateRequest::Trtllm(_) => {} } } @@ -221,6 +227,30 @@ optional_temperature_sampling_defaults_fn!( ); optional_temperature_sampling_defaults_fn!(apply_mlx_sampling_defaults, mlx_proto::SamplingParams); +/// TokenSpeed declares every sampling scalar as `optional` so the servicer +/// can distinguish "client set 0" from "client unset". Apply defaults by +/// writing `Some(value)` rather than the bare value. +fn apply_tokenspeed_sampling_defaults( + params: &mut tokenspeed_proto::SamplingParams, + defaults: SamplingDefaults, + mask: SamplingDefaultsMask, +) { + macro_rules! 
apply_opt { + ($field:ident) => { + if mask.$field { + if let Some(value) = defaults.$field { + params.$field = Some(value); + } + } + }; + } + apply_opt!(temperature); + apply_opt!(top_p); + apply_opt!(top_k); + apply_opt!(min_p); + apply_opt!(repetition_penalty); +} + /// Inject PD bootstrap metadata for SGLang if needed. /// /// SGLang uses DisaggregatedParams with bootstrap host/port/room. From e93dca991c9f3f2b9d73854cd84310d86de3a698 Mon Sep 17 00:00:00 2001 From: key4ng Date: Sun, 10 May 2026 21:01:00 -0700 Subject: [PATCH 08/24] style(grpc): tidy lib.rs comments - Collapse the duplicate 2-line crate-level docstring into a single neutral line covering all supported backends. - Drop the 5-line TokenSpeed re-export rationale block; the design rationale already lives in `tokenspeed_scheduler.rs`'s module doc. Signed-off-by: key4ng --- crates/grpc_client/src/lib.rs | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/crates/grpc_client/src/lib.rs b/crates/grpc_client/src/lib.rs index 79dddc333..0e33a0169 100644 --- a/crates/grpc_client/src/lib.rs +++ b/crates/grpc_client/src/lib.rs @@ -1,7 +1,4 @@ -//! gRPC clients for SGLang, vLLM, TensorRT-LLM, and MLX backends -//! -//! This crate provides gRPC client implementations for communicating with -//! SGLang scheduler, vLLM engine, TensorRT-LLM engine, and MLX engine backends. +//! gRPC clients for the supported inference backends. pub mod common_proto { #![allow(clippy::all, clippy::absolute_paths, unused_qualifications)] @@ -20,11 +17,6 @@ use std::sync::Arc; pub use mlx_engine::{proto as mlx_proto, MlxEngineClient}; pub use sglang_scheduler::{proto as sglang_proto, SglangSchedulerClient}; -// TokenSpeed has a fully independent wire definition (see -// ``proto/tokenspeed_scheduler.proto``) — distinct service, distinct -// messages with intentionally trimmed field sets aimed at top-tier LLM -// workloads. The client wraps that wire and translates to/from SGLang-shaped -// types at the boundary so the router's dispatch enums don't proliferate. pub use tokenspeed_scheduler::{tokenspeed_proto, TokenSpeedSchedulerClient}; use tonic::metadata::MetadataMap; pub use trtllm_service::{proto as trtllm_proto, TrtllmServiceClient}; From 6f841017c43bdb69857c514db58ea6e192eea524 Mon Sep 17 00:00:00 2001 From: key4ng Date: Fri, 8 May 2026 12:42:08 -0700 Subject: [PATCH 09/24] feat(grpc_servicer): add TokenSpeed servicer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the Python servicer that runs alongside a TokenSpeed scheduler process and serves the gRPC protocol PR1 introduced. Includes: - the async scheduler servicer (Generate/HealthCheck/Abort/ GetModelInfo/GetServerInfo/GetLoads), with cancellation handling for streaming, non-streaming, channel-close, and n>1 paths - a health-service bridge that flips SERVING/NOT_SERVING based on scheduler liveness (deep probe with bounded staleness) - a scheduler launcher that boots TokenSpeed's AsyncLLM in-process - the ``python -m smg_grpc_servicer.tokenspeed`` entrypoint - real ``GetLoads`` plumbing backed by ``AsyncLLM.get_load()`` so router-side load balancing reflects scheduler-side metrics - 57 unit tests covering the servicer, health service, proto conversion, finish reasons, sampling params, streaming/non- streaming generation, abort/cancel (incl. 
n>1), model/server info, and load metrics This is part 2 of 3 splitting #1351: - PR1: Rust gRPC + protocol (merged first) - PR2 (this): Python servicer + unit tests - PR3: CI workflows + e2e tests Stacked on PR1 — the servicer imports the proto stubs PR1 generates from ``crates/grpc_client/proto/tokenspeed_scheduler.proto``. Fixes a 🔴 critical from review on #1351: - FakeAsyncLLM.generate_request crashed with ``TypeError: unhashable type: 'list'`` when n>1, because ``_build_generate_req`` rewrites ``rid`` to a list of per-choice ids. The fake engine now registers state for each child rid, so ``test_cancel_aborts_all_n_children`` exercises the cancel sweep instead of dying at setup. Signed-off-by: key4ng --- grpc_servicer/pyproject.toml | 19 +- .../smg_grpc_servicer/tokenspeed/__init__.py | 11 + .../smg_grpc_servicer/tokenspeed/__main__.py | 71 ++ .../tokenspeed/health_servicer.py | 130 ++ .../tokenspeed/scheduler_launcher.py | 60 + .../smg_grpc_servicer/tokenspeed/server.py | 195 +++ .../smg_grpc_servicer/tokenspeed/servicer.py | 909 ++++++++++++++ grpc_servicer/tests/__init__.py | 0 grpc_servicer/tests/conftest.py | 22 + .../tests/test_tokenspeed_health_servicer.py | 98 ++ .../tests/test_tokenspeed_servicer.py | 1048 +++++++++++++++++ 11 files changed, 2562 insertions(+), 1 deletion(-) create mode 100644 grpc_servicer/smg_grpc_servicer/tokenspeed/__init__.py create mode 100644 grpc_servicer/smg_grpc_servicer/tokenspeed/__main__.py create mode 100644 grpc_servicer/smg_grpc_servicer/tokenspeed/health_servicer.py create mode 100644 grpc_servicer/smg_grpc_servicer/tokenspeed/scheduler_launcher.py create mode 100644 grpc_servicer/smg_grpc_servicer/tokenspeed/server.py create mode 100644 grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py create mode 100644 grpc_servicer/tests/__init__.py create mode 100644 grpc_servicer/tests/conftest.py create mode 100644 grpc_servicer/tests/test_tokenspeed_health_servicer.py create mode 100644 grpc_servicer/tests/test_tokenspeed_servicer.py diff --git a/grpc_servicer/pyproject.toml b/grpc_servicer/pyproject.toml index bea1e10f9..893255b98 100644 --- a/grpc_servicer/pyproject.toml +++ b/grpc_servicer/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" [project] name = "smg-grpc-servicer" version = "0.5.2" -description = "SMG gRPC servicer implementations for LLM inference engines (vLLM, SGLang, MLX)" +description = "SMG gRPC servicer implementations for LLM inference engines (vLLM, SGLang, MLX, TokenSpeed)" requires-python = ">=3.10" dependencies = [ "smg-grpc-proto>=0.4.6", @@ -36,6 +36,23 @@ sglang = ["sglang>=0.5.10"] # without this floor, installing [mlx] against an older proto build would # crash at import time when smg_grpc_servicer.mlx.server runs. mlx = ["smg-grpc-proto>=0.4.7", "mlx>=0.22.0", "mlx-lm>=0.22.0"] +# Note: there is intentionally no ``tokenspeed`` extra. TokenSpeed is not +# published to PyPI; it is installed out-of-tree from the lightseekorg +# checkout via ``scripts/ci_install_tokenspeed.sh`` (CI) or a manual +# ``pip install -e ./tokenspeed/python`` (local dev). An extra named +# ``tokenspeed`` would imply ``pip install smg-grpc-servicer[tokenspeed]`` +# yields a working tokenspeed setup; it does not. 
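+# Local-dev sketch (illustrative; assumes the TokenSpeed checkout sits next to
+# this repo — adjust paths to wherever your lightseekorg checkout lives):
+#   pip install -e ./tokenspeed/python
+#   pip install -e "./grpc_servicer[test]"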
+test = [ + "pytest>=8.0.0", + "pytest-asyncio>=0.23.0", +] + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] +markers = [ + "tokenspeed: tests that require TokenSpeed", +] [project.urls] Homepage = "https://github.com/lightseekorg/smg" diff --git a/grpc_servicer/smg_grpc_servicer/tokenspeed/__init__.py b/grpc_servicer/smg_grpc_servicer/tokenspeed/__init__.py new file mode 100644 index 000000000..d5ced6c52 --- /dev/null +++ b/grpc_servicer/smg_grpc_servicer/tokenspeed/__init__.py @@ -0,0 +1,11 @@ +"""TokenSpeed gRPC servicer implementation. + +Mirrors smg_grpc_servicer.vllm / smg_grpc_servicer.sglang. Wraps TokenSpeed's +AsyncLLM (main-process async frontend) behind the SGLang gRPC service so the +existing Rust router (which auto-detects the SGLang proto) can route traffic +to TokenSpeed without needing a new client. +""" + +from smg_grpc_servicer.tokenspeed.servicer import TokenSpeedSchedulerServicer + +__all__ = ["TokenSpeedSchedulerServicer"] diff --git a/grpc_servicer/smg_grpc_servicer/tokenspeed/__main__.py b/grpc_servicer/smg_grpc_servicer/tokenspeed/__main__.py new file mode 100644 index 000000000..fb80dcace --- /dev/null +++ b/grpc_servicer/smg_grpc_servicer/tokenspeed/__main__.py @@ -0,0 +1,71 @@ +"""CLI entrypoint for the TokenSpeed gRPC server. + +Usage:: + + python -m smg_grpc_servicer.tokenspeed --model-path --host 127.0.0.1 --port 50051 + +All :class:`tokenspeed.runtime.utils.server_args.ServerArgs` flags are accepted +verbatim (we reuse TokenSpeed's own ``prepare_server_args`` so there is no +flag drift between the HTTP and gRPC frontends). +""" + +from __future__ import annotations + +import asyncio +import logging +import sys + +import uvloop +from tokenspeed.runtime.utils.server_args import prepare_server_args + +from smg_grpc_servicer.tokenspeed.server import serve_grpc + + +def main(argv: list[str] | None = None) -> None: + if argv is None: + argv = sys.argv[1:] + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(name)s] %(levelname)s %(message)s", + ) + + # TokenSpeed's ``ServerArgs.resolve_kernel_backends`` defaults + # ``sampling_backend`` to ``"greedy"`` when the user doesn't pass + # ``--sampling-backend``. The greedy backend is argmax-only and + # ignores per-request ``temperature``/``top_p``/``top_k`` — fine for + # the legacy CLI where users opt in to sampling explicitly, but + # disastrous for a gateway-fronted gRPC servicer where per-request + # sampling params arrive on every call. With Llama-3.2-1B the + # always-argmax behavior collapses into single-token loops + # (\\n×N, ' ('×N, "no"×N) within a few hundred steps and + # generation runs to ``max_new_tokens`` — the smg e2e function-calling + # suite makes this directly observable. Force a sampling-respecting + # default unless the operator explicitly chose one. + if not any(a == "--sampling-backend" or a.startswith("--sampling-backend=") for a in argv): + argv = [*argv, "--sampling-backend", "flashinfer"] + + # TokenSpeed's logprob computation is gated by ``--enable-output-logprobs`` + # (default OFF, see ``ServerArgs.enable_output_logprobs``); without the + # flag, requests asking for logprobs receive empty arrays rather than an + # error. The smg gateway's OpenAI-compat path expects per-token logprobs + # whenever ``logprobs=True`` is set, so flip the flag on by default for a + # gateway-fronted gRPC servicer. Operators who want the smaller CUDA-graph + # footprint can pass ``--enable-output-logprobs=False`` explicitly. 
+ # ``--enable-top-logprobs`` is intentionally NOT injected: TokenSpeed + # raises at startup when it's set (the path is not yet implemented). + if not any( + a == "--enable-output-logprobs" or a.startswith("--enable-output-logprobs=") for a in argv + ): + argv = [*argv, "--enable-output-logprobs"] + + server_args = prepare_server_args(argv) + # The scheduler processes will read these env vars; make sure we ran + # through TokenSpeed's shared env/resource setup path instead of + # duplicating it here. + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + asyncio.run(serve_grpc(server_args)) + + +if __name__ == "__main__": + main() diff --git a/grpc_servicer/smg_grpc_servicer/tokenspeed/health_servicer.py b/grpc_servicer/smg_grpc_servicer/tokenspeed/health_servicer.py new file mode 100644 index 000000000..d6b04a62a --- /dev/null +++ b/grpc_servicer/smg_grpc_servicer/tokenspeed/health_servicer.py @@ -0,0 +1,130 @@ +"""Standard ``grpc.health.v1.Health`` servicer for the TokenSpeed backend. + +Mirrors ``smg_grpc_servicer.sglang.health_servicer.SGLangHealthServicer`` — +same service-name semantics, same lifecycle (NOT_SERVING → SERVING → NOT_SERVING), +same ``check/watch`` contract — but sources liveness signals from a TokenSpeed +:class:`AsyncLLM` instead of an SGLang ``GrpcRequestManager``. + +The Rust router uses this health check to auto-detect the backend runtime. +TokenSpeed ships its own ``tokenspeed.grpc.scheduler.TokenSpeedScheduler`` +service identity (see ``proto/tokenspeed_scheduler.proto``) so the probe +distinguishes TokenSpeed workers from real SGLang workers regardless of any +wire-level message-type sharing between the two backends. +""" + +from __future__ import annotations + +import logging +import time +from collections.abc import AsyncIterator +from typing import TYPE_CHECKING + +import grpc +from grpc_health.v1 import health_pb2, health_pb2_grpc +from smg_grpc_proto.generated import tokenspeed_scheduler_pb2 + +if TYPE_CHECKING: + from tokenspeed.runtime.engine.async_llm import AsyncLLM + +logger = logging.getLogger(__name__) + +# Seconds of scheduler silence — with pending requests — before we report +# NOT_SERVING. Matches the SGLang equivalent so oncall dashboards are aligned. +STUCK_SCHEDULER_THRESHOLD_SEC = 30.0 + +# Source the advertised service name from the proto descriptor so a future +# ``package`` or ``service`` rename in tokenspeed_scheduler.proto stays in +# sync without a hand-edited string here. +TOKENSPEED_SCHEDULER_SERVICE_NAME = tokenspeed_scheduler_pb2.DESCRIPTOR.services_by_name[ + "TokenSpeedScheduler" +].full_name + + +class TokenSpeedHealthServicer(health_pb2_grpc.HealthServicer): + """Health servicer that tracks TokenSpeed's AsyncLLM liveness. + + Advertises two service levels: + + * ``""`` (empty) — overall server health, flipped to SERVING once the + warmup request succeeds and back to NOT_SERVING on shutdown. + * ``tokenspeed.grpc.scheduler.TokenSpeedScheduler`` — readiness: the + base status, plus a scheduler-responsiveness check (if there are + pending requests but the scheduler hasn't pushed output for >30s, + report NOT_SERVING). 
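+
+    A router-side readiness probe against the per-service name looks roughly
+    like the sketch below (the address and channel setup are illustrative,
+    not part of this module)::
+
+        import grpc
+        from grpc_health.v1 import health_pb2, health_pb2_grpc
+
+        channel = grpc.insecure_channel("127.0.0.1:50051")
+        stub = health_pb2_grpc.HealthStub(channel)
+        resp = stub.Check(health_pb2.HealthCheckRequest(
+            service=TokenSpeedHealthServicer.TOKENSPEED_SERVICE))
+        ready = resp.status == health_pb2.HealthCheckResponse.SERVING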
+ """ + + OVERALL_SERVER = "" + TOKENSPEED_SERVICE = TOKENSPEED_SCHEDULER_SERVICE_NAME + + def __init__(self, async_llm: AsyncLLM, scheduler_info: dict): + self.async_llm = async_llm + self.scheduler_info = scheduler_info + self._serving_status: dict[str, int] = { + self.OVERALL_SERVER: health_pb2.HealthCheckResponse.NOT_SERVING, + self.TOKENSPEED_SERVICE: health_pb2.HealthCheckResponse.NOT_SERVING, + } + logger.info("TokenSpeed gRPC health service initialized") + + def set_serving(self) -> None: + """Flip both services to SERVING (call after successful warmup).""" + self._serving_status[self.OVERALL_SERVER] = health_pb2.HealthCheckResponse.SERVING + self._serving_status[self.TOKENSPEED_SERVICE] = health_pb2.HealthCheckResponse.SERVING + logger.info("TokenSpeed gRPC health status -> SERVING") + + def set_not_serving(self) -> None: + """Flip both services to NOT_SERVING (call on shutdown).""" + self._serving_status[self.OVERALL_SERVER] = health_pb2.HealthCheckResponse.NOT_SERVING + self._serving_status[self.TOKENSPEED_SERVICE] = health_pb2.HealthCheckResponse.NOT_SERVING + logger.info("TokenSpeed gRPC health status -> NOT_SERVING") + + async def Check( + self, + request: health_pb2.HealthCheckRequest, + context: grpc.aio.ServicerContext, + ) -> health_pb2.HealthCheckResponse: + service_name = request.service + logger.debug("Health check request for service=%r", service_name) + + if self.async_llm.gracefully_exit: + return health_pb2.HealthCheckResponse(status=health_pb2.HealthCheckResponse.NOT_SERVING) + + if service_name == self.OVERALL_SERVER: + return health_pb2.HealthCheckResponse( + status=self._serving_status.get( + self.OVERALL_SERVER, health_pb2.HealthCheckResponse.NOT_SERVING + ) + ) + + if service_name == self.TOKENSPEED_SERVICE: + base = self._serving_status.get( + self.TOKENSPEED_SERVICE, health_pb2.HealthCheckResponse.NOT_SERVING + ) + if base != health_pb2.HealthCheckResponse.SERVING: + return health_pb2.HealthCheckResponse(status=base) + + # Scheduler-stuck check: pending work but no recent output. + time_since_last_receive = time.time() - self.async_llm.last_receive_tstamp + pending = len(self.async_llm.rid_to_state) + if time_since_last_receive > STUCK_SCHEDULER_THRESHOLD_SEC and pending > 0: + logger.warning( + "Scheduler appears stuck: %.1fs since last receive, %d pending requests", + time_since_last_receive, + pending, + ) + return health_pb2.HealthCheckResponse( + status=health_pb2.HealthCheckResponse.NOT_SERVING + ) + + return health_pb2.HealthCheckResponse(status=health_pb2.HealthCheckResponse.SERVING) + + context.set_code(grpc.StatusCode.NOT_FOUND) + context.set_details(f"Unknown service: {service_name}") + return health_pb2.HealthCheckResponse(status=health_pb2.HealthCheckResponse.SERVICE_UNKNOWN) + + async def Watch( + self, + request: health_pb2.HealthCheckRequest, + context: grpc.aio.ServicerContext, + ) -> AsyncIterator[health_pb2.HealthCheckResponse]: + # K8s probes use Check, not Watch — we emit the current status once. + yield await self.Check(request, context) diff --git a/grpc_servicer/smg_grpc_servicer/tokenspeed/scheduler_launcher.py b/grpc_servicer/smg_grpc_servicer/tokenspeed/scheduler_launcher.py new file mode 100644 index 000000000..64acb18fa --- /dev/null +++ b/grpc_servicer/smg_grpc_servicer/tokenspeed/scheduler_launcher.py @@ -0,0 +1,60 @@ +"""Scheduler subprocess launcher for the TokenSpeed gRPC server. 
+ +Mirrors ``smg_grpc_servicer.sglang.scheduler_launcher`` but delegates to +TokenSpeed's ``_launch_subprocesses``: we get back a fully-initialised +``AsyncLLM`` along with the scheduler info dict. All scheduler/DP-controller +spawning, multiprocessing start-method, and env priming already live inside +``_launch_subprocesses`` — we only wrap it to return what the gRPC server +cares about and to keep the call site symmetric with the sibling backends. +""" + +from __future__ import annotations + +import logging +from typing import Any + +from tokenspeed.runtime.engine.async_llm import AsyncLLM +from tokenspeed.runtime.entrypoints.engine import _launch_subprocesses +from tokenspeed.runtime.utils.server_args import PortArgs, ServerArgs + +logger = logging.getLogger(__name__) + + +def launch_engine( + server_args: ServerArgs, + port_args: PortArgs | None = None, +) -> tuple[AsyncLLM, dict[str, Any]]: + """Launch TokenSpeed scheduler subprocess(es) and the main-process AsyncLLM. + + Returns: + A tuple ``(async_llm, scheduler_info)``. ``async_llm`` is the live + :class:`AsyncLLM` that the gRPC servicer will drive. ``scheduler_info`` + is the dict rank-0 sent back once its scheduler was ready (contains + e.g. ``max_total_num_tokens``, ``max_req_input_len``, ...). + + Raises: + RuntimeError: If rank-0 scheduler fails to initialize. The original + ``_launch_subprocesses`` surfaces this by re-raising the EOF/assertion + error — we propagate it unchanged. + """ + async_llm, _template_manager, scheduler_info = _launch_subprocesses( + server_args=server_args, + port_args=port_args, + ) + + # Non-zero rank nodes return (None, None, None) from _launch_subprocesses + # and block forever on the dummy health server — they never reach the gRPC + # server. Guard against callers relying on this return on secondary nodes. + if async_llm is None: + raise RuntimeError( + "launch_engine() returned no AsyncLLM. This means the current node " + "is not rank 0 in a multi-node deployment, or the scheduler died " + "during initialization. Only rank 0 may serve gRPC traffic." 
+ ) + + logger.info( + "TokenSpeed engine ready: max_total_num_tokens=%s max_req_input_len=%s", + scheduler_info.get("max_total_num_tokens"), + scheduler_info.get("max_req_input_len"), + ) + return async_llm, scheduler_info diff --git a/grpc_servicer/smg_grpc_servicer/tokenspeed/server.py b/grpc_servicer/smg_grpc_servicer/tokenspeed/server.py new file mode 100644 index 000000000..bbe67e69a --- /dev/null +++ b/grpc_servicer/smg_grpc_servicer/tokenspeed/server.py @@ -0,0 +1,195 @@ +"""Standalone TokenSpeed gRPC server — mirrors ``smg_grpc_servicer.sglang.server``.""" + +from __future__ import annotations + +import asyncio +import logging +import os +import signal +import threading +import time +from concurrent import futures + +import grpc +from grpc_health.v1 import health_pb2_grpc +from grpc_reflection.v1alpha import reflection +from smg_grpc_proto import tokenspeed_scheduler_pb2_grpc +from smg_grpc_proto.generated import tokenspeed_scheduler_pb2 +from tokenspeed.runtime.utils.server_args import ServerArgs + +from smg_grpc_servicer.tokenspeed.health_servicer import TokenSpeedHealthServicer +from smg_grpc_servicer.tokenspeed.scheduler_launcher import launch_engine +from smg_grpc_servicer.tokenspeed.servicer import TokenSpeedSchedulerServicer + +logger = logging.getLogger(__name__) + + +async def serve_grpc(server_args: ServerArgs) -> None: + """Run the TokenSpeed gRPC server until a shutdown signal is received.""" + + logger.info("Launching TokenSpeed scheduler + AsyncLLM...") + async_llm, scheduler_info = launch_engine(server_args) + + server = grpc.aio.server( + futures.ThreadPoolExecutor(max_workers=10), + options=[ + ("grpc.max_send_message_length", 1024 * 1024 * 256), + ("grpc.max_receive_message_length", 1024 * 1024 * 256), + # Match SGLang's more-permissive keepalive defaults so long + # prefill stalls don't trip GOAWAY in the Rust client. + ("grpc.http2.min_recv_ping_interval_without_data_ms", 10000), + ("grpc.keepalive_permit_without_calls", True), + ], + ) + + health_servicer = TokenSpeedHealthServicer( + async_llm=async_llm, + scheduler_info=scheduler_info, + ) + health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server) + + servicer = TokenSpeedSchedulerServicer( + async_llm=async_llm, + server_args=server_args, + scheduler_info=scheduler_info, + health_servicer=health_servicer, + ) + tokenspeed_scheduler_pb2_grpc.add_TokenSpeedSchedulerServicer_to_server(servicer, server) + + service_names = ( + tokenspeed_scheduler_pb2.DESCRIPTOR.services_by_name["TokenSpeedScheduler"].full_name, + "grpc.health.v1.Health", + reflection.SERVICE_NAME, + ) + reflection.enable_server_reflection(service_names, server) + + listen_addr = f"{server_args.host}:{server_args.port}" + server.add_insecure_port(listen_addr) + logger.info("TokenSpeed gRPC server listening on %s", listen_addr) + + await server.start() + + # Warmup on a background thread so the async server can handle the probe. + warmup_thread = threading.Thread( + target=_wait_and_warmup, + args=(server_args, health_servicer), + daemon=True, + ) + warmup_thread.start() + + loop = asyncio.get_running_loop() + stop_event = asyncio.Event() + + def _signal_handler() -> None: + logger.info("Received shutdown signal") + stop_event.set() + + for sig in (signal.SIGTERM, signal.SIGINT): + try: + loop.add_signal_handler(sig, _signal_handler) + except NotImplementedError: + # Windows and some exotic envs don't support loop.add_signal_handler. 
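+            # Ctrl+C then surfaces as KeyboardInterrupt out of
+            # stop_event.wait() below, and the finally block still runs the
+            # same shutdown path.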
+ pass + + try: + await stop_event.wait() + finally: + logger.info("Shutting down TokenSpeed gRPC server") + try: + await servicer.shutdown() + except Exception: # noqa: BLE001 + logger.exception("servicer.shutdown() raised") + await server.stop(5.0) + if warmup_thread.is_alive(): + warmup_thread.join(timeout=5.0) + + +def _wait_and_warmup( + server_args: ServerArgs, + health_servicer: TokenSpeedHealthServicer, +) -> None: + """Probe the gRPC server until it can generate one token, then set SERVING. + + We hit the external port (not the in-process servicer) so the warmup + exercises the same code path a production caller would — including the + gRPC transport, proto codec, and scheduler IPC. + """ + if os.getenv("TOKENSPEED_SKIP_GRPC_WARMUP", "0").lower() in ("1", "true", "yes"): + logger.info("TOKENSPEED_SKIP_GRPC_WARMUP=1 — skipping warmup") + health_servicer.set_serving() + return + + grpc_url = f"{server_args.host}:{server_args.port}" + channel = grpc.insecure_channel( + grpc_url, + options=[ + ("grpc.max_send_message_length", 1024 * 1024 * 256), + ("grpc.max_receive_message_length", 1024 * 1024 * 256), + ], + ) + stub = tokenspeed_scheduler_pb2_grpc.TokenSpeedSchedulerStub(channel) + + # Wait until GetModelInfo round-trips — that's the quickest confirmation + # that the gRPC server is both bound and has a live AsyncLLM behind it. + deadline = time.time() + 180 + connected = False + while time.time() < deadline: + try: + stub.GetModelInfo( + tokenspeed_scheduler_pb2.GetModelInfoRequest(), + timeout=5, + ) + connected = True + break + except Exception as e: # noqa: BLE001 + logger.debug("Warmup: GetModelInfo not ready yet: %s", e) + time.sleep(1) + + if not connected: + logger.error("TokenSpeed gRPC warmup failed: GetModelInfo never succeeded") + channel.close() + return + + # TokenSpeed serves generative LLMs only (the proto has no Embed RPC), so + # the warmup is always a 1-token generate. + warmup_ok = False + try: + warmup = tokenspeed_scheduler_pb2.GenerateRequest( + request_id=f"WARMUP_{time.time()}", + tokenized=tokenspeed_scheduler_pb2.TokenizedInput( + input_ids=[0], + original_text="warmup", + ), + sampling_params=tokenspeed_scheduler_pb2.SamplingParams( + temperature=0.0, + max_new_tokens=1, + ), + stream=False, + ) + final = None + for resp in stub.Generate(warmup, timeout=600): + final = resp + if final is None or not final.HasField("complete"): + logger.warning( + "Warmup Generate returned no Complete frame (last=%r)", + final, + ) + else: + logger.info("Warmup generation succeeded") + warmup_ok = True + except Exception as e: # noqa: BLE001 + logger.warning("TokenSpeed warmup failed: %s", e) + finally: + channel.close() + + # NOT_SERVING keeps the pod out of K8s readiness rotation when warmup + # never produced a Complete frame. + if warmup_ok: + health_servicer.set_serving() + logger.info("TokenSpeed gRPC server is ready to serve") + else: + logger.error( + "TokenSpeed gRPC warmup did not produce a complete frame; " + "health stays NOT_SERVING. K8s readiness will keep this " + "worker out of rotation until manually restarted." + ) diff --git a/grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py b/grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py new file mode 100644 index 000000000..ca641b736 --- /dev/null +++ b/grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py @@ -0,0 +1,909 @@ +"""TokenSpeed gRPC servicer. 
+ +Implements the ``tokenspeed.grpc.scheduler.TokenSpeedScheduler`` gRPC service +on top of TokenSpeed's :class:`tokenspeed.runtime.engine.async_llm.AsyncLLM` — +the main-process async frontend that replaced ``TokenizerManager`` in the +AsyncLLM refactor. + +Wire identity & message catalog +------------------------------- +TokenSpeed ships a fully independent proto (``proto/tokenspeed_scheduler.proto``) +with a distinct package, service, and message catalog. The Rust gateway's +``DetectBackendStep`` identifies the worker natively from the service name — +no SGLang-look-alike hack, no runtime marker probe. The proto's field set is +intentionally minimal (top-tier LLM serving only): no Embed, no +GetTokenizer, no SubscribeKvEvents, no multimodal, no PD-disaggregated +serving, no LoRA, no hidden-state forwarding, no classifier outputs. +Anything in that list has to be added to the proto first; it doesn't ride +on a shared SGLang message anymore. +""" + +from __future__ import annotations + +import asyncio +import dataclasses +import logging +import os +import re +import time +from collections.abc import AsyncIterator +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any + +import grpc +from google.protobuf.struct_pb2 import Struct +from google.protobuf.timestamp_pb2 import Timestamp +from smg_grpc_proto import tokenspeed_scheduler_pb2_grpc +from smg_grpc_proto.generated import tokenspeed_scheduler_pb2 + +from smg_grpc_servicer.tokenspeed.health_servicer import TokenSpeedHealthServicer + +if TYPE_CHECKING: + # Type-only imports — not resolved at module load so the servicer is + # importable in test environments that stub AsyncLLM / ServerArgs. + from tokenspeed.runtime.engine.async_llm import AsyncLLM + from tokenspeed.runtime.utils.server_args import ServerArgs + +logger = logging.getLogger(__name__) + +HEALTH_CHECK_TIMEOUT = int(os.getenv("TOKENSPEED_HEALTH_CHECK_TIMEOUT", "20")) + + +def _lazy_generate_req_input(): + """Late import for ``tokenspeed.runtime.engine.io_struct.GenerateReqInput``. + + Kept lazy so the top of this module loads in test environments that stub + the TokenSpeed engine surface (unit tests don't need a fully-working + TokenSpeed install to exercise proto ↔ request-input conversion). + """ + from tokenspeed.runtime.engine.io_struct import GenerateReqInput + + return GenerateReqInput + + +def _finish_reason_to_dict(reason: Any) -> dict | None: + """Normalise a TokenSpeed finish reason into the SGLang on-wire shape. + + TokenSpeed emits ``BaseFinishReason``-style objects (or an already-normalised + dict) in ``meta_info["finish_reason"]``; downstream code expects a dict + with at minimum ``{"type": ...}`` and optionally ``{"matched": int|str}``. + ``None`` means "still running". + + We duck-type on ``to_json()`` rather than importing the concrete + ``BaseFinishReason`` class so the servicer module loads without pulling + in TokenSpeed's full request-processing graph. + + Raises ``TypeError`` for unknown shapes rather than coercing to a fake + ``stop``: silently flipping ``length``/``abort`` to ``stop`` and leaking + a debug ``repr()`` into the user-facing ``matched_stop_str`` field would + hide real bugs and corrupt the OpenAI ``finish_reason`` semantics. The + caller wraps this in ``try/except`` and turns it into ``StatusCode.INTERNAL``. 
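+
+    Typical normalised shapes (values illustrative)::
+
+        {"type": "stop", "matched": 128009}    # stop token / stop string hit
+        {"type": "length"}                     # ran into max_new_tokens
+        {"type": "abort", "status_code": 429, "message": "queue is full"}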
+ """ + if reason is None: + return None + if isinstance(reason, dict): + return reason + to_json = getattr(reason, "to_json", None) + if callable(to_json): + try: + result = to_json() + except Exception as e: # noqa: BLE001 + raise TypeError( + f"finish_reason of type {type(reason).__name__!r} raised in " + f"to_json(); refusing to silently emit a fake stop. {e}" + ) from e + if isinstance(result, dict): + return result + raise TypeError( + f"finish_reason {type(reason).__name__!r}.to_json() returned " + f"{type(result).__name__!r}; expected dict with at least 'type'." + ) + raise TypeError( + f"Unknown finish_reason shape {type(reason).__name__!r}; expected " + f"a dict or an object with a to_json() method." + ) + + +class TokenSpeedSchedulerServicer(tokenspeed_scheduler_pb2_grpc.TokenSpeedSchedulerServicer): + """gRPC servicer exposing TokenSpeed's AsyncLLM over the dedicated TokenSpeed proto.""" + + def __init__( + self, + async_llm: AsyncLLM, + server_args: ServerArgs, + scheduler_info: dict, + health_servicer: TokenSpeedHealthServicer | None = None, + ): + self.async_llm = async_llm + self.server_args = server_args + self.scheduler_info = scheduler_info + self.health_servicer = health_servicer + self.start_time = time.time() + + # Drive AsyncLLM's output-dispatch loop. This is idempotent — the + # first caller creates the handle loop; subsequent callers (including + # the HealthCheck RPC) are no-ops thanks to ``no_create_loop``. + self.async_llm.auto_create_handle_loop() + + logger.info("TokenSpeedSchedulerServicer initialized") + + # ------------------------------------------------------------------ + # Generate (server-streaming) + # ------------------------------------------------------------------ + + async def Generate( + self, + request: tokenspeed_scheduler_pb2.GenerateRequest, + context: grpc.aio.ServicerContext, + ) -> AsyncIterator[tokenspeed_scheduler_pb2.GenerateResponse]: + rid = request.request_id + logger.info("Generate request %s (stream=%s)", rid, request.stream) + + try: + req_obj = self._build_generate_req(request) + except ValueError as e: + await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + return + + # For n>1, tokenspeed's batch handler generates fresh UUIDs per + # sub-request and tags each streamed dict with a sequential + # ``index`` (see tokenizer_manager.py::_handle_batch_request). + # Non-streaming n>1 yields a *list* of final dicts instead. We + # handle both shapes below. + expanded_rid = getattr(req_obj, "rid", None) + + # When the client sets ``no_stop_trim``, the matched stop token must + # remain in the proto's ``output_ids`` so the gateway-side detokenizer + # can render it (relevant when ``skip_special_tokens=False`` is also + # set). Capture once and thread through the response builders. + no_stop_trim = bool(request.sampling_params.no_stop_trim) + + aborted = False + try: + async for output in self.async_llm.generate_request(req_obj): + # Non-streaming n>1 emits a list of final dicts in one yield. 
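+                # (Illustrative shape for n=2, using the fields read below:
+                #  [{"index": 0, "output_ids": [...], "meta_info": {...}},
+                #   {"index": 1, "output_ids": [...], "meta_info": {...}}].)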
+ if isinstance(output, list): + for idx, item in enumerate(output): + item_reason = _finish_reason_to_dict( + item.get("meta_info", {}).get("finish_reason") + ) + if item_reason and item_reason.get("type") == "abort": + code = _abort_status_code(item_reason) + await context.abort(code, item_reason.get("message") or "aborted") + return + ci = int(item.get("index", idx)) + yield self._complete_response( + rid, item, item_reason, ci, no_stop_trim=no_stop_trim + ) + continue + + meta = output.get("meta_info", {}) + reason_dict = _finish_reason_to_dict(meta.get("finish_reason")) + is_finished = reason_dict is not None + + if reason_dict is not None and reason_dict.get("type") == "abort": + code = _abort_status_code(reason_dict) + await context.abort(code, reason_dict.get("message") or "aborted") + return + + choice_index = int(output.get("index", 0)) + + if request.stream: + yield self._chunk_response( + rid, output, reason_dict, choice_index, no_stop_trim=no_stop_trim + ) + if is_finished: + yield self._complete_response( + rid, output, reason_dict, choice_index, no_stop_trim=no_stop_trim + ) + elif is_finished: + yield self._complete_response( + rid, output, reason_dict, choice_index, no_stop_trim=no_stop_trim + ) + + except ValueError as e: + logger.warning("Generate invalid request %s: %s", rid, e) + await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + except asyncio.CancelledError: + # Client disconnected — sweep every scheduler-side rid we minted + # (including the per-choice ``{rid}-n{i}`` children n>1 creates) + # so abandoned requests don't keep consuming GPU work. + aborted = True + if isinstance(expanded_rid, list): + for r in expanded_rid: + self.async_llm.abort_request(r) + else: + self.async_llm.abort_request(rid) + raise + except grpc.aio.AbortError: + raise + except Exception as e: + logger.exception("Generate failed for request %s", rid) + await context.abort(grpc.StatusCode.INTERNAL, str(e)) + finally: + # Defensive cleanup — the scheduler owns rid_to_state, but if the + # stream was torn down before finish we need to notify it. When + # n>1 we expanded rid to a list of per-choice ids, so walk them. + if not aborted: + rids_to_check = ( + list(expanded_rid) + if isinstance(expanded_rid, list) + else ([expanded_rid] if isinstance(expanded_rid, str) else []) + ) + for r in rids_to_check: + state = self.async_llm.rid_to_state.get(r) + if state is not None and not getattr(state, "finished", False): + self.async_llm.abort_request(r) + + # ------------------------------------------------------------------ + # HealthCheck (unary) + # ------------------------------------------------------------------ + + async def HealthCheck( + self, + request: tokenspeed_scheduler_pb2.HealthCheckRequest, + context: grpc.aio.ServicerContext, + ) -> tokenspeed_scheduler_pb2.HealthCheckResponse: + """Deep health probe — sends a 1-token generation to the scheduler. + + Mirrors SGLang's contract exactly: if the scheduler pushes *any* + output within ``HEALTH_CHECK_TIMEOUT`` seconds, we consider it alive. + We bypass the normal AsyncLLM lock/metrics by crafting a dedicated + request with ``log_metrics=False`` so health checks don't skew + Prometheus counters. + """ + rid = f"HEALTH_CHECK_{time.time()}" + + if self.async_llm.gracefully_exit: + return tokenspeed_scheduler_pb2.HealthCheckResponse( + healthy=False, message="Server is shutting down" + ) + + # TokenSpeed only serves generative LLMs at this layer (the proto + # has no Embed RPC), so the probe is always a 1-token generate. 
+ GenerateReqInput = _lazy_generate_req_input() + probe = GenerateReqInput( + input_ids=[0], + sampling_params={"max_new_tokens": 1, "temperature": 0.0}, + log_metrics=False, + ) + probe.rid = rid + + tic = time.time() + + async def _drive_probe() -> bool: + try: + async for _ in self.async_llm.generate_request(probe): + return True + except Exception as e: # noqa: BLE001 — the probe is best-effort. + logger.warning("Health probe failed: %s", e) + return False + return False + + task = asyncio.create_task(_drive_probe()) + try: + while time.time() - tic < HEALTH_CHECK_TIMEOUT: + await asyncio.sleep(0.5) + # Any scheduler push after we started counts as healthy. + if self.async_llm.last_receive_tstamp > tic: + return tokenspeed_scheduler_pb2.HealthCheckResponse( + healthy=True, + message="Health check passed", + ) + if task.done(): + return tokenspeed_scheduler_pb2.HealthCheckResponse( + healthy=bool(task.result()), + message=( + "Health check passed" + if task.result() + else "Scheduler returned no output" + ), + ) + finally: + if not task.done(): + task.cancel() + # Best-effort cleanup: the probe rid shouldn't linger. + self.async_llm.abort_request(rid) + + return tokenspeed_scheduler_pb2.HealthCheckResponse( + healthy=False, + message=f"Health check timeout after {HEALTH_CHECK_TIMEOUT}s", + ) + + # ------------------------------------------------------------------ + # Abort (unary) + # ------------------------------------------------------------------ + + async def Abort( + self, + request: tokenspeed_scheduler_pb2.AbortRequest, + _context: grpc.aio.ServicerContext, + ) -> tokenspeed_scheduler_pb2.AbortResponse: + """Abort the request + any per-choice expansions from n>1. + + Generate rewrites ``n>1`` requests into a list of rids + ``[{request_id}-n0, {request_id}-n1, ...]`` so TokenSpeed's batch + path sees unique rids. Aborting only the original ``request_id`` + would leave those children running — we sweep them all. + """ + rid = request.request_id + logger.info("Abort request %s", rid) + state_map = self.async_llm.rid_to_state + + # Anchored regex avoids matching unrelated rids like "{rid}-name". 
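+        # e.g. rid="abc": "abc-n0" and "abc-n12" are swept as children, while
+        # "abc-name" and "abcd-n0" are not ("abc" itself is caught by the
+        # r == rid equality check).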
+ child_pattern = re.compile(rf"^{re.escape(rid)}-n\d+$") + targets = [r for r in state_map if r == rid or child_pattern.match(r)] + + try: + for r in targets: + self.async_llm.abort_request(r) + known = bool(targets) + return tokenspeed_scheduler_pb2.AbortResponse( + success=known, + message=( + f"Aborted {len(targets)} request(s) for {rid}" + if known + else f"Request {rid} not found" + ), + ) + except Exception as e: + logger.exception("Abort failed for %s", rid) + return tokenspeed_scheduler_pb2.AbortResponse(success=False, message=str(e)) + + # ------------------------------------------------------------------ + # GetModelInfo (unary) + # ------------------------------------------------------------------ + + async def GetModelInfo( + self, + _request: tokenspeed_scheduler_pb2.GetModelInfoRequest, + _context: grpc.aio.ServicerContext, + ) -> tokenspeed_scheduler_pb2.GetModelInfoResponse: + model_config = self.async_llm.model_config + hf_config = getattr(model_config, "hf_config", None) + + eos = getattr(hf_config, "eos_token_id", None) if hf_config else None + if isinstance(eos, int): + eos_token_ids = [eos] + elif isinstance(eos, list): + eos_token_ids = list(eos) + else: + eos_token_ids = [] + + max_req_input_len = self.scheduler_info.get("max_req_input_len") or ( + self.async_llm.max_req_input_len or 0 + ) + + # TokenSpeed's GetModelInfoResponse intentionally drops + # ``is_generation`` (always true), ``supports_vision`` (always false), + # and ``id2label_json`` / ``num_labels`` (not a classifier serving + # path). The Rust client fills those slots back in when translating + # to its SGLang-shaped wrapper. + return tokenspeed_scheduler_pb2.GetModelInfoResponse( + model_path=self.server_args.model_path, + tokenizer_path=self.server_args.tokenizer_path or "", + preferred_sampling_params=self.server_args.preferred_sampling_params or "", + weight_version="", + served_model_name=(self.server_args.served_model_name or self.server_args.model_path), + max_context_length=int(self.async_llm.context_len), + vocab_size=int(model_config.vocab_size), + model_type=(getattr(hf_config, "model_type", "") or "") if hf_config else "", + architectures=(getattr(hf_config, "architectures", []) or []) if hf_config else [], + eos_token_ids=eos_token_ids, + pad_token_id=(getattr(hf_config, "pad_token_id", 0) or 0) if hf_config else 0, + bos_token_id=(getattr(hf_config, "bos_token_id", 0) or 0) if hf_config else 0, + max_req_input_len=int(max_req_input_len), + ) + + # ------------------------------------------------------------------ + # GetServerInfo (unary) + # ------------------------------------------------------------------ + + async def GetServerInfo( + self, + _request: tokenspeed_scheduler_pb2.GetServerInfoRequest, + _context: grpc.aio.ServicerContext, + ) -> tokenspeed_scheduler_pb2.GetServerInfoResponse: + # TokenSpeed's ``ServerArgs`` is a dataclass, but tests sometimes pass + # a plain namespace. Fall back to ``__dict__`` so both shapes work. 
+ if dataclasses.is_dataclass(self.server_args) and not isinstance(self.server_args, type): + server_args_dict = dataclasses.asdict(self.server_args) + else: + server_args_dict = dict(getattr(self.server_args, "__dict__", {})) + server_args_struct = Struct() + server_args_struct.update(_make_json_serializable(server_args_dict)) + + scheduler_info_struct = Struct() + scheduler_info_struct.update(_make_json_serializable(dict(self.scheduler_info))) + + uptime = time.time() - self.start_time + start_timestamp = Timestamp() + start_timestamp.FromSeconds(int(self.start_time)) + + try: + import tokenspeed # local import: avoid module-load-time dependency + + version = getattr(tokenspeed, "__version__", "unknown") + except Exception: # noqa: BLE001 — fall back gracefully. + version = "unknown" + + return tokenspeed_scheduler_pb2.GetServerInfoResponse( + server_args=server_args_struct, + scheduler_info=scheduler_info_struct, + active_requests=len(self.async_llm.rid_to_state), + is_paused=False, + uptime_seconds=float(uptime), + tokenspeed_version=version, + start_time=start_timestamp, + max_total_num_tokens=int(self.scheduler_info.get("max_total_num_tokens", 0)), + ) + + # ------------------------------------------------------------------ + # GetLoads (unary) — bridges to TokenSpeed's scheduler-side load metrics + # ------------------------------------------------------------------ + + async def GetLoads( + self, + _request: tokenspeed_scheduler_pb2.GetLoadsRequest, + context: grpc.aio.ServicerContext, + ) -> tokenspeed_scheduler_pb2.GetLoadsResponse: + """Return per-DP-rank scheduler load by RPC-ing the scheduler subprocess. + + ``AsyncLLM`` inherits ``SchedulerControlClient.get_load`` which sends + ``GetLoadReqInput`` over the engine_core_client zmq channel and awaits + a ``List[GetLoadReqOutput]`` reply (one per DP rank). Each reply carries + the live counts the scheduler computes in ``event_loop._get_load``: + ``num_reqs`` (running + waiting), ``num_waiting_reqs``, and + ``num_pages`` (KV pages currently in use). We map those to the + ``SchedulerLoad`` proto plus a coarse aggregate so the router-side + consumer matches what it gets from SGLang. + """ + try: + load_outputs = await asyncio.wait_for( + self.async_llm.get_load(), timeout=HEALTH_CHECK_TIMEOUT + ) + except TimeoutError: + await context.abort( + grpc.StatusCode.DEADLINE_EXCEEDED, + f"tokenspeed scheduler did not respond to GetLoad within {HEALTH_CHECK_TIMEOUT}s", + ) + return + except Exception as e: # noqa: BLE001 + logger.exception("GetLoads failed") + await context.abort(grpc.StatusCode.INTERNAL, str(e)) + return + + page_size = int(getattr(self.async_llm.server_args, "page_size", 1) or 1) + # ``max_total_num_tokens`` lives on the scheduler-side ``scheduler_info`` + # dict that ``launch_engine`` plumbed through at boot — not directly on + # AsyncLLM. Fall back to ``server_args.max_total_num_tokens`` (used in + # tests' SimpleNamespace stubs). 
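+        # Worked example with illustrative numbers: page_size=16, a rank
+        # reporting num_pages=512 against max_total_num_tokens=65536 yields
+        # num_used_tokens=8192 and token_usage=0.125 in the loop below.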
+ max_total_num_tokens = int( + (self.scheduler_info.get("max_total_num_tokens") if self.scheduler_info else None) + or getattr(self.async_llm.server_args, "max_total_num_tokens", 0) + or 0 + ) + + scheduler_loads: list[tokenspeed_scheduler_pb2.SchedulerLoad] = [] + total_running = 0 + total_waiting = 0 + token_usages: list[float] = [] + for lo in load_outputs: + num_running = max(0, int(lo.num_reqs) - int(lo.num_waiting_reqs)) + num_used_tokens = int(lo.num_pages) * page_size + token_usage = ( + num_used_tokens / max_total_num_tokens if max_total_num_tokens > 0 else 0.0 + ) + scheduler_loads.append( + tokenspeed_scheduler_pb2.SchedulerLoad( + dp_rank=int(lo.dp_rank), + num_running_reqs=num_running, + num_waiting_reqs=int(lo.num_waiting_reqs), + num_total_reqs=int(lo.num_reqs), + num_used_tokens=num_used_tokens, + max_total_num_tokens=max_total_num_tokens, + token_usage=token_usage, + ) + ) + total_running += num_running + total_waiting += int(lo.num_waiting_reqs) + token_usages.append(token_usage) + + aggregate = tokenspeed_scheduler_pb2.AggregateMetrics( + total_running_reqs=total_running, + total_waiting_reqs=total_waiting, + total_reqs=total_running + total_waiting, + avg_token_usage=(sum(token_usages) / len(token_usages)) if token_usages else 0.0, + ) + + return tokenspeed_scheduler_pb2.GetLoadsResponse( + timestamp=datetime.now(timezone.utc).isoformat(), + version="tokenspeed", + dp_rank_count=len(scheduler_loads), + loads=scheduler_loads, + aggregate=aggregate, + ) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + async def shutdown(self, drain_timeout_secs: float = 30.0) -> None: + """Graceful shutdown — drain in-flight requests, then kill scheduler children. + + AsyncLLM's ``sigterm_watchdog`` polls ``gracefully_exit`` every 5s, + drains ``rid_to_state`` and finally calls + ``kill_process_tree(getpid, include_parent=True)``. That works in + steady-state but the gRPC server's main coroutine may unwind before + the watchdog ticks again, in which case the scheduler subprocesses + outlive the parent and end up orphaned. To avoid that, we: + + 1. Flag ``gracefully_exit`` so AsyncLLM stops accepting work and + the watchdog will eventually run its own cleanup. + 2. Wait up to ``drain_timeout_secs`` for ``rid_to_state`` to empty. + 3. Forcibly kill the subprocess tree (``include_parent=False``) so + the scheduler children are reaped regardless of whether the + watchdog tick fires before this coroutine returns. Idempotent + with the watchdog's own ``kill_process_tree`` call. + """ + self.async_llm.gracefully_exit = True + if self.health_servicer: + self.health_servicer.set_not_serving() + + deadline = time.monotonic() + drain_timeout_secs + while time.monotonic() < deadline: + if not getattr(self.async_llm, "rid_to_state", None): + break + await asyncio.sleep(0.5) + else: + logger.warning( + "shutdown drain timed out after %.1fs with %d in-flight requests; " + "killing scheduler children anyway", + drain_timeout_secs, + len(getattr(self.async_llm, "rid_to_state", {}) or {}), + ) + + # Reap the scheduler subprocesses without taking down our own PID; + # server.py's stop sequence still needs us alive to finish gRPC drain. 
+ try: + from tokenspeed.runtime.utils.process import kill_process_tree + except ImportError: + logger.exception( + "Could not import tokenspeed.runtime.utils.process.kill_process_tree; " + "scheduler subprocesses may be orphaned" + ) + return + kill_process_tree(os.getpid(), include_parent=False) + + def _build_generate_req(self, request: tokenspeed_scheduler_pb2.GenerateRequest): + """Translate proto GenerateRequest → TokenSpeed GenerateReqInput. + + Keeps the router's pre-tokenized inputs intact (``input_ids`` set, + ``text`` left blank) so the TokenSpeed InputProcessor skips its own + tokenizer pass. + """ + if not request.HasField("tokenized"): + raise ValueError("GenerateRequest.tokenized is required") + + input_ids = list(request.tokenized.input_ids) + if not input_ids: + raise ValueError("GenerateRequest.tokenized.input_ids is empty") + + sampling = self._sampling_params_from_proto(request.sampling_params) + + GenerateReqInput = _lazy_generate_req_input() + obj = GenerateReqInput( + input_ids=input_ids, + sampling_params=sampling, + stream=bool(request.stream), + return_logprob=bool(request.return_logprob), + # ``logprob_start_len`` is ``optional int32`` on the wire — use + # presence-tracking, not the proto3 zero-default, to distinguish + # "client omitted" (→ SGLang's ``-1`` = no input logprobs) from + # an explicit ``0`` (→ start input logprobs at position 0). + logprob_start_len=( + request.logprob_start_len if request.HasField("logprob_start_len") else -1 + ), + top_logprobs_num=int(request.top_logprobs_num or 0), + token_ids_logprob=( + list(request.token_ids_logprob) if request.token_ids_logprob else None + ), + # Hidden-state forwarding, multimodal inputs, PD-disaggregated + # serving, LoRA hot-swap and ``log_metrics`` are intentionally + # absent from TokenSpeed's wire — leaving the engine defaults in + # place keeps the call shape simple. + ) + # Older tokenspeed's ``normalize_batch_and_arguments`` treats n>1 as + # batched and asserts ``rid`` is a list in that case. One gRPC + # request carries one rid; expand it to a list of deterministic + # per-choice rids when the caller asked for multiple samples so the + # assert doesn't fire (and the scheduler can still deduplicate). + n = sampling.get("n", 1) or 1 + if n > 1: + obj.rid = [f"{request.request_id}-n{i}" for i in range(n)] + else: + obj.rid = request.request_id + + # NOTE: We deliberately do NOT set ``obj.text`` even when the proto + # carries ``original_text``. TokenSpeed's HTTP serving_chat passes + # ``input_ids=[...], text=None`` to the engine; setting both fields + # has been observed to perturb the engine's input-processor path + # (some validators and normalizers branch on whether text is + # populated). Matching the HTTP shape — ids only, text=None — + # eliminates one source of HTTP-vs-gRPC divergence. + + return obj + + @staticmethod + def _sampling_params_from_proto( + params: tokenspeed_scheduler_pb2.SamplingParams, + ) -> dict[str, Any]: + """Build the dict that ``GenerateReqInput.sampling_params`` expects. + + TokenSpeed's :class:`SamplingParams` consumes this dict via + ``SamplingParams(**obj.sampling_params)``, so field names must match + the Python class (``max_new_tokens``, ``stop``, ``stop_token_ids``, ...). + """ + out: dict[str, Any] = {} + + # All sampling scalars in tokenspeed_scheduler.proto are declared + # ``optional`` (matching ``vllm_engine.proto``). 
We use + # ``HasField()`` to forward only the values the client explicitly + # set; absent fields fall through to the engine's own + # ``SamplingParams.__init__`` defaults. This eliminates the old + # truthy-check pitfall that silently dropped ``temperature=0`` + # (BFCL's intent for greedy decoding) AND the warmup-default-zero + # crash where invalid ``top_p=0.0`` / ``repetition_penalty=0.0`` + # would reach the engine from internal probe paths. + # + # When ``temperature=0`` does reach the engine (HasField=True for + # an explicitly-sent ``0.0``), the engine + # (``sampling_params.py:104-107``) sets ``top_k=1`` to engage + # greedy decoding. That's the path BFCL relies on. + for _field in ( + "max_new_tokens", + "temperature", + "top_p", + "top_k", + "min_p", + "frequency_penalty", + "presence_penalty", + "repetition_penalty", + ): + if params.HasField(_field): + out[_field] = getattr(params, _field) + + if params.min_new_tokens: + # ``min_new_tokens`` is non-optional; 0 is the "no minimum" sentinel. + out["min_new_tokens"] = params.min_new_tokens + + # Lists + if params.stop: + out["stop"] = list(params.stop) + if params.stop_token_ids: + out["stop_token_ids"] = list(params.stop_token_ids) + + # Bools (always forwarded) + out["skip_special_tokens"] = bool(params.skip_special_tokens) + out["spaces_between_special_tokens"] = bool(params.spaces_between_special_tokens) + out["ignore_eos"] = bool(params.ignore_eos) + # When set, tokenspeed's detokenizer keeps the matched stop token in + # the rendered text (see ``runtime/engine/detokenizer.py``); we also + # suppress the servicer-side ``output_ids`` strip in + # ``_generated_output_ids`` so the EOS reaches the gateway's + # detokenizer when ``skip_special_tokens=False``. + out["no_stop_trim"] = bool(params.no_stop_trim) + + # n (OpenAI-compat, passthrough) + if params.n: + out["n"] = params.n + if params.logit_bias: + out["logit_bias"] = dict(params.logit_bias) + + # Constraint types — exactly one may be set. + if params.HasField("regex"): + out["regex"] = params.regex + elif params.HasField("json_schema"): + out["json_schema"] = params.json_schema + elif params.HasField("ebnf_grammar"): + out["ebnf"] = params.ebnf_grammar + elif params.HasField("structural_tag"): + out["structural_tag"] = params.structural_tag + + return out + + def _generated_output_ids( + self, + output: dict, + reason_dict: dict | None, + *, + no_stop_trim: bool = False, + ) -> list[int]: + """Return just the newly-generated tokens from a TokenSpeed output dict. + + TokenSpeed's AsyncLLM has two quirks that the SGLang gRPC proto contract + doesn't expect, both of which break the smg gateway's detokenization + layer and downstream tool-call parsing: + + 1. ``output_ids`` is prefixed with the Llama-3 chat-template assistant + header: ``[<|eot_id|>, <|start_header_id|>, "assistant", + <|end_header_id|>, "\\n\\n", ...generated..., ]``. The + ``skip_special_tokens=True`` detokenization strips the 128xxx + control tokens but keeps the word tokens ``"assistant"`` (78191) + and ``"\\n\\n"`` (271), so the final text looks like + ``assistant\\n\\n{"name": ...}``. The ``llama`` tool parser's + ``serde_json::from_str`` can't handle leading non-JSON prefix and + silently returns zero tool calls. + 2. The trailing stop token (e.g. ``<|eom_id|>`` = 128008) is included + in ``output_ids``; SGLang excludes it. If the gateway ever runs + with ``skip_special_tokens=False`` the stop leaks into the decoded + text and breaks JSON parsing for the same reason. 
+ + Slicing the last ``meta_info.completion_tokens`` tokens gives us the + bare generated sequence that SGLang's ``token_ids`` would carry, and + we then defensively drop any trailing matched stop token. The + per-choice ``matched_stop`` fires in a separate proto field, so no + information is lost. + """ + raw = list(output.get("output_ids") or []) + if not raw: + return raw + completion = output.get("meta_info", {}).get("completion_tokens") + if isinstance(completion, int) and 0 < completion <= len(raw): + token_ids = raw[-completion:] + else: + token_ids = raw + if not no_stop_trim and reason_dict and reason_dict.get("type") == "stop": + matched = reason_dict.get("matched") + if isinstance(matched, int) and token_ids and token_ids[-1] == matched: + token_ids = token_ids[:-1] + return token_ids + + def _chunk_response( + self, + rid: str, + output: dict, + reason_dict: dict | None, + choice_index: int = 0, + *, + no_stop_trim: bool = False, + ) -> tokenspeed_scheduler_pb2.GenerateResponse: + meta = output.get("meta_info", {}) + token_ids = self._generated_output_ids(output, reason_dict, no_stop_trim=no_stop_trim) + return tokenspeed_scheduler_pb2.GenerateResponse( + request_id=rid, + chunk=tokenspeed_scheduler_pb2.GenerateStreamChunk( + token_ids=token_ids, + prompt_tokens=int(meta.get("prompt_tokens", 0)), + completion_tokens=int(meta.get("completion_tokens", len(token_ids))), + cached_tokens=int(meta.get("cached_tokens", 0)), + output_logprobs=self._convert_output_logprobs_to_proto(output, len(token_ids)), + index=choice_index, + ), + ) + + def _complete_response( + self, + rid: str, + output: dict, + reason_dict: dict | None, + choice_index: int = 0, + *, + no_stop_trim: bool = False, + ) -> tokenspeed_scheduler_pb2.GenerateResponse: + meta = output.get("meta_info", {}) + token_ids = self._generated_output_ids(output, reason_dict, no_stop_trim=no_stop_trim) + + finish_reason = "stop" + matched_kwargs: dict[str, Any] = {} + if reason_dict: + kind = reason_dict.get("type") + if kind == "length": + finish_reason = "length" + elif kind == "abort": + finish_reason = "abort" + matched = reason_dict.get("matched") + if isinstance(matched, int): + matched_kwargs["matched_token_id"] = matched + elif isinstance(matched, str): + matched_kwargs["matched_stop_str"] = matched + + return tokenspeed_scheduler_pb2.GenerateResponse( + request_id=rid, + complete=tokenspeed_scheduler_pb2.GenerateComplete( + output_ids=token_ids, + finish_reason=finish_reason, + prompt_tokens=int(meta.get("prompt_tokens", 0)), + completion_tokens=int(meta.get("completion_tokens", len(token_ids))), + cached_tokens=int(meta.get("cached_tokens", 0)), + output_logprobs=self._convert_output_logprobs_to_proto(output, len(token_ids)), + index=choice_index, + **matched_kwargs, + ), + ) + + @staticmethod + def _convert_output_logprobs_to_proto( + output: dict, n_keep: int + ) -> tokenspeed_scheduler_pb2.OutputLogProbs | None: + """Build an ``OutputLogProbs`` proto from a tokenspeed output dict. + + TokenSpeed accumulates the request's logprobs in per-request state + across chunks; ``meta_info["output_token_logprobs"]`` is therefore the + running cumulative list of detokenized + ``(logprob: float, token_id: int, text: Optional[str])`` tuples, and + ``meta_info["output_top_logprobs"]`` is the parallel list of top-K + alternatives per position (each entry is ``None`` or a list of the + same tuple shape). 
+ + We slice the cumulative list down to just **this frame's tokens** by + taking the last ``len(output["output_ids"])`` entries — that's how + many new tokens this frame emitted — and then keep only the first + ``n_keep`` of those, so the alignment matches whatever + ``_generated_output_ids`` returned (it strips a trailing stop token + when the finish reason is ``stop``, leaving the last logprob entry + with no corresponding output id). + + Returns ``None`` when there are no logprobs to emit — either the + client did not request them, or the server was started without + ``--enable-output-logprobs`` (in which case TokenSpeed silently + leaves these meta_info lists empty rather than raising). + """ + if n_keep <= 0: + return None + meta = output.get("meta_info", {}) or {} + raw_token = meta.get("output_token_logprobs") or [] + if not raw_token: + return None + n_chunk = len(output.get("output_ids", []) or []) + if n_chunk <= 0: + return None + + raw_top = meta.get("output_top_logprobs") or [] + chunk_token = raw_token[-n_chunk:] if len(raw_token) >= n_chunk else raw_token + chunk_top = raw_top[-n_chunk:] if len(raw_top) >= n_chunk else raw_top + delta_token = chunk_token[:n_keep] + delta_top = chunk_top[:n_keep] + + top_proto = [] + for entry in delta_top: + if entry: + top_proto.append( + tokenspeed_scheduler_pb2.TopLogProbs( + values=[t[0] for t in entry], + token_ids=[t[1] for t in entry], + ) + ) + else: + # Position with no top-K data (e.g. ``--enable-top-logprobs`` + # is not yet implemented in TokenSpeed; we still emit a + # placeholder per position so the gateway can align indices). + top_proto.append(tokenspeed_scheduler_pb2.TopLogProbs()) + + return tokenspeed_scheduler_pb2.OutputLogProbs( + token_logprobs=[t[0] for t in delta_token], + token_ids=[t[1] for t in delta_token], + top_logprobs=top_proto, + ) + + +def _abort_status_code(reason: dict) -> grpc.StatusCode: + status_code = reason.get("status_code") + if status_code == 400: + return grpc.StatusCode.INVALID_ARGUMENT + if status_code in (408, 504): + return grpc.StatusCode.DEADLINE_EXCEEDED + if status_code == 429: + return grpc.StatusCode.RESOURCE_EXHAUSTED + return grpc.StatusCode.INTERNAL + + +def _make_json_serializable(obj: Any) -> Any: + """Flatten an arbitrary dataclass/config graph into JSON-safe primitives.""" + if obj is None or isinstance(obj, str | int | float | bool): + return obj + if isinstance(obj, list | tuple | set): + return [_make_json_serializable(x) for x in obj] + if isinstance(obj, dict): + return {str(k): _make_json_serializable(v) for k, v in obj.items()} + return str(obj) diff --git a/grpc_servicer/tests/__init__.py b/grpc_servicer/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/grpc_servicer/tests/conftest.py b/grpc_servicer/tests/conftest.py new file mode 100644 index 000000000..3ceadba4f --- /dev/null +++ b/grpc_servicer/tests/conftest.py @@ -0,0 +1,22 @@ +"""Pytest configuration for smg-grpc-servicer unit tests. + +Adds the parent directory to ``sys.path`` so editable installs work +without needing ``pip install -e``, and declares an asyncio-mode default. 
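+
+Typical invocation (sketch; assumes the ``[test]`` extra is installed and you
+run from the repo root)::
+
+    pytest grpc_servicer/tests -m "not tokenspeed"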
+""" + +from __future__ import annotations + +import pathlib +import sys + +import pytest + +_HERE = pathlib.Path(__file__).resolve().parent +_PKG_ROOT = _HERE.parent + +if str(_PKG_ROOT) not in sys.path: + sys.path.insert(0, str(_PKG_ROOT)) + + +def pytest_configure(config: pytest.Config) -> None: + config.addinivalue_line("markers", "tokenspeed: tests that require TokenSpeed") diff --git a/grpc_servicer/tests/test_tokenspeed_health_servicer.py b/grpc_servicer/tests/test_tokenspeed_health_servicer.py new file mode 100644 index 000000000..df4856af1 --- /dev/null +++ b/grpc_servicer/tests/test_tokenspeed_health_servicer.py @@ -0,0 +1,98 @@ +"""Unit tests for ``smg_grpc_servicer.tokenspeed.health_servicer``.""" + +from __future__ import annotations + +import time +from dataclasses import dataclass, field +from typing import Any +from unittest.mock import MagicMock + +import grpc +import pytest +from grpc_health.v1 import health_pb2 # noqa: E402 +from smg_grpc_servicer.tokenspeed.health_servicer import ( # noqa: E402 + TokenSpeedHealthServicer, +) + + +@dataclass +class FakeEngine: + gracefully_exit: bool = False + last_receive_tstamp: float = 0.0 + rid_to_state: dict[str, Any] = field(default_factory=dict) + + +@pytest.fixture +def servicer() -> TokenSpeedHealthServicer: + return TokenSpeedHealthServicer( + async_llm=FakeEngine(), + scheduler_info={}, + ) + + +@pytest.mark.asyncio +async def test_initial_state_is_not_serving(servicer: TokenSpeedHealthServicer): + ctx = MagicMock(spec=grpc.aio.ServicerContext) + resp = await servicer.Check(health_pb2.HealthCheckRequest(service=""), ctx) + assert resp.status == health_pb2.HealthCheckResponse.NOT_SERVING + + +@pytest.mark.asyncio +async def test_set_serving_flips_both_levels(servicer: TokenSpeedHealthServicer): + ctx = MagicMock(spec=grpc.aio.ServicerContext) + servicer.set_serving() + + # overall + resp = await servicer.Check(health_pb2.HealthCheckRequest(service=""), ctx) + assert resp.status == health_pb2.HealthCheckResponse.SERVING + + # specific + resp = await servicer.Check( + health_pb2.HealthCheckRequest(service=servicer.TOKENSPEED_SERVICE), ctx + ) + assert resp.status == health_pb2.HealthCheckResponse.SERVING + + +@pytest.mark.asyncio +async def test_shutdown_flips_back(servicer: TokenSpeedHealthServicer): + ctx = MagicMock(spec=grpc.aio.ServicerContext) + servicer.set_serving() + servicer.async_llm.gracefully_exit = True + resp = await servicer.Check(health_pb2.HealthCheckRequest(service=""), ctx) + assert resp.status == health_pb2.HealthCheckResponse.NOT_SERVING + + +@pytest.mark.asyncio +async def test_unknown_service_returns_unknown(servicer: TokenSpeedHealthServicer): + ctx = MagicMock(spec=grpc.aio.ServicerContext) + resp = await servicer.Check(health_pb2.HealthCheckRequest(service="bogus.Service"), ctx) + assert resp.status == health_pb2.HealthCheckResponse.SERVICE_UNKNOWN + ctx.set_code.assert_called_once_with(grpc.StatusCode.NOT_FOUND) + + +@pytest.mark.asyncio +async def test_stuck_scheduler_flips_to_not_serving( + servicer: TokenSpeedHealthServicer, +): + ctx = MagicMock(spec=grpc.aio.ServicerContext) + servicer.set_serving() + # Simulate "pending requests, but scheduler hasn't pushed output for 45s" + servicer.async_llm.last_receive_tstamp = time.time() - 45 + servicer.async_llm.rid_to_state["rid-1"] = object() + + resp = await servicer.Check( + health_pb2.HealthCheckRequest(service=servicer.TOKENSPEED_SERVICE), ctx + ) + assert resp.status == health_pb2.HealthCheckResponse.NOT_SERVING + + +@pytest.mark.asyncio +async def 
test_recent_activity_keeps_serving(servicer: TokenSpeedHealthServicer): + ctx = MagicMock(spec=grpc.aio.ServicerContext) + servicer.set_serving() + servicer.async_llm.last_receive_tstamp = time.time() - 1 + servicer.async_llm.rid_to_state["rid-1"] = object() + resp = await servicer.Check( + health_pb2.HealthCheckRequest(service=servicer.TOKENSPEED_SERVICE), ctx + ) + assert resp.status == health_pb2.HealthCheckResponse.SERVING diff --git a/grpc_servicer/tests/test_tokenspeed_servicer.py b/grpc_servicer/tests/test_tokenspeed_servicer.py new file mode 100644 index 000000000..112b9dfb2 --- /dev/null +++ b/grpc_servicer/tests/test_tokenspeed_servicer.py @@ -0,0 +1,1048 @@ +"""Unit tests for ``smg_grpc_servicer.tokenspeed.servicer``. + +Runs against a minimal ``FakeAsyncLLM`` that implements only the AsyncLLM +surface the servicer actually touches. We *do* require TokenSpeed to be +importable (the servicer takes real request classes from ``tokenspeed.*``), +so the whole module is skipped when TokenSpeed is not installed. +""" + +from __future__ import annotations + +import asyncio +from collections.abc import Callable +from dataclasses import dataclass, field +from types import SimpleNamespace +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import grpc +import pytest + +pytest.importorskip( + "smg_grpc_proto", + reason="smg-grpc-proto must be installed to test the servicer", +) + +from smg_grpc_proto.generated import tokenspeed_scheduler_pb2 # noqa: E402 +from smg_grpc_servicer.tokenspeed import servicer as _servicer_module # noqa: E402 +from smg_grpc_servicer.tokenspeed.servicer import ( # noqa: E402 + TokenSpeedSchedulerServicer, + _abort_status_code, + _finish_reason_to_dict, + _make_json_serializable, +) + +# --------------------------------------------------------------------------- +# Stub request class. The servicer lazily imports ``GenerateReqInput`` so +# tests can substitute a minimal local stand-in without pulling in +# TokenSpeed's full scheduler graph. (No ``EmbeddingReqInput`` — the slim +# TokenSpeed proto removed the Embed RPC.) +# --------------------------------------------------------------------------- + + +class _StubReq: + """Minimal stand-in with the attributes the servicer sets on req objects.""" + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + # Allow later attribute assignment for rid / text. + self.rid = None + self.text = None + + +class StubGenerateReqInput(_StubReq): + pass + + +@pytest.fixture(autouse=True) +def _stub_request_inputs(monkeypatch): + """Redirect the servicer's lazy GenerateReqInput import to a local stub.""" + monkeypatch.setattr(_servicer_module, "_lazy_generate_req_input", lambda: StubGenerateReqInput) + yield + + +# --------------------------------------------------------------------------- +# Local fake finish-reason classes. The servicer duck-types on ``.to_json()`` +# so tests don't need to import TokenSpeed's request_types module (which +# pulls in the full scheduler graph and breaks in minimal test envs). 
+# --------------------------------------------------------------------------- + + +class FINISH_MATCHED_TOKEN: + def __init__(self, matched): + self.matched = matched + + def to_json(self): + return {"type": "stop", "matched": self.matched} + + +class FINISH_MATCHED_STR: + def __init__(self, matched): + self.matched = matched + + def to_json(self): + return {"type": "stop", "matched": self.matched} + + +class FINISH_LENGTH: + def __init__(self, length): + self.length = length + + def to_json(self): + return {"type": "length", "length": self.length} + + +class FINISH_ABORT: + def __init__(self, message="Unknown error"): + self.message = message + + def to_json(self): + return {"type": "abort", "message": self.message} + + +# --------------------------------------------------------------------------- +# FakeAsyncLLM — minimal stand-in for TokenSpeed's AsyncLLM in unit tests. +# --------------------------------------------------------------------------- + + +@dataclass +class _FakeState: + finished: bool = False + + +@dataclass +class FakeAsyncLLM: + """Implements just enough AsyncLLM surface to drive the servicer.""" + + outputs: list[dict] = field(default_factory=list) + is_generation: bool = True + context_len: int = 8192 + max_req_input_len: int | None = 4096 + # Captured state — the servicer mutates/inspects these. + rid_to_state: dict[str, _FakeState] = field(default_factory=dict) + gracefully_exit: bool = False + last_receive_tstamp: float = 0.0 + handle_loop_started: bool = False + aborted_rids: list[str] = field(default_factory=list) + # Override hook: a callable producing outputs per request, used for + # tests that need dynamic yields (e.g. cancel mid-stream). + generate_fn: Callable[[Any], Any] | None = None + + # Default load-fixture: single DP rank, 1 running request, no waiting, + # 100 used pages out of (max_total_num_tokens / page_size). Tests can + # override ``load_outputs`` directly to assert proto-mapping semantics. + load_outputs: list[Any] = field(default_factory=list) + max_total_num_tokens: int = 8192 + + server_args: Any = field( + default_factory=lambda: SimpleNamespace( + model_path="fake-model", + tokenizer_path="fake-model", + served_model_name="fake-model", + preferred_sampling_params=None, + page_size=16, + ) + ) + model_config: Any = field( + default_factory=lambda: SimpleNamespace( + vocab_size=32000, + is_multimodal=False, + hf_config=SimpleNamespace( + eos_token_id=2, + pad_token_id=0, + bos_token_id=1, + model_type="llama", + architectures=["LlamaForCausalLM"], + ), + ) + ) + + def auto_create_handle_loop(self) -> None: + self.handle_loop_started = True + + def abort_request(self, rid: str) -> None: + self.aborted_rids.append(rid) + self.rid_to_state.pop(rid, None) + + async def get_load(self): + # Mirror SchedulerControlClient.get_load — returns the configured + # ``load_outputs`` so tests can drive proto-mapping assertions. + return list(self.load_outputs) + + async def generate_request(self, obj): + # Record the request so tests can assert on what was forwarded. + # ``_build_generate_req`` rewrites ``rid`` to a list of per-choice ids + # when n>1; register state for each so the cancel sweep can abort them + # individually (and so dict assignment doesn't crash on a list key). 
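+        # Example (mirrors the cancel-sweep tests below): rid="rid-1" with n=3
+        # arrives here as ["rid-1-n0", "rid-1-n1", "rid-1-n2"], so three states
+        # are registered instead of one.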
+ rid_attr = getattr(obj, "rid", None) or "no-rid" + rids = list(rid_attr) if isinstance(rid_attr, list) else [rid_attr] + for r in rids: + self.rid_to_state[r] = _FakeState() + if self.generate_fn is not None: + async for out in self.generate_fn(obj): + self.last_receive_tstamp = 9999.0 # anything > tic + yield out + return + for out in self.outputs: + self.last_receive_tstamp = 9999.0 + yield out + for r in rids: + self.rid_to_state[r].finished = True + + +@pytest.fixture +def fake_engine() -> FakeAsyncLLM: + return FakeAsyncLLM() + + +@pytest.fixture +def servicer(fake_engine: FakeAsyncLLM) -> TokenSpeedSchedulerServicer: + return TokenSpeedSchedulerServicer( + async_llm=fake_engine, + server_args=fake_engine.server_args, + scheduler_info={ + "max_total_num_tokens": 100000, + "max_req_input_len": 4096, + }, + ) + + +class _FakeAbortError(grpc.aio.AbortError): + """Stand-in for grpc.aio.AbortError raised by our mock context.abort().""" + + def __init__(self, code: grpc.StatusCode, details: str): + super().__init__() + self.code = code + self.details = details + + def __str__(self) -> str: # makes pytest.raises(match=...) useful + return f"ABORT({self.code.name}, {self.details})" + + +def _make_context() -> MagicMock: + """Build a grpc.aio.ServicerContext whose ``abort()`` raises AbortError. + + Real gRPC servicer contexts raise ``grpc.aio.AbortError`` from + ``context.abort()``. The servicer has a dedicated ``except + grpc.aio.AbortError: raise`` branch to let that propagate cleanly, so + the mock reproduces that behaviour. + """ + ctx = MagicMock(spec=grpc.aio.ServicerContext) + + async def _abort(code, details): + raise _FakeAbortError(code, details) + + ctx.abort = AsyncMock(side_effect=_abort) + ctx.set_code = MagicMock() + ctx.set_details = MagicMock() + return ctx + + +# --------------------------------------------------------------------------- +# Pure-helper tests +# --------------------------------------------------------------------------- + + +class TestFinishReasonToDict: + def test_none(self): + assert _finish_reason_to_dict(None) is None + + def test_length(self): + assert _finish_reason_to_dict(FINISH_LENGTH(length=42)) == { + "type": "length", + "length": 42, + } + + def test_matched_token(self): + assert _finish_reason_to_dict(FINISH_MATCHED_TOKEN(matched=7)) == { + "type": "stop", + "matched": 7, + } + + def test_matched_str(self): + assert _finish_reason_to_dict(FINISH_MATCHED_STR(matched="")) == { + "type": "stop", + "matched": "", + } + + def test_abort(self): + out = _finish_reason_to_dict(FINISH_ABORT(message="boom")) + assert out["type"] == "abort" + assert out["message"] == "boom" + + def test_passthrough_dict(self): + d = {"type": "stop", "matched": "foo"} + assert _finish_reason_to_dict(d) is d + + def test_unknown_raises_typeerror(self): + # Unknown shapes raise TypeError rather than coercing to a fake + # ``stop`` dict: silently flipping length/abort to stop and leaking + # repr() into the user-facing matched_stop_str field would corrupt + # the OpenAI ``finish_reason`` semantics. The Generate handler's + # ``except Exception`` turns the TypeError into INTERNAL. 
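+        # Neither a bare string nor an int is a dict or exposes ``to_json()``,
+        # so both fall through to the TypeError branch.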
+ with pytest.raises(TypeError, match="Unknown finish_reason shape"): + _finish_reason_to_dict("weird") + with pytest.raises(TypeError, match="Unknown finish_reason shape"): + _finish_reason_to_dict(42) + + +class TestAbortStatusCode: + @pytest.mark.parametrize( + "status_code, expected", + [ + (400, grpc.StatusCode.INVALID_ARGUMENT), + (408, grpc.StatusCode.DEADLINE_EXCEEDED), + (504, grpc.StatusCode.DEADLINE_EXCEEDED), + (429, grpc.StatusCode.RESOURCE_EXHAUSTED), + (500, grpc.StatusCode.INTERNAL), + (None, grpc.StatusCode.INTERNAL), + ], + ) + def test_mapping(self, status_code, expected): + assert _abort_status_code({"status_code": status_code}) == expected + + +class TestMakeJsonSerializable: + def test_primitives(self): + assert _make_json_serializable(1) == 1 + assert _make_json_serializable("x") == "x" + assert _make_json_serializable(True) is True + assert _make_json_serializable(None) is None + + def test_list_tuple_set(self): + assert _make_json_serializable([1, "a"]) == [1, "a"] + assert _make_json_serializable((1, 2)) == [1, 2] + assert _make_json_serializable({1, 2, 3}) in ( + [1, 2, 3], + [1, 3, 2], + [2, 1, 3], + [2, 3, 1], + [3, 1, 2], + [3, 2, 1], + ) + + def test_nested_dict(self): + assert _make_json_serializable({"a": [1, {"b": 2}]}) == {"a": [1, {"b": 2}]} + + def test_exotic_types_coerced_to_str(self): + class Foo: + def __str__(self): + return "foo-str" + + assert _make_json_serializable(Foo()) == "foo-str" + + +# --------------------------------------------------------------------------- +# Sampling params conversion +# --------------------------------------------------------------------------- + + +class TestSamplingParamsConversion: + def test_defaults_not_forwarded(self): + params = tokenspeed_scheduler_pb2.SamplingParams() + out = TokenSpeedSchedulerServicer._sampling_params_from_proto(params) + # proto3 defaults (0 / False / "") should not end up as TokenSpeed + # overrides — only the always-forwarded bool fields appear. 
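+        # (Forwarding the zeros would clobber the engine's own defaults, for
+        # example turning an unset temperature into an explicit 0.)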
+ assert "temperature" not in out + assert "top_p" not in out + assert "top_k" not in out + assert "max_new_tokens" not in out + # always-forwarded bools + assert out["skip_special_tokens"] is False + assert out["spaces_between_special_tokens"] is False + assert out["ignore_eos"] is False + + def test_numeric_fields_forwarded(self): + params = tokenspeed_scheduler_pb2.SamplingParams( + temperature=0.7, + top_p=0.9, + top_k=50, + min_p=0.05, + frequency_penalty=0.1, + presence_penalty=0.2, + repetition_penalty=1.1, + max_new_tokens=128, + min_new_tokens=4, + ) + out = TokenSpeedSchedulerServicer._sampling_params_from_proto(params) + assert out["temperature"] == pytest.approx(0.7) + assert out["top_p"] == pytest.approx(0.9) + assert out["top_k"] == 50 + assert out["min_p"] == pytest.approx(0.05) + assert out["frequency_penalty"] == pytest.approx(0.1) + assert out["presence_penalty"] == pytest.approx(0.2) + assert out["repetition_penalty"] == pytest.approx(1.1) + assert out["max_new_tokens"] == 128 + assert out["min_new_tokens"] == 4 + + def test_stop_lists_and_logit_bias(self): + params = tokenspeed_scheduler_pb2.SamplingParams( + stop=["\n\n", ""], + stop_token_ids=[2, 0], + logit_bias={"100": -10.0, "200": 10.0}, + ) + out = TokenSpeedSchedulerServicer._sampling_params_from_proto(params) + assert out["stop"] == ["\n\n", ""] + assert out["stop_token_ids"] == [2, 0] + assert out["logit_bias"] == {"100": -10.0, "200": 10.0} + + @pytest.mark.parametrize( + "setter, key, value", + [ + (lambda p: setattr(p, "regex", "a.*"), "regex", "a.*"), + (lambda p: setattr(p, "json_schema", "{}"), "json_schema", "{}"), + (lambda p: setattr(p, "ebnf_grammar", "g"), "ebnf", "g"), + (lambda p: setattr(p, "structural_tag", "tag"), "structural_tag", "tag"), + ], + ) + def test_constraints(self, setter, key, value): + params = tokenspeed_scheduler_pb2.SamplingParams() + setter(params) + out = TokenSpeedSchedulerServicer._sampling_params_from_proto(params) + assert out[key] == value + + +# --------------------------------------------------------------------------- +# Generate RPC +# --------------------------------------------------------------------------- + + +def _make_generate_request( + *, + request_id: str = "rid-1", + input_ids: list[int] | None = None, + stream: bool = False, + max_new_tokens: int = 16, +) -> tokenspeed_scheduler_pb2.GenerateRequest: + return tokenspeed_scheduler_pb2.GenerateRequest( + request_id=request_id, + tokenized=tokenspeed_scheduler_pb2.TokenizedInput( + # Preserve explicit empty-list inputs (for "rejects empty ids" test); + # only fall back to the default if the caller didn't supply any. + input_ids=(input_ids if input_ids is not None else [1, 2, 3, 4]), + original_text="hello", + ), + sampling_params=tokenspeed_scheduler_pb2.SamplingParams( + temperature=0.0, + max_new_tokens=max_new_tokens, + ), + stream=stream, + ) + + +class TestGenerate: + @pytest.mark.asyncio + async def test_non_streaming_emits_complete( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + # TokenSpeed's AsyncLLM includes the trailing matched-stop token in + # ``output_ids`` (and prepends chat-template header tokens — modeled in + # ``test_strips_chat_template_prefix`` below). The servicer normalizes + # these out before the proto goes to the smg gateway so the tool + # parsers see the same tokens they would from the SGLang path. 
Here we + # check the matched-stop trim: ``raw=[10,11,12]`` with ``matched=12`` + # should arrive as ``[10,11]`` on the wire, and the matched id still + # rides in the ``matched_token_id`` field. + fake_engine.outputs = [ + { + "text": "hi", + "output_ids": [10, 11, 12], + "meta_info": { + "prompt_tokens": 4, + "completion_tokens": 3, + "cached_tokens": 0, + "finish_reason": FINISH_MATCHED_TOKEN(matched=12), + }, + } + ] + ctx = _make_context() + req = _make_generate_request(stream=False) + + frames = [frame async for frame in servicer.Generate(req, ctx)] + assert len(frames) == 1 + frame = frames[0] + assert frame.request_id == "rid-1" + assert frame.HasField("complete") + complete = frame.complete + assert list(complete.output_ids) == [10, 11] + assert complete.finish_reason == "stop" + assert complete.matched_token_id == 12 + assert complete.prompt_tokens == 4 + # Meta's completion_tokens passes through unchanged — matches SGLang's + # ``meta_info.get("completion_tokens")`` convention — even though the + # on-the-wire ``output_ids`` drops the stop token. + assert complete.completion_tokens == 3 + ctx.abort.assert_not_called() + + @pytest.mark.asyncio + async def test_strips_chat_template_prefix( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + """Reproducer for the bug where ``assistant\\n\\n`` leaked into the + decoded text and broke the ``llama`` tool-call parser. + + Real-world capture on Llama-3.2-1B-Instruct with a function-calling + prompt — ``output_ids`` was 27 tokens: 5 chat-template header tokens + (``<|eot_id|>, <|start_header_id|>, "assistant", <|end_header_id|>, + "\\n\\n"``) + 21 generated JSON tokens + 1 ``<|eom_id|>`` stop. With + ``skip_special_tokens=True`` only the 128xxx control tokens get + stripped at detokenization time, so the word token ``"assistant"`` + (78191) and ``"\\n\\n"`` (271) leaked into the text and flipped + ``serde_json::from_str`` from succeeding on clean JSON to failing on + ``assistant\\n\\n{...}``. + + The servicer now slices to the last ``completion_tokens`` tokens so + downstream detokenization only sees the actual generated content. + """ + fake_engine.outputs = [ + { + "text": '{"name": "add", "parameters": {"a": 3, "b": 5}}', + # Shape observed in the wild: [<|eot|>, <|start|>, "assistant", + # <|end|>, "\n\n", ...21 json tokens, <|eom|>] = 27 tokens. + # ``completion_tokens`` in TokenSpeed's meta covers the content + # *plus* the stop token, so 21 + 1 = 22. + "output_ids": [ + 128009, + 128006, + 78191, + 128007, + 271, + *range(9000, 9021), + 128008, + ], + "meta_info": { + "prompt_tokens": 200, + "completion_tokens": 22, + "cached_tokens": 0, + "finish_reason": FINISH_MATCHED_TOKEN(matched=128008), + }, + } + ] + ctx = _make_context() + req = _make_generate_request(stream=False) + + frames = [frame async for frame in servicer.Generate(req, ctx)] + complete = frames[0].complete + # Header tokens dropped via the ``raw[-completion_tokens:]`` slice; + # trailing stop token dropped because ``matched == token_ids[-1]``. + assert list(complete.output_ids) == list(range(9000, 9021)) + assert complete.matched_token_id == 128008 + # meta_info.completion_tokens passes through; only ``output_ids`` is + # normalized. Keeps the tokenspeed servicer's wire contract aligned + # with the SGLang reference. 
+ assert complete.completion_tokens == 22 + + @pytest.mark.asyncio + async def test_streaming_emits_chunks_then_complete( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + fake_engine.outputs = [ + { + "text": "hi", + "output_ids": [10], # delta chunk 1 + "meta_info": { + "prompt_tokens": 4, + "completion_tokens": 1, + "cached_tokens": 0, + "finish_reason": None, + }, + }, + { + "text": "hi there", + "output_ids": [11, 12], # delta chunk 2 + finish + "meta_info": { + "prompt_tokens": 4, + "completion_tokens": 3, + "cached_tokens": 0, + "finish_reason": FINISH_LENGTH(length=16), + }, + }, + ] + ctx = _make_context() + req = _make_generate_request(stream=True) + + frames = [frame async for frame in servicer.Generate(req, ctx)] + # Expect: 2 chunks + 1 complete (emitted alongside the final chunk). + # ``completion_tokens`` here (3) exceeds this chunk's delta length (2), + # so the slice falls back to the raw delta. Length-finish has no + # matched stop to strip either, so token_ids pass through. + assert len(frames) == 3 + assert frames[0].HasField("chunk") + assert list(frames[0].chunk.token_ids) == [10] + assert frames[1].HasField("chunk") + assert list(frames[1].chunk.token_ids) == [11, 12] + assert frames[2].HasField("complete") + assert frames[2].complete.finish_reason == "length" + assert list(frames[2].complete.output_ids) == [11, 12] + + @pytest.mark.asyncio + async def test_empty_input_ids_rejected( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + ctx = _make_context() + req = _make_generate_request(input_ids=[]) + + with pytest.raises(_FakeAbortError) as exc: + async for _ in servicer.Generate(req, ctx): + pass + assert exc.value.code == grpc.StatusCode.INVALID_ARGUMENT + ctx.abort.assert_awaited_once() + + @pytest.mark.asyncio + async def test_abort_finish_reason_surfaces_as_grpc_error( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + fake_engine.outputs = [ + { + "text": "", + "output_ids": [], + "meta_info": { + "prompt_tokens": 0, + "completion_tokens": 0, + "cached_tokens": 0, + "finish_reason": { + "type": "abort", + "message": "client disconnected", + "status_code": 400, + }, + }, + } + ] + ctx = _make_context() + req = _make_generate_request() + + with pytest.raises(_FakeAbortError) as exc: + async for _ in servicer.Generate(req, ctx): + pass + assert exc.value.code == grpc.StatusCode.INVALID_ARGUMENT + + @pytest.mark.asyncio + async def test_cancel_calls_abort_request( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + """Cancelling the Generate task should tell the scheduler to drop the rid.""" + + started = asyncio.Event() + + async def never_finish(_obj): + started.set() + # Block forever so we can cancel from outside. ``yield`` is + # unreachable but keeps this an async generator. + await asyncio.sleep(30) + yield {} # pragma: no cover + + fake_engine.generate_fn = never_finish + ctx = _make_context() + req = _make_generate_request() + + gen = servicer.Generate(req, ctx) + task = asyncio.create_task(_drain(gen)) + await started.wait() + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + assert "rid-1" in fake_engine.aborted_rids + + @pytest.mark.asyncio + async def test_cancel_aborts_all_n_children( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + """n>1 expands rid to a list of per-choice ids; cancel must sweep them all. 
+ + _build_generate_req rewrites ``rid`` to ``[rid-n0, rid-n1, ...]`` so + TokenSpeed's batch path sees unique rids per choice. If Generate's + cancel handler aborts only the original rid, the child scheduler + requests keep consuming GPU work. This test guards that edge. + """ + started = asyncio.Event() + + async def never_finish(_obj): + started.set() + await asyncio.sleep(30) + yield {} # pragma: no cover + + fake_engine.generate_fn = never_finish + ctx = _make_context() + req = _make_generate_request() + req.sampling_params.n = 3 + + gen = servicer.Generate(req, ctx) + task = asyncio.create_task(_drain(gen)) + await started.wait() + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + # Every per-choice rid must have had abort_request called. + assert set(fake_engine.aborted_rids) >= {"rid-1-n0", "rid-1-n1", "rid-1-n2"} + + +async def _drain(async_gen): + async for _ in async_gen: + pass + + +# --------------------------------------------------------------------------- +# Embed RPC +# --------------------------------------------------------------------------- + + +# --------------------------------------------------------------------------- +# Abort / HealthCheck / GetModelInfo / GetServerInfo / GetLoads +# +# Note: TokenSpeed's slim proto removes Embed / GetTokenizer / SubscribeKvEvents +# entirely, so there are no tests for them — the methods aren't on the +# servicer surface. +# --------------------------------------------------------------------------- + + +class TestAbortRpc: + @pytest.mark.asyncio + async def test_abort_known( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + fake_engine.rid_to_state["rid-1"] = _FakeState() + resp = await servicer.Abort( + tokenspeed_scheduler_pb2.AbortRequest(request_id="rid-1"), + _make_context(), + ) + assert resp.success is True + assert "rid-1" in fake_engine.aborted_rids + + @pytest.mark.asyncio + async def test_abort_unknown( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + resp = await servicer.Abort( + tokenspeed_scheduler_pb2.AbortRequest(request_id="missing"), + _make_context(), + ) + assert resp.success is False + # Nothing to abort — no state for "missing" or any "missing-n*" child. + assert fake_engine.aborted_rids == [] + + @pytest.mark.asyncio + async def test_abort_sweeps_n_children( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + """Abort("rid-1") must sweep the per-choice rids Generate mints + when ``sampling_params.n > 1`` (``rid-1-n0``, ``rid-1-n1``, ...). + """ + for child in ("rid-1-n0", "rid-1-n1", "rid-1-n2"): + fake_engine.rid_to_state[child] = _FakeState() + # An unrelated rid the sweep must NOT touch. 
+ fake_engine.rid_to_state["unrelated-rid"] = _FakeState() + + resp = await servicer.Abort( + tokenspeed_scheduler_pb2.AbortRequest(request_id="rid-1"), + _make_context(), + ) + assert resp.success is True + assert sorted(fake_engine.aborted_rids) == [ + "rid-1-n0", + "rid-1-n1", + "rid-1-n2", + ] + + +class TestHealthCheck: + @pytest.mark.asyncio + async def test_reports_shutdown( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + fake_engine.gracefully_exit = True + resp = await servicer.HealthCheck( + tokenspeed_scheduler_pb2.HealthCheckRequest(), _make_context() + ) + assert resp.healthy is False + assert "shutting down" in resp.message.lower() + + @pytest.mark.asyncio + async def test_reports_healthy_when_scheduler_pushes_output( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + # generate_request yields once and updates last_receive_tstamp, which + # is what the health RPC watches for. + fake_engine.outputs = [ + { + "text": "", + "output_ids": [99], + "meta_info": {"finish_reason": FINISH_LENGTH(length=1)}, + } + ] + resp = await servicer.HealthCheck( + tokenspeed_scheduler_pb2.HealthCheckRequest(), _make_context() + ) + assert resp.healthy is True + + +class TestGetModelInfo: + @pytest.mark.asyncio + async def test_basic_fields( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + resp = await servicer.GetModelInfo( + tokenspeed_scheduler_pb2.GetModelInfoRequest(), _make_context() + ) + assert resp.model_path == "fake-model" + assert resp.vocab_size == 32000 + assert resp.max_context_length == 8192 + assert list(resp.eos_token_ids) == [2] + assert resp.model_type == "llama" + assert list(resp.architectures) == ["LlamaForCausalLM"] + + +class TestGetServerInfo: + @pytest.mark.asyncio + async def test_returns_scheduler_info( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + fake_engine.rid_to_state["a"] = _FakeState() + fake_engine.rid_to_state["b"] = _FakeState() + resp = await servicer.GetServerInfo( + tokenspeed_scheduler_pb2.GetServerInfoRequest(), _make_context() + ) + assert resp.active_requests == 2 + assert resp.max_total_num_tokens == 100000 + assert resp.tokenspeed_version + + @pytest.mark.asyncio + async def test_uses_tokenspeed_service_bases(self, servicer: TokenSpeedSchedulerServicer): + """TokenSpeed's servicer inherits the dedicated + ``TokenSpeedSchedulerServicer`` stub — identity is carried by the + proto package/service name, not by a field inside ``server_args``. + Guard the inheritance so nobody reverts to ``SglangSchedulerServicer`` + under the impression that 'wire shape is the same'; the wire shape + is the same, the *service path* is not, and the Rust router routes + on the service path. + """ + from smg_grpc_proto.generated import tokenspeed_scheduler_pb2_grpc + + assert isinstance(servicer, tokenspeed_scheduler_pb2_grpc.TokenSpeedSchedulerServicer) + + +class TestGetLoads: + @pytest.mark.asyncio + async def test_no_dp_ranks_returns_empty( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + # Bridge returns an empty list (e.g. before scheduler boots) — proto + # comes back with 0 ranks but still validly populated for the router. 
+ fake_engine.load_outputs = [] + resp = await servicer.GetLoads(tokenspeed_scheduler_pb2.GetLoadsRequest(), _make_context()) + assert resp.dp_rank_count == 0 + assert resp.version == "tokenspeed" + assert list(resp.loads) == [] + assert resp.aggregate.total_running_reqs == 0 + assert resp.aggregate.total_waiting_reqs == 0 + + @pytest.mark.asyncio + async def test_maps_load_output_fields( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + # 2 DP ranks. rank 0 has 3 reqs (2 running, 1 waiting) and 100 pages + # used; rank 1 has 1 reqs (1 running, 0 waiting) and 200 pages used. + # page_size=16 (from fake_engine.server_args), max_total_num_tokens=100000 + # (from the servicer fixture's scheduler_info). + fake_engine.load_outputs = [ + SimpleNamespace(dp_rank=0, num_reqs=3, num_waiting_reqs=1, num_pages=100), + SimpleNamespace(dp_rank=1, num_reqs=1, num_waiting_reqs=0, num_pages=200), + ] + resp = await servicer.GetLoads(tokenspeed_scheduler_pb2.GetLoadsRequest(), _make_context()) + assert resp.dp_rank_count == 2 + assert len(resp.loads) == 2 + # rank 0 + l0 = resp.loads[0] + assert l0.dp_rank == 0 + assert l0.num_running_reqs == 2 # num_reqs - num_waiting_reqs + assert l0.num_waiting_reqs == 1 + assert l0.num_total_reqs == 3 + assert l0.num_used_tokens == 100 * 16 # pages * page_size + assert l0.max_total_num_tokens == 100000 + assert l0.token_usage == pytest.approx(100 * 16 / 100000) + # rank 1 + l1 = resp.loads[1] + assert l1.dp_rank == 1 + assert l1.num_running_reqs == 1 + assert l1.num_used_tokens == 200 * 16 + # aggregate + assert resp.aggregate.total_running_reqs == 3 + assert resp.aggregate.total_waiting_reqs == 1 + assert resp.aggregate.total_reqs == 4 + assert resp.aggregate.avg_token_usage == pytest.approx( + (100 * 16 / 100000 + 200 * 16 / 100000) / 2 + ) + + @pytest.mark.asyncio + async def test_scheduler_timeout_aborts_with_deadline_exceeded( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer, monkeypatch + ): + # If the scheduler subprocess never replies, the bridge call hangs. + # The servicer wraps it in ``asyncio.wait_for`` and aborts with + # DEADLINE_EXCEEDED rather than blocking the gRPC call indefinitely. 
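+        # The fake bridge below just sleeps; HEALTH_CHECK_TIMEOUT is
+        # monkeypatched down to 0.05s so the abort path triggers quickly.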
+ async def _hang(): + await asyncio.sleep(60) + return [] + + fake_engine.get_load = _hang # type: ignore[method-assign] + monkeypatch.setattr(_servicer_module, "HEALTH_CHECK_TIMEOUT", 0.05) + ctx = _make_context() + with pytest.raises(_FakeAbortError) as exc: + await servicer.GetLoads(tokenspeed_scheduler_pb2.GetLoadsRequest(), ctx) + assert exc.value.code == grpc.StatusCode.DEADLINE_EXCEEDED + + +# --------------------------------------------------------------------------- +# _build_generate_req semantics (pre-tokenized input) +# --------------------------------------------------------------------------- + + +class TestBuildGenerateReq: + def test_preserves_input_ids(self, servicer: TokenSpeedSchedulerServicer): + req = _make_generate_request(input_ids=[11, 22, 33], stream=True) + obj = servicer._build_generate_req(req) + assert obj.input_ids == [11, 22, 33] + assert obj.rid == "rid-1" + assert obj.stream is True + assert obj.sampling_params["max_new_tokens"] == 16 + + def test_rejects_missing_tokenized(self, servicer: TokenSpeedSchedulerServicer): + req = tokenspeed_scheduler_pb2.GenerateRequest(request_id="x") + with pytest.raises(ValueError, match="tokenized"): + servicer._build_generate_req(req) + + +# --------------------------------------------------------------------------- +# Output logprobs proto conversion +# --------------------------------------------------------------------------- + + +class TestConvertOutputLogprobsToProto: + """``_convert_output_logprobs_to_proto`` reads the cumulative + ``meta_info["output_token_logprobs"]`` / ``output_top_logprobs`` lists + that TokenSpeed accumulates per request, slices the last + ``len(output_ids)`` entries (the tokens this frame emitted), and keeps + the first ``n_keep`` so the result aligns with whatever + ``_generated_output_ids`` returned (which may have stripped a trailing + stop token).""" + + def test_returns_none_when_logprobs_empty(self): + # ``--enable-output-logprobs`` not set on the server → the keys exist + # in meta_info but the lists are empty. Must not return a half-built + # proto in this case (gateway would treat empty as "logprobs missing"). + out = { + "output_ids": [10, 20, 30], + "meta_info": {"output_token_logprobs": [], "output_top_logprobs": []}, + } + assert TokenSpeedSchedulerServicer._convert_output_logprobs_to_proto(out, n_keep=3) is None + + def test_returns_none_when_keys_missing(self): + # Logprobs not requested at all → meta_info lacks the keys entirely. + out = {"output_ids": [10, 20, 30], "meta_info": {}} + assert TokenSpeedSchedulerServicer._convert_output_logprobs_to_proto(out, n_keep=3) is None + + def test_returns_none_when_n_keep_zero(self): + # Stop-token strip can leave n_keep == 0 for a 1-token frame whose + # only token was the stop. Don't emit a proto with a length mismatch. + out = { + "output_ids": [99], + "meta_info": { + "output_token_logprobs": [(-0.1, 99, None)], + "output_top_logprobs": [None], + }, + } + assert TokenSpeedSchedulerServicer._convert_output_logprobs_to_proto(out, n_keep=0) is None + + def test_non_streaming_full_output(self): + # Non-streaming: output_ids covers the entire generation; cumulative + # meta_info matches it exactly. n_keep == len(output_ids) → emit all. 
+ out = { + "output_ids": [10, 20, 30], + "meta_info": { + "output_token_logprobs": [ + (-0.5, 10, None), + (-0.3, 20, None), + (-0.1, 30, None), + ], + "output_top_logprobs": [None, None, None], + }, + } + proto = TokenSpeedSchedulerServicer._convert_output_logprobs_to_proto(out, n_keep=3) + assert proto is not None + assert list(proto.token_logprobs) == pytest.approx([-0.5, -0.3, -0.1]) + assert list(proto.token_ids) == [10, 20, 30] + assert len(proto.top_logprobs) == 3 + # ``None`` entries in raw_top translate to empty TopLogProbs placeholders. + for tl in proto.top_logprobs: + assert list(tl.values) == [] + assert list(tl.token_ids) == [] + + def test_streaming_chunk_emits_only_delta(self): + # Streaming chunk: output_ids has just the new tokens for this chunk, + # but meta_info is cumulative across the entire request. The slice + # ``[-len(output_ids):]`` on the cumulative list must yield exactly + # the delta this chunk represents. + out = { + "output_ids": [40, 50], # 2 new tokens this chunk + "meta_info": { + # cumulative: 4 prior tokens + 2 new + "output_token_logprobs": [ + (-1.1, 10, None), + (-1.2, 20, None), + (-1.3, 30, None), + (-1.4, 99, None), + (-0.7, 40, None), + (-0.6, 50, None), + ], + "output_top_logprobs": [None] * 6, + }, + } + proto = TokenSpeedSchedulerServicer._convert_output_logprobs_to_proto(out, n_keep=2) + assert proto is not None + assert list(proto.token_logprobs) == pytest.approx([-0.7, -0.6]) + assert list(proto.token_ids) == [40, 50] + + def test_top_k_alternatives(self): + # When the user requests top_logprobs=3, each position in + # output_top_logprobs is a list of K (logprob, token_id, text) tuples. + # Translate each into a TopLogProbs proto with parallel value/id arrays. + out = { + "output_ids": [40], + "meta_info": { + "output_token_logprobs": [(-0.7, 40, None)], + "output_top_logprobs": [ + [(-0.7, 40, None), (-1.2, 41, None), (-2.5, 42, None)], + ], + }, + } + proto = TokenSpeedSchedulerServicer._convert_output_logprobs_to_proto(out, n_keep=1) + assert proto is not None + assert len(proto.top_logprobs) == 1 + tl = proto.top_logprobs[0] + assert list(tl.values) == pytest.approx([-0.7, -1.2, -2.5]) + assert list(tl.token_ids) == [40, 41, 42] + + def test_strips_stop_token_alignment(self): + # When ``_generated_output_ids`` strips a trailing stop token, + # n_keep == len(output_ids) - 1. The converter must take the first + # n_keep entries of this frame's cumulative slice — emitting the + # logprob for the stripped stop token would misalign with the + # ``token_ids`` field on the proto. + out = { + "output_ids": [10, 20, 99], # 99 = stop, will be stripped → n_keep=2 + "meta_info": { + "output_token_logprobs": [ + (-0.5, 10, None), + (-0.3, 20, None), + (-0.1, 99, None), # logprob for the stop we just stripped + ], + "output_top_logprobs": [None, None, None], + }, + } + proto = TokenSpeedSchedulerServicer._convert_output_logprobs_to_proto(out, n_keep=2) + assert proto is not None + # Note: 99's logprob is dropped; emitted logprobs match the kept tokens. 
+ assert list(proto.token_logprobs) == pytest.approx([-0.5, -0.3]) + assert list(proto.token_ids) == [10, 20] From 6bb18d28a380e8040e194b37a473bb944176d069 Mon Sep 17 00:00:00 2001 From: key4ng Date: Sat, 9 May 2026 13:22:32 -0700 Subject: [PATCH 10/24] refactor(grpc_servicer): tighten _finish_reason_to_dict MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Trim ~13 lines: collapse the early-returns into a single conditional, drop the inner ``try/except`` around ``to_json()`` (propagating the original exception is more useful than wrapping it), and shorten the docstring. Behavior is unchanged — the same shapes accepted, the same TypeError raised on unknown shapes. Signed-off-by: key4ng --- .../smg_grpc_servicer/tokenspeed/servicer.py | 39 +++++++------------ 1 file changed, 13 insertions(+), 26 deletions(-) diff --git a/grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py b/grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py index ca641b736..28af42211 100644 --- a/grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py +++ b/grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py @@ -62,36 +62,23 @@ def _lazy_generate_req_input(): def _finish_reason_to_dict(reason: Any) -> dict | None: - """Normalise a TokenSpeed finish reason into the SGLang on-wire shape. - - TokenSpeed emits ``BaseFinishReason``-style objects (or an already-normalised - dict) in ``meta_info["finish_reason"]``; downstream code expects a dict - with at minimum ``{"type": ...}`` and optionally ``{"matched": int|str}``. - ``None`` means "still running". - - We duck-type on ``to_json()`` rather than importing the concrete - ``BaseFinishReason`` class so the servicer module loads without pulling - in TokenSpeed's full request-processing graph. - - Raises ``TypeError`` for unknown shapes rather than coercing to a fake - ``stop``: silently flipping ``length``/``abort`` to ``stop`` and leaking - a debug ``repr()`` into the user-facing ``matched_stop_str`` field would - hide real bugs and corrupt the OpenAI ``finish_reason`` semantics. The - caller wraps this in ``try/except`` and turns it into ``StatusCode.INTERNAL``. + """Normalise a TokenSpeed finish reason into a dict. + + TokenSpeed emits ``BaseFinishReason``-style objects (or an already- + normalised dict) in ``meta_info["finish_reason"]``; downstream code + expects a dict with at minimum ``{"type": ...}`` and optionally + ``{"matched": int|str}``. ``None`` means "still running". + + We duck-type on ``to_json()`` so the servicer module loads without + pulling in TokenSpeed's full request-processing graph. Unknown shapes + raise ``TypeError`` rather than silently flipping ``length`` / ``abort`` + to ``stop`` — the caller maps that to ``StatusCode.INTERNAL``. """ - if reason is None: - return None - if isinstance(reason, dict): + if reason is None or isinstance(reason, dict): return reason to_json = getattr(reason, "to_json", None) if callable(to_json): - try: - result = to_json() - except Exception as e: # noqa: BLE001 - raise TypeError( - f"finish_reason of type {type(reason).__name__!r} raised in " - f"to_json(); refusing to silently emit a fake stop. 
{e}" - ) from e + result = to_json() if isinstance(result, dict): return result raise TypeError( From 93038d17a0a9b0e84b05394201d12c104f06af2c Mon Sep 17 00:00:00 2001 From: key4ng Date: Mon, 11 May 2026 21:07:28 -0700 Subject: [PATCH 11/24] docs(grpc_servicer): use --model in tokenspeed entrypoint usage example Upstream tokenspeed renamed the launch flag from ``--model-path`` to ``--model``. Update the docstring example so copy-paste still works. Signed-off-by: key4ng --- grpc_servicer/smg_grpc_servicer/tokenspeed/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grpc_servicer/smg_grpc_servicer/tokenspeed/__main__.py b/grpc_servicer/smg_grpc_servicer/tokenspeed/__main__.py index fb80dcace..b4e6fb0e6 100644 --- a/grpc_servicer/smg_grpc_servicer/tokenspeed/__main__.py +++ b/grpc_servicer/smg_grpc_servicer/tokenspeed/__main__.py @@ -2,7 +2,7 @@ Usage:: - python -m smg_grpc_servicer.tokenspeed --model-path --host 127.0.0.1 --port 50051 + python -m smg_grpc_servicer.tokenspeed --model --host 127.0.0.1 --port 50051 All :class:`tokenspeed.runtime.utils.server_args.ServerArgs` flags are accepted verbatim (we reuse TokenSpeed's own ``prepare_server_args`` so there is no From a812f5ca6ef294ebec703b743694db7086eb0c6c Mon Sep 17 00:00:00 2001 From: key4ng Date: Mon, 11 May 2026 23:07:40 -0700 Subject: [PATCH 12/24] =?UTF-8?q?fix(grpc=5Fservicer):=20handle=20ServerAr?= =?UTF-8?q?gs=20``=5Fpath``=20=E2=86=92=20bare-name=20renames?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Upstream lightseekorg/tokenspeed renamed the model + tokenizer ``ServerArgs`` fields alongside the matching CLI flag renames: - ``ServerArgs.model_path`` → ``ServerArgs.model`` - ``ServerArgs.tokenizer_path`` → ``ServerArgs.tokenizer`` Both are sources of fields in ``GetModelInfo``, so post-bump that RPC fails with: AttributeError: 'ServerArgs' object has no attribute 'model_path' AttributeError: 'ServerArgs' object has no attribute 'tokenizer_path' Pick whichever attribute is populated so the servicer works against both old and new tokenspeed pins: model_path = getattr(self.server_args, "model", None) or getattr( self.server_args, "model_path", "" ) tokenizer_path = getattr(self.server_args, "tokenizer", None) or getattr( self.server_args, "tokenizer_path", "" ) The proto fields stay named ``model_path`` / ``tokenizer_path`` because those are the on-wire contracts the router consumes. 57/57 unit tests still pass. Signed-off-by: key4ng --- .../smg_grpc_servicer/tokenspeed/servicer.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py b/grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py index 28af42211..22c9926a9 100644 --- a/grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py +++ b/grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py @@ -366,12 +366,24 @@ async def GetModelInfo( # and ``id2label_json`` / ``num_labels`` (not a classifier serving # path). The Rust client fills those slots back in when translating # to its SGLang-shaped wrapper. + # Upstream renamed ``ServerArgs.model_path`` → ``ServerArgs.model`` + # and ``ServerArgs.tokenizer_path`` → ``ServerArgs.tokenizer`` + # alongside the ``--model-path`` → ``--model`` flag rename. Old + # versions still set the ``_path`` form; new ones set the bare + # form. Pick whichever is populated so the servicer works against + # both. 
+ model_path = getattr(self.server_args, "model", None) or getattr( + self.server_args, "model_path", "" + ) + tokenizer_path = getattr(self.server_args, "tokenizer", None) or getattr( + self.server_args, "tokenizer_path", "" + ) return tokenspeed_scheduler_pb2.GetModelInfoResponse( - model_path=self.server_args.model_path, - tokenizer_path=self.server_args.tokenizer_path or "", + model_path=model_path, + tokenizer_path=tokenizer_path or "", preferred_sampling_params=self.server_args.preferred_sampling_params or "", weight_version="", - served_model_name=(self.server_args.served_model_name or self.server_args.model_path), + served_model_name=(self.server_args.served_model_name or model_path), max_context_length=int(self.async_llm.context_len), vocab_size=int(model_config.vocab_size), model_type=(getattr(hf_config, "model_type", "") or "") if hf_config else "", From 8a3e651ea20575fccef5708d98839daf1a744abf Mon Sep 17 00:00:00 2001 From: key4ng Date: Tue, 12 May 2026 07:54:54 -0700 Subject: [PATCH 13/24] feat(grpc_servicer): wrap json_schema as structural_tag for reasoning models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When tokenspeed runs with a reasoning parser that has an xgrammar template (e.g. ``gpt-oss`` → ``harmony``), forwarding a raw JSON-schema constraint causes xgrammar to fight the Harmony channel preamble (``<|channel|>analysis<|message|>…``): the model either generates garbage or stalls until ``max_tokens``, leaving ``content`` empty. Mirror tokenspeed's HTTP entrypoint (``serving_chat.py``): when a ``reasoning_parser`` is configured, wrap the user's JSON schema via ``structural_tag_for_reasoning_json_schema()`` so the grammar only activates inside the response channel. Parsers without an xgrammar mapping fall back to the raw json_schema unchanged. Plumbs ``reasoning_parser`` into ``_sampling_params_from_proto`` as a keyword-only argument so the helper stays a static method and existing tests keep passing without modification. The new import of ``tokenspeed.runtime.grammar.reasoning_structural_tag`` is wrapped in ``try/except ImportError`` so stale tokenspeed pins fall back to raw json_schema rather than crashing. 
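Sketch of the resulting branches for a request that carries
``json_schema='{"type": "object"}'`` (the wrapped value shown is a
placeholder, not real xgrammar output):

    # reasoning_parser="gpt-oss", wrapper import succeeds
    #   -> sampling_params["structural_tag"] = <wrapped tag>
    # no reasoning_parser, parser without an xgrammar template, or ImportError
    #   -> sampling_params["json_schema"] = '{"type": "object"}'  (unchanged)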
Signed-off-by: key4ng --- .../smg_grpc_servicer/tokenspeed/servicer.py | 33 ++++++++++- .../tests/test_tokenspeed_servicer.py | 55 +++++++++++++++++++ 2 files changed, 86 insertions(+), 2 deletions(-) diff --git a/grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py b/grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py index 22c9926a9..8d9387f1b 100644 --- a/grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py +++ b/grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py @@ -22,6 +22,7 @@ import asyncio import dataclasses +import json import logging import os import re @@ -588,7 +589,10 @@ def _build_generate_req(self, request: tokenspeed_scheduler_pb2.GenerateRequest) if not input_ids: raise ValueError("GenerateRequest.tokenized.input_ids is empty") - sampling = self._sampling_params_from_proto(request.sampling_params) + sampling = self._sampling_params_from_proto( + request.sampling_params, + reasoning_parser=getattr(self.server_args, "reasoning_parser", None), + ) GenerateReqInput = _lazy_generate_req_input() obj = GenerateReqInput( @@ -636,6 +640,8 @@ def _build_generate_req(self, request: tokenspeed_scheduler_pb2.GenerateRequest) @staticmethod def _sampling_params_from_proto( params: tokenspeed_scheduler_pb2.SamplingParams, + *, + reasoning_parser: str | None = None, ) -> dict[str, Any]: """Build the dict that ``GenerateReqInput.sampling_params`` expects. @@ -703,7 +709,30 @@ def _sampling_params_from_proto( if params.HasField("regex"): out["regex"] = params.regex elif params.HasField("json_schema"): - out["json_schema"] = params.json_schema + # Mirror tokenspeed serving_chat.py: when the engine is + # running with a reasoning parser that has an xgrammar + # template (e.g. ``gpt-oss`` → ``harmony``), wrap the user's + # JSON schema as a structural tag so the grammar only + # activates inside the response channel. Without this, + # xgrammar fights the Harmony channel preamble + # (``<|channel|>analysis<|message|>…``) and the model stalls + # until ``max_tokens``. 
+ wrapped: str | None = None + if reasoning_parser: + try: + from tokenspeed.runtime.grammar.reasoning_structural_tag import ( + structural_tag_for_reasoning_json_schema, + ) + + wrapped = structural_tag_for_reasoning_json_schema( + reasoning_parser, json.loads(params.json_schema) + ) + except ImportError: + wrapped = None + if wrapped is not None: + out["structural_tag"] = wrapped + else: + out["json_schema"] = params.json_schema elif params.HasField("ebnf_grammar"): out["ebnf"] = params.ebnf_grammar elif params.HasField("structural_tag"): diff --git a/grpc_servicer/tests/test_tokenspeed_servicer.py b/grpc_servicer/tests/test_tokenspeed_servicer.py index 112b9dfb2..89ed5c549 100644 --- a/grpc_servicer/tests/test_tokenspeed_servicer.py +++ b/grpc_servicer/tests/test_tokenspeed_servicer.py @@ -402,6 +402,61 @@ def test_constraints(self, setter, key, value): out = TokenSpeedSchedulerServicer._sampling_params_from_proto(params) assert out[key] == value + def test_json_schema_no_reasoning_parser_passes_through(self): + params = tokenspeed_scheduler_pb2.SamplingParams(json_schema='{"type": "object"}') + out = TokenSpeedSchedulerServicer._sampling_params_from_proto(params, reasoning_parser=None) + assert out["json_schema"] == '{"type": "object"}' + assert "structural_tag" not in out + + def test_json_schema_with_reasoning_parser_wraps_as_structural_tag(self, monkeypatch): + import sys + import types + + fake_module = types.ModuleType("tokenspeed.runtime.grammar.reasoning_structural_tag") + captured: dict[str, Any] = {} + + def _fake_wrap(rp: str, schema: Any) -> str: + captured["rp"] = rp + captured["schema"] = schema + return '{"wrapped": "tag"}' + + fake_module.structural_tag_for_reasoning_json_schema = _fake_wrap + monkeypatch.setitem( + sys.modules, + "tokenspeed.runtime.grammar.reasoning_structural_tag", + fake_module, + ) + + params = tokenspeed_scheduler_pb2.SamplingParams(json_schema='{"type": "object"}') + out = TokenSpeedSchedulerServicer._sampling_params_from_proto( + params, reasoning_parser="gpt-oss" + ) + + assert "json_schema" not in out + assert out["structural_tag"] == '{"wrapped": "tag"}' + assert captured["rp"] == "gpt-oss" + assert captured["schema"] == {"type": "object"} + + def test_json_schema_unknown_parser_falls_back_to_raw(self, monkeypatch): + import sys + import types + + fake_module = types.ModuleType("tokenspeed.runtime.grammar.reasoning_structural_tag") + fake_module.structural_tag_for_reasoning_json_schema = lambda rp, s: None + monkeypatch.setitem( + sys.modules, + "tokenspeed.runtime.grammar.reasoning_structural_tag", + fake_module, + ) + + params = tokenspeed_scheduler_pb2.SamplingParams(json_schema='{"type": "object"}') + out = TokenSpeedSchedulerServicer._sampling_params_from_proto( + params, reasoning_parser="unknown-parser" + ) + + assert out["json_schema"] == '{"type": "object"}' + assert "structural_tag" not in out + # --------------------------------------------------------------------------- # Generate RPC From c869ee968991d6fe8fc8b7b0dbdfceef1c70994f Mon Sep 17 00:00:00 2001 From: key4ng Date: Fri, 8 May 2026 12:46:16 -0700 Subject: [PATCH 14/24] ci(tokenspeed): add CI install + GPU e2e coverage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wires TokenSpeed into CI and the GPU e2e suite: - ``.github/actions/setup-tokenspeed`` composite action and ``scripts/ci_install_tokenspeed.sh`` to source-install TokenSpeed (kernel + scheduler) at a pinned ref, with a wheel cache lookup so repeat runs skip the ~20 min 
compile - e2e-gpu-job.yml: add a tokenspeed engine lane, gated on secret access so forked PRs skip cleanly - pr-test-rust.yml: install the same proto deps so Rust-only changes that touch ``crates/grpc_client`` still cover the tokenspeed proto - e2e_test infra: ``constants``, ``hooks``, ``worker``, and ``model_specs`` learn about a ``tokenspeed`` runtime alongside sglang/vllm/trtllm; ``worker.py`` adds the launch builder; the suite-wide ``@pytest.mark.engine(...)`` markers expand to include tokenspeed - Function-calling and tool_choice e2e suites swap to ``Qwen/Qwen3-4B-Instruct-2507`` for tool-call coverage (the Qwen3 family is what TokenSpeed's model registry currently supports) This is part 3 of 3 splitting #1351: - PR1: Rust gRPC + protocol - PR2: Python servicer + unit tests - PR3 (this): CI workflows + e2e tests Stacked on PR2. e2e wiring exercises both the Rust router from PR1 and the Python servicer from PR2 against a live TokenSpeed worker. Addresses CatherineSue's review on #1351: - drop the verbose Qwen3-4B docstring on ``TestToolChoiceQwen`` — that context belongs in the PR description, not in the test file Signed-off-by: key4ng --- .github/actions/setup-tokenspeed/action.yml | 23 +++ .github/workflows/e2e-gpu-job.yml | 20 +- .github/workflows/pr-test-rust.yml | 11 + .../chat_completions/test_enable_thinking.py | 2 +- .../chat_completions/test_function_calling.py | 29 ++- .../chat_completions/test_openai_server.py | 78 +++++++- .../test_structured_output.py | 6 +- e2e_test/chat_completions/test_validation.py | 8 +- e2e_test/completions/test_basic.py | 4 +- e2e_test/fixtures/hooks.py | 18 +- e2e_test/infra/constants.py | 17 +- e2e_test/infra/model_specs.py | 10 + e2e_test/infra/worker.py | 55 ++++- e2e_test/responses/test_sampling_params.py | 2 +- e2e_test/responses/test_state_management.py | 2 +- e2e_test/responses/test_streaming_events.py | 2 +- e2e_test/responses/test_structured_output.py | 2 +- e2e_test/responses/test_tools_call.py | 2 +- e2e_test/router/test_mmlu.py | 2 +- e2e_test/router/test_worker_api.py | 5 + scripts/ci_install_tokenspeed.sh | 189 ++++++++++++++++++ 21 files changed, 439 insertions(+), 48 deletions(-) create mode 100644 .github/actions/setup-tokenspeed/action.yml create mode 100755 scripts/ci_install_tokenspeed.sh diff --git a/.github/actions/setup-tokenspeed/action.yml b/.github/actions/setup-tokenspeed/action.yml new file mode 100644 index 000000000..348bb45f7 --- /dev/null +++ b/.github/actions/setup-tokenspeed/action.yml @@ -0,0 +1,23 @@ +name: 'Setup TokenSpeed Backend' +description: 'Create Python venv and install TokenSpeed (engine + kernel + scheduler) from source.' + +inputs: + github-token: + description: >- + GitHub token with read access to the (private) lightseekorg/tokenspeed + repository. Forwarded to ``ci_install_tokenspeed.sh`` as + ``TOKENSPEED_GITHUB_TOKEN`` so the clone uses HTTPS basic auth. 
+ required: true + +runs: + using: 'composite' + steps: + - name: Setup Python venv + shell: bash + run: bash scripts/ci_setup_python_venv.sh + + - name: Install TokenSpeed + shell: bash + env: + TOKENSPEED_GITHUB_TOKEN: ${{ inputs.github-token }} + run: bash scripts/ci_install_tokenspeed.sh diff --git a/.github/workflows/e2e-gpu-job.yml b/.github/workflows/e2e-gpu-job.yml index bc05403ee..d4080ae74 100644 --- a/.github/workflows/e2e-gpu-job.yml +++ b/.github/workflows/e2e-gpu-job.yml @@ -6,7 +6,7 @@ on: engine: required: true type: string - description: "Engine to test: sglang, vllm, or trtllm" + description: "Engine to test: sglang, vllm, trtllm, or tokenspeed" gpu_tier: required: true type: string @@ -42,6 +42,16 @@ on: jobs: run: + # TokenSpeed lanes need ``TOKENSPEED_GITHUB_TOKEN`` to clone the + # private lightseekorg/tokenspeed repo, but GitHub does not pass + # secrets to workflows triggered by forked-PR events. Skip the + # tokenspeed engine on fork PRs so the job is reported as ``skipped`` + # rather than failing inside ``setup-tokenspeed`` with a 404. Same- + # repo PRs and pushes still run normally. + if: >- + inputs.engine != 'tokenspeed' + || github.event_name != 'pull_request' + || github.event.pull_request.head.repo.full_name == github.repository runs-on: ${{ inputs.runner }} timeout-minutes: ${{ inputs.timeout }} permissions: @@ -68,6 +78,14 @@ jobs: if: inputs.engine == 'trtllm' uses: ./.github/actions/setup-trtllm + - name: Setup TokenSpeed backend + if: inputs.engine == 'tokenspeed' + uses: ./.github/actions/setup-tokenspeed + with: + # lightseekorg/tokenspeed is private; ``secrets: inherit`` on the + # caller (pr-test-rust.yml) makes this secret available here. + github-token: ${{ secrets.TOKENSPEED_GITHUB_TOKEN }} + # Artifact downloads - name: Download wheel artifact uses: actions/download-artifact@v8 diff --git a/.github/workflows/pr-test-rust.yml b/.github/workflows/pr-test-rust.yml index 29621368b..e8ff049b8 100644 --- a/.github/workflows/pr-test-rust.yml +++ b/.github/workflows/pr-test-rust.yml @@ -390,6 +390,7 @@ jobs: - 'scripts/ci_setup_python_venv.sh' - 'scripts/ci_install_sglang.sh' - 'scripts/ci_install_vllm.sh' + - 'scripts/ci_install_tokenspeed.sh' - 'scripts/ci_install_e2e_deps.sh' - 'scripts/ci_killall_sglang.sh' - 'scripts/ci_build_wheel.sh' @@ -404,6 +405,7 @@ jobs: - 'e2e_test/router/**' - 'scripts/ci_install_vllm.sh' - 'scripts/ci_install_trtllm.sh' + - 'scripts/ci_install_tokenspeed.sh' agentic: - 'crates/mcp/**' - 'crates/data_connector/**' @@ -445,6 +447,10 @@ jobs: timeout: 20 - engine: trtllm timeout: 90 + # TokenSpeed builds kernel (CUDA) + scheduler (C++/CMake) from + # source, so first run takes ~30 min; cached runs are faster. + - engine: tokenspeed + timeout: 60 uses: ./.github/workflows/e2e-gpu-job.yml with: engine: ${{ matrix.engine }} @@ -555,6 +561,11 @@ jobs: timeout: 20 - engine: trtllm timeout: 30 + # Picks up TestChatCompletionGptOss (gpt-oss-20b, ``@pytest.mark.gpu(2)``) + # on the tokenspeed engine; the 1-GPU job collected the test class but + # pytest skipped it at collection because the runner only had 1 GPU. 
+ - engine: tokenspeed + timeout: 60 uses: ./.github/workflows/e2e-gpu-job.yml with: engine: ${{ matrix.engine }} diff --git a/e2e_test/chat_completions/test_enable_thinking.py b/e2e_test/chat_completions/test_enable_thinking.py index c555ac8e8..aebe48db7 100644 --- a/e2e_test/chat_completions/test_enable_thinking.py +++ b/e2e_test/chat_completions/test_enable_thinking.py @@ -22,7 +22,7 @@ # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(1) @pytest.mark.model("Qwen/Qwen3-30B-A3B") @pytest.mark.gateway(extra_args=["--reasoning-parser", "qwen3", "--history-backend", "memory"]) diff --git a/e2e_test/chat_completions/test_function_calling.py b/e2e_test/chat_completions/test_function_calling.py index b4df06e4c..83cd1b1c6 100644 --- a/e2e_test/chat_completions/test_function_calling.py +++ b/e2e_test/chat_completions/test_function_calling.py @@ -22,7 +22,9 @@ # Shared Tool Definitions # ============================================================================= -# System message for Llama3.2 function calling +# System message for Llama3.2 function calling — prescribes the +# {"name": ..., "parameters": ...} JSON shape that the ``llama`` tool +# parser looks for. Used by ``TestToolChoiceLlama`` below. LLAMA_SYSTEM_MESSAGE = ( "You are a helpful assistant with tool calling capabilities. " "Only reply with a tool call if the function exists in the library provided by the user. " @@ -100,14 +102,14 @@ # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(1) @pytest.mark.model("meta-llama/Llama-3.2-1B-Instruct") @pytest.mark.gateway(extra_args=["--tool-call-parser", "llama", "--history-backend", "memory"]) @pytest.mark.parametrize("setup_backend", ["grpc"], indirect=True) @pytest.mark.parametrize("api_client", ["openai", "smg"], indirect=True) class TestOpenAIServerFunctionCalling: - """Tests for OpenAI-compatible function calling with Llama tool parser.""" + """Tests for OpenAI-compatible function calling with the llama tool parser.""" def test_function_calling_format(self, model, api_client): """Test: Whether the function call format returned by the AI is correct. @@ -265,8 +267,8 @@ def test_function_calling_streaming_args_parsing(self, model, api_client): }, "required": ["a", "b"], }, - # Llama-3.2-1B is flaky in tool call. It won't always respond with - # parameters unless we set strict. + # Llama-3.2-1B is flaky in tool call format, so we force it + # with strict mode. "strict": True, }, } @@ -377,7 +379,6 @@ def test_function_call_required(self, model, api_client): - When tool_choice == "required", the model should return one or more tool_calls. """ - tools = [ { "type": "function", @@ -457,7 +458,6 @@ def test_function_call_specific(self, model, api_client): - When tool_choice is a specific ToolChoice, the model should return one or more tool_calls. """ - tools = [ { "type": "function", @@ -526,7 +526,6 @@ def test_streaming_multiple_choices_finish_reason(self, model, api_client): This tests the fix for the bug where only the last index got a finish_reason chunk. 
""" - tools = [ { "type": "function", @@ -709,7 +708,7 @@ def test_streaming_multiple_choices_without_tools(self, model, api_client): # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(1) @pytest.mark.model("meta-llama/Llama-3.1-8B-Instruct") @pytest.mark.gateway(extra_args=["--tool-call-parser", "pythonic", "--history-backend", "memory"]) @@ -1489,7 +1488,7 @@ def test_conflicting_defs_required_tool_choice(self, model, api_client): # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(1) @pytest.mark.model("meta-llama/Llama-3.2-1B-Instruct") @pytest.mark.gateway(extra_args=["--tool-call-parser", "llama", "--history-backend", "memory"]) @@ -1510,9 +1509,9 @@ class TestToolChoiceLlama(_TestToolChoiceBase): # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(1) -@pytest.mark.model("Qwen/Qwen2.5-7B-Instruct") +@pytest.mark.model("Qwen/Qwen3-4B-Instruct-2507") @pytest.mark.gateway(extra_args=["--tool-call-parser", "qwen", "--history-backend", "memory"]) @pytest.mark.parametrize("setup_backend", ["grpc"], indirect=True) @pytest.mark.parametrize("api_client", ["openai", "smg"], indirect=True) @@ -1579,9 +1578,9 @@ def test_conflicting_defs_required_tool_choice(self, model, api_client): } -@pytest.mark.engine("sglang", "vllm", "trtllm") -@pytest.mark.gpu(2) -@pytest.mark.model("Qwen/Qwen2.5-14B-Instruct") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") +@pytest.mark.gpu(1) +@pytest.mark.model("Qwen/Qwen3-4B-Instruct-2507") @pytest.mark.gateway(extra_args=["--tool-call-parser", "qwen", "--history-backend", "memory"]) @pytest.mark.parametrize("setup_backend", ["grpc"], indirect=True) @pytest.mark.parametrize("api_client", ["openai", "smg"], indirect=True) diff --git a/e2e_test/chat_completions/test_openai_server.py b/e2e_test/chat_completions/test_openai_server.py index 517cdbb4a..d05d73287 100644 --- a/e2e_test/chat_completions/test_openai_server.py +++ b/e2e_test/chat_completions/test_openai_server.py @@ -20,7 +20,7 @@ # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(1) @pytest.mark.model("meta-llama/Llama-3.1-8B-Instruct") @pytest.mark.gateway(extra_args=["--history-backend", "memory"]) @@ -33,7 +33,23 @@ class TestChatCompletion: # Harmony (gpt-oss) does not trim because its detokenization is not channel-aware. 
STOP_SEQUENCE_TRIMMED = True - @pytest.mark.parametrize("logprobs", [None, 5]) + @pytest.mark.parametrize( + "logprobs", + [ + None, + pytest.param( + 5, + marks=pytest.mark.skip_for_runtime( + "tokenspeed", + reason=( + "tokenspeed's --enable-top-logprobs is not yet implemented " + "(raises at startup); base output logprobs work via " + "--enable-output-logprobs but the test requires top_logprobs=5" + ), + ), + ), + ], + ) @pytest.mark.parametrize("parallel_sample_num", [1, 2]) def test_chat_completion(self, model, api_client, logprobs, parallel_sample_num): """Test non-streaming chat completion with logprobs and parallel sampling.""" @@ -73,7 +89,23 @@ def test_chat_completion(self, model, api_client, logprobs, parallel_sample_num) assert response.usage.completion_tokens > 0 assert response.usage.total_tokens > 0 - @pytest.mark.parametrize("logprobs", [None, 5]) + @pytest.mark.parametrize( + "logprobs", + [ + None, + pytest.param( + 5, + marks=pytest.mark.skip_for_runtime( + "tokenspeed", + reason=( + "tokenspeed's --enable-top-logprobs is not yet implemented " + "(raises at startup); base output logprobs work via " + "--enable-output-logprobs but the test requires top_logprobs=5" + ), + ), + ), + ], + ) @pytest.mark.parametrize("parallel_sample_num", [1, 2]) def test_chat_completion_stream(self, model, api_client, logprobs, parallel_sample_num): """Test streaming chat completion with logprobs and parallel sampling.""" @@ -359,7 +391,7 @@ def _delta_text(delta): return delta.content or getattr(delta, "reasoning_content", "") or "" -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(2) @pytest.mark.model("openai/gpt-oss-20b") @pytest.mark.gateway(extra_args=["--history-backend", "memory"]) @@ -375,13 +407,45 @@ class TestChatCompletionGptOss(TestChatCompletion): STOP_SEQUENCE_TRIMMED = False - @pytest.mark.parametrize("logprobs", [None, 5]) + @pytest.mark.parametrize( + "logprobs", + [ + None, + pytest.param( + 5, + marks=pytest.mark.skip_for_runtime( + "tokenspeed", + reason=( + "tokenspeed's --enable-top-logprobs is not yet implemented " + "(raises at startup); base output logprobs work via " + "--enable-output-logprobs but the test requires top_logprobs=5" + ), + ), + ), + ], + ) @pytest.mark.parametrize("parallel_sample_num", [1, 2]) def test_chat_completion(self, model, api_client, logprobs, parallel_sample_num): """Test non-streaming chat completion with logprobs and parallel sampling.""" super().test_chat_completion(model, api_client, logprobs, parallel_sample_num) - @pytest.mark.parametrize("logprobs", [None, 5]) + @pytest.mark.parametrize( + "logprobs", + [ + None, + pytest.param( + 5, + marks=pytest.mark.skip_for_runtime( + "tokenspeed", + reason=( + "tokenspeed's --enable-top-logprobs is not yet implemented " + "(raises at startup); base output logprobs work via " + "--enable-output-logprobs but the test requires top_logprobs=5" + ), + ), + ), + ], + ) @pytest.mark.parametrize("parallel_sample_num", [1, 2]) @pytest.mark.skip_for_runtime( "trtllm", reason="trtllm may return more top_logprobs than requested in streaming" @@ -402,7 +466,7 @@ def test_response_prefill(self, model, api_client): pass -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(4) @pytest.mark.model("openai/gpt-oss-120b") @pytest.mark.gateway(extra_args=["--history-backend", "memory"]) diff --git a/e2e_test/chat_completions/test_structured_output.py 
b/e2e_test/chat_completions/test_structured_output.py index 654443bbd..2c8fe7235 100644 --- a/e2e_test/chat_completions/test_structured_output.py +++ b/e2e_test/chat_completions/test_structured_output.py @@ -124,7 +124,7 @@ def test_response_format_json_schema_stream(self, model, api_client): # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(1) @pytest.mark.model("meta-llama/Llama-3.1-8B-Instruct") @pytest.mark.gateway(extra_args=["--history-backend", "memory"]) @@ -139,7 +139,7 @@ class TestStructuredOutputRegular(_TestStructuredOutputBase): # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(2) @pytest.mark.model("openai/gpt-oss-20b") @pytest.mark.gateway(extra_args=["--history-backend", "memory"]) @@ -149,7 +149,7 @@ class TestStructuredOutputGptOss(_TestStructuredOutputBase): """Structured output tests for Harmony models (GPT-OSS 20B, 1 GPU).""" -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(4) @pytest.mark.model("openai/gpt-oss-120b") @pytest.mark.gateway(extra_args=["--history-backend", "memory"]) diff --git a/e2e_test/chat_completions/test_validation.py b/e2e_test/chat_completions/test_validation.py index 7192f8e08..96e70f371 100644 --- a/e2e_test/chat_completions/test_validation.py +++ b/e2e_test/chat_completions/test_validation.py @@ -37,7 +37,7 @@ def get_tokenizer(model_path: str): # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(1) @pytest.mark.model("meta-llama/Llama-3.1-8B-Instruct") @pytest.mark.gateway(extra_args=["--history-backend", "memory"]) @@ -107,7 +107,7 @@ def test_ignore_eos(self, model, api_client): # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(1) @pytest.mark.model("meta-llama/Llama-3.1-8B-Instruct") @pytest.mark.gateway(extra_args=["--history-backend", "memory"]) @@ -167,7 +167,7 @@ def run_chat_completion(): # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(2) @pytest.mark.model("openai/gpt-oss-20b") @pytest.mark.gateway(extra_args=["--history-backend", "memory"]) @@ -243,7 +243,7 @@ def test_tool_choice_with_response_format_rejected(self, model, api_client): # ============================================================================= -@pytest.mark.engine("sglang", "vllm") +@pytest.mark.engine("sglang", "vllm", "tokenspeed") @pytest.mark.gpu(1) @pytest.mark.model("meta-llama/Llama-3.2-1B-Instruct") @pytest.mark.gateway(extra_args=["--tool-call-parser", "llama", "--history-backend", "memory"]) diff --git a/e2e_test/completions/test_basic.py b/e2e_test/completions/test_basic.py index e91d3f90f..865bbcf87 100644 --- a/e2e_test/completions/test_basic.py +++ b/e2e_test/completions/test_basic.py @@ -9,7 +9,7 @@ import pytest -@pytest.mark.engine("sglang", "vllm") +@pytest.mark.engine("sglang", 
"vllm", "tokenspeed") @pytest.mark.gpu(1) @pytest.mark.model("meta-llama/Llama-3.1-8B-Instruct") @pytest.mark.parametrize("setup_backend", ["grpc"], indirect=True) @@ -161,7 +161,7 @@ def test_non_streaming_echo_max_tokens_zero(self, model, api_client): assert response.usage.completion_tokens == 0 -@pytest.mark.engine("sglang", "vllm") +@pytest.mark.engine("sglang", "vllm", "tokenspeed") @pytest.mark.gpu(1) @pytest.mark.model("meta-llama/Llama-3.1-8B-Instruct") @pytest.mark.parametrize("setup_backend", ["grpc"], indirect=True) diff --git a/e2e_test/fixtures/hooks.py b/e2e_test/fixtures/hooks.py index ca15e0027..372fa4bad 100644 --- a/e2e_test/fixtures/hooks.py +++ b/e2e_test/fixtures/hooks.py @@ -82,12 +82,18 @@ def pytest_configure(config: pytest.Config) -> None: def pytest_runtest_setup(item: pytest.Item) -> None: - """Skip tests marked with ``@pytest.mark.skip_for_runtime``.""" - marker = item.get_closest_marker("skip_for_runtime") - if marker: - current_runtime = get_runtime() - skip_runtimes = marker.args - if current_runtime in skip_runtimes: + """Skip tests marked with ``@pytest.mark.skip_for_runtime``. + + A single test item can carry multiple ``skip_for_runtime`` marks — e.g. a + method-level ``@pytest.mark.skip_for_runtime("trtllm", ...)`` plus a + parametrize-attached ``pytest.param(5, marks=skip_for_runtime("tokenspeed", + ...))``. ``get_closest_marker`` only returns one of them, which silently + drops the others. Iterate every mark so a runtime that's named in any of + them gets skipped, regardless of which is "closest". + """ + current_runtime = get_runtime() + for marker in item.iter_markers(name="skip_for_runtime"): + if current_runtime in marker.args: reason = marker.kwargs.get("reason", f"Not supported on {current_runtime}") pytest.skip(f"Skipping for {current_runtime}: {reason}") diff --git a/e2e_test/infra/constants.py b/e2e_test/infra/constants.py index 19d6421f8..9bb7a194c 100644 --- a/e2e_test/infra/constants.py +++ b/e2e_test/infra/constants.py @@ -25,6 +25,7 @@ class Runtime(StrEnum): SGLANG = "sglang" VLLM = "vllm" TRTLLM = "trtllm" + TOKENSPEED = "tokenspeed" OPENAI = "openai" XAI = "xai" GEMINI = "gemini" @@ -33,7 +34,7 @@ class Runtime(StrEnum): # Convenience sets LOCAL_MODES = frozenset({ConnectionMode.HTTP, ConnectionMode.GRPC}) -LOCAL_RUNTIMES = frozenset({Runtime.SGLANG, Runtime.VLLM, Runtime.TRTLLM}) +LOCAL_RUNTIMES = frozenset({Runtime.SGLANG, Runtime.VLLM, Runtime.TRTLLM, Runtime.TOKENSPEED}) CLOUD_RUNTIMES = frozenset({Runtime.OPENAI, Runtime.XAI, Runtime.GEMINI, Runtime.ANTHROPIC}) # Fixture parameter names (used in @pytest.mark.parametrize) @@ -51,7 +52,9 @@ class Runtime(StrEnum): ENV_MODELS = "E2E_MODELS" ENV_BACKENDS = "E2E_BACKENDS" ENV_MODEL = "E2E_MODEL" -ENV_RUNTIME = "E2E_RUNTIME" # Runtime for gRPC tests: "sglang", "vllm", or "trtllm" +ENV_RUNTIME = ( + "E2E_RUNTIME" # Runtime for gRPC tests — one of Runtime.{SGLANG,VLLM,TRTLLM,TOKENSPEED} +) ENV_STARTUP_TIMEOUT = "E2E_STARTUP_TIMEOUT" ENV_SKIP_MODEL_POOL = "SKIP_MODEL_POOL" ENV_SKIP_BACKEND_SETUP = "SKIP_BACKEND_SETUP" @@ -100,11 +103,21 @@ def is_trtllm() -> bool: return get_runtime() == "trtllm" +def is_tokenspeed() -> bool: + """Check if tests are running with TokenSpeed runtime. + + Returns: + True if E2E_RUNTIME is "tokenspeed", False otherwise. 
+ """ + return get_runtime() == "tokenspeed" + + # Runtime display labels RUNTIME_LABELS = { "sglang": "SGLang", "vllm": "vLLM", "trtllm": "TensorRT-LLM", + "tokenspeed": "TokenSpeed", } ENV_SHOW_ROUTER_LOGS = "SHOW_ROUTER_LOGS" diff --git a/e2e_test/infra/model_specs.py b/e2e_test/infra/model_specs.py index 3fd1eb2cd..fdf6a750f 100644 --- a/e2e_test/infra/model_specs.py +++ b/e2e_test/infra/model_specs.py @@ -61,6 +61,16 @@ def _resolve_model_path(hf_path: str) -> str: "tp": 1, "features": ["chat", "streaming", "reasoning"], }, + # Qwen3 instruct (non-thinking variant) — emits the same + # `\n{"name": ..., "arguments": ...}\n` format as + # Qwen 2.5, so the gateway's ``qwen`` tool-call parser applies. Used by + # ``TestToolChoiceQwen`` and ``TestMultiTurnToolCall``: a Qwen3 model is + # required because the Qwen2 family is not in TokenSpeed's model registry. + "Qwen/Qwen3-4B-Instruct-2507": { + "model": _resolve_model_path("Qwen/Qwen3-4B-Instruct-2507"), + "tp": 1, + "features": ["chat", "streaming", "function_calling", "tool_choice"], + }, # Thinking/reasoning model (larger) "Qwen/Qwen3-30B-A3B": { "model": _resolve_model_path("Qwen/Qwen3-30B-A3B"), diff --git a/e2e_test/infra/worker.py b/e2e_test/infra/worker.py index 8e72d4e04..638813f11 100644 --- a/e2e_test/infra/worker.py +++ b/e2e_test/infra/worker.py @@ -34,7 +34,7 @@ class Worker: """A single inference worker process.""" model_id: str - engine: str # "sglang", "vllm", or "trtllm" + engine: str # "sglang", "vllm", "trtllm", or "tokenspeed" port: int gpu_ids: list[int] mode: ConnectionMode = ConnectionMode.HTTP @@ -178,6 +178,13 @@ def _build_cmd(self) -> list[str]: return self._build_vllm_http_cmd(model_path, tp_size, spec) elif self.engine == "trtllm": return self._build_trtllm_cmd(model_path, tp_size, spec) + elif self.engine == "tokenspeed": + if self.mode != ConnectionMode.GRPC: + raise ValueError( + "TokenSpeed e2e workers only support gRPC mode; " + "HTTP mode would go through the existing OpenAI frontend." + ) + return self._build_tokenspeed_grpc_cmd(model_path, tp_size, spec) else: raise ValueError(f"Unsupported engine: {self.engine}") @@ -261,6 +268,52 @@ def _build_vllm_base_cmd( cmd.extend(extra) return cmd + def _build_tokenspeed_grpc_cmd(self, model_path: str, tp_size: int, spec: dict) -> list[str]: + """Build TokenSpeed gRPC server command. + + Launches the SMG-hosted TokenSpeed gRPC server + (``smg_grpc_servicer.tokenspeed``) which wraps TokenSpeed's AsyncLLM + behind the dedicated ``tokenspeed.grpc.scheduler`` service. + Auto-detected as TokenSpeed by the Rust router via its native + service-name handshake. + """ + cmd = [ + "python3", + "-m", + "smg_grpc_servicer.tokenspeed", + "--model-path", + model_path, + "--host", + DEFAULT_HOST, + "--port", + str(self.port), + "--tensor-parallel-size", + str(tp_size), + "--log-level", + "warning", + # Mirrors what trtllm does and what sglang/vllm do implicitly: + # the smg gateway translates ``tool_choice=required`` and + # ``tool_choice={function}`` into a json_schema constraint on the + # sampling-params proto. TokenSpeed honors that constraint only + # when a grammar backend is configured — its default is ``None``, + # which silently drops the constraint and lets the model free-run. + "--grammar-backend", + "xgrammar", + # Per-token sampled-token logprobs are gated by this flag in + # tokenspeed (``ServerArgs.enable_output_logprobs`` defaults + # OFF). 
Without it, requests asking for logprobs silently + # receive empty arrays — see test_chat_completion[*-5-*] which + # exercises ``logprobs=True, top_logprobs=5`` and asserts + # logprobs are returned. Top-K logprobs are still missing + # upstream (``--enable-top-logprobs`` is not yet implemented), + # so those parametrize variants stay skipped. + "--enable-output-logprobs", + ] + extra = spec.get("tokenspeed_args", []) + if extra: + cmd.extend(extra) + return cmd + def _build_trtllm_cmd(self, model_path: str, tp_size: int, spec: dict) -> list[str]: """Build TensorRT-LLM gRPC server command.""" # Create config file to enable xgrammar guided decoding diff --git a/e2e_test/responses/test_sampling_params.py b/e2e_test/responses/test_sampling_params.py index 2faf34994..d7fbb7180 100644 --- a/e2e_test/responses/test_sampling_params.py +++ b/e2e_test/responses/test_sampling_params.py @@ -103,7 +103,7 @@ class TestSamplingParamsLocal(_SamplingParamsBase): """Regular model (Qwen via SGLang).""" -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(2) @pytest.mark.e2e @pytest.mark.model("openai/gpt-oss-20b") diff --git a/e2e_test/responses/test_state_management.py b/e2e_test/responses/test_state_management.py index ce8679f5d..5ec76e1db 100644 --- a/e2e_test/responses/test_state_management.py +++ b/e2e_test/responses/test_state_management.py @@ -328,7 +328,7 @@ def test_mutually_exclusive_parameters(self, model, api_client): # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(2) @pytest.mark.e2e @pytest.mark.model("openai/gpt-oss-20b") diff --git a/e2e_test/responses/test_streaming_events.py b/e2e_test/responses/test_streaming_events.py index be06235de..4c2613a3a 100644 --- a/e2e_test/responses/test_streaming_events.py +++ b/e2e_test/responses/test_streaming_events.py @@ -106,7 +106,7 @@ def test_output_item_event_emitted(self, model, api_client): # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(2) @pytest.mark.e2e @pytest.mark.model("openai/gpt-oss-20b") diff --git a/e2e_test/responses/test_structured_output.py b/e2e_test/responses/test_structured_output.py index 359910942..9567f4241 100644 --- a/e2e_test/responses/test_structured_output.py +++ b/e2e_test/responses/test_structured_output.py @@ -115,7 +115,7 @@ def test_structured_output_json_schema(self, model, api_client): # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(2) @pytest.mark.e2e @pytest.mark.model("openai/gpt-oss-20b") diff --git a/e2e_test/responses/test_tools_call.py b/e2e_test/responses/test_tools_call.py index f4e5bda92..8e906260b 100644 --- a/e2e_test/responses/test_tools_call.py +++ b/e2e_test/responses/test_tools_call.py @@ -763,7 +763,7 @@ def _check_stream(events, expected_label): # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(2) @pytest.mark.e2e @pytest.mark.model("openai/gpt-oss-20b") diff --git 
a/e2e_test/router/test_mmlu.py b/e2e_test/router/test_mmlu.py index 2b7937116..1f85e09ff 100644 --- a/e2e_test/router/test_mmlu.py +++ b/e2e_test/router/test_mmlu.py @@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) -@pytest.mark.engine("sglang", "vllm") +@pytest.mark.engine("sglang", "vllm", "tokenspeed") @pytest.mark.gpu(1) @pytest.mark.e2e @pytest.mark.parametrize("setup_backend", ["grpc"], indirect=True) diff --git a/e2e_test/router/test_worker_api.py b/e2e_test/router/test_worker_api.py index 742241ec3..46017e5e2 100644 --- a/e2e_test/router/test_worker_api.py +++ b/e2e_test/router/test_worker_api.py @@ -219,6 +219,11 @@ def test_igw_multiple_workers(self): @pytest.mark.e2e +# TokenSpeed deliberately excluded: this test class spins up its worker +# via ``ConnectionMode.HTTP``, and ``Worker._build_tokenspeed_grpc_cmd`` +# rejects HTTP mode — TokenSpeed has no HTTP frontend in this repo. +# Including ``tokenspeed`` here would fail deterministically on every +# run rather than validate health-check behaviour. @pytest.mark.engine("sglang", "vllm") @pytest.mark.gpu(1) class TestDisableHealthCheck: diff --git a/scripts/ci_install_tokenspeed.sh b/scripts/ci_install_tokenspeed.sh new file mode 100755 index 000000000..e59494ea0 --- /dev/null +++ b/scripts/ci_install_tokenspeed.sh @@ -0,0 +1,189 @@ +#!/bin/bash +# Install TokenSpeed from source (engine + kernel + scheduler) for CI. +# +# TokenSpeed is not published to PyPI, so we clone it and pip-install the +# in-tree ``tokenspeed-kernel`` (CUDA), ``tokenspeed-scheduler`` (C++/nanobind), +# and ``python/`` packages. Mirrors the upstream ``docker/Dockerfile`` pipeline. +# +# Prerequisites (expected on k8s-runner-gpu nodes): +# - NVIDIA driver 580+ (CUDA 13) +# - CUDA 13.0 toolkit at /usr/local/cuda-13.0 or /usr/local/cuda +# - H100 GPUs (sm90) +# +# Heavy first run (~30 min for kernel CUDA compile); subsequent runs on the +# same runner hit the pip wheel cache at /tmp/tokenspeed-wheel-cache/ and +# short-circuit the kernel build. + +set -euo pipefail + +# Activate venv if it exists +if [ -f ".venv/bin/activate" ]; then + source .venv/bin/activate +fi + +# Pin to a tested TokenSpeed SHA so CI is reproducible. Floating against +# ``main`` has bitten us before (lightseekorg/tokenspeed renamed server_args, +# the gRPC servicer broke until we caught up). Bump this explicitly when we +# want a newer runtime, ideally via a scheduled bump-and-CI routine rather +# than ad hoc. +# +# This SHA is from lightseekorg/tokenspeed main; it includes dense +# ``LlamaForCausalLM`` registration, the Qwen3 / gpt-oss arches the e2e +# suite (``test_function_calling``, ``test_openai_server`` etc.) runs +# against, lightseekorg/tokenspeed#598 (defensive ``pad_token_id`` read +# in ``Qwen3MoeModel.__init__`` so ``Qwen/Qwen3-30B-A3B`` loads), +# lightseekorg/tokenspeed#578 (FSM absorbs late ``ExtendResultEvent`` +# after a request is terminalized — without this the scheduler crashes +# under retract pressure on the nightly Qwen3-30B-A3B run), and +# lightseekorg/tokenspeed#602 (release scheduler slot + cancel +# non-stream handlers on client disconnect; eliminates the long stream +# of ``state was deleted in AsyncLLM`` warnings that preceded the +# crash). 
+TOKENSPEED_REF="${TOKENSPEED_REF:-eabeb106a070825d5549fbc84ecd8e11651cf3fe}" +TOKENSPEED_REPO="${TOKENSPEED_REPO:-https://github.com/lightseekorg/tokenspeed.git}" +TOKENSPEED_DIR="${TOKENSPEED_DIR:-/tmp/tokenspeed-src}" +WHEEL_CACHE="${TOKENSPEED_WHEEL_CACHE:-/tmp/tokenspeed-wheel-cache}" + +# lightseekorg/tokenspeed is private, so the clone needs HTTPS basic auth. +# CI passes the token via the ``setup-tokenspeed`` action's ``github-token`` +# input; locally you can export ``TOKENSPEED_GITHUB_TOKEN`` yourself. +if [ -n "${TOKENSPEED_GITHUB_TOKEN:-}" ]; then + TOKENSPEED_REPO="https://x-access-token:${TOKENSPEED_GITHUB_TOKEN}@${TOKENSPEED_REPO#https://}" +fi + +# Install uv for faster package management (mirrors ci_install_sglang.sh). +if ! command -v uv &> /dev/null; then + echo "Installing uv..." + curl -LsSf https://astral.sh/uv/install.sh | sh + export PATH="$HOME/.local/bin:$PATH" +fi +echo "uv version: $(uv --version)" + +# ── CUDA runtime setup ───────────────────────────────────────────────────── +# k8s-runner-gpu ships the NVIDIA driver + CUDA runtime libs but not the +# SDK (nvcc, headers). Install them on demand — same approach as +# ``ci_install_sglang.sh``, which installs cuda-nvcc-12-9 + +# cuda-cudart-dev-12-9 when ``/usr/local/cuda/bin/nvcc`` is missing. +# TokenSpeed's Dockerfile targets CUDA 13.0, so install the matching +# toolkit packages here. +CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" +if [ ! -x "${CUDA_HOME}/bin/nvcc" ]; then + echo "Installing CUDA toolkit (nvcc not found at ${CUDA_HOME}/bin/nvcc)..." + curl -fsSL -o /tmp/cuda-keyring.deb \ + https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb + sudo dpkg -i /tmp/cuda-keyring.deb + rm /tmp/cuda-keyring.deb + sudo apt-get update -qq + # cuda-nvcc-13-0: provides nvcc + cuda_runtime_api.h + # cuda-cudart-dev-13-0: provides cuda_runtime.h + libcudart headers + # cuda-libraries-dev-13-0: meta-package pulling in cublas / curand / + # cusolver / cusparse / cufft / nvrtc / + # nvjitlink dev headers that tokenspeed-kernel + # needs (cublas_v2.h, curand.h, cublasLt.h, ...) + sudo apt-get install -y --no-install-recommends \ + cuda-nvcc-13-0 \ + cuda-cudart-dev-13-0 \ + cuda-libraries-dev-13-0 + # apt installs under /usr/local/cuda-13.0; expose the /usr/local/cuda + # alias the job-level ``CUDA_HOME: /usr/local/cuda`` env expects. + if [ ! -d "${CUDA_HOME}/bin" ] && [ -d "/usr/local/cuda-13.0/bin" ]; then + sudo ln -sfn /usr/local/cuda-13.0 "${CUDA_HOME}" + fi + echo "nvcc installed: $(${CUDA_HOME}/bin/nvcc --version | tail -1)" +else + echo "nvcc already available: $(${CUDA_HOME}/bin/nvcc --version | tail -1)" +fi +export CUDA_HOME +export PATH="$CUDA_HOME/bin:$PATH" +export LD_LIBRARY_PATH="${CUDA_HOME}/lib64:${CUDA_HOME}/extras/CUPTI/lib64:${LD_LIBRARY_PATH:-}" +# Torch's JIT cpp_extension builder compiles some TokenSpeed runtime +# extensions (e.g. ``tokenspeed_hostfunc_ext``) with plain g++ and +# doesn't pass ``-I$CUDA_HOME/include``; expose the headers via CPATH / +# CPLUS_INCLUDE_PATH so the compile picks them up. +export CPATH="${CUDA_HOME}/include${CPATH:+:$CPATH}" +export CPLUS_INCLUDE_PATH="${CUDA_HOME}/include${CPLUS_INCLUDE_PATH:+:$CPLUS_INCLUDE_PATH}" + +# ── Clone TokenSpeed ──────────────────────────────────────────────────────── +# ``git clone --branch`` only accepts branch/tag names, not SHAs, so we +# init+fetch+checkout instead. Works for both SHAs and refs. +if [ ! 
-d "$TOKENSPEED_DIR" ]; then + echo "Cloning TokenSpeed ${TOKENSPEED_REF} from ${TOKENSPEED_REPO}..." + git init -q "$TOKENSPEED_DIR" + (cd "$TOKENSPEED_DIR" \ + && git remote add origin "$TOKENSPEED_REPO" \ + && git fetch --depth 1 origin "$TOKENSPEED_REF" \ + && git checkout FETCH_HEAD) +else + echo "TokenSpeed clone exists at $TOKENSPEED_DIR, reusing" + (cd "$TOKENSPEED_DIR" && git fetch --depth 1 origin "$TOKENSPEED_REF" && git checkout "$TOKENSPEED_REF") +fi + +cd "$TOKENSPEED_DIR" + +# ── System dependencies (mirrors docker/Dockerfile) ───────────────────────── +export DEBIAN_FRONTEND=noninteractive +sudo apt-get update -qq +sudo apt-get install -y --no-install-recommends libssl-dev libopenmpi-dev cmake + +# ── Kernel + scheduler + engine install ──────────────────────────────────── +# Step 1: plain Python requirements. +uv pip install -r tokenspeed-kernel/python/requirements/cuda.txt + +# Step 2: build-isolation=off so nanobind/cutlass build dependencies are shared. +uv pip install -r tokenspeed-kernel/python/requirements/cuda-thirdparty.txt \ + --no-build-isolation + +# Step 3: kernel (CUDA compile — the expensive one). Try the cached wheel first. +CACHED_KERNEL_WHEEL=$(find "$WHEEL_CACHE" -name "tokenspeed_kernel-*.whl" 2>/dev/null | head -1 || true) +if [ -n "$CACHED_KERNEL_WHEEL" ] && [ -f "$CACHED_KERNEL_WHEEL" ]; then + echo "Installing cached tokenspeed-kernel wheel: $CACHED_KERNEL_WHEEL" + uv pip install "$CACHED_KERNEL_WHEEL" --no-build-isolation +else + echo "Building tokenspeed-kernel from source (this takes ~30 min the first time)..." + MAX_JOBS="${MAX_JOBS:-16}" FLASHINFER_CUDA_ARCH_LIST="9.0a 10.0a" \ + uv pip install tokenspeed-kernel/python/ --no-build-isolation + # Cache the built wheel — uv stores wheels under its cache, copy out. + mkdir -p "$WHEEL_CACHE" + python3 -c "import tokenspeed_kernel, os, shutil, glob; \ + d = os.path.dirname(tokenspeed_kernel.__file__); \ + site = os.path.dirname(d); \ + whls = glob.glob(os.path.join(site, 'tokenspeed_kernel-*.dist-info')); \ + print('kernel install dir:', whls)" || true +fi + +# Step 4: scheduler (scikit-build-core + nanobind + CMake). +echo "Building tokenspeed-scheduler..." +uv pip install tokenspeed-scheduler/ + +# Step 5: the Python runtime (pure-Python). +uv pip install "./python" --no-build-isolation + +# ── Persist env to subsequent CI steps ───────────────────────────────────── +if [ -n "${GITHUB_ENV:-}" ]; then + echo "CUDA_HOME=$CUDA_HOME" >> "$GITHUB_ENV" + echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH" >> "$GITHUB_ENV" + # See note above: needed so torch's JIT C++ extension builder sees + # CUDA headers when it bypasses nvcc for .cpp sources. + echo "CPATH=$CPATH" >> "$GITHUB_ENV" + echo "CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH" >> "$GITHUB_ENV" +fi +if [ -n "${GITHUB_PATH:-}" ]; then + # Make ``nvcc`` discoverable to downstream steps (pytest spawns the + # worker which may trigger CUDA extension builds). + echo "$CUDA_HOME/bin" >> "$GITHUB_PATH" +fi + +# ── smg gRPC packages (same as other engines: from source so PR changes land) ─ +cd - > /dev/null +echo "Installing smg-grpc-proto and smg-grpc-servicer from source..." 
+uv pip install -e crates/grpc_client/python/ +uv pip install -e grpc_servicer/ + +# ── Verification ────────────────────────────────────────────────────────── +echo "=== TokenSpeed verification ===" +python3 -c "from tokenspeed.runtime.engine.async_llm import AsyncLLM; \ + print('AsyncLLM bases:', [b.__name__ for b in AsyncLLM.__bases__])" +python3 -c "from smg_grpc_servicer.tokenspeed.servicer import TokenSpeedSchedulerServicer; \ + print('gRPC servicer: importable')" + +echo "TokenSpeed installation complete" From 7b152c02e859ac1aec0d587de803b23b9005fa84 Mon Sep 17 00:00:00 2001 From: key4ng Date: Mon, 11 May 2026 20:28:05 -0700 Subject: [PATCH 15/24] ci(tokenspeed): drop private-repo auth now that tokenspeed is open-source MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit lightseekorg/tokenspeed is now public, so the clone no longer needs HTTPS basic auth and the fork-PR skip-on-missing-secret guard is no longer necessary. - scripts/ci_install_tokenspeed.sh: drop the ``TOKENSPEED_GITHUB_TOKEN`` rewrite block; the default ``https://github.com/...`` URL clones anonymously. - .github/actions/setup-tokenspeed/action.yml: drop the ``github-token`` input and the env-forwarding step. - .github/workflows/e2e-gpu-job.yml: drop the job-level ``if`` guard that skipped tokenspeed on forked PRs, and the ``with: github-token: ${{ secrets.TOKENSPEED_GITHUB_TOKEN }}`` plumbing on the ``Setup TokenSpeed backend`` step. TOKENSPEED_REF stays pinned to a tested SHA — bumping that is a separate decision. Signed-off-by: key4ng --- .github/actions/setup-tokenspeed/action.yml | 10 ---------- .github/workflows/e2e-gpu-job.yml | 14 -------------- scripts/ci_install_tokenspeed.sh | 7 ------- 3 files changed, 31 deletions(-) diff --git a/.github/actions/setup-tokenspeed/action.yml b/.github/actions/setup-tokenspeed/action.yml index 348bb45f7..547ea35fd 100644 --- a/.github/actions/setup-tokenspeed/action.yml +++ b/.github/actions/setup-tokenspeed/action.yml @@ -1,14 +1,6 @@ name: 'Setup TokenSpeed Backend' description: 'Create Python venv and install TokenSpeed (engine + kernel + scheduler) from source.' -inputs: - github-token: - description: >- - GitHub token with read access to the (private) lightseekorg/tokenspeed - repository. Forwarded to ``ci_install_tokenspeed.sh`` as - ``TOKENSPEED_GITHUB_TOKEN`` so the clone uses HTTPS basic auth. - required: true - runs: using: 'composite' steps: @@ -18,6 +10,4 @@ runs: - name: Install TokenSpeed shell: bash - env: - TOKENSPEED_GITHUB_TOKEN: ${{ inputs.github-token }} run: bash scripts/ci_install_tokenspeed.sh diff --git a/.github/workflows/e2e-gpu-job.yml b/.github/workflows/e2e-gpu-job.yml index d4080ae74..29b693899 100644 --- a/.github/workflows/e2e-gpu-job.yml +++ b/.github/workflows/e2e-gpu-job.yml @@ -42,16 +42,6 @@ on: jobs: run: - # TokenSpeed lanes need ``TOKENSPEED_GITHUB_TOKEN`` to clone the - # private lightseekorg/tokenspeed repo, but GitHub does not pass - # secrets to workflows triggered by forked-PR events. Skip the - # tokenspeed engine on fork PRs so the job is reported as ``skipped`` - # rather than failing inside ``setup-tokenspeed`` with a 404. Same- - # repo PRs and pushes still run normally. 
- if: >- - inputs.engine != 'tokenspeed' - || github.event_name != 'pull_request' - || github.event.pull_request.head.repo.full_name == github.repository runs-on: ${{ inputs.runner }} timeout-minutes: ${{ inputs.timeout }} permissions: @@ -81,10 +71,6 @@ jobs: - name: Setup TokenSpeed backend if: inputs.engine == 'tokenspeed' uses: ./.github/actions/setup-tokenspeed - with: - # lightseekorg/tokenspeed is private; ``secrets: inherit`` on the - # caller (pr-test-rust.yml) makes this secret available here. - github-token: ${{ secrets.TOKENSPEED_GITHUB_TOKEN }} # Artifact downloads - name: Download wheel artifact diff --git a/scripts/ci_install_tokenspeed.sh b/scripts/ci_install_tokenspeed.sh index e59494ea0..efa2eabe6 100755 --- a/scripts/ci_install_tokenspeed.sh +++ b/scripts/ci_install_tokenspeed.sh @@ -44,13 +44,6 @@ TOKENSPEED_REPO="${TOKENSPEED_REPO:-https://github.com/lightseekorg/tokenspeed.g TOKENSPEED_DIR="${TOKENSPEED_DIR:-/tmp/tokenspeed-src}" WHEEL_CACHE="${TOKENSPEED_WHEEL_CACHE:-/tmp/tokenspeed-wheel-cache}" -# lightseekorg/tokenspeed is private, so the clone needs HTTPS basic auth. -# CI passes the token via the ``setup-tokenspeed`` action's ``github-token`` -# input; locally you can export ``TOKENSPEED_GITHUB_TOKEN`` yourself. -if [ -n "${TOKENSPEED_GITHUB_TOKEN:-}" ]; then - TOKENSPEED_REPO="https://x-access-token:${TOKENSPEED_GITHUB_TOKEN}@${TOKENSPEED_REPO#https://}" -fi - # Install uv for faster package management (mirrors ci_install_sglang.sh). if ! command -v uv &> /dev/null; then echo "Installing uv..." From 34747d2459242e261ce2924c0af1391e6ec997db Mon Sep 17 00:00:00 2001 From: key4ng Date: Mon, 11 May 2026 20:45:09 -0700 Subject: [PATCH 16/24] ci(tokenspeed): run install inside the official build-env container MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Switches the TokenSpeed e2e lane to run inside the upstream-published ``lightseekorg/tokenspeed-runner:cu130-torch-2.11.0`` image instead of on the bare runner host. That image already ships CUDA 13.0, Torch 2.11.0, nanobind, cmake, libopenmpi-dev, and libssl-dev — everything the install script previously apt-installed at job time. Changes: - ``.github/workflows/e2e-gpu-job.yml``: add a job-level ``container:`` expression that evaluates to the runner image only when ``inputs.engine == 'tokenspeed'``, otherwise ``null`` (bare host, no change for the SGLang/vLLM/TRT-LLM/MLX lanes). Bind-mounts ``/models`` and ``/tmp/tokenspeed-wheel-cache`` so the existing model cache and kernel-wheel cache work unchanged inside the container. - ``scripts/ci_install_tokenspeed.sh``: drop the CUDA-toolkit apt-get block (~45 lines), the ``libssl-dev``/``libopenmpi-dev``/``cmake`` install, and the ``GITHUB_ENV``/``GITHUB_PATH`` persistence — all redundant now that the container baseline already exports ``CUDA_HOME``, ``LD_LIBRARY_PATH``, etc. Script shrinks from 189 → 105 lines. - Bump ``TOKENSPEED_REF`` to ``70030b29`` (latest lightseekorg/tokenspeed main, "Add KV cache events to scheduler"). The unavoidable ~30 min CUDA kernel compile still happens on first run because upstream doesn't publish pre-built engine wheels yet — the wheel cache at ``/tmp/tokenspeed-wheel-cache`` short-circuits it on every subsequent run. 
Signed-off-by: key4ng --- .github/workflows/e2e-gpu-job.yml | 5 ++ scripts/ci_install_tokenspeed.sh | 101 ++++-------------------------- 2 files changed, 17 insertions(+), 89 deletions(-) diff --git a/.github/workflows/e2e-gpu-job.yml b/.github/workflows/e2e-gpu-job.yml index 29b693899..9edb7f93a 100644 --- a/.github/workflows/e2e-gpu-job.yml +++ b/.github/workflows/e2e-gpu-job.yml @@ -46,6 +46,11 @@ jobs: timeout-minutes: ${{ inputs.timeout }} permissions: contents: read + # TokenSpeed lane runs inside the official build-environment image so + # the install script doesn't have to apt-install the CUDA toolkit or + # the build deps. The expression evaluates to `null` for other engines, + # which means "no container" (bare host, as before). + container: ${{ inputs.engine == 'tokenspeed' && fromJSON('{"image":"lightseekorg/tokenspeed-runner:cu130-torch-2.11.0","options":"--gpus all --shm-size 32g --ipc=host --volume /models:/models --volume /tmp/tokenspeed-wheel-cache:/tmp/tokenspeed-wheel-cache"}') || null }} env: E2E_ENGINE: ${{ inputs.engine }} E2E_RUNTIME: ${{ inputs.engine }} diff --git a/scripts/ci_install_tokenspeed.sh b/scripts/ci_install_tokenspeed.sh index efa2eabe6..c3bd2410e 100755 --- a/scripts/ci_install_tokenspeed.sh +++ b/scripts/ci_install_tokenspeed.sh @@ -5,14 +5,15 @@ # in-tree ``tokenspeed-kernel`` (CUDA), ``tokenspeed-scheduler`` (C++/nanobind), # and ``python/`` packages. Mirrors the upstream ``docker/Dockerfile`` pipeline. # -# Prerequisites (expected on k8s-runner-gpu nodes): -# - NVIDIA driver 580+ (CUDA 13) -# - CUDA 13.0 toolkit at /usr/local/cuda-13.0 or /usr/local/cuda -# - H100 GPUs (sm90) +# Designed to run inside the official +# ``lightseekorg/tokenspeed-runner:cu130-torch-2.11.0`` container (set by +# e2e-gpu-job.yml). The image ships CUDA 13.0, Torch 2.11.0, nanobind, +# cmake, libopenmpi-dev and libssl-dev, so this script skips the +# host-side toolkit install entirely. # # Heavy first run (~30 min for kernel CUDA compile); subsequent runs on the -# same runner hit the pip wheel cache at /tmp/tokenspeed-wheel-cache/ and -# short-circuit the kernel build. +# same runner hit the pip wheel cache at /tmp/tokenspeed-wheel-cache/ (host +# volume-mounted into the container) and short-circuit the kernel build. set -euo pipefail @@ -21,25 +22,11 @@ if [ -f ".venv/bin/activate" ]; then source .venv/bin/activate fi -# Pin to a tested TokenSpeed SHA so CI is reproducible. Floating against -# ``main`` has bitten us before (lightseekorg/tokenspeed renamed server_args, -# the gRPC servicer broke until we caught up). Bump this explicitly when we -# want a newer runtime, ideally via a scheduled bump-and-CI routine rather -# than ad hoc. -# -# This SHA is from lightseekorg/tokenspeed main; it includes dense -# ``LlamaForCausalLM`` registration, the Qwen3 / gpt-oss arches the e2e -# suite (``test_function_calling``, ``test_openai_server`` etc.) runs -# against, lightseekorg/tokenspeed#598 (defensive ``pad_token_id`` read -# in ``Qwen3MoeModel.__init__`` so ``Qwen/Qwen3-30B-A3B`` loads), -# lightseekorg/tokenspeed#578 (FSM absorbs late ``ExtendResultEvent`` -# after a request is terminalized — without this the scheduler crashes -# under retract pressure on the nightly Qwen3-30B-A3B run), and -# lightseekorg/tokenspeed#602 (release scheduler slot + cancel -# non-stream handlers on client disconnect; eliminates the long stream -# of ``state was deleted in AsyncLLM`` warnings that preceded the -# crash). 
-TOKENSPEED_REF="${TOKENSPEED_REF:-eabeb106a070825d5549fbc84ecd8e11651cf3fe}" +# Pinned SHA from lightseekorg/tokenspeed main. Bump explicitly (ideally via +# a scheduled bump-and-CI routine) rather than floating against ``main`` — +# upstream has renamed APIs before and the gRPC servicer broke until we +# caught up. +TOKENSPEED_REF="${TOKENSPEED_REF:-70030b298bc6abf6903348057605cc083bf70746}" TOKENSPEED_REPO="${TOKENSPEED_REPO:-https://github.com/lightseekorg/tokenspeed.git}" TOKENSPEED_DIR="${TOKENSPEED_DIR:-/tmp/tokenspeed-src}" WHEEL_CACHE="${TOKENSPEED_WHEEL_CACHE:-/tmp/tokenspeed-wheel-cache}" @@ -52,50 +39,6 @@ if ! command -v uv &> /dev/null; then fi echo "uv version: $(uv --version)" -# ── CUDA runtime setup ───────────────────────────────────────────────────── -# k8s-runner-gpu ships the NVIDIA driver + CUDA runtime libs but not the -# SDK (nvcc, headers). Install them on demand — same approach as -# ``ci_install_sglang.sh``, which installs cuda-nvcc-12-9 + -# cuda-cudart-dev-12-9 when ``/usr/local/cuda/bin/nvcc`` is missing. -# TokenSpeed's Dockerfile targets CUDA 13.0, so install the matching -# toolkit packages here. -CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" -if [ ! -x "${CUDA_HOME}/bin/nvcc" ]; then - echo "Installing CUDA toolkit (nvcc not found at ${CUDA_HOME}/bin/nvcc)..." - curl -fsSL -o /tmp/cuda-keyring.deb \ - https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb - sudo dpkg -i /tmp/cuda-keyring.deb - rm /tmp/cuda-keyring.deb - sudo apt-get update -qq - # cuda-nvcc-13-0: provides nvcc + cuda_runtime_api.h - # cuda-cudart-dev-13-0: provides cuda_runtime.h + libcudart headers - # cuda-libraries-dev-13-0: meta-package pulling in cublas / curand / - # cusolver / cusparse / cufft / nvrtc / - # nvjitlink dev headers that tokenspeed-kernel - # needs (cublas_v2.h, curand.h, cublasLt.h, ...) - sudo apt-get install -y --no-install-recommends \ - cuda-nvcc-13-0 \ - cuda-cudart-dev-13-0 \ - cuda-libraries-dev-13-0 - # apt installs under /usr/local/cuda-13.0; expose the /usr/local/cuda - # alias the job-level ``CUDA_HOME: /usr/local/cuda`` env expects. - if [ ! -d "${CUDA_HOME}/bin" ] && [ -d "/usr/local/cuda-13.0/bin" ]; then - sudo ln -sfn /usr/local/cuda-13.0 "${CUDA_HOME}" - fi - echo "nvcc installed: $(${CUDA_HOME}/bin/nvcc --version | tail -1)" -else - echo "nvcc already available: $(${CUDA_HOME}/bin/nvcc --version | tail -1)" -fi -export CUDA_HOME -export PATH="$CUDA_HOME/bin:$PATH" -export LD_LIBRARY_PATH="${CUDA_HOME}/lib64:${CUDA_HOME}/extras/CUPTI/lib64:${LD_LIBRARY_PATH:-}" -# Torch's JIT cpp_extension builder compiles some TokenSpeed runtime -# extensions (e.g. ``tokenspeed_hostfunc_ext``) with plain g++ and -# doesn't pass ``-I$CUDA_HOME/include``; expose the headers via CPATH / -# CPLUS_INCLUDE_PATH so the compile picks them up. -export CPATH="${CUDA_HOME}/include${CPATH:+:$CPATH}" -export CPLUS_INCLUDE_PATH="${CUDA_HOME}/include${CPLUS_INCLUDE_PATH:+:$CPLUS_INCLUDE_PATH}" - # ── Clone TokenSpeed ──────────────────────────────────────────────────────── # ``git clone --branch`` only accepts branch/tag names, not SHAs, so we # init+fetch+checkout instead. Works for both SHAs and refs. 
@@ -113,11 +56,6 @@ fi cd "$TOKENSPEED_DIR" -# ── System dependencies (mirrors docker/Dockerfile) ───────────────────────── -export DEBIAN_FRONTEND=noninteractive -sudo apt-get update -qq -sudo apt-get install -y --no-install-recommends libssl-dev libopenmpi-dev cmake - # ── Kernel + scheduler + engine install ──────────────────────────────────── # Step 1: plain Python requirements. uv pip install -r tokenspeed-kernel/python/requirements/cuda.txt @@ -151,21 +89,6 @@ uv pip install tokenspeed-scheduler/ # Step 5: the Python runtime (pure-Python). uv pip install "./python" --no-build-isolation -# ── Persist env to subsequent CI steps ───────────────────────────────────── -if [ -n "${GITHUB_ENV:-}" ]; then - echo "CUDA_HOME=$CUDA_HOME" >> "$GITHUB_ENV" - echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH" >> "$GITHUB_ENV" - # See note above: needed so torch's JIT C++ extension builder sees - # CUDA headers when it bypasses nvcc for .cpp sources. - echo "CPATH=$CPATH" >> "$GITHUB_ENV" - echo "CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH" >> "$GITHUB_ENV" -fi -if [ -n "${GITHUB_PATH:-}" ]; then - # Make ``nvcc`` discoverable to downstream steps (pytest spawns the - # worker which may trigger CUDA extension builds). - echo "$CUDA_HOME/bin" >> "$GITHUB_PATH" -fi - # ── smg gRPC packages (same as other engines: from source so PR changes land) ─ cd - > /dev/null echo "Installing smg-grpc-proto and smg-grpc-servicer from source..." From 7279401f272d8592dae838aa1df62b30dc2483cc Mon Sep 17 00:00:00 2001 From: key4ng Date: Mon, 11 May 2026 20:52:07 -0700 Subject: [PATCH 17/24] ci(tokenspeed): revert job-level container; k8s runner can't use docker MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous commit (388cfc19) moved the TokenSpeed lane inside ``lightseekorg/tokenspeed-runner:cu130-torch-2.11.0`` via a job-level ``container:`` directive. On the k8s-runner-gpu pods this fails: the runner image pull hits ``unexpected EOF`` mid-download (likely ephemeral-storage pressure on the pod for a large CUDA image) and the retry can't reach the docker daemon at all (it's not stably available inside the pod). Roll back the workflow ``container:`` block and restore the host-side toolkit install in ``ci_install_tokenspeed.sh`` (CUDA-13 apt-get, ``libssl-dev``/``libopenmpi-dev``/``cmake``, ``GITHUB_ENV``/``GITHUB_PATH`` persistence). Keep the bumped ``TOKENSPEED_REF`` — that part is independent of the container experiment. The "use the upstream runner image" simplification is still the right direction in principle; revisiting after the runner infra grows stable container support, or after upstream publishes engine wheels (which would remove the source-build step entirely). Signed-off-by: key4ng --- .github/workflows/e2e-gpu-job.yml | 5 -- scripts/ci_install_tokenspeed.sh | 77 ++++++++++++++++++++++++++++--- 2 files changed, 70 insertions(+), 12 deletions(-) diff --git a/.github/workflows/e2e-gpu-job.yml b/.github/workflows/e2e-gpu-job.yml index 9edb7f93a..29b693899 100644 --- a/.github/workflows/e2e-gpu-job.yml +++ b/.github/workflows/e2e-gpu-job.yml @@ -46,11 +46,6 @@ jobs: timeout-minutes: ${{ inputs.timeout }} permissions: contents: read - # TokenSpeed lane runs inside the official build-environment image so - # the install script doesn't have to apt-install the CUDA toolkit or - # the build deps. The expression evaluates to `null` for other engines, - # which means "no container" (bare host, as before). 
- container: ${{ inputs.engine == 'tokenspeed' && fromJSON('{"image":"lightseekorg/tokenspeed-runner:cu130-torch-2.11.0","options":"--gpus all --shm-size 32g --ipc=host --volume /models:/models --volume /tmp/tokenspeed-wheel-cache:/tmp/tokenspeed-wheel-cache"}') || null }} env: E2E_ENGINE: ${{ inputs.engine }} E2E_RUNTIME: ${{ inputs.engine }} diff --git a/scripts/ci_install_tokenspeed.sh b/scripts/ci_install_tokenspeed.sh index c3bd2410e..094560b62 100755 --- a/scripts/ci_install_tokenspeed.sh +++ b/scripts/ci_install_tokenspeed.sh @@ -5,15 +5,14 @@ # in-tree ``tokenspeed-kernel`` (CUDA), ``tokenspeed-scheduler`` (C++/nanobind), # and ``python/`` packages. Mirrors the upstream ``docker/Dockerfile`` pipeline. # -# Designed to run inside the official -# ``lightseekorg/tokenspeed-runner:cu130-torch-2.11.0`` container (set by -# e2e-gpu-job.yml). The image ships CUDA 13.0, Torch 2.11.0, nanobind, -# cmake, libopenmpi-dev and libssl-dev, so this script skips the -# host-side toolkit install entirely. +# Prerequisites (expected on k8s-runner-gpu nodes): +# - NVIDIA driver 580+ (CUDA 13) +# - CUDA 13.0 toolkit at /usr/local/cuda-13.0 or /usr/local/cuda +# - H100 GPUs (sm90) # # Heavy first run (~30 min for kernel CUDA compile); subsequent runs on the -# same runner hit the pip wheel cache at /tmp/tokenspeed-wheel-cache/ (host -# volume-mounted into the container) and short-circuit the kernel build. +# same runner hit the pip wheel cache at /tmp/tokenspeed-wheel-cache/ and +# short-circuit the kernel build. set -euo pipefail @@ -39,6 +38,50 @@ if ! command -v uv &> /dev/null; then fi echo "uv version: $(uv --version)" +# ── CUDA runtime setup ───────────────────────────────────────────────────── +# k8s-runner-gpu ships the NVIDIA driver + CUDA runtime libs but not the +# SDK (nvcc, headers). Install them on demand — same approach as +# ``ci_install_sglang.sh``, which installs cuda-nvcc-12-9 + +# cuda-cudart-dev-12-9 when ``/usr/local/cuda/bin/nvcc`` is missing. +# TokenSpeed's Dockerfile targets CUDA 13.0, so install the matching +# toolkit packages here. +CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" +if [ ! -x "${CUDA_HOME}/bin/nvcc" ]; then + echo "Installing CUDA toolkit (nvcc not found at ${CUDA_HOME}/bin/nvcc)..." + curl -fsSL -o /tmp/cuda-keyring.deb \ + https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb + sudo dpkg -i /tmp/cuda-keyring.deb + rm /tmp/cuda-keyring.deb + sudo apt-get update -qq + # cuda-nvcc-13-0: provides nvcc + cuda_runtime_api.h + # cuda-cudart-dev-13-0: provides cuda_runtime.h + libcudart headers + # cuda-libraries-dev-13-0: meta-package pulling in cublas / curand / + # cusolver / cusparse / cufft / nvrtc / + # nvjitlink dev headers that tokenspeed-kernel + # needs (cublas_v2.h, curand.h, cublasLt.h, ...) + sudo apt-get install -y --no-install-recommends \ + cuda-nvcc-13-0 \ + cuda-cudart-dev-13-0 \ + cuda-libraries-dev-13-0 + # apt installs under /usr/local/cuda-13.0; expose the /usr/local/cuda + # alias the job-level ``CUDA_HOME: /usr/local/cuda`` env expects. + if [ ! 
-d "${CUDA_HOME}/bin" ] && [ -d "/usr/local/cuda-13.0/bin" ]; then + sudo ln -sfn /usr/local/cuda-13.0 "${CUDA_HOME}" + fi + echo "nvcc installed: $(${CUDA_HOME}/bin/nvcc --version | tail -1)" +else + echo "nvcc already available: $(${CUDA_HOME}/bin/nvcc --version | tail -1)" +fi +export CUDA_HOME +export PATH="$CUDA_HOME/bin:$PATH" +export LD_LIBRARY_PATH="${CUDA_HOME}/lib64:${CUDA_HOME}/extras/CUPTI/lib64:${LD_LIBRARY_PATH:-}" +# Torch's JIT cpp_extension builder compiles some TokenSpeed runtime +# extensions (e.g. ``tokenspeed_hostfunc_ext``) with plain g++ and +# doesn't pass ``-I$CUDA_HOME/include``; expose the headers via CPATH / +# CPLUS_INCLUDE_PATH so the compile picks them up. +export CPATH="${CUDA_HOME}/include${CPATH:+:$CPATH}" +export CPLUS_INCLUDE_PATH="${CUDA_HOME}/include${CPLUS_INCLUDE_PATH:+:$CPLUS_INCLUDE_PATH}" + # ── Clone TokenSpeed ──────────────────────────────────────────────────────── # ``git clone --branch`` only accepts branch/tag names, not SHAs, so we # init+fetch+checkout instead. Works for both SHAs and refs. @@ -56,6 +99,11 @@ fi cd "$TOKENSPEED_DIR" +# ── System dependencies (mirrors docker/Dockerfile) ───────────────────────── +export DEBIAN_FRONTEND=noninteractive +sudo apt-get update -qq +sudo apt-get install -y --no-install-recommends libssl-dev libopenmpi-dev cmake + # ── Kernel + scheduler + engine install ──────────────────────────────────── # Step 1: plain Python requirements. uv pip install -r tokenspeed-kernel/python/requirements/cuda.txt @@ -89,6 +137,21 @@ uv pip install tokenspeed-scheduler/ # Step 5: the Python runtime (pure-Python). uv pip install "./python" --no-build-isolation +# ── Persist env to subsequent CI steps ───────────────────────────────────── +if [ -n "${GITHUB_ENV:-}" ]; then + echo "CUDA_HOME=$CUDA_HOME" >> "$GITHUB_ENV" + echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH" >> "$GITHUB_ENV" + # See note above: needed so torch's JIT C++ extension builder sees + # CUDA headers when it bypasses nvcc for .cpp sources. + echo "CPATH=$CPATH" >> "$GITHUB_ENV" + echo "CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH" >> "$GITHUB_ENV" +fi +if [ -n "${GITHUB_PATH:-}" ]; then + # Make ``nvcc`` discoverable to downstream steps (pytest spawns the + # worker which may trigger CUDA extension builds). + echo "$CUDA_HOME/bin" >> "$GITHUB_PATH" +fi + # ── smg gRPC packages (same as other engines: from source so PR changes land) ─ cd - > /dev/null echo "Installing smg-grpc-proto and smg-grpc-servicer from source..." From 7e948c61571b717cafeb58af4c6f8cd97ac23d47 Mon Sep 17 00:00:00 2001 From: key4ng Date: Mon, 11 May 2026 21:06:59 -0700 Subject: [PATCH 18/24] fix(e2e): pass --model to tokenspeed worker (upstream renamed --model-path) Upstream lightseekorg/tokenspeed (current main) renamed the ``--model-path`` argparse flag to ``--model``. The old name remains only as a positional alias on the parser; passing ``--model-path`` fails the worker boot with: __main__.py: error: unrecognized arguments: --model-path Switch ``_build_tokenspeed_grpc_cmd`` to the new flag form. The servicer's ``server_args.model_path`` attribute is still populated because upstream's ``prepare_server_args`` resolves both the positional ``model_path`` and ``--model`` sources into the same attribute. Only the tokenspeed lane is affected; sglang / vllm / trtllm / mlx each have their own argparse and continue to accept ``--model-path``. 
Signed-off-by: key4ng --- e2e_test/infra/worker.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/e2e_test/infra/worker.py b/e2e_test/infra/worker.py index 638813f11..750581714 100644 --- a/e2e_test/infra/worker.py +++ b/e2e_test/infra/worker.py @@ -281,7 +281,9 @@ def _build_tokenspeed_grpc_cmd(self, model_path: str, tp_size: int, spec: dict) "python3", "-m", "smg_grpc_servicer.tokenspeed", - "--model-path", + # Upstream renamed ``--model-path`` to ``--model`` (with the old + # name kept only as a positional alias). Use the new flag form. + "--model", model_path, "--host", DEFAULT_HOST, From 24d2f14f28f7dff5cf09b588f3c2518db1f4474c Mon Sep 17 00:00:00 2001 From: key4ng Date: Mon, 11 May 2026 21:42:13 -0700 Subject: [PATCH 19/24] test(e2e): swap TestEnableThinking model to Qwen3.5-27B MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The bumped ``TOKENSPEED_REF`` dropped ``Qwen3MoeForCausalLM`` from its model registry, so the existing ``Qwen/Qwen3-30B-A3B`` worker fails to start on the tokenspeed lane with: ValueError: Model architectures ['Qwen3MoeForCausalLM'] are not supported for now. Switch ``TestEnableThinking`` to ``Qwen/Qwen3.5-27B``. The model: - is ``Qwen3_5ForConditionalGeneration`` — in the current tokenspeed registry (the family upstream is investing in); - has the ``enable_thinking`` chat-template toggle the test exercises; - emits the same ``...`` reasoning markers, so the existing ``qwen3`` SMG reasoning parser handles it without changes; - weighs ~57 GB (54 GB text + ~3 GB vision encoder), which fits on the 1×H100 80GB lane with ~20 GB headroom for KV cache and activations. No SMG-side parser changes needed for this test. For future Qwen3.5 function-calling tests, the SMG tool-parser factory already auto-maps ``Qwen/Qwen3.5*`` → ``qwen_xml`` at ``crates/tool_parser/src/factory.rs:353``. Changes: - ``e2e_test/infra/model_specs.py``: add a ``Qwen/Qwen3.5-27B`` spec entry (``tp: 1``, thinking + reasoning features) and point ``DEFAULT_ENABLE_THINKING_MODEL_PATH`` at it. The existing ``Qwen/Qwen3-30B-A3B`` entry stays for the nightly perf benchmark. - ``e2e_test/chat_completions/test_enable_thinking.py``: update the ``@pytest.mark.model`` to ``Qwen/Qwen3.5-27B``. 
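
For orientation, a hedged sketch of how an entry like this is assumed to be
consumed by the harness (``get_model_spec`` and the ``tp``/``features`` keys are
in ``model_specs.py``; the exact fixture wiring is not part of this diff):

```python
# Look up the spec the @pytest.mark.model id points at and read the fields the
# worker launch needs. Usage here is illustrative, not the fixtures' actual code.
from e2e_test.infra.model_specs import get_model_spec

spec = get_model_spec("Qwen/Qwen3.5-27B")
assert spec["tp"] == 1                   # single-GPU lane
assert "thinking" in spec["features"]    # what TestEnableThinking selects on
model_path = spec["model"]               # resolved local path or HF id
```
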
Signed-off-by: key4ng --- e2e_test/chat_completions/test_enable_thinking.py | 4 ++-- e2e_test/infra/model_specs.py | 11 ++++++++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/e2e_test/chat_completions/test_enable_thinking.py b/e2e_test/chat_completions/test_enable_thinking.py index aebe48db7..85bcb29e0 100644 --- a/e2e_test/chat_completions/test_enable_thinking.py +++ b/e2e_test/chat_completions/test_enable_thinking.py @@ -18,13 +18,13 @@ # ============================================================================= -# Enable Thinking Tests (Qwen 30B) +# Enable Thinking Tests (Qwen3) # ============================================================================= @pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(1) -@pytest.mark.model("Qwen/Qwen3-30B-A3B") +@pytest.mark.model("Qwen/Qwen3.5-27B") @pytest.mark.gateway(extra_args=["--reasoning-parser", "qwen3", "--history-backend", "memory"]) @pytest.mark.parametrize("setup_backend", ["grpc"], indirect=True) @pytest.mark.parametrize("api_client", ["openai", "smg"], indirect=True) diff --git a/e2e_test/infra/model_specs.py b/e2e_test/infra/model_specs.py index fdf6a750f..c02df807c 100644 --- a/e2e_test/infra/model_specs.py +++ b/e2e_test/infra/model_specs.py @@ -72,6 +72,15 @@ def _resolve_model_path(hf_path: str) -> str: "features": ["chat", "streaming", "function_calling", "tool_choice"], }, # Thinking/reasoning model (larger) + # Dense Qwen3.5 with the ``enable_thinking`` chat-template toggle. Used + # by ``TestEnableThinking``. ``Qwen3_5ForConditionalGeneration`` + # architecture, supported by tokenspeed's current registry where the + # older ``Qwen3-30B-A3B`` (``Qwen3MoeForCausalLM``) is not. + "Qwen/Qwen3.5-27B": { + "model": _resolve_model_path("Qwen/Qwen3.5-27B"), + "tp": 1, + "features": ["chat", "streaming", "thinking", "reasoning"], + }, "Qwen/Qwen3-30B-A3B": { "model": _resolve_model_path("Qwen/Qwen3-30B-A3B"), "tp": 1, @@ -246,7 +255,7 @@ def get_model_spec(model_id: str) -> dict: DEFAULT_MODEL_PATH = MODEL_SPECS["meta-llama/Llama-3.1-8B-Instruct"]["model"] DEFAULT_SMALL_MODEL_PATH = MODEL_SPECS["meta-llama/Llama-3.2-1B-Instruct"]["model"] DEFAULT_REASONING_MODEL_PATH = MODEL_SPECS["deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"]["model"] -DEFAULT_ENABLE_THINKING_MODEL_PATH = MODEL_SPECS["Qwen/Qwen3-30B-A3B"]["model"] +DEFAULT_ENABLE_THINKING_MODEL_PATH = MODEL_SPECS["Qwen/Qwen3.5-27B"]["model"] DEFAULT_QWEN_FUNCTION_CALLING_MODEL_PATH = MODEL_SPECS["Qwen/Qwen2.5-7B-Instruct"]["model"] DEFAULT_MISTRAL_FUNCTION_CALLING_MODEL_PATH = MODEL_SPECS["mistralai/Mistral-7B-Instruct-v0.3"][ "model" From 4f58c5c12e17506136e8fd1d876c3ab496020f5b Mon Sep 17 00:00:00 2001 From: key4ng Date: Mon, 11 May 2026 22:27:13 -0700 Subject: [PATCH 20/24] =?UTF-8?q?fix(e2e):=20cap=20Qwen3.5-27B=20context?= =?UTF-8?q?=20at=2016K=20for=20the=201=C3=97H100=20KV-cache=20budget?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit After model load (~57 GB on 1×H100 80GB) the remaining ~16 GB of GPU memory can't pre-allocate KV cache for the model's 256K native context. vLLM fails the worker boot with: ValueError: To serve at least one request with the models's max seq len (262144), (16.17 GiB KV cache is needed, which is larger than the available KV cache memory (16.0 GiB). The TestEnableThinking test sends short single-turn chats ("Hello"), so cap the context at 16K — same value the ``Qwen/Qwen2.5-14B-Instruct`` spec uses in this file. 
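
As a back-of-the-envelope sanity check (a sketch using only the numbers quoted
above, and assuming KV-cache memory scales roughly linearly with the configured
max seq len):

```python
# At the 256K native context the KV cache needs 16.17 GiB but only ~16 GiB is
# left after the ~57 GB model load; at a 16K cap the same per-token cost is tiny.
native_ctx = 262_144          # tokens
kv_at_native = 16.17          # GiB needed for the full context
budget = 16.0                 # GiB left on 1xH100 80GB after weights

per_token = kv_at_native / native_ctx
print(f"KV at 16K ctx: {per_token * 16_384:.2f} GiB of {budget:.1f} GiB")  # ~1.01 GiB
```
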
Apply across all engines that pre-allocate KV cache by max-seq-len (sglang, vllm, tokenspeed); TRT-LLM allocates dynamically and keeps the existing ``free_gpu_memory_fraction: 0.8`` knob. Also pass ``--enforce-eager`` on the engines that respect it — the hybrid Mamba (Gated DeltaNet) + attention layout makes CUDA-graph capture finicky and is unnecessary for a smoke test. Signed-off-by: key4ng --- e2e_test/infra/model_specs.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/e2e_test/infra/model_specs.py b/e2e_test/infra/model_specs.py index c02df807c..272fecd50 100644 --- a/e2e_test/infra/model_specs.py +++ b/e2e_test/infra/model_specs.py @@ -75,11 +75,20 @@ def _resolve_model_path(hf_path: str) -> str: # Dense Qwen3.5 with the ``enable_thinking`` chat-template toggle. Used # by ``TestEnableThinking``. ``Qwen3_5ForConditionalGeneration`` # architecture, supported by tokenspeed's current registry where the - # older ``Qwen3-30B-A3B`` (``Qwen3MoeForCausalLM``) is not. + # older ``Qwen3-30B-A3B`` (``Qwen3MoeForCausalLM``) is not. The 256K + # native context overflows the KV-cache budget on 1×H100 80GB after the + # ~57 GB model load, so cap engines at 16K (matches the ``Qwen2.5-14B`` + # spec) — the test sends short single-turn chats. Hybrid Mamba + + # attention also makes CUDA-graph capture finicky, so force eager mode + # on the engines that respect it. "Qwen/Qwen3.5-27B": { "model": _resolve_model_path("Qwen/Qwen3.5-27B"), "tp": 1, "features": ["chat", "streaming", "thinking", "reasoning"], + "sglang_args": ["--context-length=16384"], + "vllm_args": ["--max-model-len=16384", "--enforce-eager"], + "tokenspeed_args": ["--max-model-len=16384", "--enforce-eager"], + "trtllm_extra_config": {"kv_cache_config": {"free_gpu_memory_fraction": 0.8}}, }, "Qwen/Qwen3-30B-A3B": { "model": _resolve_model_path("Qwen/Qwen3-30B-A3B"), From 779445a969134c108b4997ab310e969d77dbc9b2 Mon Sep 17 00:00:00 2001 From: key4ng Date: Mon, 11 May 2026 23:05:59 -0700 Subject: [PATCH 21/24] test(e2e): retarget TestEnableThinking at Qwen/Qwen3-4B MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Dense Qwen3 (``Qwen3ForCausalLM``) is in the bumped tokenspeed pin's model registry where ``Qwen3MoeForCausalLM`` (``Qwen3-30B-A3B``) and ``Qwen3_5ForConditionalGeneration`` (``Qwen3.5-27B``, which also overflowed vLLM's CUDA-13 ``nixl_ep`` dependency) are not. Hybrid Qwen3 keeps the ``enable_thinking`` chat-template toggle the test needs and emits the same ``...`` markers the existing ``qwen3`` SMG reasoning parser handles. 4B fits trivially on the 1×H100 lane (~8 GB weights), so the engine-specific KV cap and ``--enforce-eager`` args the Qwen3.5-27B spec needed are no longer required. Changes: - ``e2e_test/infra/model_specs.py``: replace the ``Qwen/Qwen3.5-27B`` entry with ``Qwen/Qwen3-4B``; point ``DEFAULT_ENABLE_THINKING_MODEL_PATH`` at it. Keep the ``Qwen/Qwen3-30B-A3B`` entry for the nightly perf benchmark. - ``e2e_test/chat_completions/test_enable_thinking.py``: update ``@pytest.mark.model`` to ``Qwen/Qwen3-4B``. 
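
For reference, the client-side shape of what the smoke test exercises (a sketch:
the gateway address is made up, and the ``extra_body``/``chat_template_kwargs``
plumbing is assumed here; only the ``enable_thinking`` template flag is named in
this change):

```python
# Toggle the Qwen3 thinking block from the request side and read back the split
# fields the qwen3 reasoning parser produces.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="dummy")
resp = client.chat.completions.create(
    model="Qwen/Qwen3-4B",
    messages=[{"role": "user", "content": "Hello"}],
    extra_body={"chat_template_kwargs": {"enable_thinking": True}},
)
msg = resp.choices[0].message
# With --reasoning-parser qwen3 the <think>...</think> span is surfaced separately:
print(getattr(msg, "reasoning_content", None))
print(msg.content)
```
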
Signed-off-by: key4ng --- .../chat_completions/test_enable_thinking.py | 2 +- e2e_test/infra/model_specs.py | 25 ++++++------------- 2 files changed, 9 insertions(+), 18 deletions(-) diff --git a/e2e_test/chat_completions/test_enable_thinking.py b/e2e_test/chat_completions/test_enable_thinking.py index 85bcb29e0..a896bd54a 100644 --- a/e2e_test/chat_completions/test_enable_thinking.py +++ b/e2e_test/chat_completions/test_enable_thinking.py @@ -24,7 +24,7 @@ @pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(1) -@pytest.mark.model("Qwen/Qwen3.5-27B") +@pytest.mark.model("Qwen/Qwen3-4B") @pytest.mark.gateway(extra_args=["--reasoning-parser", "qwen3", "--history-backend", "memory"]) @pytest.mark.parametrize("setup_backend", ["grpc"], indirect=True) @pytest.mark.parametrize("api_client", ["openai", "smg"], indirect=True) diff --git a/e2e_test/infra/model_specs.py b/e2e_test/infra/model_specs.py index 272fecd50..90d33fd56 100644 --- a/e2e_test/infra/model_specs.py +++ b/e2e_test/infra/model_specs.py @@ -71,24 +71,15 @@ def _resolve_model_path(hf_path: str) -> str: "tp": 1, "features": ["chat", "streaming", "function_calling", "tool_choice"], }, - # Thinking/reasoning model (larger) - # Dense Qwen3.5 with the ``enable_thinking`` chat-template toggle. Used - # by ``TestEnableThinking``. ``Qwen3_5ForConditionalGeneration`` - # architecture, supported by tokenspeed's current registry where the - # older ``Qwen3-30B-A3B`` (``Qwen3MoeForCausalLM``) is not. The 256K - # native context overflows the KV-cache budget on 1×H100 80GB after the - # ~57 GB model load, so cap engines at 16K (matches the ``Qwen2.5-14B`` - # spec) — the test sends short single-turn chats. Hybrid Mamba + - # attention also makes CUDA-graph capture finicky, so force eager mode - # on the engines that respect it. - "Qwen/Qwen3.5-27B": { - "model": _resolve_model_path("Qwen/Qwen3.5-27B"), + # Hybrid Qwen3 with the ``enable_thinking`` chat-template toggle. Used + # by ``TestEnableThinking``. Dense (``Qwen3ForCausalLM``), so it lands + # on tokenspeed's current model registry where the larger + # ``Qwen3-30B-A3B`` (``Qwen3MoeForCausalLM``) does not. Uses the + # existing ``qwen3`` reasoning parser. 
+ "Qwen/Qwen3-4B": { + "model": _resolve_model_path("Qwen/Qwen3-4B"), "tp": 1, "features": ["chat", "streaming", "thinking", "reasoning"], - "sglang_args": ["--context-length=16384"], - "vllm_args": ["--max-model-len=16384", "--enforce-eager"], - "tokenspeed_args": ["--max-model-len=16384", "--enforce-eager"], - "trtllm_extra_config": {"kv_cache_config": {"free_gpu_memory_fraction": 0.8}}, }, "Qwen/Qwen3-30B-A3B": { "model": _resolve_model_path("Qwen/Qwen3-30B-A3B"), @@ -264,7 +255,7 @@ def get_model_spec(model_id: str) -> dict: DEFAULT_MODEL_PATH = MODEL_SPECS["meta-llama/Llama-3.1-8B-Instruct"]["model"] DEFAULT_SMALL_MODEL_PATH = MODEL_SPECS["meta-llama/Llama-3.2-1B-Instruct"]["model"] DEFAULT_REASONING_MODEL_PATH = MODEL_SPECS["deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"]["model"] -DEFAULT_ENABLE_THINKING_MODEL_PATH = MODEL_SPECS["Qwen/Qwen3.5-27B"]["model"] +DEFAULT_ENABLE_THINKING_MODEL_PATH = MODEL_SPECS["Qwen/Qwen3-4B"]["model"] DEFAULT_QWEN_FUNCTION_CALLING_MODEL_PATH = MODEL_SPECS["Qwen/Qwen2.5-7B-Instruct"]["model"] DEFAULT_MISTRAL_FUNCTION_CALLING_MODEL_PATH = MODEL_SPECS["mistralai/Mistral-7B-Instruct-v0.3"][ "model" From 5ab36713b84ef1c49febf3ee84eca705c63bc063 Mon Sep 17 00:00:00 2001 From: Jue Wang Date: Tue, 12 May 2026 23:27:40 +0000 Subject: [PATCH 22/24] add none reasoning parser --- crates/reasoning_parser/src/factory.rs | 7 +- crates/reasoning_parser/src/lib.rs | 2 +- crates/reasoning_parser/src/parsers/mod.rs | 2 + crates/reasoning_parser/src/parsers/none.rs | 123 ++++++++++++++++++++ 4 files changed, 132 insertions(+), 2 deletions(-) create mode 100644 crates/reasoning_parser/src/parsers/none.rs diff --git a/crates/reasoning_parser/src/factory.rs b/crates/reasoning_parser/src/factory.rs index eeda6cbd1..4eb9e2c18 100644 --- a/crates/reasoning_parser/src/factory.rs +++ b/crates/reasoning_parser/src/factory.rs @@ -9,7 +9,7 @@ use tokio::sync::Mutex; use crate::{ parsers::{ BaseReasoningParser, CohereCmdParser, DeepSeekR1Parser, Glm45Parser, KimiParser, - MiniMaxParser, NanoV3Parser, Qwen3Parser, QwenThinkingParser, Step3Parser, + MiniMaxParser, NanoV3Parser, NoneParser, Qwen3Parser, QwenThinkingParser, Step3Parser, }, traits::{ParserConfig, ReasoningParser, DEFAULT_MAX_BUFFER_SIZE}, }; @@ -176,6 +176,11 @@ impl ParserFactory { Box::new(BaseReasoningParser::new(ParserConfig::default())) }); + // Register no-op parser: returns all text as normal content, + // never produces reasoning_content. Selectable via + // `--reasoning-parser none`. 
+ registry.register_parser("none", || Box::new(NoneParser::new())); + // Register DeepSeek-R1 parser (starts with in_reasoning=true) registry.register_parser("deepseek_r1", || Box::new(DeepSeekR1Parser::new())); diff --git a/crates/reasoning_parser/src/lib.rs b/crates/reasoning_parser/src/lib.rs index 83d49fbd4..97d51d998 100644 --- a/crates/reasoning_parser/src/lib.rs +++ b/crates/reasoning_parser/src/lib.rs @@ -5,7 +5,7 @@ pub mod traits; pub use factory::{ParserFactory, ParserRegistry, PooledParser}; pub use parsers::{ BaseReasoningParser, CohereCmdParser, DeepSeekR1Parser, Glm45Parser, KimiParser, MiniMaxParser, - NanoV3Parser, Qwen3Parser, QwenThinkingParser, Step3Parser, + NanoV3Parser, NoneParser, Qwen3Parser, QwenThinkingParser, Step3Parser, }; pub use traits::{ ParseError, ParserConfig, ParserResult, ReasoningParser, DEFAULT_MAX_BUFFER_SIZE, diff --git a/crates/reasoning_parser/src/parsers/mod.rs b/crates/reasoning_parser/src/parsers/mod.rs index 81757d04c..baa7b795f 100644 --- a/crates/reasoning_parser/src/parsers/mod.rs +++ b/crates/reasoning_parser/src/parsers/mod.rs @@ -5,6 +5,7 @@ pub mod glm45; pub mod kimi; pub mod minimax; pub mod nano_v3; +pub mod none; pub mod qwen3; pub mod step3; @@ -15,5 +16,6 @@ pub use glm45::Glm45Parser; pub use kimi::KimiParser; pub use minimax::MiniMaxParser; pub use nano_v3::NanoV3Parser; +pub use none::NoneParser; pub use qwen3::{Qwen3Parser, QwenThinkingParser}; pub use step3::Step3Parser; diff --git a/crates/reasoning_parser/src/parsers/none.rs b/crates/reasoning_parser/src/parsers/none.rs new file mode 100644 index 000000000..82e9e4363 --- /dev/null +++ b/crates/reasoning_parser/src/parsers/none.rs @@ -0,0 +1,123 @@ +// No-op reasoning parser. +// +// Returns all input text as `normal_text` and never produces reasoning text, +// regardless of whether the input contains ``/`` (or any other) +// markers. Use this when the model emits a single content stream and the +// caller does not want any portion of it separated into `reasoning_content`. + +use crate::traits::{ParseError, ParserResult, ReasoningParser}; + +/// Parser that performs no reasoning extraction. +/// +/// Every byte received is forwarded to `normal_text`; `reasoning_text` is always +/// empty. State is trivial: no buffering, no tokens, no flags. 
+#[derive(Debug, Clone, Default)] +pub struct NoneParser; + +impl NoneParser { + pub fn new() -> Self { + Self + } +} + +impl ReasoningParser for NoneParser { + fn detect_and_parse_reasoning(&mut self, text: &str) -> Result { + Ok(ParserResult::normal(text.to_string())) + } + + fn parse_reasoning_streaming_incremental( + &mut self, + text: &str, + ) -> Result { + Ok(ParserResult::normal(text.to_string())) + } + + fn reset(&mut self) {} + + fn model_type(&self) -> &str { + "none" + } + + fn is_in_reasoning(&self) -> bool { + false + } + + fn mark_reasoning_started(&mut self) {} + + fn mark_think_start_stripped(&mut self) {} +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn plain_text_goes_to_normal() { + let mut parser = NoneParser::new(); + let result = parser + .detect_and_parse_reasoning("just some content") + .unwrap(); + assert_eq!(result.normal_text, "just some content"); + assert_eq!(result.reasoning_text, ""); + } + + #[test] + fn think_tags_are_kept_in_normal_text() { + let mut parser = NoneParser::new(); + let result = parser + .detect_and_parse_reasoning("cotanswer") + .unwrap(); + assert_eq!(result.normal_text, "cotanswer"); + assert_eq!(result.reasoning_text, ""); + } + + #[test] + fn streaming_passes_chunks_through_unchanged() { + let mut parser = NoneParser::new(); + + let r1 = parser + .parse_reasoning_streaming_incremental("") + .unwrap(); + assert_eq!(r1.normal_text, ""); + assert_eq!(r1.reasoning_text, ""); + + let r2 = parser + .parse_reasoning_streaming_incremental("hidden cot") + .unwrap(); + assert_eq!(r2.normal_text, "hidden cot"); + assert_eq!(r2.reasoning_text, ""); + + let r3 = parser + .parse_reasoning_streaming_incremental("visible") + .unwrap(); + assert_eq!(r3.normal_text, "visible"); + assert_eq!(r3.reasoning_text, ""); + } + + #[test] + fn empty_input_is_normal_and_empty() { + let mut parser = NoneParser::new(); + let result = parser.detect_and_parse_reasoning("").unwrap(); + assert!(result.is_empty()); + } + + #[test] + fn mark_helpers_do_not_change_behavior() { + let mut parser = NoneParser::new(); + parser.mark_reasoning_started(); + parser.mark_think_start_stripped(); + assert!(!parser.is_in_reasoning()); + + let result = parser + .detect_and_parse_reasoning("xy") + .unwrap(); + assert_eq!(result.normal_text, "xy"); + assert_eq!(result.reasoning_text, ""); + } + + #[test] + fn model_type_is_none() { + let parser = NoneParser::new(); + assert_eq!(parser.model_type(), "none"); + } +} From 963baa42bcec9dfda42623ae0f545ef42333b468 Mon Sep 17 00:00:00 2001 From: Keyang Ru Date: Mon, 11 May 2026 23:48:04 -0700 Subject: [PATCH 23/24] =?UTF-8?q?ci(release):=20add=20dev=20wheel=20workfl?= =?UTF-8?q?ow=20=E2=86=92=20GitHub=20Releases=20(#1471)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: key4ng --- .github/workflows/release-pypi-dev.yml | 348 +++++++++++++++++++++++++ 1 file changed, 348 insertions(+) create mode 100644 .github/workflows/release-pypi-dev.yml diff --git a/.github/workflows/release-pypi-dev.yml b/.github/workflows/release-pypi-dev.yml new file mode 100644 index 000000000..6db07b0b3 --- /dev/null +++ b/.github/workflows/release-pypi-dev.yml @@ -0,0 +1,348 @@ +name: Release SMG dev wheels (GitHub Releases) + +# Builds dev wheels for smg, smg-grpc-proto, and smg-grpc-servicer and +# publishes them as a single GitHub Release per workflow run. Avoids the +# 10 GB PyPI project quota and lets us delete old dev releases freely. 
+# Prod releases continue to go to PyPI via release-pypi.yml / release-grpc.yml. + +on: + workflow_dispatch: + inputs: + release_smg: + description: "Build & release smg (Rust wheel)" + type: boolean + default: true + release_servicer: + description: "Build & release smg-grpc-servicer" + type: boolean + default: true + release_proto: + description: "Build & release smg-grpc-proto" + type: boolean + default: true + # Self-validating: PRs that touch this workflow file auto-build (no release + # is created on pull_request runs; they are pure dry-runs for CI). + pull_request: + paths: + - .github/workflows/release-pypi-dev.yml + +permissions: + contents: read + +jobs: + prepare: + name: Compute dev versions + runs-on: ubuntu-latest + outputs: + smg_version: ${{ steps.compute.outputs.smg_version }} + proto_version: ${{ steps.compute.outputs.proto_version }} + servicer_version: ${{ steps.compute.outputs.servicer_version }} + release_tag: ${{ steps.compute.outputs.release_tag }} + steps: + - uses: actions/checkout@v6 + + - id: compute + run: | + set -euo pipefail + SUFFIX="dev${{ github.run_number }}" + RELEASE_TAG="smg-dev-${{ github.run_number }}" + + read_version() { + grep -m1 '^version = ' "$1" | sed 's/version = "\(.*\)"/\1/' + } + + # Bump patch so dev sorts AFTER current stable in PEP 440 ordering. + bump_patch() { + local IFS=. + # shellcheck disable=SC2206 + local parts=($1) + parts[2]=$(( ${parts[2]} + 1 )) + echo "${parts[0]}.${parts[1]}.${parts[2]}" + } + + SMG_VERSION="$(bump_patch "$(read_version bindings/python/pyproject.toml)").${SUFFIX}" + PROTO_VERSION="$(bump_patch "$(read_version crates/grpc_client/python/pyproject.toml)").${SUFFIX}" + SERVICER_VERSION="$(bump_patch "$(read_version grpc_servicer/pyproject.toml)").${SUFFIX}" + + echo "smg_version=${SMG_VERSION}" >> "$GITHUB_OUTPUT" + echo "proto_version=${PROTO_VERSION}" >> "$GITHUB_OUTPUT" + echo "servicer_version=${SERVICER_VERSION}" >> "$GITHUB_OUTPUT" + echo "release_tag=${RELEASE_TAG}" >> "$GITHUB_OUTPUT" + + { + echo "## Planned dev versions" + echo "| Package | Version | Will build? |" + echo "|---|---|---|" + echo "| smg | ${SMG_VERSION} | ${{ github.event_name == 'pull_request' || inputs.release_smg }} |" + echo "| smg-grpc-proto | ${PROTO_VERSION} | ${{ github.event_name == 'pull_request' || inputs.release_proto }} |" + echo "| smg-grpc-servicer | ${SERVICER_VERSION} | ${{ github.event_name == 'pull_request' || inputs.release_servicer }} |" + echo "" + echo "Release tag (workflow_dispatch only): \`${RELEASE_TAG}\`" + } >> "$GITHUB_STEP_SUMMARY" + + # ── smg (Rust wheel via maturin) — slim matrix: manylinux x86_64 + sdist ── + build-smg: + name: Build smg dev wheel + needs: prepare + if: ${{ github.event_name == 'pull_request' || inputs.release_smg }} + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + with: + path: smg-repo + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: "3.13" + + # Patches Python metadata + Rust runtime version so the wheel reports a + # consistent dev version everywhere — smg.__version__, the smg.smg_rs + # banner from model_gateway::build.rs, and pyproject.toml. PEP 440 + # (1.4.2.dev42) for Python, semver (1.4.2-dev42) for Cargo. Other + # workspace crates reference model_gateway via path with no version pin. 
+ - name: Patch dev version (Python + model_gateway Cargo) + working-directory: smg-repo + env: + SMG_VERSION: ${{ needs.prepare.outputs.smg_version }} + run: | + set -euo pipefail + SMG_CARGO_VERSION="${SMG_VERSION/.dev/-dev}" + + sed -i.bak "s/^version = \".*\"/version = \"${SMG_VERSION}\"/" bindings/python/pyproject.toml + sed -i.bak "s/__version__ = \".*\"/__version__ = \"${SMG_VERSION}\"/" bindings/python/src/smg/version.py + sed -i.bak "s/^version = \".*\"/version = \"${SMG_CARGO_VERSION}\"/" model_gateway/Cargo.toml + rm -f bindings/python/pyproject.toml.bak bindings/python/src/smg/version.py.bak model_gateway/Cargo.toml.bak + + echo "Patched smg version:" + echo " pyproject.toml + __version__: ${SMG_VERSION}" + echo " model_gateway/Cargo.toml: ${SMG_CARGO_VERSION}" + + - name: Build manylinux x86_64 wheels + uses: PyO3/maturin-action@v1 + with: + working-directory: smg-repo/bindings/python + target: x86_64 + manylinux: auto + # PR runs build a single interpreter (3.12) to keep dry-run fast + # (~6 min vs ~25 min); workflow_dispatch builds the full matrix. + args: --release --out dist --features vendored-openssl --interpreter ${{ github.event_name == 'pull_request' && '3.12' || '3.10 3.11 3.12 3.13' }} + rust-toolchain: stable + before-script-linux: | + if command -v yum &> /dev/null; then + yum update -y && yum install -y wget unzip gcc gcc-c++ perl-core make + elif command -v apt-get &> /dev/null; then + apt-get update && apt-get install -y wget unzip gcc g++ perl make + fi + (cd /tmp && \ + wget https://github.com/protocolbuffers/protobuf/releases/download/v32.0/protoc-32.0-linux-x86_64.zip && \ + unzip protoc-32.0-linux-x86_64.zip -d /usr/local && \ + rm protoc-32.0-linux-x86_64.zip) + protoc --version + + # The manylinux wheel build above runs inside Docker as root; dist/ ends + # up root-owned. Reclaim ownership so the host-side sdist step can write. 
+ - name: Reclaim dist ownership from manylinux container + run: | + sudo chown -R "$(whoami):$(whoami)" smg-repo/bindings/python/dist + ls -la smg-repo/bindings/python/dist + + - name: Build sdist + uses: PyO3/maturin-action@v1 + with: + working-directory: smg-repo/bindings/python + command: sdist + args: --out dist + rust-toolchain: stable + + - name: Check packages + run: | + pip install -U twine + twine check --strict smg-repo/bindings/python/dist/* + + - uses: actions/upload-artifact@v7 + with: + name: smg-dev-dist + path: smg-repo/bindings/python/dist/ + + # ── smg-grpc-proto (pure Python) ── + build-proto: + name: Build smg-grpc-proto dev + needs: prepare + if: ${{ github.event_name == 'pull_request' || inputs.release_proto }} + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 + with: + python-version: "3.12" + - name: Install build deps + run: pip install build twine grpcio-tools + - name: Patch dev version + env: + PROTO_VERSION: ${{ needs.prepare.outputs.proto_version }} + run: | + set -euo pipefail + sed -i.bak "s/^version = \".*\"/version = \"${PROTO_VERSION}\"/" \ + crates/grpc_client/python/pyproject.toml + rm -f crates/grpc_client/python/pyproject.toml.bak + - name: Copy proto files (replace symlink with real files) + run: | + rm -f crates/grpc_client/python/smg_grpc_proto/proto + mkdir -p crates/grpc_client/python/smg_grpc_proto/proto + cp crates/grpc_client/proto/*.proto crates/grpc_client/python/smg_grpc_proto/proto/ + - name: Build package + run: cd crates/grpc_client/python && python -m build + - name: Check package + run: twine check --strict crates/grpc_client/python/dist/* + - uses: actions/upload-artifact@v7 + with: + name: proto-dev-dist + path: crates/grpc_client/python/dist/ + + # ── smg-grpc-servicer (pure Python) ── + build-servicer: + name: Build smg-grpc-servicer dev + needs: prepare + if: ${{ github.event_name == 'pull_request' || inputs.release_servicer }} + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 + with: + python-version: "3.12" + - name: Install build deps + run: pip install build twine + # When proto is also being built in this run, pin the servicer's proto + # dep to the exact dev version so consumers get a coherent dev set. + - name: Patch dev version (+ pin proto dep if releasing proto together) + env: + SERVICER_VERSION: ${{ needs.prepare.outputs.servicer_version }} + PROTO_VERSION: ${{ needs.prepare.outputs.proto_version }} + RELEASE_PROTO: ${{ github.event_name == 'pull_request' || inputs.release_proto }} + run: | + set -euo pipefail + sed -i.bak "s/^version = \".*\"/version = \"${SERVICER_VERSION}\"/" \ + grpc_servicer/pyproject.toml + if [ "${RELEASE_PROTO}" = "true" ]; then + # Matches both the core dep ("smg-grpc-proto>=0.4.6") and the + # mlx extra ("smg-grpc-proto>=0.4.7"). + sed -i.bak -E "s|smg-grpc-proto>=[0-9.]+|smg-grpc-proto==${PROTO_VERSION}|g" \ + grpc_servicer/pyproject.toml + echo "Pinned smg-grpc-proto==${PROTO_VERSION} for dev coherence" + fi + rm -f grpc_servicer/pyproject.toml.bak + - name: Build package + run: cd grpc_servicer && python -m build + - name: Check package + run: twine check --strict grpc_servicer/dist/* + - uses: actions/upload-artifact@v7 + with: + name: servicer-dev-dist + path: grpc_servicer/dist/ + + # ── Create the GitHub Release with all built dev wheels attached ── + # Only fires for workflow_dispatch — pull_request runs validate the build + # but never create a release. 
+ release: + name: Create GitHub Release + needs: [prepare, build-smg, build-proto, build-servicer] + # Conditions: + # - Only on workflow_dispatch (no releases on pull_request runs). + # - prepare must have succeeded so we have a tag/versions. + # - No requested build may have failed. + # - At least one build must have produced artifacts. + if: >- + always() + && github.event_name == 'workflow_dispatch' + && github.repository == 'lightseekorg/smg' + && needs.prepare.result == 'success' + && needs.build-smg.result != 'failure' + && needs.build-proto.result != 'failure' + && needs.build-servicer.result != 'failure' + && ( + needs.build-smg.result == 'success' + || needs.build-proto.result == 'success' + || needs.build-servicer.result == 'success' + ) + runs-on: ubuntu-latest + permissions: + contents: write + steps: + - name: Download all dev artifacts + uses: actions/download-artifact@v8 + with: + path: dist + pattern: "*-dev-dist" + merge-multiple: true + + - name: List collected assets + run: ls -lh dist/ + + - name: Create dev release + uses: softprops/action-gh-release@v2 + with: + tag_name: ${{ needs.prepare.outputs.release_tag }} + name: "smg dev ${{ github.run_number }}" + prerelease: true + make_latest: "false" + target_commitish: ${{ github.sha }} + fail_on_unmatched_files: true + body: | + Auto-generated dev release from workflow run ${{ github.run_number }}. + + **Versions** + - `smg==${{ needs.prepare.outputs.smg_version }}` — build: `${{ needs.build-smg.result }}` + - `smg-grpc-proto==${{ needs.prepare.outputs.proto_version }}` — build: `${{ needs.build-proto.result }}` + - `smg-grpc-servicer==${{ needs.prepare.outputs.servicer_version }}` — build: `${{ needs.build-servicer.result }}` + + **Install** + + ```bash + pip install smg smg-grpc-servicer smg-grpc-proto \ + --find-links https://github.com/${{ github.repository }}/releases/expanded_assets/${{ needs.prepare.outputs.release_tag }} + ``` + + Source commit: `${{ github.sha }}` + files: dist/* + + # ── Final summary in the run UI ── + summary: + name: Summary + needs: + - prepare + - build-smg + - build-proto + - build-servicer + - release + if: always() + runs-on: ubuntu-latest + steps: + - run: | + { + if [ "${{ github.event_name }}" = "pull_request" ]; then + echo "## Dev build summary (pull_request — no GitHub Release created)" + else + echo "## Dev release summary" + fi + echo "" + echo "| Package | Version | Build |" + echo "|---|---|---|" + echo "| smg | ${{ needs.prepare.outputs.smg_version }} | ${{ needs.build-smg.result }} |" + echo "| smg-grpc-proto | ${{ needs.prepare.outputs.proto_version }} | ${{ needs.build-proto.result }} |" + echo "| smg-grpc-servicer | ${{ needs.prepare.outputs.servicer_version }} | ${{ needs.build-servicer.result }} |" + echo "" + if [ "${{ github.event_name }}" = "workflow_dispatch" ] && [ "${{ needs.release.result }}" = "success" ]; then + TAG="${{ needs.prepare.outputs.release_tag }}" + REPO="${{ github.repository }}" + echo "### Release" + echo "" + echo "https://github.com/${REPO}/releases/tag/${TAG}" + echo "" + echo "### Install" + echo '```bash' + echo "pip install smg smg-grpc-servicer smg-grpc-proto \\" + echo " --find-links https://github.com/${REPO}/releases/expanded_assets/${TAG}" + echo '```' + fi + } >> "$GITHUB_STEP_SUMMARY" From 921a3d9f1429d5d964a78a4ac93e5d0b2bac018b Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Tue, 12 May 2026 01:43:34 -0700 Subject: [PATCH 24/24] ci(release): publish dev wheels to whl index (#1473) Signed-off-by: zhyncs 
<46627482+zhyncs@users.noreply.github.com> --- .github/workflows/release-pypi-dev.yml | 171 ++++++++++++++++++++----- scripts/update_whl_index.py | 133 +++++++++++++++++++ 2 files changed, 269 insertions(+), 35 deletions(-) create mode 100755 scripts/update_whl_index.py diff --git a/.github/workflows/release-pypi-dev.yml b/.github/workflows/release-pypi-dev.yml index 6db07b0b3..2ce07deca 100644 --- a/.github/workflows/release-pypi-dev.yml +++ b/.github/workflows/release-pypi-dev.yml @@ -1,8 +1,9 @@ -name: Release SMG dev wheels (GitHub Releases) +name: Release SMG dev wheels (whl index) # Builds dev wheels for smg, smg-grpc-proto, and smg-grpc-servicer and -# publishes them as a single GitHub Release per workflow run. Avoids the -# 10 GB PyPI project quota and lets us delete old dev releases freely. +# publishes them to lightseekorg/whl releases + simple package indexes. +# Requires a TOKENSPEED_GITHUB_TOKEN secret with write access to lightseekorg/whl. +# Avoids the 10 GB PyPI project quota and lets us delete old dev releases freely. # Prod releases continue to go to PyPI via release-pypi.yml / release-grpc.yml. on: @@ -241,11 +242,11 @@ jobs: name: servicer-dev-dist path: grpc_servicer/dist/ - # ── Create the GitHub Release with all built dev wheels attached ── + # ── Publish dev wheels to lightseekorg/whl and update simple indexes ── # Only fires for workflow_dispatch — pull_request runs validate the build - # but never create a release. + # but never publish artifacts or update indexes. release: - name: Create GitHub Release + name: Publish to whl index needs: [prepare, build-smg, build-proto, build-servicer] # Conditions: # - Only on workflow_dispatch (no releases on pull_request runs). @@ -266,8 +267,6 @@ jobs: || needs.build-servicer.result == 'success' ) runs-on: ubuntu-latest - permissions: - contents: write steps: - name: Download all dev artifacts uses: actions/download-artifact@v8 @@ -279,32 +278,134 @@ jobs: - name: List collected assets run: ls -lh dist/ - - name: Create dev release - uses: softprops/action-gh-release@v2 + - name: Check out smg repo + uses: actions/checkout@v6 with: - tag_name: ${{ needs.prepare.outputs.release_tag }} - name: "smg dev ${{ github.run_number }}" - prerelease: true - make_latest: "false" - target_commitish: ${{ github.sha }} - fail_on_unmatched_files: true - body: | - Auto-generated dev release from workflow run ${{ github.run_number }}. 
-
-            **Versions**
-            - `smg==${{ needs.prepare.outputs.smg_version }}` — build: `${{ needs.build-smg.result }}`
-            - `smg-grpc-proto==${{ needs.prepare.outputs.proto_version }}` — build: `${{ needs.build-proto.result }}`
-            - `smg-grpc-servicer==${{ needs.prepare.outputs.servicer_version }}` — build: `${{ needs.build-servicer.result }}`
-
-            **Install**
-
-            ```bash
-            pip install smg smg-grpc-servicer smg-grpc-proto \
-              --find-links https://github.com/${{ github.repository }}/releases/expanded_assets/${{ needs.prepare.outputs.release_tag }}
-            ```
-
-            Source commit: `${{ github.sha }}`
-          files: dist/*
+          path: smg-repo
+
+      - name: Check out whl index repo
+        uses: actions/checkout@v6
+        with:
+          repository: lightseekorg/whl
+          ref: gh-pages
+          path: whl-repo
+          token: ${{ secrets.TOKENSPEED_GITHUB_TOKEN }}
+
+      - name: Create whl release and upload assets
+        env:
+          GH_TOKEN: ${{ secrets.TOKENSPEED_GITHUB_TOKEN }}
+          RELEASE_TAG: ${{ needs.prepare.outputs.release_tag }}
+          WHL_REPO: lightseekorg/whl
+          SMG_VERSION: ${{ needs.prepare.outputs.smg_version }}
+          PROTO_VERSION: ${{ needs.prepare.outputs.proto_version }}
+          SERVICER_VERSION: ${{ needs.prepare.outputs.servicer_version }}
+        run: |
+          set -euo pipefail
+
+          notes_file="$(mktemp)"
+          cat > "${notes_file}" <<EOF
+          Dev wheels from smg workflow run ${{ github.run_number }} (commit ${{ github.sha }}).
+
+          - smg==${SMG_VERSION}
+          - smg-grpc-proto==${PROTO_VERSION}
+          - smg-grpc-servicer==${SERVICER_VERSION}
+          EOF
+
+          if gh release view "${RELEASE_TAG}" --repo "${WHL_REPO}" >/dev/null 2>&1; then
+            gh release edit "${RELEASE_TAG}" \
+              --repo "${WHL_REPO}" \
+              --title "smg dev ${{ github.run_number }}" \
+              --prerelease \
+              --notes-file "${notes_file}"
+          else
+            gh release create "${RELEASE_TAG}" \
+              --repo "${WHL_REPO}" \
+              --title "smg dev ${{ github.run_number }}" \
+              --prerelease \
+              --notes-file "${notes_file}"
+          fi
+
+          gh release upload "${RELEASE_TAG}" dist/* --repo "${WHL_REPO}" --clobber
+
+      - name: Update whl package indexes
+        env:
+          RELEASE_TAG: ${{ needs.prepare.outputs.release_tag }}
+        run: |
+          set -euo pipefail
+
+          cd whl-repo
+          git config user.name "github-actions[bot]"
+          git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
+
+          rm -rf wheelhouse
+          mkdir -p wheelhouse/smg wheelhouse/smg-grpc-proto wheelhouse/smg-grpc-servicer
+          cp ../dist/smg-*.whl wheelhouse/smg/ 2>/dev/null || true
+          cp ../dist/smg_grpc_proto-*.whl wheelhouse/smg-grpc-proto/ 2>/dev/null || true
+          cp ../dist/smg_grpc_servicer-*.whl wheelhouse/smg-grpc-servicer/ 2>/dev/null || true
+
+          # Make workflow re-runs idempotent for the same release tag. The
+          # shared update_whl_index.py script appends wheel links by design.
+          for index_file in \
+            cu129/smg/index.html \
+            cu129/smg-grpc-proto/index.html \
+            cu129/smg-grpc-servicer/index.html \
+            cu130/smg/index.html \
+            cu130/smg-grpc-proto/index.html \
+            cu130/smg-grpc-servicer/index.html \
+            rocm7.2/smg/index.html \
+            rocm7.2/smg-grpc-proto/index.html \
+            rocm7.2/smg-grpc-servicer/index.html; do
+            [ -f "${index_file}" ] && sed -i "\#/${RELEASE_TAG}/#d" "${index_file}"
+          done
+
+          for cuda in 129 130; do
+            python3 ../smg-repo/scripts/update_whl_index.py \
+              --package smg \
+              --cuda "${cuda}" \
+              --release-tag "${RELEASE_TAG}" \
+              --wheel-dir wheelhouse/smg \
+              --whl-repo-dir .
+            python3 ../smg-repo/scripts/update_whl_index.py \
+              --package smg-grpc-proto \
+              --cuda "${cuda}" \
+              --release-tag "${RELEASE_TAG}" \
+              --wheel-dir wheelhouse/smg-grpc-proto \
+              --whl-repo-dir .
+            python3 ../smg-repo/scripts/update_whl_index.py \
+              --package smg-grpc-servicer \
+              --cuda "${cuda}" \
+              --release-tag "${RELEASE_TAG}" \
+              --wheel-dir wheelhouse/smg-grpc-servicer \
+              --whl-repo-dir .
+          done
+
+          for package in smg smg-grpc-proto smg-grpc-servicer; do
+            python3 ../smg-repo/scripts/update_whl_index.py \
+              --package "${package}" \
+              --rocm 7.2 \
+              --release-tag "${RELEASE_TAG}" \
+              --wheel-dir "wheelhouse/${package}" \
+              --whl-repo-dir .
+          done
+
+          rm -rf wheelhouse
+          git add cu129 cu130 rocm7.2
+          if git diff --cached --quiet; then
+            echo "No whl index changes to commit."
+          else
+            git commit -s -m "Add smg dev ${{ github.run_number }} wheels"
+            git push origin gh-pages
+          fi

   # ── Final summary in the run UI ──
   summary:
@@ -337,12 +438,12 @@ jobs:
             REPO="${{ github.repository }}"
             echo "### Release"
             echo ""
-            echo "https://github.com/${REPO}/releases/tag/${TAG}"
+            echo "https://github.com/lightseekorg/whl/releases/tag/${TAG}"
             echo ""
             echo "### Install"
             echo '```bash'
             echo "pip install smg smg-grpc-servicer smg-grpc-proto \\"
-            echo "  --find-links https://github.com/${REPO}/releases/expanded_assets/${TAG}"
+            echo "  --extra-index-url https://lightseek.org/whl/cu129/"
             echo '```'
           fi
         } >> "$GITHUB_STEP_SUMMARY"
diff --git a/scripts/update_whl_index.py b/scripts/update_whl_index.py
new file mode 100755
index 000000000..b01c92a6b
--- /dev/null
+++ b/scripts/update_whl_index.py
@@ -0,0 +1,133 @@
+#!/usr/bin/env python3
+# Reference: https://github.com/flashinfer-ai/flashinfer/blob/v0.2.0/scripts/update_whl_index.py
+"""Update the wheel index in the lightseekorg/whl repository.
+
+Index layout (matches the existing repo structure):
+    /cu<cuda>/index.html              ← PEP 503 top-level index
+    /cu<cuda>/<package>/index.html    ← per-package wheel list
+    /rocm<rocm>/index.html            ← PEP 503 top-level index
+    /rocm<rocm>/<package>/index.html  ← per-package wheel list
+
+Install example (PyTorch-style --extra-index-url):
+    pip install smg \
+        --extra-index-url https://lightseek.org/whl/cu129/
+"""
+
+import argparse
+import hashlib
+import pathlib
+
+BASE_URL = "https://github.com/lightseekorg/whl/releases/download"
+
+
+def _cuda_display(cuda_digits: str) -> str:
+    """'129' -> '12.9', '130' -> '13.0'"""
+    return f"{cuda_digits[:-1]}.{cuda_digits[-1]}"
+
+
+def _platform_index(cuda: str | None, rocm: str | None) -> tuple[str, str]:
+    if cuda:
+        return f"cu{cuda}", f"CUDA {_cuda_display(cuda)}"
+    if rocm:
+        return f"rocm{rocm}", f"ROCm {rocm}"
+    raise ValueError("Either cuda or rocm must be provided")
+
+
+def compute_sha256(path: pathlib.Path) -> str:
+    with open(path, "rb") as f:
+        return hashlib.sha256(f.read()).hexdigest()
+
+
+def _ensure_in_top_index(index_root: pathlib.Path, package: str) -> None:
+    """Add package to a platform index if not already listed (PEP 503)."""
+    top_index = index_root / "index.html"
+    entry = f'<a href="{package}/">{package}</a><br>\n'
+    if top_index.exists():
+        if entry in top_index.read_text():
+            return
+        with top_index.open("a") as f:
+            f.write(entry)
+    else:
+        top_index.write_text(f"<!DOCTYPE html>\n{entry}")
+    print(f"  Added {package} to top-level index")
+
+
+def update_index(
+    package: str,
+    cuda: str | None,
+    rocm: str | None,
+    release_tag: str,
+    wheel_dir: str,
+    whl_repo_dir: str,
+) -> None:
+    platform_dir, platform_display = _platform_index(cuda, rocm)
+    index_root = pathlib.Path(whl_repo_dir) / platform_dir
+    index_dir = index_root / package
+    index_dir.mkdir(exist_ok=True, parents=True)
+
+    # Keep the platform index up-to-date for --index-url support
+    _ensure_in_top_index(index_root, package)
+
+    index_file = index_dir / "index.html"
+    if not index_file.exists():
+        index_file.write_text(
+            f"<!DOCTYPE html>\n"
+            f"<html>\n<body>\n"
+            f"<h1>{package} wheels for {platform_display}</h1>\n"
+        )
+
+    wheels = sorted(pathlib.Path(wheel_dir).glob("*.whl"))
+    if not wheels:
+        print(f"WARNING: no .whl files found in {wheel_dir}")
+        return
+
+    for path in wheels:
+        sha256 = compute_sha256(path)
+        full_url = f"{BASE_URL}/{release_tag}/{path.name}#sha256={sha256}"
+        with index_file.open("a") as f:
+            f.write(f'<a href="{full_url}">{path.name}</a><br>\n')
+        print(f"  Indexed: {path.name}")
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Update wheel index for lightseekorg/whl")
+    parser.add_argument(
+        "--package",
+        required=True,
+        help="Package name",
+    )
+    platform = parser.add_mutually_exclusive_group(required=True)
+    platform.add_argument(
+        "--cuda",
+        help="CUDA version digits (e.g. 129, 130)",
+    )
+    platform.add_argument(
+        "--rocm",
+        help="ROCm version (e.g. 7.2)",
+    )
+    parser.add_argument(
+        "--release-tag",
+        required=True,
+        help="Release tag in lightseekorg/whl",
+    )
+    parser.add_argument(
+        "--wheel-dir",
+        default="wheelhouse",
+        help="Directory containing .whl files (default: wheelhouse)",
+    )
+    parser.add_argument(
+        "--whl-repo-dir",
+        default=".",
+        help="Root of the lightseekorg/whl checkout (default: .)",
+    )
+    args = parser.parse_args()
+    update_index(
+        args.package,
+        args.cuda,
+        args.rocm,
+        args.release_tag,
+        args.wheel_dir,
+        args.whl_repo_dir,
+    )
+
+
+if __name__ == "__main__":
+    main()
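
For local debugging, the index update can be driven directly (a sketch with a
made-up tag and paths; the flags and parameters mirror the argparse definition
above, and the import assumes ``scripts/`` is on ``sys.path``):

```python
# Re-run what the release job does for one platform/package pair against a
# local checkout of the gh-pages branch.
from update_whl_index import update_index

update_index(
    package="smg",
    cuda="129",
    rocm=None,
    release_tag="smg-dev-123",     # hypothetical lightseekorg/whl release tag
    wheel_dir="wheelhouse/smg",    # directory holding the built .whl files
    whl_repo_dir=".",              # root of the gh-pages checkout
)
# cu129/smg/index.html now ends with one
#   <a href=".../releases/download/smg-dev-123/<wheel>#sha256=...">...</a><br>
# line per wheel, which pip resolves via
#   pip install smg --extra-index-url https://lightseek.org/whl/cu129/
```
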