From 872e5d2b78c0fe31f31dfbccbca0b3a7716ed044 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Tue, 5 May 2026 19:52:03 +0800 Subject: [PATCH 1/2] chore: reformat codebase Signed-off-by: Bugen Zhao --- rustfmt.toml | 2 + rustfmt.unstable.toml | 3 +- .../examples/external_engine_chat_qwen.rs | 9 +- src/chat/src/backend/mod.rs | 19 +- src/chat/src/event.rs | 15 +- src/chat/src/lib.rs | 34 +- src/chat/src/output/default/mod.rs | 50 +-- src/chat/src/output/default/reasoning.rs | 3 +- src/chat/src/output/default/tool.rs | 7 +- src/chat/src/output/harmony/mod.rs | 61 ++-- src/chat/src/output/mod.rs | 22 +- src/chat/src/output/structured.rs | 13 +- src/chat/src/parser/mod.rs | 6 +- src/chat/src/parser/reasoning/cohere_cmd.rs | 3 +- src/chat/src/parser/reasoning/deepseek_r1.rs | 3 +- src/chat/src/parser/reasoning/delimited.rs | 19 +- src/chat/src/parser/reasoning/gemma4.rs | 9 +- src/chat/src/parser/reasoning/mod.rs | 28 +- src/chat/src/parser/tool/deepseek_v32.rs | 14 +- src/chat/src/parser/tool/deepseek_v4.rs | 4 +- src/chat/src/parser/tool/external.rs | 38 +- src/chat/src/parser/tool/gemma4.rs | 33 +- src/chat/src/parser/tool/minimax_m2.rs | 20 +- src/chat/src/parser/tool/mod.rs | 36 +- src/chat/src/parser/tool/parameters.rs | 38 +- src/chat/src/parser/tool/qwen_coder.rs | 12 +- src/chat/src/parser/tool/streaming.rs | 7 +- src/chat/src/parser/tool/test_utils.rs | 3 +- src/chat/src/parser/tool/utils.rs | 7 +- .../src/renderer/deepseek_v32/encoding.rs | 5 +- src/chat/src/renderer/deepseek_v32/tests.rs | 14 +- src/chat/src/renderer/deepseek_v4/encoding.rs | 19 +- src/chat/src/renderer/hf/format.rs | 16 +- src/chat/src/renderer/hf/mod.rs | 25 +- src/chat/src/renderer/hf/template.rs | 6 +- src/chat/src/renderer/hf/tojson.rs | 10 +- src/chat/src/renderer/mod.rs | 3 +- src/chat/src/renderer/selection.rs | 3 +- src/chat/src/request.rs | 76 ++-- src/chat/src/stream.rs | 7 +- src/chat/tests/chat.rs | 38 +- src/cmd/src/cli.rs | 112 +++--- src/cmd/src/cli/serve_validate.rs | 26 +- src/cmd/src/cli/unsupported.rs | 124 ++++--- src/cmd/src/logging.rs | 22 +- src/cmd/src/main.rs | 18 +- src/cmd/src/managed_engine.rs | 31 +- .../examples/external_engine_logprobs.rs | 10 +- .../examples/external_engine_utility_call.rs | 48 +-- src/engine-core-client/src/client.rs | 155 ++++---- src/engine-core-client/src/client/imp.rs | 83 ++--- src/engine-core-client/src/client/state.rs | 71 ++-- src/engine-core-client/src/client/stream.rs | 21 +- .../src/coordinator/bootstrap.rs | 29 +- .../src/coordinator/external.rs | 23 +- .../src/coordinator/handle.rs | 18 +- .../src/coordinator/inproc.rs | 19 +- src/engine-core-client/src/metrics.rs | 25 +- .../src/protocol/handshake.rs | 5 +- .../src/protocol/logprobs.rs | 80 +++-- .../src/protocol/logprobs/tests.rs | 21 +- .../src/protocol/logprobs/wire.rs | 40 +-- src/engine-core-client/src/protocol/mod.rs | 51 +-- src/engine-core-client/src/test_utils.rs | 49 ++- src/engine-core-client/src/tests/client.rs | 245 +++---------- src/engine-core-client/src/transport.rs | 159 ++++----- src/llm/examples/external_engine_smoke.rs | 13 +- src/llm/src/lib.rs | 17 +- src/llm/src/log_stats.rs | 23 +- src/llm/src/output.rs | 64 ++-- src/llm/src/request.rs | 7 +- src/llm/src/request_metrics.rs | 27 +- src/llm/tests/generate.rs | 45 +-- src/metrics/src/lib.rs | 9 +- src/metrics/src/scheduler.rs | 21 +- .../examples/external_engine_openai_qwen.rs | 25 +- src/server/src/config.rs | 22 +- src/server/src/error.rs | 6 +- src/server/src/grpc/convert.rs | 85 ++--- src/server/src/grpc/tests.rs | 19 +- 
src/server/src/lib.rs | 42 +-- src/server/src/listener.rs | 38 +- src/server/src/middleware/load.rs | 10 +- src/server/src/middleware/metrics.rs | 4 +- src/server/src/routes/cache.rs | 3 +- src/server/src/routes/http_client_tests.rs | 24 +- src/server/src/routes/inference/generate.rs | 9 +- .../src/routes/inference/generate/convert.rs | 3 +- .../src/routes/inference/generate/validate.rs | 3 +- .../src/routes/openai/chat_completions.rs | 91 +++-- .../routes/openai/chat_completions/convert.rs | 26 +- .../routes/openai/chat_completions/types.rs | 42 ++- .../openai/chat_completions/validate.rs | 3 +- src/server/src/routes/openai/completions.rs | 31 +- .../src/routes/openai/completions/convert.rs | 26 +- .../src/routes/openai/completions/types.rs | 34 +- .../src/routes/openai/completions/validate.rs | 8 +- .../src/routes/openai/utils/logprobs.rs | 27 +- .../routes/openai/utils/structured_outputs.rs | 32 +- .../src/routes/openai/utils/validated_json.rs | 10 +- src/server/src/routes/tests.rs | 235 ++++-------- src/server/src/state.rs | 20 +- src/server/src/utils.rs | 30 +- src/text/src/backend/hf/config.rs | 34 +- src/text/src/backend/hf/mod.rs | 7 +- src/text/src/backend/hf/model_files.rs | 65 ++-- src/text/src/backend/mod.rs | 10 +- src/text/src/incremental.rs | 21 +- src/text/src/lib.rs | 32 +- src/text/src/lower.rs | 53 +-- src/text/src/output/decoded.rs | 51 +-- src/text/src/output/logprobs.rs | 42 +-- src/text/src/output/mod.rs | 7 +- src/text/src/request.rs | 54 +-- src/text/src/tokenizer/byte_level_decode.rs | 6 +- src/text/src/tokenizer/hf.rs | 22 +- src/text/src/tokenizer/mod.rs | 11 +- src/text/src/tokenizer/tekken.rs | 4 +- src/text/src/tokenizer/tiktoken.rs | 334 ++++++++++-------- 119 files changed, 1908 insertions(+), 2204 deletions(-) diff --git a/rustfmt.toml b/rustfmt.toml index 35011368..e619a753 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1 +1,3 @@ style_edition = "2024" +chain_width = 80 +use_field_init_shorthand = true diff --git a/rustfmt.unstable.toml b/rustfmt.unstable.toml index 5fa0fe10..8585f04f 100644 --- a/rustfmt.unstable.toml +++ b/rustfmt.unstable.toml @@ -4,7 +4,7 @@ unstable_features = true style_edition = "2024" -comment_width = 100 +chain_width = 80 format_code_in_doc_comments = true format_macro_matchers = true normalize_comments = true @@ -13,3 +13,4 @@ imports_granularity = "Module" group_imports = "StdExternalCrate" reorder_impl_items = true wrap_comments = true +use_field_init_shorthand = true diff --git a/src/chat/examples/external_engine_chat_qwen.rs b/src/chat/examples/external_engine_chat_qwen.rs index 4a5b75de..d99d672d 100644 --- a/src/chat/examples/external_engine_chat_qwen.rs +++ b/src/chat/examples/external_engine_chat_qwen.rs @@ -96,10 +96,7 @@ async fn main() -> Result<()> { println!("request_id={request_id}"); println!("prompt={}", args.prompt); - let mut stream = chat - .chat(request) - .await - .context("failed to submit chat request")?; + let mut stream = chat.chat(request).await.context("failed to submit chat request")?; let output = tokio::time::timeout(output_timeout, async { let mut final_reasoning = String::new(); let mut final_text = String::new(); @@ -170,9 +167,7 @@ async fn main() -> Result<()> { .await .context("timed out waiting for chat output")??; - chat.shutdown() - .await - .context("failed to shut down chat client")?; + chat.shutdown().await.context("failed to shut down chat client")?; println!("final_reasoning={:?}", output.0); println!("final_text={:?}", output.1); diff --git a/src/chat/src/backend/mod.rs 
b/src/chat/src/backend/mod.rs index 5e94daac..aa9eab06 100644 --- a/src/chat/src/backend/mod.rs +++ b/src/chat/src/backend/mod.rs @@ -25,7 +25,8 @@ pub trait ChatBackend: Send + Sync { /// Return the renderer used for chat-prompt construction. fn chat_renderer(&self) -> DynChatRenderer; - /// Create a request-scoped output processor after request-level adjustments are applied. + /// Create a request-scoped output processor after request-level adjustments + /// are applied. fn new_chat_output_processor( &self, request: &mut ChatRequest, @@ -36,10 +37,12 @@ pub trait ChatBackend: Send + Sync { /// Shared trait-object form of [`ChatBackend`]. pub type DynChatBackend = Arc; -/// Convenience trait for backends that can serve both raw text generation and chat templating. +/// Convenience trait for backends that can serve both raw text generation and +/// chat templating. /// -/// This is mainly useful in tests and small examples, where one mock/backend often implements -/// both sides and callers want `ChatLlm` to wire the shared object into `TextLlm` automatically. +/// This is mainly useful in tests and small examples, where one mock/backend +/// often implements both sides and callers want `ChatLlm` to wire the shared +/// object into `TextLlm` automatically. pub trait ChatTextBackend: ChatBackend + TextBackend {} impl ChatTextBackend for T where T: ChatBackend + TextBackend + ?Sized {} @@ -54,11 +57,11 @@ pub struct LoadModelBackendsOptions { pub renderer: RendererSelection, /// How to serialize `message.content` when rendering the chat template. pub chat_template_content_format: ChatTemplateContentFormatOption, - /// Optional server-default chat template override, provided either as an inline template or - /// as a path to a template file. + /// Optional server-default chat template override, provided either as an + /// inline template or as a path to a template file. pub chat_template: Option, - /// Optional server-default keyword arguments merged into every chat-template render before - /// request-level `chat_template_kwargs`. + /// Optional server-default keyword arguments merged into every + /// chat-template render before request-level `chat_template_kwargs`. pub default_chat_template_kwargs: HashMap, } diff --git a/src/chat/src/event.rs b/src/chat/src/event.rs index 3f9fb4c0..08603b43 100644 --- a/src/chat/src/event.rs +++ b/src/chat/src/event.rs @@ -80,7 +80,8 @@ impl [AssistantContentBlock] { .filter(|s: &String| !s.is_empty()) } - /// Return whether this assistant message contains any non-empty reasoning text blocks. + /// Return whether this assistant message contains any non-empty reasoning + /// text blocks. pub fn has_reasoning(&self) -> bool { self.iter().any(|block| match block { AssistantContentBlock::Reasoning { text } => !text.is_empty(), @@ -95,8 +96,7 @@ impl [AssistantContentBlock] { /// Return whether this assistant message contains any tool-call blocks. pub fn has_tool_calls(&self) -> bool { - self.iter() - .any(|block| matches!(block, AssistantContentBlock::ToolCall(_))) + self.iter().any(|block| matches!(block, AssistantContentBlock::ToolCall(_))) } } @@ -124,7 +124,8 @@ impl AssistantMessage { /// Streamed chat event emitted by [`crate::ChatEventStream`]. #[derive(Debug, Clone, PartialEq)] pub enum ChatEvent { - /// The request was accepted, streaming has started, and prompt metadata is ready. + /// The request was accepted, streaming has started, and prompt metadata is + /// ready. Start { /// The actual prompt token IDs for this request. 
prompt_token_ids: Arc<[u32]>, @@ -158,14 +159,16 @@ pub enum ChatEvent { id: String, name: String, }, - /// One incremental tool-call arguments delta for the currently open tool call. + /// One incremental tool-call arguments delta for the currently open tool + /// call. ToolCallArgumentsDelta { index: usize, delta: String }, /// One tool call has ended. ToolCallEnd { index: usize, call: AssistantToolCall, }, - /// Terminal event carrying the final assembled assistant message and finish metadata. + /// Terminal event carrying the final assembled assistant message and finish + /// metadata. Done { message: AssistantMessage, /// Number of prompt tokens actually sent to the engine after chat diff --git a/src/chat/src/lib.rs b/src/chat/src/lib.rs index 00ab7a48..d57e09b1 100644 --- a/src/chat/src/lib.rs +++ b/src/chat/src/lib.rs @@ -1,10 +1,11 @@ //! Minimal chat facade above [`vllm_text`]. //! //! This crate keeps the northbound boundary intentionally small: -//! `messages -> rendered prompt -> tokenized prompt -> engine request -> streamed structured -//! assistant events`. The request side remains text-first, while the response side can emit -//! structured reasoning and final-answer blocks. It is closer to vLLM's internal chat-rendering -//! flow than to a full OpenAI-compatible surface. +//! `messages -> rendered prompt -> tokenized prompt -> engine request -> +//! streamed structured assistant events`. The request side remains text-first, +//! while the response side can emit structured reasoning and final-answer +//! blocks. It is closer to vLLM's internal chat-rendering flow than to a full +//! OpenAI-compatible surface. pub use backend::hf::HfChatBackend; pub use backend::{ @@ -90,8 +91,9 @@ pub fn validate_parser_overrides( /// Structured chat facade above [`TextLlm`]. /// -/// This layer stays above raw text semantics: it takes care of chat-template rendering, exposes -/// structured assistant events, and adds chat-specific request semantics such as tool calls. +/// This layer stays above raw text semantics: it takes care of chat-template +/// rendering, exposes structured assistant events, and adds chat-specific +/// request semantics such as tool calls. pub struct ChatLlm { text: TextLlm, backend: DynChatBackend, @@ -102,7 +104,8 @@ pub struct ChatLlm { } impl ChatLlm { - /// Create a new chat facade from a text-generation facade plus a chat backend. + /// Create a new chat facade from a text-generation facade plus a chat + /// backend. pub fn new(text: TextLlm, backend: DynChatBackend) -> Self { Self { text, @@ -112,8 +115,8 @@ impl ChatLlm { } } - /// Convenience constructor for one shared backend object that implements both text and chat - /// responsibilities. + /// Convenience constructor for one shared backend object that implements + /// both text and chat responsibilities. pub fn from_shared_backend(llm: Llm, backend: DynChatTextBackend) -> Self { let text = TextLlm::new(llm, backend.clone()); Self::new(text, backend) @@ -131,7 +134,8 @@ impl ChatLlm { self } - /// Expose the underlying text facade for raw text-generation routes such as `/v1/completions`. + /// Expose the underlying text facade for raw text-generation routes such as + /// `/v1/completions`. pub fn text(&self) -> &TextLlm { &self.text } @@ -141,7 +145,8 @@ impl ChatLlm { self.text.model_id() } - /// Expose the underlying engine-core client for low-level utility/admin calls. + /// Expose the underlying engine-core client for low-level utility/admin + /// calls. 
pub fn engine_core_client(&self) -> &EngineCoreClient { self.text.engine_core_client() } @@ -171,12 +176,7 @@ impl ChatLlm { add_special_tokens: request.add_special_tokens, data_parallel_rank: request.data_parallel_rank, }; - let decoded_stream = self - .text - .generate(text_request) - .await? - .map_err(Error::from) - .boxed(); + let decoded_stream = self.text.generate(text_request).await?.map_err(Error::from).boxed(); let structured_stream = output_processor.process(decoded_stream)?; diff --git a/src/chat/src/output/default/mod.rs b/src/chat/src/output/default/mod.rs index 2f716499..df2bed5b 100644 --- a/src/chat/src/output/default/mod.rs +++ b/src/chat/src/output/default/mod.rs @@ -28,10 +28,12 @@ trait_set! { trait ContentEventStream = Stream> + Send + 'static; } -/// Default request-scoped output processor used by Hugging Face style chat backends. +/// Default request-scoped output processor used by Hugging Face style chat +/// backends. /// -/// This implementation assumes the backend already emitted decoded text deltas, then optionally -/// layers reasoning parsing and tool-call parsing before assembling final structured chat events. +/// This implementation assumes the backend already emitted decoded text deltas, +/// then optionally layers reasoning parsing and tool-call parsing before +/// assembling final structured chat events. pub struct DefaultChatOutputProcessor { intermediate: bool, reasoning_parser: Option>, @@ -39,10 +41,12 @@ pub struct DefaultChatOutputProcessor { } impl DefaultChatOutputProcessor { - /// Build the default output processor and apply any parser-specific request adjustments. + /// Build the default output processor and apply any parser-specific request + /// adjustments. /// - /// Parser resolution happens here so that request validation, prompt rendering, and streaming - /// all observe the same parser-adjusted request state. + /// Parser resolution happens here so that request validation, prompt + /// rendering, and streaming all observe the same parser-adjusted + /// request state. pub fn new( request: &mut ChatRequest, model_id: &str, @@ -77,8 +81,9 @@ impl DefaultChatOutputProcessor { /// Build the plain-text-only default output processor. /// - /// This keeps the default structured chat-event assembly but disables both reasoning parsing - /// and tool-call parsing completely, so that all content is treated as opaque text. + /// This keeps the default structured chat-event assembly but disables both + /// reasoning parsing and tool-call parsing completely, so that all + /// content is treated as opaque text. 
pub fn plain_text_only() -> Self { Self { intermediate: false, @@ -106,13 +111,11 @@ impl DefaultChatOutputProcessor { let parser = factory.create(parser_name, &request.tools)?; - parser - .adjust_request(request) - .map_err(|error| Error::ParserInitialization { - kind: "tool", - name: parser_name.to_string(), - error: error.into(), - })?; + parser.adjust_request(request).map_err(|error| Error::ParserInitialization { + kind: "tool", + name: parser_name.to_string(), + error: error.into(), + })?; TOOL_PARSER_LOG_ONCE.call_once(|| info!(parser_name, "using tool parser")); Ok(parser) @@ -138,13 +141,11 @@ impl DefaultChatOutputProcessor { let parser = factory.create(parser_name, tokenizer)?; - parser - .adjust_request(request) - .map_err(|error| Error::ParserInitialization { - kind: "reasoning", - name: parser_name.to_string(), - error: error.into(), - })?; + parser.adjust_request(request).map_err(|error| Error::ParserInitialization { + kind: "reasoning", + name: parser_name.to_string(), + error: error.into(), + })?; REASONING_PARSER_LOG_ONCE.call_once(|| info!(parser_name, "using reasoning parser")); Ok(Some(parser)) @@ -155,8 +156,9 @@ static TOOL_PARSER_LOG_ONCE: Once = Once::new(); static REASONING_PARSER_LOG_ONCE: Once = Once::new(); impl ChatOutputProcessor for DefaultChatOutputProcessor { - /// Transforms a raw generate-output token stream into structured chat events - /// through three sequential stages once text decoding has already happened: + /// Transforms a raw generate-output token stream into structured chat + /// events through three sequential stages once text decoding has + /// already happened: /// /// 1. [`reasoning_event_stream`] — reasoning/content separation /// 2. [`tool_event_stream`] — tool-call parsing diff --git a/src/chat/src/output/default/reasoning.rs b/src/chat/src/output/default/reasoning.rs index 7028be23..966e8ffb 100644 --- a/src/chat/src/output/default/reasoning.rs +++ b/src/chat/src/output/default/reasoning.rs @@ -35,7 +35,8 @@ impl ReasoningState { } } - /// Convert one decoded text delta into zero or more semantic assistant deltas. + /// Convert one decoded text delta into zero or more semantic assistant + /// deltas. fn process_delta(&mut self, delta: String) -> Vec { // If the parser has already failed, skip parsing and return plain text deltas. if self.parser_failed { diff --git a/src/chat/src/output/default/tool.rs b/src/chat/src/output/default/tool.rs index 4b5e7a7e..cb8a7232 100644 --- a/src/chat/src/output/default/tool.rs +++ b/src/chat/src/output/default/tool.rs @@ -184,9 +184,10 @@ fn push_text_delta(events: &mut Vec, kind: AssistantBlockKind, d /// Tool parsing when `intermediate=false` (`FinalOnly` mode). /// -/// We keep this separate because some adaptor-backed parsers may not correctly handle the full text -/// passed to incremental `push` interface, but override `parse_complete()` with a dedicated -/// one-shot implementation to ensure correctness. +/// We keep this separate because some adaptor-backed parsers may not correctly +/// handle the full text passed to incremental `push` interface, but override +/// `parse_complete()` with a dedicated one-shot implementation to ensure +/// correctness. #[try_stream] async fn final_only_tool_event_stream( stream: impl ContentEventStream, diff --git a/src/chat/src/output/harmony/mod.rs b/src/chat/src/output/harmony/mod.rs index 66b7b113..5dc6bc31 100644 --- a/src/chat/src/output/harmony/mod.rs +++ b/src/chat/src/output/harmony/mod.rs @@ -1,8 +1,8 @@ //! 
Native Harmony output processing for `gpt_oss`. //! -//! Unlike the default text-first pipeline, this processor consumes `DecodedTextEvent` -//! token IDs directly and lets the official `openai-harmony` parser recover the -//! structured assistant message shape at token granularity. +//! Unlike the default text-first pipeline, this processor consumes +//! `DecodedTextEvent` token IDs directly and lets the official `openai-harmony` +//! parser recover the structured assistant message shape at token granularity. use std::sync::LazyLock; @@ -28,8 +28,8 @@ use crate::request::ChatRequest; /// Request-scoped Harmony output processor used for `model_type == "gpt_oss"`. /// -/// This processor keeps the existing northbound `ChatEvent` shape, but swaps the -/// parsed-assistant backend from generic text/reasoning/tool parsers to the +/// This processor keeps the existing northbound `ChatEvent` shape, but swaps +/// the parsed-assistant backend from generic text/reasoning/tool parsers to the /// official Harmony token parser. #[derive(Debug)] pub struct HarmonyChatOutputProcessor { @@ -80,10 +80,12 @@ impl HarmonyChatOutputProcessor { } } -/// Validate that the generic parser selections are compatible with native Harmony output parsing. +/// Validate that the generic parser selections are compatible with native +/// Harmony output parsing. /// -/// `gpt_oss` uses a model-specific token-level parser, so any generic reasoning/tool parser -/// override is rejected instead of being silently ignored. +/// `gpt_oss` uses a model-specific token-level parser, so any generic +/// reasoning/tool parser override is rejected instead of being silently +/// ignored. pub(crate) fn validate_harmony_parser_overrides( tool_call_parser: &ParserSelection, reasoning_parser: &ParserSelection, @@ -132,9 +134,7 @@ impl HarmonyState { for &token_id in token_ids { let completed_before = self.parser.messages().len(); - self.parser - .process(token_id) - .map_err(harmony_output_parsing_error)?; + self.parser.process(token_id).map_err(harmony_output_parsing_error)?; let completed_after = self.parser.messages().len(); if let Some(delta) = self @@ -186,7 +186,8 @@ impl HarmonyState { Ok(events) } - /// Flush Harmony parser state at EOS and emit any newly finalized assistant events. + /// Flush Harmony parser state at EOS and emit any newly finalized assistant + /// events. 
fn process_eos(&mut self) -> Result> { let completed_before = self.parser.messages().len(); let pending_key = HarmonyGroupKey { @@ -194,14 +195,10 @@ impl HarmonyState { channel: self.parser.current_channel(), recipient: self.parser.current_recipient(), }; - let pending_content = self - .parser - .current_content() - .map_err(harmony_output_parsing_error)?; + let pending_content = + self.parser.current_content().map_err(harmony_output_parsing_error)?; - self.parser - .process_eos() - .map_err(harmony_output_parsing_error)?; + self.parser.process_eos().map_err(harmony_output_parsing_error)?; let completed_after = self.parser.messages().len(); let mut events = Vec::new(); @@ -212,10 +209,7 @@ impl HarmonyState { let final_message = &self.parser.messages()[completed_before]; let final_text = harmony_message_text(final_message); - let tail = final_text - .strip_prefix(&pending_content) - .unwrap_or(final_text) - .to_string(); + let tail = final_text.strip_prefix(&pending_content).unwrap_or(final_text).to_string(); if !tail.is_empty() { self.emit_group( HarmonyGroup { @@ -241,7 +235,8 @@ impl HarmonyState { Ok(events) } - /// Flush one coalesced Harmony content group into internal assistant events. + /// Flush one coalesced Harmony content group into internal assistant + /// events. fn emit_group(&mut self, group: HarmonyGroup, events: &mut Vec) { let channel = group.key.channel.as_deref(); let recipient = group.key.recipient.as_deref(); @@ -282,9 +277,7 @@ impl HarmonyState { return; } - let recipient = recipient - .expect("tool groups always have recipient") - .to_string(); + let recipient = recipient.expect("tool groups always have recipient").to_string(); let opens_same_call = match self.open_tool_call.as_ref() { Some(open_call) => open_call.recipient == recipient, None => false, @@ -323,7 +316,8 @@ impl HarmonyState { } } -/// Convert decoded token updates into internal assistant events with Harmony parsing. +/// Convert decoded token updates into internal assistant events with Harmony +/// parsing. #[try_stream] async fn harmony_assistant_event_stream( decoded: DynDecodedTextEventStream, @@ -392,11 +386,9 @@ fn harmony_encoding() -> Result<&'static HarmonyEncoding> { .context("failed to load harmony encoding for gpt-oss") }); - ENCODING - .as_ref() - .map_err(|error| Error::HarmonyOutputParsing { - error: error.to_report_string().into(), - }) + ENCODING.as_ref().map_err(|error| Error::HarmonyOutputParsing { + error: error.to_report_string().into(), + }) } fn harmony_output_parsing_error( @@ -415,7 +407,8 @@ fn harmony_message_text(message: &HarmonyMessage) -> &str { &text.text } -/// Map one Harmony `(channel, recipient)` pair to a visible assistant block kind. +/// Map one Harmony `(channel, recipient)` pair to a visible assistant block +/// kind. fn text_block_kind(channel: Option<&str>, recipient: Option<&str>) -> Option { match (channel, recipient) { (Some("final"), _) => Some(AssistantBlockKind::Text), diff --git a/src/chat/src/output/mod.rs b/src/chat/src/output/mod.rs index 50a53881..93387c5a 100644 --- a/src/chat/src/output/mod.rs +++ b/src/chat/src/output/mod.rs @@ -21,8 +21,10 @@ pub(crate) use harmony::validate_harmony_parser_overrides; /// Internal assistant event before final assembly. /// -/// - [`ContentEvent`]: subenum after reasoning parsing, carries only text content. -/// - [`AssistantEvent`]: full event after tool parsing, adds tool-call variants. +/// - [`ContentEvent`]: subenum after reasoning parsing, carries only text +/// content. 
+/// - [`AssistantEvent`]: full event after tool parsing, adds tool-call +/// variants. #[subenum(ContentEvent)] #[derive(Debug, Clone, PartialEq)] pub(crate) enum AssistantEvent { @@ -44,7 +46,8 @@ pub(crate) enum AssistantEvent { }, /// The start of a new tool call, with its declared name and generated ID. ToolCallStart { id: String, name: String }, - /// A delta for the arguments of the currently open tool call. Must follow a `ToolCallStart`. + /// A delta for the arguments of the currently open tool call. Must follow a + /// `ToolCallStart`. ToolCallArgumentsDelta { delta: String }, #[subenum(ContentEvent)] Done { @@ -57,8 +60,8 @@ pub(crate) enum AssistantEvent { } impl ContentEvent { - /// Convert a [`DecodedTextEvent`] into one or more [`ContentEvent`] values by treating all text - /// as plain (non-reasoning) content. + /// Convert a [`DecodedTextEvent`] into one or more [`ContentEvent`] values + /// by treating all text as plain (non-reasoning) content. fn from_decoded_plain_text(event: DecodedTextEvent) -> Vec { match event { DecodedTextEvent::Start { @@ -106,7 +109,8 @@ pub type DynDecodedTextEventStream = Pin> + Send>>; -/// Request-scoped output processor from decoded text events into structured chat events. +/// Request-scoped output processor from decoded text events into structured +/// chat events. pub trait ChatOutputProcessor: Send { /// Consume decoded text stream and return the structured chat-event stream. fn process(self: Box, decoded: DynDecodedTextEventStream) -> Result; @@ -124,8 +128,10 @@ trait_set! { pub(crate) trait ChatEventStream = Stream> + Send + 'static; } -/// Generate the northbound tool-call ID using the OpenAI-style `call_` format. -// TODO: support other ID scheme like Kimi-K2's `functions.{name}:{global_index}`. +/// Generate the northbound tool-call ID using the OpenAI-style `call_` +/// format. +// TODO: support other ID scheme like Kimi-K2's +// `functions.{name}:{global_index}`. pub(crate) fn generate_tool_call_id() -> String { format!("call_{}", &Uuid::new_v4().simple().to_string()[..24]) } diff --git a/src/chat/src/output/structured.rs b/src/chat/src/output/structured.rs index 5559c8e1..ed6e3a51 100644 --- a/src/chat/src/output/structured.rs +++ b/src/chat/src/output/structured.rs @@ -16,7 +16,8 @@ use crate::event::{ }; use crate::{FinishReason, Result}; -/// One currently open assistant text-like block being assembled from streamed deltas. +/// One currently open assistant text-like block being assembled from streamed +/// deltas. struct OpenTextBlock { /// Stable position of this block in the final assistant message. index: usize, @@ -144,8 +145,8 @@ impl StructuredEventState { Ok(events) } - /// Append one semantic text delta to the current block, or open a new block when the semantic - /// kind changes. + /// Append one semantic text delta to the current block, or open a new block + /// when the semantic kind changes. fn push_text_delta( &mut self, kind: AssistantBlockKind, @@ -217,8 +218,7 @@ impl StructuredEventState { name: open_tool_call.name, arguments: open_tool_call.arguments, }; - self.message - .push_block(AssistantContentBlock::ToolCall(call.clone())); + self.message.push_block(AssistantContentBlock::ToolCall(call.clone())); events.push(ChatEvent::ToolCallEnd { index: open_tool_call.index, call, @@ -226,7 +226,8 @@ impl StructuredEventState { } } -/// Wrap one parsed assistant stream into the public structured chat event stream. +/// Wrap one parsed assistant stream into the public structured chat event +/// stream. 
#[try_stream] pub(crate) async fn structured_chat_event_stream( stream: impl AssistantEventStream, diff --git a/src/chat/src/parser/mod.rs b/src/chat/src/parser/mod.rs index 078ce372..52e83b3e 100644 --- a/src/chat/src/parser/mod.rs +++ b/src/chat/src/parser/mod.rs @@ -74,8 +74,7 @@ impl ParserFactory { /// Add a case-insensitive substring match from model ID to parser name. pub fn register_pattern(&mut self, pattern: &str, parser_name: &str) -> &mut Self { - self.patterns - .push((pattern.to_lowercase(), parser_name.to_string())); + self.patterns.push((pattern.to_lowercase(), parser_name.to_string())); self } @@ -100,7 +99,8 @@ impl ParserFactory { names } - /// Get the constructor for a parser by its exact registered name, or return None if not found. + /// Get the constructor for a parser by its exact registered name, or return + /// None if not found. pub fn creator(&self, name: &str) -> Option<&C> { self.creators.get(name) } diff --git a/src/chat/src/parser/reasoning/cohere_cmd.rs b/src/chat/src/parser/reasoning/cohere_cmd.rs index d949ff49..16f4d21e 100644 --- a/src/chat/src/parser/reasoning/cohere_cmd.rs +++ b/src/chat/src/parser/reasoning/cohere_cmd.rs @@ -8,7 +8,8 @@ pub struct CohereCmdReasoningParser { } impl CohereCmdReasoningParser { - /// Create a Cohere Command parser backed by the shared delimited state machine. + /// Create a Cohere Command parser backed by the shared delimited state + /// machine. pub fn new(tokenizer: DynTokenizer) -> Result { Ok(Self { inner: DelimitedReasoningParser::new( diff --git a/src/chat/src/parser/reasoning/deepseek_r1.rs b/src/chat/src/parser/reasoning/deepseek_r1.rs index 343651d0..ce5aeb7a 100644 --- a/src/chat/src/parser/reasoning/deepseek_r1.rs +++ b/src/chat/src/parser/reasoning/deepseek_r1.rs @@ -12,7 +12,8 @@ pub struct DeepSeekR1ReasoningParser { } impl DeepSeekR1ReasoningParser { - /// Create a DeepSeek R1 parser backed by the shared delimited state machine. + /// Create a DeepSeek R1 parser backed by the shared delimited state + /// machine. pub fn new(tokenizer: DynTokenizer) -> Result { Ok(Self { inner: DelimitedReasoningParser::new(tokenizer, "", "", true)?, diff --git a/src/chat/src/parser/reasoning/delimited.rs b/src/chat/src/parser/reasoning/delimited.rs index bce8baf5..ad6e71df 100644 --- a/src/chat/src/parser/reasoning/delimited.rs +++ b/src/chat/src/parser/reasoning/delimited.rs @@ -37,17 +37,13 @@ impl DelimitedReasoningParser { default_in_reasoning: bool, ) -> Result { let start_token_id = - tokenizer - .token_to_id(start_token) - .ok_or_else(|| ReasoningError::MissingToken { - token: start_token.to_string(), - })?; + tokenizer.token_to_id(start_token).ok_or_else(|| ReasoningError::MissingToken { + token: start_token.to_string(), + })?; let end_token_id = - tokenizer - .token_to_id(end_token) - .ok_or_else(|| ReasoningError::MissingToken { - token: end_token.to_string(), - })?; + tokenizer.token_to_id(end_token).ok_or_else(|| ReasoningError::MissingToken { + token: end_token.to_string(), + })?; Ok(Self { tokenizer, @@ -117,7 +113,8 @@ impl DelimitedReasoningParser { delta } - /// Return the longest trailing suffix that could still complete a delimiter. + /// Return the longest trailing suffix that could still complete a + /// delimiter. 
fn partial_suffix_len(&self, text: &str) -> usize { let mut best = 0; for idx in text.char_indices().map(|(idx, _)| idx).skip(1) { diff --git a/src/chat/src/parser/reasoning/gemma4.rs b/src/chat/src/parser/reasoning/gemma4.rs index 9271ccc4..ca60645d 100644 --- a/src/chat/src/parser/reasoning/gemma4.rs +++ b/src/chat/src/parser/reasoning/gemma4.rs @@ -76,9 +76,8 @@ impl Gemma4ReasoningParser { /// Apply Gemma4-specific reasoning post-processing to one parsed delta. fn post_process(&mut self, mut result: ReasoningDelta) -> ReasoningDelta { if let Some(reasoning) = result.reasoning.take() { - result.reasoning = self - .strip_thought_prefix(&reasoning) - .filter(|text| !text.is_empty()); + result.reasoning = + self.strip_thought_prefix(&reasoning).filter(|text| !text.is_empty()); } result } @@ -93,8 +92,8 @@ impl ReasoningParser for Gemma4ReasoningParser { } fn adjust_request(&self, request: &mut ChatRequest) -> Result<()> { - // Gemma4's reasoning delimiters are marked as special tokens, so we need to ensure they are - // not stripped during decoding. + // Gemma4's reasoning delimiters are marked as special tokens, so we need to + // ensure they are not stripped during decoding. request.decode_options.skip_special_tokens = false; Ok(()) } diff --git a/src/chat/src/parser/reasoning/mod.rs b/src/chat/src/parser/reasoning/mod.rs index b84334a1..e61697a4 100644 --- a/src/chat/src/parser/reasoning/mod.rs +++ b/src/chat/src/parser/reasoning/mod.rs @@ -58,8 +58,8 @@ pub type DeepSeekV4ReasoningParser = Qwen3ReasoningParser; /// GLM45 currently shares the standard `...` parser. pub type Glm45ReasoningParser = Qwen3ReasoningParser; /// Kimi K2 currently shares the standard `...` parser. -// TODO: kimi k2 may implicitly end reasoning by starting a tool call section using -// <|tool_calls_section_begin|>, we should support that. +// TODO: kimi k2 may implicitly end reasoning by starting a tool call section +// using <|tool_calls_section_begin|>, we should support that. pub type KimiK2ReasoningParser = Qwen3ReasoningParser; /// MiniMax M2 currently shares the standard `...` parser. pub type MiniMaxM2ReasoningParser = Qwen3ReasoningParser; @@ -107,14 +107,16 @@ impl ReasoningDelta { } } -/// Incremental parser that splits decoded text deltas into reasoning and content. +/// Incremental parser that splits decoded text deltas into reasoning and +/// content. pub trait ReasoningParser: Send { /// Construct a boxed parser instance for one request stream. fn create(tokenizer: DynTokenizer) -> Result> where Self: Sized + 'static; - /// Initialize parser state from prompt token IDs before output deltas arrive. + /// Initialize parser state from prompt token IDs before output deltas + /// arrive. fn initialize(&mut self, _prompt_token_ids: &[u32]) -> Result<()> { Ok(()) } @@ -150,14 +152,16 @@ type ReasoningParserCreator = fn(DynTokenizer) -> Result; impl ReasoningParserFactory { - /// Get the global reasoning parser factory with built-in registrations and model mappings. + /// Get the global reasoning parser factory with built-in registrations and + /// model mappings. pub fn global() -> &'static Self { static INSTANCE: LazyLock = LazyLock::new(ReasoningParserFactory::new); &INSTANCE } - /// Create the default registry with built-in parser names and model mappings. + /// Create the default registry with built-in parser names and model + /// mappings. 
pub fn new() -> Self { let mut factory = Self::default(); @@ -214,13 +218,11 @@ impl ReasoningParserFactory { name: &str, tokenizer: DynTokenizer, ) -> crate::Result> { - let creator = self - .creator(name) - .ok_or_else(|| crate::Error::ParserUnavailableByName { - kind: "reasoning", - name: name.to_string(), - available_names: self.list(), - })?; + let creator = self.creator(name).ok_or_else(|| crate::Error::ParserUnavailableByName { + kind: "reasoning", + name: name.to_string(), + available_names: self.list(), + })?; creator(tokenizer).map_err(|error| crate::Error::ParserInitialization { kind: "reasoning", diff --git a/src/chat/src/parser/tool/deepseek_v32.rs b/src/chat/src/parser/tool/deepseek_v32.rs index b13b962a..87fbc55a 100644 --- a/src/chat/src/parser/tool/deepseek_v32.rs +++ b/src/chat/src/parser/tool/deepseek_v32.rs @@ -76,9 +76,9 @@ struct DsmlParameter { /// /// Arguments are emitted only after a full `invoke` block is parsed. /// -/// DeepSeek V3.2 relies on DSML markers such as `|DSML|`, which are represented -/// as special tokens in the tokenizer and therefore must be preserved during -/// decode for parsing to work. +/// DeepSeek V3.2 relies on DSML markers such as `|DSML|`, which are +/// represented as special tokens in the tokenizer and therefore must be +/// preserved during decode for parsing to work. pub struct DeepSeekV32ToolParser { buffer: String, mode: DsmlMode, @@ -93,8 +93,8 @@ impl DeepSeekV32ToolParser { Self::with_tokens(tools, DsmlTokens::V32) } - /// Create a parser with custom DSML tokens, for reuse by DeepSeek V4 which has different - /// markers but mostly shared logic. + /// Create a parser with custom DSML tokens, for reuse by DeepSeek V4 which + /// has different markers but mostly shared logic. pub(super) fn with_tokens(tools: &[ChatTool], tokens: DsmlTokens) -> Self { Self { buffer: String::new(), @@ -239,9 +239,7 @@ fn tool_calls_start_event(input: &mut DsmlInput<'_>, tokens: DsmlTokens) -> Moda /// Parse a DSML function-calls end marker. fn tool_calls_end_event(input: &mut DsmlInput<'_>, tokens: DsmlTokens) -> ModalResult { - literal(tokens.tool_calls_end) - .value(DsmlEvent::ToolCallsEnd) - .parse_next(input) + literal(tokens.tool_calls_end).value(DsmlEvent::ToolCallsEnd).parse_next(input) } /// Parse a trailing rest after DSML function calls. diff --git a/src/chat/src/parser/tool/deepseek_v4.rs b/src/chat/src/parser/tool/deepseek_v4.rs index 8869e508..6a1f8ded 100644 --- a/src/chat/src/parser/tool/deepseek_v4.rs +++ b/src/chat/src/parser/tool/deepseek_v4.rs @@ -21,8 +21,8 @@ use crate::request::ChatTool; /// /// Arguments are emitted only after a full `invoke` block is parsed. /// -/// V4 reuses the V3.2 DSML invoke/parameter grammar but wraps calls in `<|DSML|tool_calls>` -/// instead of `<|DSML|function_calls>`. +/// V4 reuses the V3.2 DSML invoke/parameter grammar but wraps calls in +/// `<|DSML|tool_calls>` instead of `<|DSML|function_calls>`. pub struct DeepSeekV4ToolParser(DeepSeekV32ToolParser); impl DsmlTokens { diff --git a/src/chat/src/parser/tool/external.rs b/src/chat/src/parser/tool/external.rs index 3fe0cb3d..44def981 100644 --- a/src/chat/src/parser/tool/external.rs +++ b/src/chat/src/parser/tool/external.rs @@ -7,7 +7,8 @@ use super::{Result, ToolCallDelta, ToolParseResult}; use crate::ToolParser; use crate::request::ChatTool; -/// Adaptor that exposes the external `tool-parser` through the local [`ToolParser`] interface. +/// Adaptor that exposes the external `tool-parser` through the local +/// [`ToolParser`] interface. 
pub(crate) struct ExternalToolParserAdaptor<P>
{ pub(crate) inner: P, tools: Vec, @@ -26,23 +27,23 @@ where { /// Delagating to the external `parse_complete()`. /// - /// We don't rely on the default `push()+finish()` lifecycle, because some external parsers may - /// not correctly handle the full text passed to incremental `push()` interface. - // TODO: instead of working around like this, we should make incremental `push()` robust enough - // to handle decoded text in arbitrary chunk sizes, as optimizations like speculative decoding - // or batching may still make the chunk "too long" to be correctly parsed in one `push()` call. + /// We don't rely on the default `push()+finish()` lifecycle, because some + /// external parsers may not correctly handle the full text passed to + /// incremental `push()` interface. + // TODO: instead of working around like this, we should make incremental + // `push()` robust enough to handle decoded text in arbitrary chunk sizes, + // as optimizations like speculative decoding or batching may still make the + // chunk "too long" to be correctly parsed in one `push()` call. fn parse_complete(&mut self, output: &str) -> Result { let (normal_text, calls) = poll_external(self.inner.parse_complete(output))?; // The external `parse_complete()` path does not receive tools and may therefore - // return calls with invalid names. Filter them here against the request-scoped tool - // set captured at parser creation time. + // return calls with invalid names. Filter them here against the request-scoped + // tool set captured at parser creation time. let calls = calls .into_iter() .filter(|tool_call| { - self.tools - .iter() - .any(|tool| tool.function.name == tool_call.function.name) + self.tools.iter().any(|tool| tool.function.name == tool_call.function.name) }) .enumerate() .map(|(tool_index, tool_call)| ToolCallDelta { @@ -77,10 +78,11 @@ where /// Bridge the external async trait into our synchronous local trait. /// -/// This is intentionally a temporary compatibility layer: the current external parser -/// implementations are CPU-only and their async fns do not actually suspend. As long as that -/// dependency behavior stays unchanged, `now_or_never()` is a robust adaptation strategy and we -/// don't have to spawn a thread to `block_on()` the future. +/// This is intentionally a temporary compatibility layer: the current external +/// parser implementations are CPU-only and their async fns do not actually +/// suspend. As long as that dependency behavior stays unchanged, +/// `now_or_never()` is a robust adaptation strategy and we don't have to spawn +/// a thread to `block_on()` the future. fn poll_external( future: impl Future>, ) -> Result { @@ -105,11 +107,7 @@ fn convert_tool_call_item(item: tool_parser::types::ToolCallItem) -> ToolCallDel fn convert_parse_result(result: tool_parser::types::StreamingParseResult) -> ToolParseResult { ToolParseResult { normal_text: result.normal_text, - calls: result - .calls - .into_iter() - .map(convert_tool_call_item) - .collect(), + calls: result.calls.into_iter().map(convert_tool_call_item).collect(), } } diff --git a/src/chat/src/parser/tool/gemma4.rs b/src/chat/src/parser/tool/gemma4.rs index 0f19e090..3f45a6a6 100644 --- a/src/chat/src/parser/tool/gemma4.rs +++ b/src/chat/src/parser/tool/gemma4.rs @@ -121,7 +121,8 @@ impl Gemma4ToolParser { /// This is the core of the accumulate-then-parse-then-diff strategy: /// 1. Parse `raw_args` with `parse_gemma4_args()` /// 2. Convert to JSON string - /// 3. 
Withhold trailing closing characters (`"}`) that may move as more tokens arrive + /// 3. Withhold trailing closing characters (`"}`) that may move as more + /// tokens arrive /// 4. Diff against previously streamed JSON and emit only new chars /// /// Why withholding is necessary: @@ -300,10 +301,7 @@ fn scan_tool_tail(input: &str) -> ToolTailState { continue; } - let next = input[i..] - .chars() - .next() - .expect("scan index must stay in bounds"); + let next = input[i..].chars().next().expect("scan index must stay in bounds"); match next { '{' => object_depth += 1, '[' => array_depth += 1, @@ -538,10 +536,7 @@ fn parse_gemma4_array(array: &str, partial: bool) -> Result> { i = skip_over_string_delim(array, i).unwrap_or(array.len()); continue; } - let next = array[i..] - .chars() - .next() - .expect("index must stay in bounds"); + let next = array[i..].chars().next().expect("index must stay in bounds"); match next { '{' => depth += 1, '}' => depth -= 1, @@ -565,10 +560,7 @@ fn parse_gemma4_array(array: &str, partial: bool) -> Result> { i = skip_over_string_delim(array, i).unwrap_or(array.len()); continue; } - let next = array[i..] - .chars() - .next() - .expect("index must stay in bounds"); + let next = array[i..].chars().next().expect("index must stay in bounds"); match next { '[' => depth += 1, ']' => depth -= 1, @@ -586,10 +578,7 @@ fn parse_gemma4_array(array: &str, partial: bool) -> Result> { // Bare value let value_start = i; while i < array.len() { - let next = array[i..] - .chars() - .next() - .expect("index must stay in bounds"); + let next = array[i..].chars().next().expect("index must stay in bounds"); if matches!(next, ',' | ']') { break; } @@ -616,10 +605,7 @@ fn skip_over_string_delim(input: &str, start: usize) -> Option { fn skip_separators(input: &str, index: &mut usize) { while *index < input.len() { - let next = input[*index..] - .chars() - .next() - .expect("index must stay in bounds"); + let next = input[*index..].chars().next().expect("index must stay in bounds"); if !matches!(next, ' ' | ',' | '\n' | '\t') { break; } @@ -629,10 +615,7 @@ fn skip_separators(input: &str, index: &mut usize) { fn skip_value_whitespace(input: &str, index: &mut usize) { while *index < input.len() { - let next = input[*index..] - .chars() - .next() - .expect("index must stay in bounds"); + let next = input[*index..].chars().next().expect("index must stay in bounds"); if !matches!(next, ' ' | '\n' | '\t') { break; } diff --git a/src/chat/src/parser/tool/minimax_m2.rs b/src/chat/src/parser/tool/minimax_m2.rs index 2a6f4663..e89879d9 100644 --- a/src/chat/src/parser/tool/minimax_m2.rs +++ b/src/chat/src/parser/tool/minimax_m2.rs @@ -76,9 +76,7 @@ impl MinimaxM2ToolParser { } MinimaxM2Event::ToolBlockStart => self.mode = MinimaxM2Mode::ToolBlock, MinimaxM2Event::Invoke { name, raw_params } => { - let arguments = self - .tool_parameters - .convert_params_with_schema(&name, raw_params); + let arguments = self.tool_parameters.convert_params_with_schema(&name, raw_params); let arguments = serde_json::to_string(&arguments) .map_err(|error| parsing_failed!("failed to serialize arguments: {}", error))?; @@ -163,9 +161,7 @@ fn parse_text_event(input: &mut MinimaxM2Input<'_>) -> ModalResult) -> ModalResult { - literal(TOOL_CALL_START) - .value(MinimaxM2Event::ToolBlockStart) - .parse_next(input) + literal(TOOL_CALL_START).value(MinimaxM2Event::ToolBlockStart).parse_next(input) } /// Parse a safe text run before the next MiniMax M2 marker. 
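The `literal(..).value(..)` chains reformatted in this hunk are winnow parser combinators running over `Partial` (streaming) input. A minimal self-contained sketch of the same shape, using a hypothetical marker and event type rather than this parser's real ones:

use winnow::token::literal;
use winnow::{ModalResult, Parser, Partial};

// Hypothetical marker and event type for illustration; the real parser uses
// its model-specific markers such as MiniMax M2's tool-call delimiters.
const EXAMPLE_MARKER: &str = "<example_tool_call>";

#[derive(Clone, Debug)]
enum ExampleEvent {
    ToolBlockStart,
}

// Recognize the marker at the head of the streaming input and map it to a
// semantic event, instead of returning the matched text itself.
fn example_start(input: &mut Partial<&str>) -> ModalResult<ExampleEvent> {
    literal(EXAMPLE_MARKER).value(ExampleEvent::ToolBlockStart).parse_next(input)
}

fn main() {
    let mut input = Partial::new("<example_tool_call>{...}");
    assert!(matches!(example_start(&mut input), Ok(ExampleEvent::ToolBlockStart)));
}

With `Partial`, an input that ends mid-marker yields an `Incomplete` error mode rather than a hard failure, which is what lets these streaming tool parsers buffer the tail and retry once more decoded text arrives.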
@@ -466,9 +462,7 @@ mod tests { #[test] fn minimax_m2_streaming_does_not_emit_incomplete_tool_call() { let mut parser = MinimaxM2ToolParser::new(&test_tools()); - let result = parser - .push(r#""#) - .unwrap(); + let result = parser.push(r#""#).unwrap(); assert!(result.normal_text.is_empty()); assert!(result.calls.is_empty()); @@ -477,9 +471,7 @@ mod tests { #[test] fn minimax_m2_finish_fails_incomplete_tool_call() { let mut parser = MinimaxM2ToolParser::new(&test_tools()); - parser - .push(r#""#) - .unwrap(); + parser.push(r#""#).unwrap(); assert!(parser.finish().is_err()); } @@ -495,9 +487,7 @@ mod tests { #[test] fn minimax_m2_malformed_tool_call_fails_fast() { let mut parser = MinimaxM2ToolParser::new(&test_tools()); - let error = parser - .push("") - .unwrap_err(); + let error = parser.push("").unwrap_err(); expect!["tool parser parsing failed: "].assert_eq(&error.to_report_string()); } diff --git a/src/chat/src/parser/tool/mod.rs b/src/chat/src/parser/tool/mod.rs index 88838def..c7f7fae6 100644 --- a/src/chat/src/parser/tool/mod.rs +++ b/src/chat/src/parser/tool/mod.rs @@ -84,8 +84,9 @@ pub struct ToolParseResult { impl ToolParseResult { /// Append another parser result onto this one. /// - /// Note that this does not attempt to merge multiple deltas for the same tool call into one - /// complete item. Call `coalesce_calls()` after if that behavior is desired. + /// Note that this does not attempt to merge multiple deltas for the same + /// tool call into one complete item. Call `coalesce_calls()` after if + /// that behavior is desired. pub(crate) fn append(&mut self, mut other: Self) { self.normal_text.push_str(&other.normal_text); self.calls.append(&mut other.calls); @@ -93,9 +94,10 @@ impl ToolParseResult { /// Merge multiple deltas for the same tool call into one complete item. /// - /// This is primarily used by the default `parse_complete()` implementation, which delegates - /// through the incremental parser lifecycle and then needs to collapse streaming-style argument - /// fragments into one final tool call. + /// This is primarily used by the default `parse_complete()` implementation, + /// which delegates through the incremental parser lifecycle and then + /// needs to collapse streaming-style argument fragments into one final + /// tool call. pub(crate) fn coalesce_calls(mut self) -> Self { let mut merged = BTreeMap::::new(); let mut order = Vec::new(); @@ -116,10 +118,8 @@ impl ToolParseResult { } } - self.calls = order - .into_iter() - .filter_map(|tool_index| merged.remove(&tool_index)) - .collect(); + self.calls = + order.into_iter().filter_map(|tool_index| merged.remove(&tool_index)).collect(); self } } @@ -176,13 +176,15 @@ type ToolParserCreator = fn(&[ChatTool]) -> Result>; pub type ToolParserFactory = ParserFactory; impl ToolParserFactory { - /// Get the global tool parser factory with built-in registrations and model mappings. + /// Get the global tool parser factory with built-in registrations and model + /// mappings. pub fn global() -> &'static Self { static INSTANCE: LazyLock = LazyLock::new(ToolParserFactory::new); &INSTANCE } - /// Create the default registry with built-in parser names and model mappings. + /// Create the default registry with built-in parser names and model + /// mappings. pub fn new() -> Self { let mut factory = Self::default(); @@ -254,13 +256,11 @@ impl ToolParserFactory { /// Construct a parser from an exact name. 
pub fn create(&self, name: &str, tools: &[ChatTool]) -> crate::Result> { - let creator = self - .creator(name) - .ok_or_else(|| crate::Error::ParserUnavailableByName { - kind: "tool", - name: name.to_string(), - available_names: self.list(), - })?; + let creator = self.creator(name).ok_or_else(|| crate::Error::ParserUnavailableByName { + kind: "tool", + name: name.to_string(), + available_names: self.list(), + })?; creator(tools).map_err(|error| crate::Error::ParserInitialization { kind: "tool", diff --git a/src/chat/src/parser/tool/parameters.rs b/src/chat/src/parser/tool/parameters.rs index 3be38453..2d3ba0fd 100644 --- a/src/chat/src/parser/tool/parameters.rs +++ b/src/chat/src/parser/tool/parameters.rs @@ -12,9 +12,10 @@ pub(super) struct ToolSchemas { /// Normalized parameter schema for one tool. /// -/// This is a minimal subset of JSON Schema with some normalization heuristics to support common -/// schema patterns and upstream schema variations, focused on coercing raw string parameter values -/// into more specific JSON types for downstream tool call execution. +/// This is a minimal subset of JSON Schema with some normalization heuristics +/// to support common schema patterns and upstream schema variations, focused on +/// coercing raw string parameter values into more specific JSON types for +/// downstream tool call execution. #[derive(Debug, Clone, Default, PartialEq, Eq)] pub(super) struct ToolSchema { params: BTreeMap, @@ -46,7 +47,8 @@ impl ToolSchemas { /// Convert raw string parameter values for one named tool. /// - /// Unknown tool names use an empty schema, so all parameters fall back to strings. + /// Unknown tool names use an empty schema, so all parameters fall back to + /// strings. pub(super) fn convert_params_with_schema( &self, function_name: &str, @@ -74,8 +76,8 @@ impl ToolSchemas { } impl ToolSchema { - /// Return an empty schema with no parameter information, which causes all parameters to be - /// treated as strings. + /// Return an empty schema with no parameter information, which causes all + /// parameters to be treated as strings. const fn empty() -> &'static Self { static EMPTY: ToolSchema = ToolSchema { params: BTreeMap::new(), @@ -101,8 +103,9 @@ impl ToolSchema { /// Convert one raw parameter value using its normalized schema type. /// - /// If the parameter name is unknown, or we don't have a schema for it, or the value fails to - /// convert, this falls back to returning the raw string as a JSON string value. + /// If the parameter name is unknown, or we don't have a schema for it, or + /// the value fails to convert, this falls back to returning the raw + /// string as a JSON string value. 
fn convert(&self, name: &str, value: &str) -> Value { if value.eq_ignore_ascii_case("null") { return Value::Null; @@ -128,12 +131,7 @@ impl JsonParamType { if let Some(composite) = schema.get("anyOf").or_else(|| schema.get("oneOf")) { let param_type = composite .as_array() - .map(|schemas| { - schemas - .iter() - .filter_map(Self::from_schema) - .collect::>() - }) + .map(|schemas| schemas.iter().filter_map(Self::from_schema).collect::>()) .filter(|types| !types.is_empty()) .map(Self::one_of) .unwrap_or(Self::Object); @@ -213,18 +211,14 @@ impl JsonParamType { fn convert_value(param_type: &JsonParamType, value: &str) -> Option { match param_type { JsonParamType::String => Some(Value::String(value.to_string())), - JsonParamType::Integer => value - .parse::() - .ok() - .map(Number::from) - .map(Value::Number), + JsonParamType::Integer => value.parse::().ok().map(Number::from).map(Value::Number), JsonParamType::Number => convert_number(value), JsonParamType::Boolean => convert_boolean(value), JsonParamType::Object | JsonParamType::Array => serde_json::from_str(value).ok(), JsonParamType::Null => value.eq_ignore_ascii_case("null").then_some(Value::Null), - JsonParamType::OneOf(types) => types - .iter() - .find_map(|param_type| convert_value(param_type, value)), + JsonParamType::OneOf(types) => { + types.iter().find_map(|param_type| convert_value(param_type, value)) + } } } diff --git a/src/chat/src/parser/tool/qwen_coder.rs b/src/chat/src/parser/tool/qwen_coder.rs index 236914ae..d5d0e5a8 100644 --- a/src/chat/src/parser/tool/qwen_coder.rs +++ b/src/chat/src/parser/tool/qwen_coder.rs @@ -76,9 +76,7 @@ impl Qwen3CoderToolParser { QwenCoderEvent::ToolCallStart => self.mode = QwenCoderMode::ToolCall, QwenCoderEvent::ToolCall { name, raw_params } => { self.mode = QwenCoderMode::Text; - let arguments = self - .tool_parameters - .convert_params_with_schema(&name, raw_params); + let arguments = self.tool_parameters.convert_params_with_schema(&name, raw_params); let arguments = serde_json::to_string(&arguments) .map_err(|error| parsing_failed!("failed to serialize arguments: {}", error))?; @@ -157,9 +155,7 @@ fn parse_text_event(input: &mut QwenCoderInput<'_>) -> ModalResult) -> ModalResult { - literal(TOOL_CALL_START) - .value(QwenCoderEvent::ToolCallStart) - .parse_next(input) + literal(TOOL_CALL_START).value(QwenCoderEvent::ToolCallStart).parse_next(input) } /// Parse a safe text run before the next Qwen Coder marker. @@ -316,9 +312,7 @@ mod tests { #[test] fn qwen_coder_parse_complete_extracts_empty_arguments() { let mut parser = Qwen3CoderToolParser::new(&test_tools()); - let result = parser - .parse_complete(&build_tool_call("get_weather", &[])) - .unwrap(); + let result = parser.parse_complete(&build_tool_call("get_weather", &[])).unwrap(); assert_eq!(result.calls.len(), 1); assert_eq!(result.calls[0].name.as_deref(), Some("get_weather")); diff --git a/src/chat/src/parser/tool/streaming.rs b/src/chat/src/parser/tool/streaming.rs index 75e767eb..eba3eb50 100644 --- a/src/chat/src/parser/tool/streaming.rs +++ b/src/chat/src/parser/tool/streaming.rs @@ -22,7 +22,8 @@ struct ActiveToolState { } impl StreamingToolState { - /// Start tracking a new active tool call and return its stable stream index. + /// Start tracking a new active tool call and return its stable stream + /// index. 
pub(crate) fn begin_tool_call(&mut self) -> usize { let tool_index = self.next_tool_index; self.next_tool_index += 1; @@ -52,9 +53,7 @@ impl StreamingToolState { /// Return the streamed argument prefix for the active tool. pub(crate) fn active_streamed_arguments(&self) -> Option<&str> { - self.active_tool - .as_ref() - .map(|tool| tool.streamed_arguments.as_str()) + self.active_tool.as_ref().map(|tool| tool.streamed_arguments.as_str()) } /// Replace the streamed argument prefix tracked for the active tool. diff --git a/src/chat/src/parser/tool/test_utils.rs b/src/chat/src/parser/tool/test_utils.rs index 076765dd..1057dc80 100644 --- a/src/chat/src/parser/tool/test_utils.rs +++ b/src/chat/src/parser/tool/test_utils.rs @@ -91,7 +91,8 @@ pub fn collect_stream(parser: &mut T, chunks: &[&str]) - result.coalesce_calls() } -/// Split text into chunks containing at most `chunk_chars` Unicode scalar values. +/// Split text into chunks containing at most `chunk_chars` Unicode scalar +/// values. pub fn split_by_chars(text: &str, chunk_chars: usize) -> Vec<&str> { let mut chunks = Vec::new(); let mut start = 0; diff --git a/src/chat/src/parser/tool/utils.rs b/src/chat/src/parser/tool/utils.rs index 2db2cbdc..9d172084 100644 --- a/src/chat/src/parser/tool/utils.rs +++ b/src/chat/src/parser/tool/utils.rs @@ -52,9 +52,10 @@ pub(super) fn safe_text_len(input: &mut Partial<&str>, marker: &str) -> ModalRes /// Parse one event from a buffered streaming input. /// /// Returns: -/// - `Ok(Some((event, consumed_len)))` if an event was successfully parsed, along with the number -/// of bytes consumed from the buffer. -/// - `Ok(None)` if the buffer does not contain a full event yet, and more data is needed. +/// - `Ok(Some((event, consumed_len)))` if an event was successfully parsed, +/// along with the number of bytes consumed from the buffer. +/// - `Ok(None)` if the buffer does not contain a full event yet, and more data +/// is needed. /// - `Err` if a parsing error occurred. 
pub(super) fn parse_buffered_event( buffer: &str, diff --git a/src/chat/src/renderer/deepseek_v32/encoding.rs b/src/chat/src/renderer/deepseek_v32/encoding.rs index efea3430..28913170 100644 --- a/src/chat/src/renderer/deepseek_v32/encoding.rs +++ b/src/chat/src/renderer/deepseek_v32/encoding.rs @@ -289,10 +289,7 @@ fn render_tool_message( ))); } - if tool_results_by_id - .insert(tool_call_id.as_str(), content) - .is_some() - { + if tool_results_by_id.insert(tool_call_id.as_str(), content).is_some() { return Err(Error::ChatTemplate(format!( "invalid DeepSeek V3.2 tool message: duplicate tool_call_id `{tool_call_id}`" ))); diff --git a/src/chat/src/renderer/deepseek_v32/tests.rs b/src/chat/src/renderer/deepseek_v32/tests.rs index 14ddf110..72f2d14e 100644 --- a/src/chat/src/renderer/deepseek_v32/tests.rs +++ b/src/chat/src/renderer/deepseek_v32/tests.rs @@ -85,14 +85,12 @@ fn render_request(request: &ChatRequest) -> String { } fn render_result(request: &ChatRequest) -> Result { - DeepSeekV32ChatRenderer::new() - .render(request) - .map(|rendered| { - rendered - .prompt - .into_text() - .expect("deepseek renderer should return text prompt") - }) + DeepSeekV32ChatRenderer::new().render(request).map(|rendered| { + rendered + .prompt + .into_text() + .expect("deepseek renderer should return text prompt") + }) } fn thinking_request(messages: Vec) -> ChatRequest { diff --git a/src/chat/src/renderer/deepseek_v4/encoding.rs b/src/chat/src/renderer/deepseek_v4/encoding.rs index 3d92e46c..ca698076 100644 --- a/src/chat/src/renderer/deepseek_v4/encoding.rs +++ b/src/chat/src/renderer/deepseek_v4/encoding.rs @@ -47,9 +47,7 @@ pub(super) fn render_request(request: &ChatRequest) -> Result { let (thinking_mode, max_reasoning_effort) = resolve_thinking_options(request)?; let request_tools = request_tools(request); let synthetic_tool_system = needs_synthetic_tool_system(request, request_tools); - let drop_thinking = request - .parse_template_bool("drop_thinking")? - .unwrap_or(true) + let drop_thinking = request.parse_template_bool("drop_thinking")?.unwrap_or(true) && !rendered_tools_present(request, request_tools); let last_user_render_index = find_last_user_render_index(request.messages.as_slice(), synthetic_tool_system); @@ -119,9 +117,10 @@ pub(super) fn render_request(request: &ChatRequest) -> Result { Ok(out) } -/// Resolve DeepSeek V4's thinking controls. Unlike the Python tokenizer wrapper, -/// the Rust renderer only consumes the typed top-level `reasoning_effort`; the -/// generic template-kwargs map is left for HF templates. +/// Resolve DeepSeek V4's thinking controls. Unlike the Python tokenizer +/// wrapper, the Rust renderer only consumes the typed top-level +/// `reasoning_effort`; the generic template-kwargs map is left for HF +/// templates. fn resolve_thinking_options(request: &ChatRequest) -> Result<(ThinkingMode, bool)> { let mut thinking_mode = match request.enable_thinking()?.unwrap_or(false) { true => ThinkingMode::Thinking, @@ -207,7 +206,8 @@ fn is_user_like_entry(message: &ChatMessage) -> bool { ) } -/// Return whether the next rendered entry is assistant, or there is no next entry. +/// Return whether the next rendered entry is assistant, or there is no next +/// entry. fn next_rendered_entry_is_assistant_or_end(messages: &[ChatMessage], message_index: usize) -> bool { let mut next_index = message_index + 1; if matches!(messages[message_index], ChatMessage::ToolResponse { .. 
}) { @@ -374,10 +374,7 @@ fn sorted_tool_response_indices( let ChatMessage::ToolResponse { tool_call_id, .. } = &messages[*index] else { unreachable!("tool response block should only contain tool messages"); }; - tool_call_order - .get(tool_call_id.as_str()) - .copied() - .unwrap_or(0) + tool_call_order.get(tool_call_id.as_str()).copied().unwrap_or(0) }); indices } diff --git a/src/chat/src/renderer/hf/format.rs b/src/chat/src/renderer/hf/format.rs index 58382b91..a9b35d0f 100644 --- a/src/chat/src/renderer/hf/format.rs +++ b/src/chat/src/renderer/hf/format.rs @@ -84,10 +84,9 @@ fn is_attr_access(expr: &Expr, varname: &str, key: &str) -> bool { fn is_var_or_elems_access(expr: &Expr, varname: &str, key: Option<&str>) -> bool { match expr { - Expr::Filter(f) => f - .expr - .as_ref() - .is_some_and(|inner| is_var_or_elems_access(inner, varname, key)), + Expr::Filter(f) => { + f.expr.as_ref().is_some_and(|inner| is_var_or_elems_access(inner, varname, key)) + } Expr::Test(t) => is_var_or_elems_access(&t.expr, varname, key), Expr::Slice(s) => is_var_or_elems_access(&s.expr, varname, key), _ => key.map_or_else( @@ -240,7 +239,8 @@ fn has_content_item_loop(root: &Stmt<'_>) -> bool { }) } -/// Detect the content format expected by a Jinja2 chat template based on AST analysis. +/// Detect the content format expected by a Jinja2 chat template based on AST +/// analysis. pub fn detect_chat_template_content_format(template: &str) -> ChatTemplateContentFormat { let ast = match parse( template, @@ -287,11 +287,7 @@ mod tests { fn iter_vllm_example_template_paths() -> impl Iterator { let mut paths = fs::read_dir(vllm_examples_dir()) .expect("failed to read vLLM example template directory") - .map(|entry| { - entry - .expect("failed to read vLLM example template dir entry") - .path() - }) + .map(|entry| entry.expect("failed to read vLLM example template dir entry").path()) .filter(|path| path.extension().is_some_and(|ext| ext == "jinja")) .collect::>(); paths.sort(); diff --git a/src/chat/src/renderer/hf/mod.rs b/src/chat/src/renderer/hf/mod.rs index 2db489bd..33edaa23 100644 --- a/src/chat/src/renderer/hf/mod.rs +++ b/src/chat/src/renderer/hf/mod.rs @@ -30,7 +30,8 @@ pub use template::{load_chat_template, resolve_chat_template}; pub use self::format::ChatTemplateContentFormatOption; -/// Hugging Face chat-template renderer backed by the local Jinja chat-template state. +/// Hugging Face chat-template renderer backed by the local Jinja chat-template +/// state. pub struct HfChatRenderer { default_template: Option, default_template_kwargs: HashMap, @@ -76,8 +77,8 @@ impl HfChatRenderer { ); info!("using configured chat template override"); } else if let Some(chat_template_path) = files.chat_template_path.as_deref() { - // If independent chat template file(s) exist and contain non-empty content, they take - // priority over template entries in the tokenizer config + // If independent chat template file(s) exist and contain non-empty content, + // they take priority over template entries in the tokenizer config let file_template = load_chat_template(chat_template_path) .map_err(|error| Error::ChatTemplate(error.to_report_string()))?; @@ -103,11 +104,12 @@ impl HfChatRenderer { ) } - /// Apply the chat template to one chat request, rendering the prompt string to be tokenized - /// and submitted to the model. + /// Apply the chat template to one chat request, rendering the prompt string + /// to be tokenized and submitted to the model. 
/// - /// If the request carries a per-request `chat_template` override, a temporary template is - /// compiled from that string and used instead of the model's default. + /// If the request carries a per-request `chat_template` override, a + /// temporary template is compiled from that string and used instead of + /// the model's default. fn apply_chat_template(&self, request: &ChatRequest) -> Result { let override_template = request .chat_options @@ -133,9 +135,7 @@ impl HfChatRenderer { ) -> Result { let messages = to_template_messages(&request.messages, effective_template.content_format())?; - let tools = request - .tool_parsing_enabled() - .then(|| to_template_tools(&request.tools)); + let tools = request.tool_parsing_enabled().then(|| to_template_tools(&request.tools)); trace!( message_count = messages.len(), content_format = ?effective_template.content_format(), @@ -334,10 +334,7 @@ fn to_template_content( } fn to_template_tools(tools: &[ChatTool]) -> Vec { - tools - .iter() - .map(|tool| TemplateTool(tool.to_openai_tool())) - .collect() + tools.iter().map(|tool| TemplateTool(tool.to_openai_tool())).collect() } #[cfg(test)] diff --git a/src/chat/src/renderer/hf/template.rs b/src/chat/src/renderer/hf/template.rs index 2c468bae..b71efc1e 100644 --- a/src/chat/src/renderer/hf/template.rs +++ b/src/chat/src/renderer/hf/template.rs @@ -97,7 +97,8 @@ pub fn resolve_chat_template(chat_template: &str) -> Result { Err(TemplateError::MissingTemplatePath) } -/// One compiled chat template with its Jinja environment and detected content format. +/// One compiled chat template with its Jinja environment and detected content +/// format. pub(super) struct CompiledChatTemplate { /// Cached, fully-configured environment for one compiled template. env: Environment<'static>, @@ -119,7 +120,8 @@ impl CompiledChatTemplate { }) } - /// Apply the compiled template to the given context and return the rendered prompt. + /// Apply the compiled template to the given context and return the rendered + /// prompt. pub fn apply(&self, ctx: TemplateContext<'_>) -> Result { let tmpl = self.env.get_template("chat")?; tmpl.render(ctx).map_err(TemplateError::from) diff --git a/src/chat/src/renderer/hf/tojson.rs b/src/chat/src/renderer/hf/tojson.rs index bcf44e8d..1c5c20f4 100644 --- a/src/chat/src/renderer/hf/tojson.rs +++ b/src/chat/src/renderer/hf/tojson.rs @@ -18,9 +18,7 @@ pub(super) fn hf_tojson_filter( ) -> std::result::Result { let ensure_ascii = kwargs.get::>("ensure_ascii")?.unwrap_or(false); let indent = parse_indent( - kwargs - .get::>>("indent")? 
- .map(|value| value.0), + kwargs.get::>>("indent")?.map(|value| value.0), ); let separators = parse_separators( kwargs @@ -157,15 +155,13 @@ mod tests { fn render(template: &str, payload: serde_json::Value) -> String { let mut env = Environment::new(); env.add_filter("tojson", hf_tojson_filter); - env.render_str(template, json!({ "payload": payload })) - .unwrap() + env.render_str(template, json!({ "payload": payload })).unwrap() } fn render_error(template: &str, payload: serde_json::Value) -> minijinja::Error { let mut env = Environment::new(); env.add_filter("tojson", hf_tojson_filter); - env.render_str(template, json!({ "payload": payload })) - .unwrap_err() + env.render_str(template, json!({ "payload": payload })).unwrap_err() } #[test] diff --git a/src/chat/src/renderer/mod.rs b/src/chat/src/renderer/mod.rs index 4cd1cd42..07ff5d0b 100644 --- a/src/chat/src/renderer/mod.rs +++ b/src/chat/src/renderer/mod.rs @@ -22,7 +22,8 @@ pub struct RenderedPrompt { /// Minimal chat-prompt renderer used by `vllm-chat`. pub trait ChatRenderer: Send + Sync { - /// Render one chat request into the text prompt submitted to the text backend. + /// Render one chat request into the text prompt submitted to the text + /// backend. fn render(&self, request: &ChatRequest) -> Result; } diff --git a/src/chat/src/renderer/selection.rs b/src/chat/src/renderer/selection.rs index 1764aaf3..f4bd565b 100644 --- a/src/chat/src/renderer/selection.rs +++ b/src/chat/src/renderer/selection.rs @@ -23,7 +23,8 @@ impl RendererSelection { pub const DEEPSEEK_V4_LITERAL: &str = "deepseek_v4"; pub const HF_LITERAL: &str = "hf"; - /// Resolve the renderer selection using the given model type string, if it's `Auto`. + /// Resolve the renderer selection using the given model type string, if + /// it's `Auto`. pub fn resolve(self, model_type: &str) -> Self { match self { Self::Auto => match model_type { diff --git a/src/chat/src/request.rs b/src/chat/src/request.rs index c3dea401..284a6819 100644 --- a/src/chat/src/request.rs +++ b/src/chat/src/request.rs @@ -55,7 +55,8 @@ pub enum ChatContent { } impl ChatContent { - /// Flatten the text content into one plain string without adding separators. + /// Flatten the text content into one plain string without adding + /// separators. // TODO: this method will be truly fallible once we add non-text content parts. pub fn try_flatten_to_text(&self) -> Result { Ok(match self { @@ -64,7 +65,8 @@ impl ChatContent { }) } - /// Return whether flattening this chat content would produce an empty string. + /// Return whether flattening this chat content would produce an empty + /// string. pub fn is_empty(&self) -> bool { match self { Self::Text(text) => text.is_empty(), @@ -164,7 +166,8 @@ impl ChatMessage { } } - /// Construct one chat message with assistant role and structured content blocks. + /// Construct one chat message with assistant role and structured content + /// blocks. pub fn assistant_blocks(content: Vec) -> Self { Self::Assistant { content } } @@ -225,16 +228,20 @@ impl From for ChatMessage { pub enum GenerationPromptMode { /// Append a generation prompt for a new assistant turn. /// - /// Equivalent to `add_generation_prompt = true` and `continue_final_message = false`. + /// Equivalent to `add_generation_prompt = true` and `continue_final_message + /// = false`. #[default] StartNewAssistant, /// Leave the final assistant message open so generation continues it. /// - /// Equivalent to `add_generation_prompt = false` and `continue_final_message = true`. 
+ /// Equivalent to `add_generation_prompt = false` and + /// `continue_final_message = true`. ContinueFinalAssistant, - /// Render the existing chat history without adding any trailing generation prompt. + /// Render the existing chat history without adding any trailing generation + /// prompt. /// - /// Equivalent to `add_generation_prompt = false` and `continue_final_message = false`. + /// Equivalent to `add_generation_prompt = false` and + /// `continue_final_message = false`. NoGenerationPrompt, } @@ -267,16 +274,17 @@ impl ReasoningEffort { /// Chat-template-related request options. /// -/// These are the small subset of chat controls that currently affect prompt rendering in -/// `vllm-chat`. +/// These are the small subset of chat controls that currently affect prompt +/// rendering in `vllm-chat`. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct ChatOptions { - /// Controls whether rendering starts a new assistant turn, continues the final assistant - /// message, or emits no trailing generation prompt at all. + /// Controls whether rendering starts a new assistant turn, continues the + /// final assistant message, or emits no trailing generation prompt at + /// all. pub generation_prompt_mode: GenerationPromptMode, - /// Per-request Jinja chat template override. When set, this template is used instead of the - /// model's default chat template. + /// Per-request Jinja chat template override. When set, this template is + /// used instead of the model's default chat template. pub chat_template: Option, /// Effort level exposed to chat templates for reasoning models. @@ -298,7 +306,8 @@ impl Default for ChatOptions { } impl ChatOptions { - /// Whether to add a generation prompt for a new assistant turn after the existing chat history. + /// Whether to add a generation prompt for a new assistant turn after the + /// existing chat history. pub fn add_generation_prompt(&self) -> bool { matches!( self.generation_prompt_mode, @@ -306,7 +315,8 @@ impl ChatOptions { ) } - /// Whether to leave the final assistant message open so generation continues it. + /// Whether to leave the final assistant message open so generation + /// continues it. pub fn continue_final_message(&self) -> bool { matches!( self.generation_prompt_mode, @@ -325,7 +335,8 @@ pub struct ChatTool { } impl ChatTool { - /// Used internally for template rendering and passed to `tool-parser` crate. + /// Used internally for template rendering and passed to `tool-parser` + /// crate. pub(crate) fn to_openai_tool(&self) -> OpenAiTool { OpenAiTool { tool_type: "function".to_string(), @@ -348,7 +359,8 @@ pub enum ChatToolChoice { None, } -/// One chat request ready to be rendered into a prompt and lowered into a generate request. +/// One chat request ready to be rendered into a prompt and lowered into a +/// generate request. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct ChatRequest { /// Stable caller-supplied request ID. @@ -365,14 +377,17 @@ pub struct ChatRequest { pub tool_choice: ChatToolChoice, /// Text decode options for incremental detokenization. pub decode_options: TextDecodeOptions, - /// Whether to emit intermediate northbound content deltas before the terminal result. + /// Whether to emit intermediate northbound content deltas before the + /// terminal result. /// - /// If `false`, callers only observe the terminal accumulated assistant output. If `true`, - /// callers may receive zero or more incremental content events before the final terminal one. 
+ /// If `false`, callers only observe the terminal accumulated assistant
+ /// output. If `true`, callers may receive zero or more incremental
+ /// content events before the final terminal one.
 pub intermediate: bool,
 /// Request scheduling priority (lower means earlier handling; default 0).
 pub priority: i32,
- /// Documents for RAG (retrieval-augmented generation), passed to the chat template.
+ /// Documents for RAG (retrieval-augmented generation), passed to the chat
+ /// template.
 pub documents: Option<Vec<Value>>,
 /// Salt for prefix cache isolation in multi-user environments.
 pub cache_salt: Option<String>,
@@ -422,17 +437,18 @@ impl ChatRequest {
 Ok(())
 }
- /// Return true if this request should enable tool parsing based on the tool choice and tool
- /// list.
+ /// Return true if this request should enable tool parsing based on the tool
+ /// choice and tool list.
 pub(crate) fn tool_parsing_enabled(&self) -> bool {
 matches!(self.tool_choice, ChatToolChoice::Auto) && !self.tools.is_empty()
 }
 /// Return the request-level thinking toggle when explicitly requested.
 ///
- /// We currently accept the two request kwargs `thinking` and `enable_thinking`. Both must be
- /// booleans when present. If both are present, they must have the same value. If neither key
- /// is provided, return `None`.
+ /// We currently accept the two request kwargs `thinking` and
+ /// `enable_thinking`. Both must be booleans when present. If both are
+ /// present, they must have the same value. If neither key is provided,
+ /// return `None`.
 pub(crate) fn enable_thinking(&self) -> Result<Option<bool>> {
 let thinking = self.parse_template_bool("thinking")?;
 let enable_thinking = self.parse_template_bool("enable_thinking")?;
@@ -462,7 +478,8 @@ impl ChatRequest {
 }
 impl ChatRole {
- /// Return the chat-template role string used by the current text-only chat backend.
+ /// Return the chat-template role string used by the current text-only chat
+ /// backend.
 pub fn as_str(&self) -> &'static str {
 match self {
 Self::System => "system",
@@ -564,10 +581,7 @@ mod tests {
 #[test]
 fn enable_thinking_accepts_matching_duplicate_kwargs() {
 let mut request = ChatRequest::for_test();
- request
- .chat_options
- .template_kwargs
- .insert("thinking".to_string(), json!(true));
+ request.chat_options.template_kwargs.insert("thinking".to_string(), json!(true));
 request
 .chat_options
 .template_kwargs
diff --git a/src/chat/src/stream.rs b/src/chat/src/stream.rs
index a9ef663e..8a8dea46 100644
--- a/src/chat/src/stream.rs
+++ b/src/chat/src/stream.rs
@@ -44,7 +44,8 @@ impl ChatEventStream {
 &self.request_id
 }
- /// Collect the stream to completion and return the final assembled assistant message.
+ /// Collect the stream to completion and return the final assembled
+ /// assistant message.
 pub async fn collect_message(mut self) -> Result {
 use futures::StreamExt as _;
@@ -103,8 +104,8 @@ impl ChatEventStream {
 }
 }
- // Note: this is actually unreachable, as the underlying stream always emit an error on
- // unexpected close.
+ // Note: this is actually unreachable, as the underlying stream always emits an
+ // error on unexpected close.
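That note pins down the stream's termination contract: consumers drain events until the terminal output arrives, and an unexpected close surfaces as an error rather than a silent end. A hedged sketch of the collect loop under that contract, with illustrative types rather than the crate's `ChatEventStream`:

use futures::StreamExt as _;

enum Event {
    Delta(String),
    Terminal(String),
}

async fn collect<S>(mut stream: S) -> Result<String, &'static str>
where
    S: futures::Stream<Item = Result<Event, &'static str>> + Unpin,
{
    let mut partial = String::new();
    while let Some(event) = stream.next().await {
        match event? {
            Event::Delta(chunk) => partial.push_str(&chunk), // intermediate content
            Event::Terminal(text) => return Ok(text),        // terminal output: done
        }
    }
    // Normally unreachable: the underlying stream errors on unexpected close
    // instead of ending silently.
    Err("stream closed before terminal output")
}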
Err(Error::StreamClosedBeforeTerminalOutput { request_id: self.request_id, }) diff --git a/src/chat/tests/chat.rs b/src/chat/tests/chat.rs index c9079c55..a72d4168 100644 --- a/src/chat/tests/chat.rs +++ b/src/chat/tests/chat.rs @@ -761,10 +761,7 @@ async fn chat_stream_reports_decode_failure_as_error_event() { })) )); - match timeout(Duration::from_secs(2), stream.next()) - .await - .unwrap() - { + match timeout(Duration::from_secs(2), stream.next()).await.unwrap() { Some(Err(vllm_chat::Error::Text(vllm_text::Error::Tokenizer(message)))) => { assert_eq!(message, "decode failed"); } @@ -1318,13 +1315,7 @@ async fn chat_collect_message_preserves_tool_call_arguments_in_final_only_mode() let mut request = sample_tool_request("chat-final-only-tool"); request.intermediate = false; - let message = chat - .chat(request) - .await - .unwrap() - .collect_message() - .await - .unwrap(); + let message = chat.chat(request).await.unwrap().collect_message().await.unwrap(); assert_eq!( message.finish_reason, @@ -1480,13 +1471,7 @@ async fn chat_stream_and_collect_preserve_prompt_and_sample_logprobs() { ) {} request.request_id = "chat-logprobs-collect".to_string(); - let collected = chat - .chat(request) - .await - .unwrap() - .collect_message() - .await - .unwrap(); + let collected = chat.chat(request).await.unwrap().collect_message().await.unwrap(); assert_eq!(collected.message.text(), "Hi"); assert_eq!( collected.prompt_logprobs, @@ -1565,9 +1550,7 @@ async fn chat_rejects_unknown_tool_parser_before_engine_request() { spawn_mock_engine_task(handshake_address.clone(), engine_id, |dealer, _| { Box::pin(async move { assert!( - timeout(Duration::from_millis(100), recv_engine_message(dealer)) - .await - .is_err(), + timeout(Duration::from_millis(100), recv_engine_message(dealer)).await.is_err(), "chat request should fail before any engine request is sent" ); }) @@ -1608,9 +1591,7 @@ async fn chat_rejects_unknown_reasoning_parser_before_engine_request() { spawn_mock_engine_task(handshake_address.clone(), engine_id, |dealer, _| { Box::pin(async move { assert!( - timeout(Duration::from_millis(100), recv_engine_message(dealer)) - .await - .is_err(), + timeout(Duration::from_millis(100), recv_engine_message(dealer)).await.is_err(), "chat request should fail before any engine request is sent" ); }) @@ -1651,9 +1632,7 @@ async fn chat_rejects_tool_requests_when_tool_parser_is_disabled() { spawn_mock_engine_task(handshake_address.clone(), engine_id, |dealer, _| { Box::pin(async move { assert!( - timeout(Duration::from_millis(100), recv_engine_message(dealer)) - .await - .is_err(), + timeout(Duration::from_millis(100), recv_engine_message(dealer)).await.is_err(), "chat request should fail before any engine request is sent" ); }) @@ -1667,10 +1646,7 @@ async fn chat_rejects_tool_requests_when_tool_parser_is_disabled() { ) .await .with_tool_call_parser(ParserSelection::None); - let error = match chat - .chat(sample_tool_request("chat-tool-parser-disabled")) - .await - { + let error = match chat.chat(sample_tool_request("chat-tool-parser-disabled")).await { Ok(_) => panic!("tool requests should fail when tool parsing is disabled"), Err(error) => error, }; diff --git a/src/cmd/src/cli.rs b/src/cmd/src/cli.rs index e344e9cb..15f9d984 100644 --- a/src/cmd/src/cli.rs +++ b/src/cmd/src/cli.rs @@ -75,7 +75,8 @@ impl Cli { pub enum Command { /// Run the Rust OpenAI frontend as a Python-supervised worker. Frontend(FrontendArgs), - /// Launch a managed Python headless engine, then run the Rust OpenAI frontend. 
+ /// Launch a managed Python headless engine, then run the Rust OpenAI + /// frontend. Serve(ServeArgs), } @@ -84,10 +85,12 @@ pub enum Command { #[educe(Debug)] pub struct SharedRuntimeArgs { #[serde(rename = "model_tag")] - /// Model identifier or local model directory used for backend loading and public model ID. + /// Model identifier or local model directory used for backend loading and + /// public model ID. pub model: String, - /// Maximum time to wait for the expected engines to register on the frontend transport. + /// Maximum time to wait for the expected engines to register on the + /// frontend transport. #[arg( long = "engine-ready-timeout-secs", env = "VLLM_ENGINE_READY_TIMEOUT_S", @@ -110,11 +113,13 @@ pub struct SharedRuntimeArgs { #[arg(long = "tokenizer-mode", default_value_t)] #[serde(default, rename = "tokenizer_mode")] pub renderer: RendererSelection, - /// Override the maximum model context length. When set, the frontend uses this value - /// instead of the model's `max_position_embeddings` from `config.json`. + /// Override the maximum model context length. When set, the frontend uses + /// this value instead of the model's `max_position_embeddings` from + /// `config.json`. #[arg(long)] pub max_model_len: Option, - /// TCP port for the gRPC Generate service. When not set, no gRPC server is started. + /// TCP port for the gRPC Generate service. When not set, no gRPC server is + /// started. #[arg(long)] #[serde(default)] pub grpc_port: Option, @@ -123,19 +128,20 @@ pub struct SharedRuntimeArgs { #[serde(default)] pub shutdown_timeout: u64, - /// The file path to the chat template, or the template in single-line form for the specified - /// model. + /// The file path to the chat template, or the template in single-line form + /// for the specified model. #[arg(long)] #[serde(default)] pub chat_template: Option, /// Default keyword arguments to pass to the chat template renderer. /// - /// These will be merged with request-level chat_template_kwargs, with request values taking - /// precedence. Useful for setting default behavior for reasoning models. + /// These will be merged with request-level chat_template_kwargs, with + /// request values taking precedence. Useful for setting default + /// behavior for reasoning models. /// - /// Example: `{"enable_thinking": false}` to disable thinking mode by default for - /// Qwen3/DeepSeek models. + /// Example: `{"enable_thinking": false}` to disable thinking mode by + /// default for Qwen3/DeepSeek models. #[arg(long, value_parser = parse_json::>, value_name = "JSON")] #[serde(default)] pub default_chat_template_kwargs: Option>, @@ -144,24 +150,26 @@ pub struct SharedRuntimeArgs { /// /// * "auto" detects the format from the template /// * "string" renders content as a string. Example: `"Hello World"` - /// * "openai" renders content as a list of dictionaries, similar to OpenAI schema. Example: - /// `[{"type": "text", "text": "Hello world!"}]` + /// * "openai" renders content as a list of dictionaries, similar to OpenAI + /// schema. Example: `[{"type": "text", "text": "Hello world!"}]` #[arg(long, default_value_t)] #[serde(default)] pub chat_template_content_format: ChatTemplateContentFormatOption, - /// Log a summary line for each completed request, including prompt/output token counts - /// and finish reason. + /// Log a summary line for each completed request, including prompt/output + /// token counts and finish reason. 
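The `--default-chat-template-kwargs` flag above relies on a `parse_json` value parser to accept a JSON object directly on the command line. A minimal sketch of such a helper; the crate's actual `parse_json` may differ in bounds and error formatting:

use serde::de::DeserializeOwned;

fn parse_json<T: DeserializeOwned>(raw: &str) -> Result<T, String> {
    serde_json::from_str(raw).map_err(|error| format!("invalid JSON: {error}"))
}

// Hypothetical usage with clap derive:
// #[arg(long, value_parser = parse_json::<HashMap<String, serde_json::Value>>, value_name = "JSON")]
// pub default_chat_template_kwargs: Option<HashMap<String, serde_json::Value>>,

clap accepts plain `fn(&str) -> Result<T, E>` functions as typed value parsers, so one generic helper can serve every JSON-valued flag.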
#[arg(long)] #[serde(default)] pub enable_log_requests: bool, - /// Disable periodic logging of engine statistics (throughput, queue depth, cache usage). + /// Disable periodic logging of engine statistics (throughput, queue depth, + /// cache usage). #[arg(long)] #[serde(default)] pub disable_log_stats: bool, - /// Unsupported Python vLLM frontend arguments recognized but not yet implemented in Rust. + /// Unsupported Python vLLM frontend arguments recognized but not yet + /// implemented in Rust. #[educe(Debug(ignore))] #[command(flatten)] #[serde(default, flatten)] @@ -169,7 +177,8 @@ pub struct SharedRuntimeArgs { } impl SharedRuntimeArgs { - /// Maximum time to wait for the expected engines to register on the frontend transport. + /// Maximum time to wait for the expected engines to register on the + /// frontend transport. pub fn ready_timeout(&self) -> Duration { Duration::from_secs(self.engine_ready_timeout_secs) } @@ -181,8 +190,8 @@ impl SharedRuntimeArgs { /// Build the OpenAI-server config for the Python-bootstrap worker contract. /// - /// The resulting config binds the Python-supplied transport addresses and inherits an already - /// open HTTP listener from the supervisor process. + /// The resulting config binds the Python-supplied transport addresses and + /// inherits an already open HTTP listener from the supervisor process. fn into_bootstrapped_config( self, listen_fd: i32, @@ -220,8 +229,8 @@ impl SharedRuntimeArgs { } } - /// Build the OpenAI-server config for the managed `serve` path that still owns the startup - /// handshake and binds its own HTTP listener. + /// Build the OpenAI-server config for the managed `serve` path that still + /// owns the startup handshake and binds its own HTTP listener. fn into_managed_config( self, listener_mode: HttpListenerMode, @@ -279,17 +288,21 @@ fn parse_runtime_args_json(value: &str) -> Result { #[derive(Educe, Clone, Args, PartialEq, Eq)] #[educe(Debug)] pub struct FrontendArgs { - /// Inherited listening socket file descriptor passed by the Python supervisor. + /// Inherited listening socket file descriptor passed by the Python + /// supervisor. #[arg(long)] pub listen_fd: i32, - /// Frontend input ROUTER socket address that the Python engines will connect to. + /// Frontend input ROUTER socket address that the Python engines will + /// connect to. #[arg(long)] pub input_address: String, - /// Frontend output PULL socket address that the Python engines will push responses to. + /// Frontend output PULL socket address that the Python engines will push + /// responses to. #[arg(long)] pub output_address: String, - /// Optional Python-owned frontend-side DP coordinator socket address for external coordinator - /// mode in the bootstrapped frontend path, i.e., `stats_update_address`. + /// Optional Python-owned frontend-side DP coordinator socket address for + /// external coordinator mode in the bootstrapped frontend path, i.e., + /// `stats_update_address`. #[arg(long)] pub coordinator_address: Option, /// Total number of data-parallel engines expected for this frontend. @@ -314,12 +327,14 @@ impl FrontendArgs { } } -/// Arguments for the managed-engine mode that spawns Python on behalf of the user. +/// Arguments for the managed-engine mode that spawns Python on behalf of the +/// user. #[derive(Educe, Clone, Args, PartialEq, Eq)] #[educe(Debug)] #[command(override_usage = "vllm-rs serve [OPTIONS] [-- ...]")] pub struct ServeArgs { - /// Only launch the managed Python headless engine and do not start the Rust frontend. 
+ /// Only launch the managed Python headless engine and do not start the Rust + /// frontend. #[arg(long)] pub headless: bool, /// Python executable used to launch the managed headless vLLM engine. @@ -334,15 +349,16 @@ pub struct ServeArgs { /// Unix domain socket path. If set, host and port arguments are ignored. #[arg(long)] pub uds: Option, - /// Host/IP used both for the managed-engine handshake endpoint and the frontend-advertised - /// input/output ZMQ socket addresses. + /// Host/IP used both for the managed-engine handshake endpoint and the + /// frontend-advertised input/output ZMQ socket addresses. #[arg( long = "data-parallel-address", visible_alias = "handshake-host", default_value = "127.0.0.1" )] pub handshake_host: String, - /// Optional TCP port for the managed-engine handshake / data-parallel RPC endpoint. + /// Optional TCP port for the managed-engine handshake / data-parallel RPC + /// endpoint. /// /// When omitted, the CLI allocates an ephemeral port automatically. #[arg( @@ -367,11 +383,12 @@ pub struct ServeArgs { #[command(flatten)] pub runtime: SharedRuntimeArgs, - /// Additional arguments forwarded to `python -m vllm.entrypoints.cli.main serve ...`. + /// Additional arguments forwarded to `python -m vllm.entrypoints.cli.main + /// serve ...`. /// - /// Arguments after an explicit `--` are forwarded verbatim. Before `--`, `vllm-rs serve` - /// automatically keeps recognized frontend options on the Rust side and forwards everything - /// else to Python. + /// Arguments after an explicit `--` are forwarded verbatim. Before `--`, + /// `vllm-rs serve` automatically keeps recognized frontend options on + /// the Rust side and forwards everything else to Python. #[arg( last = true, allow_hyphen_values = true, @@ -381,18 +398,18 @@ pub struct ServeArgs { } impl ServeArgs { - /// Build the handshake address shared by the Rust frontend and managed Python engine. + /// Build the handshake address shared by the Rust frontend and managed + /// Python engine. pub fn handshake_address(&self, handshake_port: u16) -> String { format!("tcp://{}:{}", self.handshake_host, handshake_port) } - /// Build the OpenAI-server runtime config used after the managed Python engine starts. + /// Build the OpenAI-server runtime config used after the managed Python + /// engine starts. pub fn to_frontend_config(&self, handshake_address: String) -> Config { // Prefer IPC sockets for local engine input/output. - let (local_input_address, local_output_address) = self - .frontend_local_only() - .then(frontend_ipc_addresses) - .unzip(); + let (local_input_address, local_output_address) = + self.frontend_local_only().then(frontend_ipc_addresses).unzip(); let listener_mode = match &self.uds { Some(path) => HttpListenerMode::BindUnix { path: path.clone() }, None => HttpListenerMode::BindTcp { @@ -411,7 +428,8 @@ impl ServeArgs { ) } - /// Build the managed Python-engine spawn configuration for one resolved handshake port. + /// Build the managed Python-engine spawn configuration for one resolved + /// handshake port. pub fn into_managed_engine_config(self, handshake_port: u16) -> ManagedEngineConfig { let mut python_args = self.python_args; // Manually forward some args to the Python engine. @@ -435,11 +453,11 @@ impl ServeArgs { } fn local_engine_count(&self) -> usize { - self.data_parallel_size_local - .unwrap_or(self.data_parallel_size) + self.data_parallel_size_local.unwrap_or(self.data_parallel_size) } - /// Return whether the managed Rust frontend only needs to communicate with colocated engines. 
+ /// Return whether the managed Rust frontend only needs to communicate with + /// colocated engines. fn frontend_local_only(&self) -> bool { self.data_parallel_size_local != Some(0) && self.local_engine_count() == self.data_parallel_size diff --git a/src/cmd/src/cli/serve_validate.rs b/src/cmd/src/cli/serve_validate.rs index 549b9db5..aa146945 100644 --- a/src/cmd/src/cli/serve_validate.rs +++ b/src/cmd/src/cli/serve_validate.rs @@ -6,8 +6,8 @@ use clap::error::ErrorKind; use crate::cli::Cli; -/// Python `argparse` accepts these multi-character single-dash aliases, but `clap` cannot model -/// them directly. +/// Python `argparse` accepts these multi-character single-dash aliases, but +/// `clap` cannot model them directly. const PYTHON_MULTI_CHAR_ALIASES: &[(&str, &str)] = &[ ("-asc", "--api-server-count"), ("-pp", "--pipeline-parallel-size"), @@ -28,8 +28,8 @@ const PYTHON_MULTI_CHAR_ALIASES: &[(&str, &str)] = &[ ("-ac", "--attention-config"), ]; -/// Repartition `serve` argv so Rust frontend-owned flags stay before `--`, while everything else -/// is forwarded to Python. +/// Repartition `serve` argv so Rust frontend-owned flags stay before `--`, +/// while everything else is forwarded to Python. pub(super) fn repartition_serve_args(args: &[OsString]) -> Result, clap::Error> { if args.get(1).map(|arg| arg.as_os_str()) != Some("serve".as_ref()) { return Ok(args.to_vec()); @@ -116,11 +116,9 @@ fn normalize_python_multi_char_alias(arg: &str) -> Option { } fn find_python_multi_char_alias(arg: &str) -> Option<&'static str> { - PYTHON_MULTI_CHAR_ALIASES - .iter() - .find_map(|&(alias, canonical)| { - (arg == alias || arg.starts_with(&format!("{alias}="))).then_some(canonical) - }) + PYTHON_MULTI_CHAR_ALIASES.iter().find_map(|&(alias, canonical)| { + (arg == alias || arg.starts_with(&format!("{alias}="))).then_some(canonical) + }) } fn push_chunk( @@ -163,9 +161,8 @@ fn chunk_head_is_frontend_owned( fn collect_frontend_option_names() -> (HashSet, HashSet) { let mut command = Cli::command(); - let serve_command = command - .find_subcommand_mut("serve") - .expect("serve subcommand should exist"); + let serve_command = + command.find_subcommand_mut("serve").expect("serve subcommand should exist"); let mut long_flags = HashSet::new(); let mut short_flags = HashSet::new(); @@ -209,9 +206,8 @@ fn is_help_flag(arg: &str) -> bool { fn build_missing_model_error() -> clap::Error { let mut command = Cli::command(); - let serve_command = command - .find_subcommand_mut("serve") - .expect("serve subcommand should exist"); + let serve_command = + command.find_subcommand_mut("serve").expect("serve subcommand should exist"); serve_command.error( ErrorKind::MissingRequiredArgument, "serve requires the model to appear immediately after the subcommand", diff --git a/src/cmd/src/cli/unsupported.rs b/src/cmd/src/cli/unsupported.rs index 766f5614..7112ded7 100644 --- a/src/cmd/src/cli/unsupported.rs +++ b/src/cmd/src/cli/unsupported.rs @@ -8,11 +8,11 @@ use clap::builder::{TypedValueParser, ValueParserFactory}; use itertools::Itertools; use serde::{Deserialize, Deserializer, Serialize}; -/// Marker type for frontend-owned `serve` arguments that `vllm-rs` recognizes but does not -/// support yet. +/// Marker type for frontend-owned `serve` arguments that `vllm-rs` recognizes +/// but does not support yet. /// -/// When passed as JSON args, it can be deserialized from any value, and serializes back to the -/// original value. 
+/// When passed as JSON args, it can be deserialized from any value, and +/// serializes back to the original value. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(transparent)] pub struct Unsupported(pub serde_json::Value); @@ -31,10 +31,11 @@ This may lead to unexpected behavior as the Rust frontend will completely ignore } } -/// Marker type for no-op arguments that are accepted by the Rust frontend but have no effect. +/// Marker type for no-op arguments that are accepted by the Rust frontend but +/// have no effect. /// -/// When passed as JSON args, it can be deserialized from any value, but always serializes back to -/// `null`. +/// When passed as JSON args, it can be deserialized from any value, but always +/// serializes back to `null`. #[derive(Clone, Debug, PartialEq, Eq, Serialize)] pub struct Noop; @@ -79,7 +80,8 @@ impl TypedValueParser for NoopValueParser { } } -/// Frontend-owned Python `serve` arguments that `vllm-rs` recognizes but does not support yet. +/// Frontend-owned Python `serve` arguments that `vllm-rs` recognizes but does +/// not support yet. #[serde_with::skip_serializing_none] #[derive(Debug, Clone, PartialEq, Eq, Default, Args, Serialize, Deserialize)] #[command(next_help_heading = "Options not implemented in Rust frontend yet")] @@ -96,8 +98,9 @@ pub struct UnsupportedArgs { } impl UnsupportedArgs { - /// Check whether any unsupported arguments are set, and if so, return an error listing them. - /// Also warn about any no-op arguments that are set but will be ignored. + /// Check whether any unsupported arguments are set, and if so, return an + /// error listing them. Also warn about any no-op arguments that are set + /// but will be ignored. pub(crate) fn check(&self) -> Result<(), String> { let value = serde_json::to_value(self).unwrap(); let map = value.as_object().unwrap(); @@ -113,10 +116,7 @@ impl UnsupportedArgs { if !unsupported.is_empty() { unsupported.sort_unstable(); - let bullets = unsupported - .into_iter() - .map(|key| format!("- {key}")) - .join("\n"); + let bullets = unsupported.into_iter().map(|key| format!("- {key}")).join("\n"); return Err(format!( " The following arguments are not implemented in Rust frontend yet: @@ -130,44 +130,48 @@ Remove these arguments to continue." } } -/// Frontend-owned Python `vllm serve` top-level arguments that `vllm-rs` recognizes but does not -/// support yet. +/// Frontend-owned Python `vllm serve` top-level arguments that `vllm-rs` +/// recognizes but does not support yet. /// /// Source of truth in Python vLLM: /// - `vllm.entrypoints.openai.cli_args.make_arg_parser(...)` /// - `vllm.entrypoints.cli.serve.ServeSubcommand.subparser_init(...)` /// -/// These are not part of `EngineArgs`, `AsyncEngineArgs`, `BaseFrontendArgs`, or `FrontendArgs`. -/// They live on the `serve` command itself and control managed-engine / multi-process orchestration -/// rather than the shared frontend runtime config. +/// These are not part of `EngineArgs`, `AsyncEngineArgs`, `BaseFrontendArgs`, +/// or `FrontendArgs`. They live on the `serve` command itself and control +/// managed-engine / multi-process orchestration rather than the shared frontend +/// runtime config. #[serde_with::skip_serializing_none] #[derive(Debug, Clone, PartialEq, Eq, Default, Args, Serialize, Deserialize)] pub struct TopLevelUnsupportedArgs { - /// How many API server processes to run. Defaults to data_parallel_size if not specified. + /// How many API server processes to run. 
Defaults to data_parallel_size if + /// not specified. #[arg(long, hide = true)] pub api_server_count: Option, - /// Read CLI options from a config file. Must be a YAML with the following options: - /// https://docs.vllm.ai/en/latest/configuration/serve_args.html + /// Read CLI options from a config file. Must be a YAML with the following + /// options: https://docs.vllm.ai/en/latest/configuration/serve_args.html #[arg(long)] pub config: Option, - /// Launch a gRPC server instead of the HTTP OpenAI-compatible server. Requires: - /// pip install vllm[grpc]. + /// Launch a gRPC server instead of the HTTP OpenAI-compatible server. + /// Requires: pip install vllm[grpc]. #[arg(long, default_missing_value = "true", num_args = 0..=1)] pub grpc: Option, } -/// Frontend-owned Python engine arguments that `vllm-rs` recognizes but does not support yet. +/// Frontend-owned Python engine arguments that `vllm-rs` recognizes but does +/// not support yet. /// /// Source of truth in Python vLLM: /// - `vllm.engine.arg_utils.EngineArgs.add_cli_args(...)` /// - `vllm.engine.arg_utils.AsyncEngineArgs.add_cli_args(...)` /// -/// These arguments are declared through the Python engine-args surface, but they are still -/// frontend-owned: the API server / AsyncLLM layer reads them for tokenizer setup, request -/// validation, routing, logging, and other frontend behavior, so Rust must recognize them rather -/// than treating them as pure engine passthrough. +/// These arguments are declared through the Python engine-args surface, but +/// they are still frontend-owned: the API server / AsyncLLM layer reads them +/// for tokenizer setup, request validation, routing, logging, and other +/// frontend behavior, so Rust must recognize them rather than treating them as +/// pure engine passthrough. #[serde_with::skip_serializing_none] #[derive(Debug, Clone, PartialEq, Eq, Default, Args, Serialize, Deserialize)] pub struct EngineUnsupportedArgs { @@ -182,8 +186,8 @@ pub struct EngineUnsupportedArgs { pub hf_config_path: Option, /// Allowing API requests to read local images or videos from directories - /// specified by the server file system. This is a security risk. Should only - /// be enabled in trusted environments. + /// specified by the server file system. This is a security risk. Should + /// only be enabled in trusted environments. #[arg(long)] pub allowed_local_media_path: Option, @@ -193,15 +197,16 @@ pub struct EngineUnsupportedArgs { pub allowed_media_domains: Option, /// The specific revision to use for the tokenizer on the Hugging Face Hub. - /// It can be a branch name, a tag name, or a commit id. If unspecified, will - /// use the default version. + /// It can be a branch name, a tag name, or a commit id. If unspecified, + /// will use the default version. #[arg(long)] pub tokenizer_revision: Option, /// Maximum number of log probabilities to return when `logprobs` is - /// specified in `SamplingParams`. The default value comes the default for the - /// OpenAI Chat Completions API. -1 means no cap, i.e. all (output_length * - /// vocab_size) logprobs are allowed to be returned and it may cause OOM. + /// specified in `SamplingParams`. The default value comes the default for + /// the OpenAI Chat Completions API. -1 means no cap, i.e. all + /// (output_length * vocab_size) logprobs are allowed to be returned and + /// it may cause OOM. 
#[arg(long)] pub max_logprobs: Option, @@ -219,8 +224,8 @@ pub struct EngineUnsupportedArgs { /// If `True`, enables passing text embeddings as inputs via the /// `prompt_embeds` key. /// - /// WARNING: The vLLM engine may crash if incorrect shape of embeddings is passed. - /// Only enable this flag for trusted users! + /// WARNING: The vLLM engine may crash if incorrect shape of embeddings is + /// passed. Only enable this flag for trusted users! #[arg( long, visible_alias = "no-enable-prompt-embeds", @@ -232,10 +237,10 @@ pub struct EngineUnsupportedArgs { /// The model name(s) used in the API. If multiple names are provided, the /// server will respond to any of the provided names. The model name in the /// model field of a response will be the first name in this list. If not - /// specified, the model name will be the same as the `--model` argument. Noted - /// that this name(s) will also be used in `model_name` tag content of - /// prometheus metrics, if multiple names provided, metrics tag will take the - /// first one. + /// specified, the model name will be the same as the `--model` argument. + /// Noted that this name(s) will also be used in `model_name` tag + /// content of prometheus metrics, if multiple names provided, metrics + /// tag will take the first one. #[arg(long)] pub served_model_name: Option, @@ -252,10 +257,11 @@ pub struct EngineUnsupportedArgs { /// The folder path to the generation config. Defaults to `"auto"`, the /// generation config will be loaded from model path. If set to `"vllm"`, no - /// generation config is loaded, vLLM defaults will be used. If set to a folder - /// path, the generation config will be loaded from the specified folder path. - /// If `max_new_tokens` is specified in generation config, then it sets a - /// server-wide limit on the number of output tokens for all requests. + /// generation config is loaded, vLLM defaults will be used. If set to a + /// folder path, the generation config will be loaded from the specified + /// folder path. If `max_new_tokens` is specified in generation config, + /// then it sets a server-wide limit on the number of output tokens for + /// all requests. #[arg(long)] pub generation_config: Option, @@ -359,21 +365,23 @@ pub struct EngineUnsupportedArgs { #[arg(long)] pub structured_outputs_config: Option, - /// Log aggregate rather than per-engine statistics when using data parallelism. + /// Log aggregate rather than per-engine statistics when using data + /// parallelism. #[arg(long, default_missing_value = "true", num_args = 0..=1)] pub aggregate_engine_logging: Option, } -/// Frontend-owned Python OpenAI server arguments that `vllm-rs` recognizes but does not support -/// yet. +/// Frontend-owned Python OpenAI server arguments that `vllm-rs` recognizes but +/// does not support yet. /// /// Source of truth in Python vLLM: /// - `vllm.entrypoints.openai.cli_args.BaseFrontendArgs` /// - `vllm.entrypoints.openai.cli_args.FrontendArgs` /// -/// These are not engine args. They belong to the Python OpenAI-compatible frontend / API-server -/// layer itself, for example chat-template configuration, tool/frontend behavior, TLS / CORS / -/// HTTP server settings, and other northbound server knobs. +/// These are not engine args. They belong to the Python OpenAI-compatible +/// frontend / API-server layer itself, for example chat-template configuration, +/// tool/frontend behavior, TLS / CORS / HTTP server settings, and other +/// northbound server knobs. 
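The `check` method shown earlier turns these optional fields into a report by serializing the whole struct and listing whichever keys survived `skip_serializing_none`. A sketch of that pattern with illustrative names (the `--` flag-name conversion here is an assumption, not the patch's exact formatting):

use serde::Serialize;

fn set_keys<T: Serialize>(args: &T) -> Vec<String> {
    let value = serde_json::to_value(args).expect("args should serialize to a JSON object");
    value
        .as_object()
        .map(|map| {
            map.iter()
                .filter(|(_, v)| !v.is_null()) // only flags the user actually set
                .map(|(key, _)| format!("--{}", key.replace('_', "-")))
                .collect()
        })
        .unwrap_or_default()
}

Driving the report off serde rather than matching on each field by hand means newly added unsupported flags are reported without touching the check logic.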
#[serde_with::skip_serializing_none] #[derive(Debug, Clone, PartialEq, Eq, Default, Args, Serialize, Deserialize)] pub struct ServerUnsupportedArgs { @@ -385,8 +393,8 @@ pub struct ServerUnsupportedArgs { pub lora_modules: Option, /// Whether to trust the chat template provided in the request. If False, - /// the server will always use the chat template specified by `--chat-template` - /// or the ones from tokenizer. + /// the server will always use the chat template specified by + /// `--chat-template` or the ones from tokenizer. #[arg( long, visible_alias = "no-trust-request-chat-template", @@ -575,8 +583,8 @@ pub struct ServerUnsupportedArgs { #[arg(long)] pub allowed_headers: Option, - /// If provided, the server will require one of these keys to be presented in - /// the header. + /// If provided, the server will require one of these keys to be presented + /// in the header. #[arg(long)] pub api_key: Option, @@ -615,8 +623,8 @@ pub struct ServerUnsupportedArgs { pub root_path: Option, /// Additional ASGI middleware to apply to the app. We accept multiple - /// --middleware arguments. The value should be an import path. If a function - /// is provided, vLLM will add it to the server using + /// --middleware arguments. The value should be an import path. If a + /// function is provided, vLLM will add it to the server using /// `@app.middleware('http')`. If a class is provided, vLLM will /// add it to the server using `app.add_middleware()`. #[arg(long)] diff --git a/src/cmd/src/logging.rs b/src/cmd/src/logging.rs index f22b83d1..5ccd99be 100644 --- a/src/cmd/src/logging.rs +++ b/src/cmd/src/logging.rs @@ -34,23 +34,19 @@ pub(crate) fn init_tracing() { let formatter = VllmEventFormatter::new(); let _ = tracing_subscriber::registry() - .with( - tracing_subscriber::fmt::layer() - .event_format(formatter) - .with_filter(filter), - ) + .with(tracing_subscriber::fmt::layer().event_format(formatter).with_filter(filter)) .try_init(); } -/// Build the CLI log filter by merging the vLLM-style default level with Rust-style target -/// overrides. +/// Build the CLI log filter by merging the vLLM-style default level with +/// Rust-style target overrides. /// /// Precedence: /// - Start from `VLLM_LOGGING_LEVEL` as the default level for all targets. /// - If `RUST_LOG` contains a global default level such as `warn`, it overrides /// `VLLM_LOGGING_LEVEL`. -/// - Any explicit target directives in `RUST_LOG`, such as `hyper=info`, override whichever default -/// level is active for those targets only. +/// - Any explicit target directives in `RUST_LOG`, such as `hyper=info`, +/// override whichever default level is active for those targets only. fn build_targets_filter(vllm_logging_level: Option<&str>, rust_log: Option<&str>) -> Targets { let mut filter = Targets::new().with_default(map_python_log_level(vllm_logging_level.unwrap_or("INFO"))); @@ -224,12 +220,12 @@ where } } -/// Shorten a source file path for log output while preserving enough context for -/// common Rust entrypoint and module filenames. +/// Shorten a source file path for log output while preserving enough context +/// for common Rust entrypoint and module filenames. /// /// - For `mod.rs`, keep the parent directory as `parent/mod.rs`. -/// - For `src/lib.rs` and `src/main.rs`, keep one additional component as `crate/src/lib.rs` or -/// `crate/src/main.rs` when available. +/// - For `src/lib.rs` and `src/main.rs`, keep one additional component as +/// `crate/src/lib.rs` or `crate/src/main.rs` when available. 
/// - Other files are displayed as just the basename.
fn shorten_file_path(file: &str) -> &str {
 let mut parts = file.rsplit('/');
diff --git a/src/cmd/src/main.rs b/src/cmd/src/main.rs
index 731cdd78..1a744ee2 100644
--- a/src/cmd/src/main.rs
+++ b/src/cmd/src/main.rs
@@ -15,9 +15,10 @@ use crate::managed_engine::{ManagedEngineHandle, allocate_handshake_port};
 const TOKIO_WORKER_THREADS_ENV: &str = "TOKIO_WORKER_THREADS";
 const DEFAULT_MAX_TOKIO_WORKER_THREADS: usize = 32;
-/// Cap the default number of Tokio worker threads if the user did not explicitly set
-/// `TOKIO_WORKER_THREADS` to avoid spawning too many threads on machines with a large number of
-/// CPUs, which may lead to excessive context switching and degraded performance.
+/// Cap the default number of Tokio worker threads if the user did not
+/// explicitly set `TOKIO_WORKER_THREADS` to avoid spawning too many threads on
+/// machines with a large number of CPUs, which may lead to excessive context
+/// switching and degraded performance.
 fn tokio_worker_threads() -> Option<usize> {
 if env::var_os(TOKIO_WORKER_THREADS_ENV).is_some() {
 return None;
 }
@@ -54,9 +55,7 @@ fn shutdown_signal() -> CancellationToken {
 tokio::spawn(async move {
 let ctrl_c = async {
- tokio::signal::ctrl_c()
- .await
- .expect("failed to install Ctrl-C signal handler");
+ tokio::signal::ctrl_c().await.expect("failed to install Ctrl-C signal handler");
 };
 let sigterm = async {
@@ -167,14 +166,15 @@ async fn async_main(cli: Cli) -> Result<()> {
 }
 }
 };
- // Regardless of the shutdown reason, broadcast shutdown signal here to ensure that all
- // serving tasks are notified.
+ // Regardless of the shutdown reason, broadcast shutdown signal here to ensure
+ // that all serving tasks are notified.
 shutdown.cancel();
 // Shutdown begins. Terminate the managed engine first.
 engine.shutdown(shutdown_timeout).await?;
 info!("managed engine shut down gracefully");
- // Wait for the API server to shut down gracefully by draining in-flight requests.
+ // Wait for the API server to shut down gracefully by draining in-flight
+ // requests.
 if !matches!(shutdown_reason, ShutdownReason::Server(_)) {
 serve_task.await.context("serve task join failed")??;
 }
diff --git a/src/cmd/src/managed_engine.rs b/src/cmd/src/managed_engine.rs
index dd8f96ef..0a506244 100644
--- a/src/cmd/src/managed_engine.rs
+++ b/src/cmd/src/managed_engine.rs
@@ -14,7 +14,8 @@ use tracing::info;
 const CHILD_POLL_INTERVAL: Duration = Duration::from_millis(200);
 const MIN_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
-/// Allocate one ephemeral TCP port for the managed headless-engine handshake on the given host.
+/// Allocate one ephemeral TCP port for the managed headless-engine handshake on
+/// the given host.
 pub fn allocate_handshake_port(host: &str) -> Result<u16> {
 let listener = TcpListener::bind((host, 0)).context("failed to allocate handshake port")?;
 let port = listener
@@ -78,7 +79,8 @@ pub struct ManagedEngineHandle {
 }
 impl ManagedEngineHandle {
- /// Spawn one managed Python headless engine and return a handle for monitoring it.
+ /// Spawn one managed Python headless engine and return a handle for
+ /// monitoring it.
pub async fn spawn(config: ManagedEngineConfig) -> Result<Self> {
 let command = config.to_command();
 info!(
@@ -88,10 +90,7 @@ impl ManagedEngineHandle {
 );
 let mut command = Command::from(command);
- command
- .stdin(Stdio::null())
- .stdout(Stdio::inherit())
- .stderr(Stdio::inherit());
+ command.stdin(Stdio::null()).stdout(Stdio::inherit()).stderr(Stdio::inherit());
 process_group::configure(&mut command);
@@ -106,9 +105,7 @@ impl ManagedEngineHandle {
 /// Poll whether the managed engine has exited yet.
 pub async fn try_wait(&self) -> Option<ExitStatus> {
 let mut child = self.child.lock().await;
- child
- .try_wait()
- .expect("failed to poll the status of managed engine")
+ child.try_wait().expect("failed to poll the status of managed engine")
 }
 /// Wait until the managed engine exits.
@@ -132,7 +129,8 @@ impl ManagedEngineHandle {
 return Ok(());
 };
- // Enforce a minimum shutdown timeout to give the engine process enough time to clean up.
+ // Enforce a minimum shutdown timeout to give the engine process enough time to
+ // clean up.
 let shutdown_timeout = std::cmp::max(timeout, MIN_SHUTDOWN_TIMEOUT);
 // First, try to gracefully terminate.
@@ -144,10 +142,7 @@ impl ManagedEngineHandle {
 process_group::terminate(pid)?;
 // Wait for the process to exit on its own.
- if tokio::time::timeout(shutdown_timeout, self.wait_for_exit())
- .await
- .is_ok()
- {
+ if tokio::time::timeout(shutdown_timeout, self.wait_for_exit()).await.is_ok() {
 return Ok(());
 }
@@ -163,13 +158,13 @@ impl ManagedEngineHandle {
 }
 }
-/// Process group helper functions for managing the Python subprocess and its potential children in
-/// a platform-aware way.
+/// Process group helper functions for managing the Python subprocess and its
+/// potential children in a platform-aware way.
 mod process_group {
 use super::*;
- /// Place the Python child into its own process group so `serve` can tear down
- /// the whole subtree rather than just the immediate shell process.
+ /// Place the Python child into its own process group so `serve` can tear
+ /// down the whole subtree rather than just the immediate shell process.
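The shutdown path above escalates in two steps: a graceful terminate, a bounded wait, then a hard kill. A condensed sketch of that control flow, with the signal helpers abstracted as closures; these are illustrative stand-ins, not the patch's `process_group` API:

use std::time::Duration;

async fn shutdown_child(
    timeout: Duration,
    min_timeout: Duration,
    wait_for_exit: impl std::future::Future<Output = ()>,
    terminate: impl FnOnce(),
    kill: impl FnOnce(),
) {
    // Enforce a floor so the engine always gets a minimum grace period.
    let timeout = std::cmp::max(timeout, min_timeout);
    terminate(); // graceful request first, e.g. SIGTERM to the process group
    if tokio::time::timeout(timeout, wait_for_exit).await.is_err() {
        kill(); // grace period expired: force-kill the whole group
    }
}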
pub fn configure(command: &mut Command) { unsafe { command.pre_exec(|| { diff --git a/src/engine-core-client/examples/external_engine_logprobs.rs b/src/engine-core-client/examples/external_engine_logprobs.rs index 8d24a13a..77e9f125 100644 --- a/src/engine-core-client/examples/external_engine_logprobs.rs +++ b/src/engine-core-client/examples/external_engine_logprobs.rs @@ -152,10 +152,7 @@ async fn main() -> Result<()> { println!("requested_logprobs={}", args.logprobs); println!("requested_prompt_logprobs={}", args.prompt_logprobs); - let stream = client - .call(request) - .await - .context("failed to submit engine-core request")?; + let stream = client.call(request).await.context("failed to submit engine-core request")?; let output = timeout(output_timeout, wait_for_final_output(stream)) .await .context("timed out waiting for final output")??; @@ -178,10 +175,7 @@ async fn main() -> Result<()> { println!("new_logprobs={logprobs:#?}"); println!("new_prompt_logprobs_tensors={prompt_logprobs:#?}"); - client - .shutdown() - .await - .context("failed to shut down engine-core client")?; + client.shutdown().await.context("failed to shut down engine-core client")?; if finish_reason != Some(EngineCoreFinishReason::Length) { bail!("unexpected finish_reason: expected Length, got {finish_reason:?}"); diff --git a/src/engine-core-client/examples/external_engine_utility_call.rs b/src/engine-core-client/examples/external_engine_utility_call.rs index c3b7896b..ee2a4e57 100644 --- a/src/engine-core-client/examples/external_engine_utility_call.rs +++ b/src/engine-core-client/examples/external_engine_utility_call.rs @@ -75,10 +75,8 @@ async fn main() -> Result<()> { println!("output_address={}", client.output_address()); println!("engine_identities={:x?}", client.engine_identities()); - let initial_is_sleeping = client - .is_sleeping() - .await - .context("failed to call is_sleeping utility")?; + let initial_is_sleeping = + client.is_sleeping().await.context("failed to call is_sleeping utility")?; println!("is_sleeping(initial)={initial_is_sleeping}"); @@ -96,10 +94,7 @@ async fn main() -> Result<()> { .context("failed to call reset_prefix_cache utility")?; println!("reset_prefix_cache={reset_prefix_cache}"); - client - .reset_mm_cache() - .await - .context("failed to call reset_mm_cache utility")?; + client.reset_mm_cache().await.context("failed to call reset_mm_cache utility")?; println!("reset_mm_cache=ok"); client @@ -111,40 +106,30 @@ async fn main() -> Result<()> { if args.skip_sleep_wake { println!("sleep_wake=skipped"); } else { - client - .sleep(args.sleep_level, &args.sleep_mode) - .await - .with_context(|| { - format!( - "failed to call sleep utility with level={} mode={}", - args.sleep_level, args.sleep_mode - ) - })?; + client.sleep(args.sleep_level, &args.sleep_mode).await.with_context(|| { + format!( + "failed to call sleep utility with level={} mode={}", + args.sleep_level, args.sleep_mode + ) + })?; println!( "sleep=ok level={} mode={}", args.sleep_level, args.sleep_mode ); - let sleeping_after_sleep = client - .is_sleeping() - .await - .context("failed to call is_sleeping after sleep")?; + let sleeping_after_sleep = + client.is_sleeping().await.context("failed to call is_sleeping after sleep")?; println!("is_sleeping(after_sleep)={sleeping_after_sleep}"); if !sleeping_after_sleep { bail!("engine should report sleeping=true after sleep()"); } - client - .wake_up(None) - .await - .context("failed to call wake_up utility")?; + client.wake_up(None).await.context("failed to call wake_up 
utility")?; println!("wake_up=ok"); - let sleeping_after_wake = client - .is_sleeping() - .await - .context("failed to call is_sleeping after wake_up")?; + let sleeping_after_wake = + client.is_sleeping().await.context("failed to call is_sleeping after wake_up")?; println!("is_sleeping(after_wake)={sleeping_after_wake}"); if sleeping_after_wake { @@ -152,10 +137,7 @@ async fn main() -> Result<()> { } } - client - .shutdown() - .await - .context("failed to shut down engine-core client")?; + client.shutdown().await.context("failed to shut down engine-core client")?; Ok(()) } diff --git a/src/engine-core-client/src/client.rs b/src/engine-core-client/src/client.rs index 1a05f5f9..82cf981b 100644 --- a/src/engine-core-client/src/client.rs +++ b/src/engine-core-client/src/client.rs @@ -19,15 +19,18 @@ mod stream; pub use stream::{EngineCoreOutputStream, EngineCoreStreamOutput}; -/// How the frontend acquires its request/response transport with Python `EngineCoreProc`s. +/// How the frontend acquires its request/response transport with Python +/// `EngineCoreProc`s. #[derive(Debug, Clone, PartialEq, Eq)] pub enum TransportMode { - /// The Rust process owns the startup handshake and allocates or binds the frontend transport - /// addresses itself before replying to engine `HELLO` messages. + /// The Rust process owns the startup handshake and allocates or binds the + /// frontend transport addresses itself before replying to engine + /// `HELLO` messages. HandshakeOwner { /// Shared handshake endpoint that engines dial during startup. handshake_address: String, - /// Host/IP that engines should use to connect back to the frontend transport sockets. + /// Host/IP that engines should use to connect back to the frontend + /// transport sockets. advertised_host: String, /// Total number of engines expected to join this transport. engine_count: usize, @@ -39,12 +42,15 @@ pub enum TransportMode { local_output_address: Option, }, - /// The Python supervisor has already chosen the frontend transport addresses, and the Rust - /// process only needs to bind them and wait for engine registration frames. + /// The Python supervisor has already chosen the frontend transport + /// addresses, and the Rust process only needs to bind them and wait for + /// engine registration frames. Bootstrapped { - /// Input ROUTER socket address that engines will connect to for requests. + /// Input ROUTER socket address that engines will connect to for + /// requests. input_address: String, - /// Output PULL socket address that engines will connect to for responses. + /// Output PULL socket address that engines will connect to for + /// responses. output_address: String, /// Total number of engines expected to register on this transport. engine_count: usize, @@ -53,7 +59,8 @@ pub enum TransportMode { }, } -/// Which coordinator implementation should be active when one is present for a frontend client. +/// Which coordinator implementation should be active when one is present for a +/// frontend client. #[derive(Debug, Clone, PartialEq, Eq)] pub enum CoordinatorMode { /// Run the Rust in-process coordinator for managed `serve` deployments. @@ -62,14 +69,14 @@ pub enum CoordinatorMode { External { address: String }, } -/// Configuration for connecting a Rust frontend client to an already running Python -/// `EngineCoreProc`. +/// Configuration for connecting a Rust frontend client to an already running +/// Python `EngineCoreProc`. 
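To make the two transport modes concrete, here is an illustrative handshake-owned setup. The addresses are placeholders, and the two `local_*` override fields are assumed to be `Option<String>`:

```rust
fn handshake_owned_transport() -> TransportMode {
    TransportMode::HandshakeOwner {
        // Engines dial this endpoint and send HELLO during startup.
        handshake_address: "tcp://0.0.0.0:5557".to_string(),
        // Host that engines use to connect back to the frontend sockets.
        advertised_host: "127.0.0.1".to_string(),
        engine_count: 1,
        // `None` lets the client bind its own ephemeral endpoints.
        local_input_address: None,
        local_output_address: None,
    }
}
```

In `Bootstrapped` mode the Python supervisor supplies `input_address` and `output_address` instead, and the client only binds them and waits for engine registration frames.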
#[derive(Debug, Clone, PartialEq, Eq)] pub struct EngineCoreClientConfig { /// Frontend-to-engine transport setup. pub transport_mode: TransportMode, - /// Frontend-side coordinator behavior, or `None` when requests should flow directly to engines - /// without any coordinator involvement. + /// Frontend-side coordinator behavior, or `None` when requests should flow + /// directly to engines without any coordinator involvement. pub coordinator_mode: Option, /// Model name used for frontend-side metrics labels. pub model_name: String, @@ -78,8 +85,8 @@ pub struct EngineCoreClientConfig { } impl EngineCoreClientConfig { - /// Create a new client config with the given handshake address, expecting a single engine, and - /// default values for all other fields. + /// Create a new client config with the given handshake address, expecting a + /// single engine, and default values for all other fields. pub fn new_single(handshake_address: impl Into) -> Self { Self { transport_mode: TransportMode::HandshakeOwner { @@ -114,10 +121,11 @@ impl EngineCoreClientConfig { self } - /// Override the locally bound input/output addresses for handshake-owned transport mode. + /// Override the locally bound input/output addresses for handshake-owned + /// transport mode. /// - /// This is primarily used by tests that want deterministic IPC endpoints while still exercising - /// the handshake-owned startup path. + /// This is primarily used by tests that want deterministic IPC endpoints + /// while still exercising the handshake-owned startup path. pub fn with_local_input_output_addresses( mut self, local_input_address: Option, @@ -137,13 +145,16 @@ impl EngineCoreClientConfig { } } -/// The reason a request stream is being aborted when its output stream is dropped. +/// The reason a request stream is being aborted when its output stream is +/// dropped. #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum AbortCause { - /// The consumer dropped the stream before the request reached a terminal engine output. + /// The consumer dropped the stream before the request reached a terminal + /// engine output. #[default] DroppedStream, - /// The frontend matched a stop string locally and intentionally stopped consuming the stream. + /// The frontend matched a stop string locally and intentionally stopped + /// consuming the stream. StopStringMatched, } @@ -152,8 +163,8 @@ task_local::task_local! { } impl AbortCause { - /// Return the abort cause currently associated with this task, or [`AbortCause::DroppedStream`] - /// by default. + /// Return the abort cause currently associated with this task, or + /// [`AbortCause::DroppedStream`] by default. pub fn current() -> Self { ABORT_CAUSE.try_get().unwrap_or_default() } @@ -164,14 +175,16 @@ impl AbortCause { } } -/// Internal auto-abort work item sent from stream `Drop` handlers to the abort worker. +/// Internal auto-abort work item sent from stream `Drop` handlers to the abort +/// worker. #[derive(Debug, Clone)] pub(crate) struct AbortRequest { request_id: String, cause: AbortCause, } -/// Default ZMQ-based implementation that talks directly to a Python `EngineCoreProc`. +/// Default ZMQ-based implementation that talks directly to a Python +/// `EngineCoreProc`. pub struct EngineCoreClient { config: EngineCoreClientConfig, input_address: String, @@ -190,11 +203,12 @@ pub struct EngineCoreClient { } impl EngineCoreClient { - /// Connect to Python `EngineCoreProc`s using the configured transport/coordinator modes. 
+ /// Connect to Python `EngineCoreProc`s using the configured + /// transport/coordinator modes. /// - /// In handshake-owned mode this method drives the full engine startup handshake. In - /// bootstrapped mode it binds the provided frontend sockets and waits for the expected engine - /// registration frames. + /// In handshake-owned mode this method drives the full engine startup + /// handshake. In bootstrapped mode it binds the provided frontend + /// sockets and waits for the expected engine registration frames. pub async fn connect(config: EngineCoreClientConfig) -> Result { let connected = match &config.transport_mode { TransportMode::HandshakeOwner { @@ -248,11 +262,12 @@ impl EngineCoreClient { Self::from_connected(config, connected).await } - /// Connect using handshake-owned transport mode while overriding the frontend input/output bind - /// addresses. + /// Connect using handshake-owned transport mode while overriding the + /// frontend input/output bind addresses. /// - /// This helper preserves the previous test-facing API shape. It is only valid when - /// `config.transport_mode` is `TransportMode::HandshakeOwner`. + /// This helper preserves the previous test-facing API shape. It is only + /// valid when `config.transport_mode` is + /// `TransportMode::HandshakeOwner`. // TODO: inline this pub async fn connect_with_input_output_addresses( config: EngineCoreClientConfig, @@ -264,8 +279,8 @@ impl EngineCoreClient { Self::connect(config).await } - /// Create a new client instance from the connected transport state after the startup handshake - /// completes. + /// Create a new client instance from the connected transport state after + /// the startup handshake completes. async fn from_connected( config: EngineCoreClientConfig, connected: transport::ConnectedTransport, @@ -315,12 +330,10 @@ impl EngineCoreClient { Some(coordinator_task), ) } else if let Some(address) = - dp_stats_address - .as_deref() - .or(match config.coordinator_mode.as_ref() { - Some(CoordinatorMode::External { address }) => Some(address.as_str()), - _ => None, - }) + dp_stats_address.as_deref().or(match config.coordinator_mode.as_ref() { + Some(CoordinatorMode::External { address }) => Some(address.as_str()), + _ => None, + }) { let (handle, service) = CoordinatorHandle::connect_external(address).await?; let coordinator_task = @@ -346,12 +359,14 @@ impl EngineCoreClient { }) } - /// Return the address of the input socket that the client uses to send requests to the engine. + /// Return the address of the input socket that the client uses to send + /// requests to the engine. pub fn input_address(&self) -> &str { &self.input_address } - /// Return the address of the output socket that the client listens on for engine responses. + /// Return the address of the output socket that the client listens on for + /// engine responses. pub fn output_address(&self) -> &str { &self.output_address } @@ -363,13 +378,11 @@ impl EngineCoreClient { /// Return the engine identities of all engines connected to this client. pub fn engine_identities(&self) -> Vec<&[u8]> { - self.engines - .iter() - .map(|engine| &*engine.engine_id) - .collect() + self.engines.iter().map(|engine| &*engine.engine_id).collect() } - /// Return the ready responses received from all engines on the input socket. + /// Return the ready responses received from all engines on the input + /// socket. 
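A minimal startup sequence for the single-engine defaults above, in the style of this crate's examples; the handshake address is a placeholder:

```rust
async fn start_single_engine() -> anyhow::Result<()> {
    let config = EngineCoreClientConfig::new_single("ipc:///tmp/engine-handshake.sock");
    let client = EngineCoreClient::connect(config).await?;

    // After `connect`, the frontend sockets are bound and every expected
    // engine has registered.
    println!("input={}", client.input_address());
    println!("output={}", client.output_address());
    println!("engines={:x?}", client.engine_identities());

    client.shutdown().await?;
    Ok(())
}
```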
pub fn ready_responses(&self) -> Vec<&EngineCoreReadyResponse> { self.engines .iter() @@ -377,7 +390,8 @@ impl EngineCoreClient { .collect() } - /// Return the total number of GPU blocks summed across all connected engines. + /// Return the total number of GPU blocks summed across all connected + /// engines. pub fn total_num_gpu_blocks(&self) -> u64 { self.engines .iter() @@ -388,8 +402,8 @@ impl EngineCoreClient { /// Return the minimum engine-reported `max_model_len` across all engines. /// - /// This is the auto-fitted value after KV cache profiling and may differ from - /// the originally configured value. + /// This is the auto-fitted value after KV cache profiling and may differ + /// from the originally configured value. pub fn max_model_len(&self) -> Option { self.engines .iter() @@ -398,7 +412,8 @@ impl EngineCoreClient { .min() } - /// Get the model name associated with this client used for metrics labeling. + /// Get the model name associated with this client used for metrics + /// labeling. pub fn model_name(&self) -> &str { self.inner.model_name() } @@ -416,7 +431,8 @@ impl EngineCoreClient { // Client API implementation. impl EngineCoreClient { - /// Add a new request to the engine and return a per-request raw output stream. + /// Add a new request to the engine and return a per-request raw output + /// stream. pub async fn call(&self, mut req: EngineCoreRequest) -> Result { req.client_index = self.config.client_index; req.validate()?; @@ -430,9 +446,8 @@ impl EngineCoreClient { let request_id = req.request_id.clone(); let data_parallel_rank = req.data_parallel_rank; - let (engine_id, rx) = self - .inner - .register_request(request_id.clone(), data_parallel_rank)?; + let (engine_id, rx) = + self.inner.register_request(request_id.clone(), data_parallel_rank)?; let result: Result<()> = async { if let Some(coordinator) = self.coordinator.as_ref() { @@ -449,9 +464,7 @@ impl EngineCoreClient { "registered request to engine" ); - self.inner - .send_to_engine(&engine_id, EngineCoreRequestType::Add, &req) - .await?; + self.inner.send_to_engine(&engine_id, EngineCoreRequestType::Add, &req).await?; Ok(()) } .await; @@ -480,19 +493,18 @@ impl EngineCoreClient { } for (engine_id, request_ids) in abortable { - self.inner - .do_abort_requests(&engine_id, &request_ids) - .await?; + self.inner.do_abort_requests(&engine_id, &request_ids).await?; } Ok(()) } - /// Call a typed utility method on all connected engines, returning one decoded result per - /// connected engine if all calls succeed or an error if any call fails. + /// Call a typed utility method on all connected engines, returning one + /// decoded result per connected engine if all calls succeed or an error + /// if any call fails. /// - /// Callers should pass utility arguments using Rust tuple semantics so the encoded payload - /// matches Python's `(client_index, call_id, method_name, args)` contract: - /// `()`, `(arg,)`, `(arg1, arg2)`, etc. + /// Callers should pass utility arguments using Rust tuple semantics so the + /// encoded payload matches Python's `(client_index, call_id, + /// method_name, args)` contract: `()`, `(arg,)`, `(arg1, arg2)`, etc. pub async fn call_utility(&self, method: &str, args: A) -> Result> where T: serde::de::DeserializeOwned, @@ -532,7 +544,8 @@ impl EngineCoreClient { try_join_all(futures).await } - /// Execute `collective_rpc` on all engines and flatten all engine results into one list. + /// Execute `collective_rpc` on all engines and flatten all engine results + /// into one list. 
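Because the encoded payload must line up with Python's positional `args` tuple in `(client_index, call_id, method_name, args)`, argument shapes follow Rust tuple syntax exactly. Two assumed calls in the style of this crate's utility methods:

```rust
async fn utility_examples(client: &EngineCoreClient) -> anyhow::Result<()> {
    // No arguments: pass the unit tuple `()`. One result per engine.
    let sleeping: Vec<bool> = client.call_utility("is_sleeping", ()).await?;
    println!("sleeping={sleeping:?}");

    // One argument: the trailing comma makes `(tags,)` a 1-tuple, so the
    // payload stays a positional args tuple on the Python side. The tag
    // value here is illustrative.
    let tags: Option<Vec<String>> = Some(vec!["weights".to_string()]);
    client.call_utility::<(), _>("wake_up", (tags,)).await?;
    Ok(())
}
```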
pub async fn collective_rpc( &self, method: &str, @@ -572,8 +585,7 @@ impl EngineCoreClient { /// Reset the encoder cache. pub async fn reset_encoder_cache(&self) -> Result<()> { - self.call_utility::<(), _>("reset_encoder_cache", ()) - .await?; + self.call_utility::<(), _>("reset_encoder_cache", ()).await?; Ok(()) } @@ -598,7 +610,8 @@ impl EngineCoreClient { Ok(()) } - /// Wake the engine from sleep, optionally limiting the wake-up to specific tags. + /// Wake the engine from sleep, optionally limiting the wake-up to specific + /// tags. pub async fn wake_up(&self, tags: Option>) -> Result<()> { self.call_utility::<(), _>("wake_up", (tags,)).await?; Ok(()) diff --git a/src/engine-core-client/src/client/imp.rs b/src/engine-core-client/src/client/imp.rs index d92a5303..53bae79a 100644 --- a/src/engine-core-client/src/client/imp.rs +++ b/src/engine-core-client/src/client/imp.rs @@ -32,7 +32,8 @@ pub(crate) struct ClientInner { } impl ClientInner { - /// Create a new instance with the given input send half after the startup handshake completes. + /// Create a new instance with the given input send half after the startup + /// handshake completes. pub fn new( input_send: RouterSendHalf, model_name: String, @@ -47,16 +48,17 @@ impl ClientInner { } } - /// Get the model name associated with this client used for metrics labeling. + /// Get the model name associated with this client used for metrics + /// labeling. pub fn model_name(&self) -> &str { &self.model_name } - /// Register a newly added request. Return the selected engine id and the per-request - /// output channel bound to its `request_id`. + /// Register a newly added request. Return the selected engine id and the + /// per-request output channel bound to its `request_id`. /// - /// When `data_parallel_rank` is provided, the request is routed to that specific engine - /// rank, bypassing load balancing. + /// When `data_parallel_rank` is provided, the request is routed to that + /// specific engine rank, bypassing load balancing. pub fn register_request( &self, request_id: String, @@ -83,8 +85,9 @@ impl ClientInner { let _ = self.request_reg.lock().remove(request_id); } - /// Filter the given request IDs to the subset that are still tracked as active and can be - /// aborted, grouped by the engine that originally accepted them. + /// Filter the given request IDs to the subset that are still tracked as + /// active and can be aborted, grouped by the engine that originally + /// accepted them. pub fn abortable_request_ids( &self, request_ids: &[String], @@ -96,8 +99,8 @@ impl ClientInner { Ok(registry.abortable_request_ids(request_ids)) } - /// Obtain the stream sender for one output. If it indicates the request is finished, it will be - /// removed from the registry. + /// Obtain the stream sender for one output. If it indicates the request is + /// finished, it will be removed from the registry. pub fn take_sender_for_output( &self, output: &EngineCoreOutput, @@ -105,7 +108,8 @@ impl ClientInner { self.request_reg.lock().sender_for_output(output) } - /// Remove a batch of requests that have finished or aborted, returning their stream senders. + /// Remove a batch of requests that have finished or aborted, returning + /// their stream senders. pub fn finish_requests<'a>( &self, request_ids: impl IntoIterator, @@ -113,15 +117,15 @@ impl ClientInner { self.request_reg.lock().finish_many(request_ids) } - /// Apply one scheduler stats update for the given engine to the local routing state. 
- /// Returns `false` if the engine is unknown to the client. + /// Apply one scheduler stats update for the given engine to the local + /// routing state. Returns `false` if the engine is unknown to the + /// client. pub fn apply_scheduler_stats(&self, engine_index: u32, stats: &SchedulerStats) -> bool { - self.request_reg - .lock() - .apply_scheduler_stats(engine_index, stats) + self.request_reg.lock().apply_scheduler_stats(engine_index, stats) } - /// Close all active request streams and utility calls with the first persistent health error. + /// Close all active request streams and utility calls with the first + /// persistent health error. pub fn close_registries(&self, error: Arc) { let persistent_error = self.record_health_error(error); let request_senders = self.request_reg.lock().close(); @@ -146,8 +150,8 @@ impl ClientInner { self.health_error.load().is_none() } - /// Resolve one utility output to the waiting caller. Returns `true` if a waiting caller - /// existed. + /// Resolve one utility output to the waiting caller. Returns `true` if a + /// waiting caller existed. pub fn resolve_utility_output(&self, output: UtilityOutput) -> bool { match self.utility_reg.lock().resolve(&output.call_id) { Some(sender) => { @@ -158,8 +162,9 @@ impl ClientInner { } } - /// Send the given message to the engine. The request should be first registered via - /// `register_request()` to ensure the request stream is tracked. + /// Send the given message to the engine. The request should be first + /// registered via `register_request()` to ensure the request stream is + /// tracked. pub async fn send_to_engine( &self, engine_id: &EngineId, @@ -182,18 +187,18 @@ impl ClientInner { engine_id: &EngineId, request_ids: &[String], ) -> Result<()> { - self.send_to_engine(engine_id, EngineCoreRequestType::Abort, &request_ids) - .await + self.send_to_engine(engine_id, EngineCoreRequestType::Abort, &request_ids).await } - /// Shut down by closing all active request streams and utility calls with a sticky client - /// closed error. + /// Shut down by closing all active request streams and utility calls with a + /// sticky client closed error. pub fn shutdown(&self) { self.close_registries(Arc::new(client_closed!("engine-core client shut down"))); } - /// Remove the request from the active registry for auto-abort and return the engine that the - /// request was originally routed to, if it is still active. + /// Remove the request from the active registry for auto-abort and return + /// the engine that the request was originally routed to, if it is still + /// active. pub fn take_auto_abort_target(&self, request_id: &str) -> Option { let mut registry = self.request_reg.lock(); let (_, engine_id) = registry.remove(request_id)?; @@ -203,9 +208,9 @@ impl ClientInner { Some(engine_id) } - /// Publish the first persistent health error and return the sticky error recorded for this - /// client. Later failures do not overwrite the first one so `/health` and post-close callers - /// observe a stable cause. + /// Publish the first persistent health error and return the sticky error + /// recorded for this client. Later failures do not overwrite the first + /// one so `/health` and post-close callers observe a stable cause. 
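The sticky-error contract above reduces to a first-writer-wins cell. A minimal sketch using `std::sync::OnceLock` and a `String` placeholder in place of the client's actual storage and error type:

```rust
use std::sync::{Arc, OnceLock};

struct HealthSlot {
    first_error: OnceLock<Arc<String>>, // `String` stands in for the real error type
}

impl HealthSlot {
    /// Record an error only if none is recorded yet; always return the sticky
    /// first error so every caller observes the same root cause.
    fn record(&self, error: Arc<String>) -> Arc<String> {
        self.first_error.get_or_init(|| error).clone()
    }
}
```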
fn record_health_error(&self, error: Arc) -> Arc { if let Some(existing) = self.health_error.load_full() { return existing; @@ -217,8 +222,8 @@ impl ClientInner { .expect("health error must be recorded before registries close") } - /// Assert there is a recorded health error and return a `Shared` variant wrapping it for error - /// returns when the client is already closed. + /// Assert there is a recorded health error and return a `Shared` variant + /// wrapping it for error returns when the client is already closed. fn closed_error(&self) -> Error { Error::Shared(self.health_error.load_full().expect( "closed registry must have a recorded health error before rejecting new operations", @@ -226,9 +231,9 @@ impl ClientInner { } } -/// Background loop that listens for request IDs to abort and sends abort messages to the engine. -/// This is used to implement the auto-abort behavior when a request stream is dropped without being -/// properly terminated. +/// Background loop that listens for request IDs to abort and sends abort +/// messages to the engine. This is used to implement the auto-abort behavior +/// when a request stream is dropped without being properly terminated. pub(crate) async fn run_abort_loop( inner: Arc, mut abort_rx: mpsc::UnboundedReceiver, @@ -252,9 +257,7 @@ pub(crate) async fn run_abort_loop( } } - if let Err(error) = inner - .do_abort_requests(&engine_id, slice::from_ref(&request_id)) - .await + if let Err(error) = inner.do_abort_requests(&engine_id, slice::from_ref(&request_id)).await { warn!( request_id, @@ -266,8 +269,8 @@ pub(crate) async fn run_abort_loop( } } -/// Background loop that listens for engine-core outputs and dispatches them to the corresponding -/// request streams based on their `request_id`. +/// Background loop that listens for engine-core outputs and dispatches them to +/// the corresponding request streams based on their `request_id`. pub(crate) async fn run_output_dispatcher_loop( inner: Arc, mut output_rx: mpsc::Receiver>, diff --git a/src/engine-core-client/src/client/state.rs b/src/engine-core-client/src/client/state.rs index 690461c2..61e87bf2 100644 --- a/src/engine-core-client/src/client/state.rs +++ b/src/engine-core-client/src/client/state.rs @@ -24,8 +24,8 @@ struct TrackedRequest { /// The latest real scheduler-side load snapshot observed from one engine. /// -/// These counters come from `scheduler_stats` on the normal engine output path and are the -/// preferred routing signal once available. +/// These counters come from `scheduler_stats` on the normal engine output path +/// and are the preferred routing signal once available. #[derive(Debug, Clone, Copy, PartialEq, Eq)] struct EngineLoadSnapshot { /// Requests still counted on the scheduler's waiting side. @@ -38,9 +38,9 @@ struct EngineLoadSnapshot { struct EngineRoutingState { /// Requests admitted by this frontend that have not finished yet. /// - /// This is used both as the bootstrap fallback before real scheduler stats exist and as a - /// lower bound afterwards so asynchronous scheduler snapshots cannot erase frontend admission - /// history. + /// This is used both as the bootstrap fallback before real scheduler stats + /// exist and as a lower bound afterwards so asynchronous scheduler + /// snapshots cannot erase frontend admission history. inflight: usize, /// The latest real scheduler snapshot received from this engine, if any. 
last_scheduler_stats: Option, @@ -49,9 +49,10 @@ struct EngineRoutingState { impl EngineRoutingState { /// Compute the routing score used to pick the least-loaded engine. /// - /// Scheduler stats can raise the load estimate above the frontend-local view, but they should - /// not lower it below requests this frontend has already admitted. Waiting requests still get - /// the same extra penalty as the original `waiting * 4 + running` score. + /// Scheduler stats can raise the load estimate above the frontend-local + /// view, but they should not lower it below requests this frontend has + /// already admitted. Waiting requests still get the same extra penalty + /// as the original `waiting * 4 + running` score. fn routing_score(&self) -> usize { const WAITING_WEIGHT: usize = 4; @@ -69,10 +70,12 @@ impl EngineRoutingState { } } -/// Internal registry for tracking active requests and their output stream senders. +/// Internal registry for tracking active requests and their output stream +/// senders. /// -/// This is used to route incoming outputs to the correct request stream, and to ensure proper -/// cleanup of senders when requests finish or the client shuts down. +/// This is used to route incoming outputs to the correct request stream, and to +/// ensure proper cleanup of senders when requests finish or the client shuts +/// down. #[derive(Debug)] pub struct RequestRegistry { closed: bool, @@ -92,12 +95,12 @@ impl RequestRegistry { } } - /// Register a newly added request. Create the per-request output channel bound to its - /// `request_id` and return the selected engine id. + /// Register a newly added request. Create the per-request output channel + /// bound to its `request_id` and return the selected engine id. /// - /// When `data_parallel_rank` is provided, the request is routed directly to the engine at - /// that rank index, bypassing load balancing. Otherwise the engine with the fewest in-flight - /// requests is chosen. + /// When `data_parallel_rank` is provided, the request is routed directly to + /// the engine at that rank index, bypassing load balancing. Otherwise + /// the engine with the fewest in-flight requests is chosen. pub fn register( &mut self, request_id: String, @@ -148,8 +151,8 @@ impl RequestRegistry { .expect("request registry must contain at least one engine")) } - /// Filter the given request IDs to the subset that are still tracked as active and can be - /// aborted, grouped by engine. + /// Filter the given request IDs to the subset that are still tracked as + /// active and can be aborted, grouped by engine. pub fn abortable_request_ids(&self, request_ids: &[String]) -> BTreeMap> { let mut by_engine = BTreeMap::new(); for request_id in request_ids { @@ -164,12 +167,11 @@ impl RequestRegistry { by_engine } - /// Obtain the stream sender for one output. If it indicates the request is finished, it will be - /// removed from the registry. + /// Obtain the stream sender for one output. If it indicates the request is + /// finished, it will be removed from the registry. pub fn sender_for_output(&mut self, output: &EngineCoreOutput) -> Option { if output.finished() { - self.remove(output.request_id.as_str()) - .map(|tracked| tracked.0) + self.remove(output.request_id.as_str()).map(|tracked| tracked.0) } else { self.requests .get(output.request_id.as_str()) @@ -177,7 +179,8 @@ impl RequestRegistry { } } - /// Remove a batch of requests that have finished or aborted, returning their stream senders. 
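Numerically, the score above is `max(inflight, waiting * 4 + running)`. A free-standing restatement with two worked cases:

```rust
fn routing_score(inflight: usize, scheduler: Option<(usize, usize)>) -> usize {
    const WAITING_WEIGHT: usize = 4;
    match scheduler {
        // Stats may raise the estimate, never lower it below local admissions.
        Some((waiting, running)) => inflight.max(waiting * WAITING_WEIGHT + running),
        None => inflight,
    }
}

fn main() {
    // 3 local admissions; scheduler reports 2 waiting + 5 running:
    assert_eq!(routing_score(3, Some((2, 5))), 13); // max(3, 2 * 4 + 5)
    // No snapshot yet: fall back to the frontend-local inflight count.
    assert_eq!(routing_score(3, None), 3);
}
```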
+ /// Remove a batch of requests that have finished or aborted, returning + /// their stream senders. pub fn finish_many<'a>( &mut self, request_ids: impl IntoIterator, @@ -188,8 +191,9 @@ impl RequestRegistry { .collect() } - /// Apply one scheduler stats update for the given engine to the local routing state. - /// Returns `false` if the engine is unknown to the client. + /// Apply one scheduler stats update for the given engine to the local + /// routing state. Returns `false` if the engine is unknown to the + /// client. pub fn apply_scheduler_stats(&mut self, engine_index: u32, stats: &SchedulerStats) -> bool { self.apply_scheduler_counts( engine_index, @@ -213,7 +217,8 @@ impl RequestRegistry { .collect() } - /// Remove one request from the local registry. Returns the tracked entry if it exists. + /// Remove one request from the local registry. Returns the tracked entry if + /// it exists. #[must_use] pub fn remove(&mut self, request_id: &str) -> Option<(OutputSender, EngineId)> { let tracked = self.requests.remove(request_id)?; @@ -256,7 +261,8 @@ impl RequestRegistry { } } -/// Internal registry for tracking active utility calls and their waiting receivers. +/// Internal registry for tracking active utility calls and their waiting +/// receivers. #[derive(Debug)] pub struct UtilityRegistry { closed: bool, @@ -275,7 +281,8 @@ impl Default for UtilityRegistry { } impl UtilityRegistry { - /// Allocate the next utility `call_id` and register a newly added utility call. + /// Allocate the next utility `call_id` and register a newly added utility + /// call. pub fn allocate_and_register(&mut self) -> (i64, UtilityReceiver) { let call_id = self.next_call_id.fetch_add(1, Ordering::Relaxed); let (tx, rx) = oneshot::channel(); @@ -295,9 +302,7 @@ impl UtilityRegistry { } self.closed = true; - std::mem::take(&mut self.utility_calls) - .into_values() - .collect() + std::mem::take(&mut self.utility_calls).into_values().collect() } #[cfg(test)] @@ -584,9 +589,7 @@ mod tests { let (chosen, _) = registry.register("req-ok".to_string(), Some(0)).unwrap(); assert_eq!(chosen, engine_0); - let error = registry - .register("req-bad".to_string(), Some(1)) - .unwrap_err(); + let error = registry.register("req-bad".to_string(), Some(1)).unwrap_err(); assert!(matches!( error, crate::error::Error::InvalidDataParallelRank { diff --git a/src/engine-core-client/src/client/stream.rs b/src/engine-core-client/src/client/stream.rs index 9c8975d7..3cbb215b 100644 --- a/src/engine-core-client/src/client/stream.rs +++ b/src/engine-core-client/src/client/stream.rs @@ -39,9 +39,10 @@ impl Deref for EngineCoreStreamOutput { /// Stream of raw engine-core outputs for one request. /// -/// The stream yields only [`EngineCoreStreamOutput`] values whose embedded output `request_id` -/// matches the originating `add_request()` call. Normal request completion is expected to include a -/// final output object whose `finish_reason` is non-`None`. +/// The stream yields only [`EngineCoreStreamOutput`] values whose embedded +/// output `request_id` matches the originating `add_request()` call. Normal +/// request completion is expected to include a final output object whose +/// `finish_reason` is non-`None`. pub struct EngineCoreOutputStream { request_id: String, abort_tx: mpsc::UnboundedSender, @@ -106,9 +107,10 @@ impl Stream for EngineCoreOutputStream { Poll::Ready(Some(item)) } Poll::Ready(None) => { - // If we get a `None` without seeing a finished output, this is an unexpected close - // from the engine side. 
Mark the stream as terminated with an unexpected close - // state and send an error down the stream to notify the caller. + // If we get a `None` without seeing a finished output, this is an unexpected + // close from the engine side. Mark the stream as terminated + // with an unexpected close state and send an error down the + // stream to notify the caller. warn!(self.request_id, "request stream closed unexpectedly"); self.state = State::UnexpectedClose; @@ -129,9 +131,10 @@ impl FusedStream for EngineCoreOutputStream { impl Drop for EngineCoreOutputStream { fn drop(&mut self) { if self.is_terminated() { - // If it's terminated, it means that the request either finished cleanly, or encountered - // an error or unexpected close from the engine. In any case, the request stream is - // already considered inactive and there's no need to abort it on the engine side. + // If it's terminated, it means that the request either finished cleanly, or + // encountered an error or unexpected close from the engine. In any + // case, the request stream is already considered inactive and + // there's no need to abort it on the engine side. return; } diff --git a/src/engine-core-client/src/coordinator/bootstrap.rs b/src/engine-core-client/src/coordinator/bootstrap.rs index a2cb219f..8c6855bf 100644 --- a/src/engine-core-client/src/coordinator/bootstrap.rs +++ b/src/engine-core-client/src/coordinator/bootstrap.rs @@ -18,16 +18,11 @@ impl CoordinatorBootstrap { /// Bind the engine-facing coordinator sockets on the given host. pub(crate) async fn bind(local_host: &str) -> Result { let mut input_socket = XPubSocket::new(); - let input_address = input_socket - .bind(&format!("tcp://{local_host}:0")) - .await? - .to_string(); + let input_address = input_socket.bind(&format!("tcp://{local_host}:0")).await?.to_string(); let mut output_socket = PullSocket::new(); - let output_address = output_socket - .bind(&format!("tcp://{local_host}:0")) - .await? - .to_string(); + let output_address = + output_socket.bind(&format!("tcp://{local_host}:0")).await?.to_string(); Ok(Self { input_address, @@ -37,7 +32,8 @@ impl CoordinatorBootstrap { }) } - /// Complete the engine-facing startup gate before engines are allowed to send handshake READY. + /// Complete the engine-facing startup gate before engines are allowed to + /// send handshake READY. pub(crate) async fn wait_for_startup_gate( &mut self, engine_count: usize, @@ -57,11 +53,12 @@ async fn wait_for_engine_subscriptions( ) -> Result<()> { let mut received = 0; while received < engine_count { - let message = tokio::time::timeout(ready_timeout, input_socket.recv()) - .await - .map_err(|_| Error::HandshakeTimeout { - stage: "coordinator engine subscriptions", - timeout: ready_timeout, + let message = + tokio::time::timeout(ready_timeout, input_socket.recv()).await.map_err(|_| { + Error::HandshakeTimeout { + stage: "coordinator engine subscriptions", + timeout: ready_timeout, + } })??; if message.len() != 1 { bail_unexpected_handshake_message!( @@ -89,8 +86,6 @@ async fn wait_for_engine_subscriptions( /// Send the coordinator READY marker to all subscribed engines. 
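Condensed, the startup gate counts one subscription frame per engine on the XPUB socket and then broadcasts the READY marker. A sketch using the same `zeromq` APIs as the code above, with the per-frame validation from `wait_for_engine_subscriptions` elided:

```rust
use bytes::Bytes;
use zeromq::{SocketRecv, SocketSend, XPubSocket, ZmqMessage};

async fn startup_gate(
    input_socket: &mut XPubSocket,
    engine_count: usize,
    ready_timeout: std::time::Duration,
) -> anyhow::Result<()> {
    // Each engine's SUB connection surfaces as one subscription message.
    for _ in 0..engine_count {
        let _subscription =
            tokio::time::timeout(ready_timeout, input_socket.recv()).await??;
    }
    // Broadcast READY to every subscribed engine at once.
    input_socket.send(ZmqMessage::from(Bytes::from_static(b"READY"))).await?;
    Ok(())
}
```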
async fn send_ready_to_engines(input_socket: &mut XPubSocket) -> Result<()> { - input_socket - .send(ZmqMessage::from(Bytes::from_static(b"READY"))) - .await?; + input_socket.send(ZmqMessage::from(Bytes::from_static(b"READY"))).await?; Ok(()) } diff --git a/src/engine-core-client/src/coordinator/external.rs b/src/engine-core-client/src/coordinator/external.rs index ed5934c9..eb447c48 100644 --- a/src/engine-core-client/src/coordinator/external.rs +++ b/src/engine-core-client/src/coordinator/external.rs @@ -12,8 +12,8 @@ use crate::coordinator::handle::{CoordinatorCommand, CoordinatorState}; use crate::error::{Error, Result, bail_unexpected_coordinator_output}; use crate::protocol::{OpaqueValue, decode_msgpack, encode_msgpack}; -/// Frontend-to-coordinator wakeup message sent when the first request arrives while -/// all engines are paused. +/// Frontend-to-coordinator wakeup message sent when the first request arrives +/// while all engines are paused. /// /// This matches the frontend-side msgpack tuple sent by Python /// `DPAsyncMPClient._ensure_stats_update_task` to the coordinator front socket. @@ -29,7 +29,8 @@ struct CoordinatorWakeupMessage { wave: u32, } -/// Coordinator-to-frontend state publish received on the front-side coordinator socket. +/// Coordinator-to-frontend state publish received on the front-side coordinator +/// socket. /// /// This matches the msgpack tuple periodically published by Python /// `DPCoordinatorProc.run_coordinator` to all connected frontends. @@ -52,11 +53,11 @@ struct CoordinatorStateUpdate { /// Background half of an external Python-owned coordinator connection. /// -/// This owns the command receiver and one frontend-facing XSUB socket. It mirrors -/// the subset of Python's coordinator protocol needed by the Rust bootstrapped -/// frontend: receive `(counts, wave, running)` publishes, ignore `counts`, and -/// send `(exclude_engine_index, wave)` wakeup messages when the first request -/// arrives while engines are paused. +/// This owns the command receiver and one frontend-facing XSUB socket. It +/// mirrors the subset of Python's coordinator protocol needed by the Rust +/// bootstrapped frontend: receive `(counts, wave, running)` publishes, ignore +/// `counts`, and send `(exclude_engine_index, wave)` wakeup messages when the +/// first request arrives while engines are paused. pub(crate) struct ExternalCoordinatorService { state: Arc, command_rx: mpsc::UnboundedReceiver, @@ -76,7 +77,8 @@ impl ExternalCoordinatorService { } } - /// Apply one frontend-originated command to the external coordinator state machine. + /// Apply one frontend-originated command to the external coordinator state + /// machine. async fn handle_command(&mut self, command: CoordinatorCommand) -> Result<()> { match command { CoordinatorCommand::FirstRequest { @@ -103,7 +105,8 @@ impl ExternalCoordinatorService { Ok(()) } - /// Apply one publish received from the xsub socket containing a coordinator state update. + /// Apply one publish received from the xsub socket containing a coordinator + /// state update. 
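Because the message derives `Serialize_tuple`, msgpack sees a positional array rather than a named map. A sketch of that encoding using `rmp_serde` directly in place of this crate's `encode_msgpack` helper; the first field's name and `Option<u32>` type are assumed from the surrounding docs:

```rust
use serde_tuple::Serialize_tuple;

#[derive(Serialize_tuple)]
struct CoordinatorWakeupMessage {
    exclude_engine_index: Option<u32>, // assumed shape of the first field
    wave: u32,
}

fn main() {
    let message = CoordinatorWakeupMessage {
        exclude_engine_index: Some(0),
        wave: 3,
    };
    let bytes = rmp_serde::to_vec(&message).expect("encode wakeup");
    // Positional msgpack: a 2-element array `[0, 3]`, not a named map.
    assert_eq!(bytes[0], 0x92); // fixarray of length 2
}
```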
async fn handle_publish(&mut self, message: ZmqMessage) -> Result<()> { let frames = message.into_vec(); if frames.len() != 1 { diff --git a/src/engine-core-client/src/coordinator/handle.rs b/src/engine-core-client/src/coordinator/handle.rs index a18e2f6d..dca6f70d 100644 --- a/src/engine-core-client/src/coordinator/handle.rs +++ b/src/engine-core-client/src/coordinator/handle.rs @@ -15,8 +15,8 @@ use crate::transport::EngineId; pub(crate) struct CoordinatorStateSnapshot { /// The current DP wave, which will be stamped on outgoing requests. pub current_wave: u32, - /// Whether the engines are currently running or paused, which determines if the frontend - /// must trigger a new wave on the next request. + /// Whether the engines are currently running or paused, which determines if + /// the frontend must trigger a new wave on the next request. pub engines_running: bool, } @@ -28,8 +28,8 @@ pub(crate) type CoordinatorState = Mutex; pub(crate) enum CoordinatorCommand { /// The first request arrived while all engines were paused. /// - /// The coordinator should broadcast `START_DP_WAVE` with the current wave and the target engine - /// index as the excluded engine. + /// The coordinator should broadcast `START_DP_WAVE` with the current wave + /// and the target engine index as the excluded engine. FirstRequest { target_engine_id: EngineId, wave: u32, @@ -38,9 +38,9 @@ pub(crate) enum CoordinatorCommand { /// Frontend-facing coordinator handle used by `EngineCoreClient::call()`. /// -/// This side stays intentionally small: it can read the latest wave snapshot and -/// enqueue a `FirstRequest` transition when the request path observes the system -/// in the paused state. +/// This side stays intentionally small: it can read the latest wave snapshot +/// and enqueue a `FirstRequest` transition when the request path observes the +/// system in the paused state. #[derive(Clone)] pub(crate) struct CoordinatorHandle { state: Arc, @@ -78,8 +78,8 @@ impl CoordinatorHandle { ) } - /// Build the paired frontend handle and background service around an external - /// Python-owned frontend-side coordinator socket. + /// Build the paired frontend handle and background service around an + /// external Python-owned frontend-side coordinator socket. pub(crate) async fn connect_external( coordinator_address: &str, ) -> Result<(Self, ExternalCoordinatorService)> { diff --git a/src/engine-core-client/src/coordinator/inproc.rs b/src/engine-core-client/src/coordinator/inproc.rs index 18607740..26ab0c5a 100644 --- a/src/engine-core-client/src/coordinator/inproc.rs +++ b/src/engine-core-client/src/coordinator/inproc.rs @@ -15,10 +15,11 @@ use crate::protocol::{ encode_msgpack, }; -/// Coordinator-to-engine `START_DP_WAVE` control payload encoded on the engine-facing -/// coordinator socket. +/// Coordinator-to-engine `START_DP_WAVE` control payload encoded on the +/// engine-facing coordinator socket. /// -/// This matches the msgpack tuple broadcast by Python `DPCoordinatorProc._send_start_wave`. +/// This matches the msgpack tuple broadcast by Python +/// `DPCoordinatorProc._send_start_wave`. /// /// Original Python definition: /// @@ -33,9 +34,9 @@ struct StartDpWaveMessage { /// Background half of the in-process coordinator. /// -/// This owns the command receiver and the engine-facing coordinator input socket. -/// It is the single place where wave transitions are serialized and where -/// `START_DP_WAVE` broadcasts are emitted. 
+/// This owns the command receiver and the engine-facing coordinator input +/// socket. It is the single place where wave transitions are serialized and +/// where `START_DP_WAVE` broadcasts are emitted. pub(crate) struct InProcCoordinatorRunner { state: Arc, command_rx: mpsc::UnboundedReceiver, @@ -97,14 +98,16 @@ impl InProcCoordinatorRunner { Ok(()) } - /// Apply one engine-originated control output to the coordinator state machine. + /// Apply one engine-originated control output to the coordinator state + /// machine. async fn handle_outputs(&mut self, outputs: EngineCoreOutputs) -> Result<()> { match outputs.classify() { ClassifiedEngineCoreOutputs::RequestBatch(batch) if batch.outputs.is_empty() && batch.finished_requests.is_none() => { // Stats-only output for coordinator. - // Ignore since the Rust coordinator doesn't track stats for routing decisions. + // Ignore since the Rust coordinator doesn't track stats for + // routing decisions. } ClassifiedEngineCoreOutputs::DpControl { engine_index, diff --git a/src/engine-core-client/src/metrics.rs b/src/engine-core-client/src/metrics.rs index 5eacc885..8f459396 100644 --- a/src/engine-core-client/src/metrics.rs +++ b/src/engine-core-client/src/metrics.rs @@ -5,7 +5,8 @@ use crate::protocol::stats::SchedulerStats; const WAITING_REASON_CAPACITY: &str = "capacity"; const WAITING_REASON_DEFERRED: &str = "deferred"; -/// Record the scheduler-stats-backed metrics for one engine at one point in time. +/// Record the scheduler-stats-backed metrics for one engine at one point in +/// time. pub(crate) fn record_scheduler_stats( metrics: &SchedulerMetrics, model_name: impl Into, @@ -19,10 +20,7 @@ pub(crate) fn record_scheduler_stats( }; // Scheduler state gauges. - metrics - .scheduler_running - .get_or_create(&labels) - .set(stats.num_running_reqs); + metrics.scheduler_running.get_or_create(&labels).set(stats.num_running_reqs); metrics .scheduler_waiting .get_or_create(&labels) @@ -43,10 +41,7 @@ pub(crate) fn record_scheduler_stats( reason: WAITING_REASON_DEFERRED, }) .set(stats.num_skipped_waiting_reqs); - metrics - .kv_cache_usage - .get_or_create(&labels) - .set(stats.kv_cache_usage); + metrics.kv_cache_usage.get_or_create(&labels).set(stats.kv_cache_usage); // Prefix-cache counters, including the connector-backed external cache path. metrics @@ -84,11 +79,8 @@ pub(crate) fn record_scheduler_stats( .get_or_create(&labels) .inc_by(spec_decoding_stats.num_accepted_tokens); - for (position, accepted_tokens) in spec_decoding_stats - .num_accepted_tokens_per_pos - .iter() - .copied() - .enumerate() + for (position, accepted_tokens) in + spec_decoding_stats.num_accepted_tokens_per_pos.iter().copied().enumerate() { metrics .spec_decode_num_accepted_tokens_per_pos @@ -124,9 +116,8 @@ pub(crate) fn record_scheduler_stats( // Sampled KV-cache residency histograms. 
if !stats.kv_cache_eviction_events.is_empty() { let kv_block_lifetime_seconds = metrics.kv_block_lifetime_seconds.get_or_create(&labels); - let kv_block_idle_before_evict_seconds = metrics - .kv_block_idle_before_evict_seconds - .get_or_create(&labels); + let kv_block_idle_before_evict_seconds = + metrics.kv_block_idle_before_evict_seconds.get_or_create(&labels); let kv_block_reuse_gap_seconds = metrics.kv_block_reuse_gap_seconds.get_or_create(&labels); for event in &stats.kv_cache_eviction_events { diff --git a/src/engine-core-client/src/protocol/handshake.rs b/src/engine-core-client/src/protocol/handshake.rs index eb18e353..30b9733b 100644 --- a/src/engine-core-client/src/protocol/handshake.rs +++ b/src/engine-core-client/src/protocol/handshake.rs @@ -23,8 +23,9 @@ pub struct ReadyMessage { /// Post-initialization configuration sent from each engine on the input socket /// registration message, after the handshake completes. /// -/// Contains values that may differ from the original config (e.g. `max_model_len` -/// after KV cache auto-fitting, `num_gpu_blocks` after profiling). +/// Contains values that may differ from the original config (e.g. +/// `max_model_len` after KV cache auto-fitting, `num_gpu_blocks` after +/// profiling). /// /// Original Python definition: /// diff --git a/src/engine-core-client/src/protocol/logprobs.rs b/src/engine-core-client/src/protocol/logprobs.rs index 031a16d9..ccaec6e0 100644 --- a/src/engine-core-client/src/protocol/logprobs.rs +++ b/src/engine-core-client/src/protocol/logprobs.rs @@ -14,29 +14,32 @@ use crate::error::{Error, Result, bail_ext_value_decode, ext_value_decode}; /// One token candidate and its logprob metadata for a single sequence position. /// -/// The first entry in a [`PositionLogprobs`] is always the sampled/selected token for that -/// position. Any remaining entries follow the engine's returned top-k candidate order. +/// The first entry in a [`PositionLogprobs`] is always the sampled/selected +/// token for that position. Any remaining entries follow the engine's returned +/// top-k candidate order. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct TokenLogprob { pub token_id: u32, pub logprob: f32, - /// The sampled/selected token uses its actual vocab rank. Remaining entries use 1-based top-k - /// ranks matching the engine's returned candidate order. + /// The sampled/selected token uses its actual vocab rank. Remaining entries + /// use 1-based top-k ranks matching the engine's returned candidate + /// order. pub rank: u32, } /// Logprob payload for one sequence position. /// -/// This is the semantic Rust representation used by the public client API after the lower-level -/// ndarray/tensor wire payload has been decoded. +/// This is the semantic Rust representation used by the public client API after +/// the lower-level ndarray/tensor wire payload has been decoded. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct PositionLogprobs { pub entries: Vec, } impl PositionLogprobs { - /// Convert one decoded logprobs row into this per-position form by grouping each token/logprob - /// pair together with the sampled/selected token's actual vocab rank. + /// Convert one decoded logprobs row into this per-position form by grouping + /// each token/logprob pair together with the sampled/selected token's + /// actual vocab rank. 
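A concrete single-position illustration of the ranking rule above, with all numbers made up: the sampled token leads with its true vocab rank, followed by the top-k alternatives with 1-based candidate ranks:

```rust
fn example_position() -> PositionLogprobs {
    PositionLogprobs {
        entries: vec![
            // Sampled token first, carrying its actual vocab rank (7 here).
            TokenLogprob { token_id: 311, logprob: -2.3, rank: 7 },
            // Then the returned top-k alternatives, ranked 1, 2, ...
            TokenLogprob { token_id: 42, logprob: -0.4, rank: 1 },
            TokenLogprob { token_id: 96, logprob: -1.1, rank: 2 },
        ],
    }
}
```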
fn from_decoded_row(token_ids: &[u32], logprobs: &[f32], sampled_rank: u32) -> Result { if token_ids.len() != logprobs.len() { bail_ext_value_decode!( @@ -68,15 +71,18 @@ impl PositionLogprobs { /// Decoded per-request logprobs payload for one engine-core output. /// -/// Unlike the Python wire payload, this public Rust type is already fully semantic: one -/// [`PositionLogprobs`] per scored position, each containing the sampled/selected token plus any -/// returned top-k alternatives for that same position. +/// Unlike the Python wire payload, this public Rust type is already fully +/// semantic: one [`PositionLogprobs`] per scored position, each containing the +/// sampled/selected token plus any returned top-k alternatives for that same +/// position. /// -/// The Python engine still sends logprobs as ndarray/tensor-shaped wire tuples. Rust resolves that -/// lower-level representation during decode and exposes only this per-position form to callers. +/// The Python engine still sends logprobs as ndarray/tensor-shaped wire tuples. +/// Rust resolves that lower-level representation during decode and exposes only +/// this per-position form to callers. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct Logprobs { - /// One decoded logprobs record per scored position in this engine-core output. + /// One decoded logprobs record per scored position in this engine-core + /// output. pub positions: Vec, } @@ -92,12 +98,14 @@ impl Logprobs { } } -/// Output field wrapper that is initially deserialized from the Python wire shape, then resolved -/// into [`Logprobs`] before the decoded message is returned to callers. +/// Output field wrapper that is initially deserialized from the Python wire +/// shape, then resolved into [`Logprobs`] before the decoded message is +/// returned to callers. #[derive(Clone, PartialEq, Debug, EnumAsInner)] pub enum MaybeWireLogprobs { - /// The logprobs are still in the wire format and need to be resolved by looking up aux frames - /// and decoding raw views. Should only be used internally during deserialization. + /// The logprobs are still in the wire format and need to be resolved by + /// looking up aux frames and decoding raw views. Should only be used + /// internally during deserialization. Wire(Box), /// The actual decoded logprobs value, Direct(Logprobs), @@ -149,8 +157,8 @@ impl Serialize for MaybeWireLogprobs { } impl MaybeWireLogprobs { - /// Resolve the wire representation into decoded logprobs by looking up aux frames and decoding - /// raw views as needed. + /// Resolve the wire representation into decoded logprobs by looking up aux + /// frames and decoding raw views as needed. fn resolve(self, frames: &[Frame], field_prefix: &str) -> Result where Frame: AsRef<[u8]>, @@ -163,8 +171,8 @@ impl MaybeWireLogprobs { } impl EngineCoreOutputs { - /// Resolve all wire-format fields in-place by looking up aux frames and decoding raw-view - /// payloads as needed. + /// Resolve all wire-format fields in-place by looking up aux frames and + /// decoding raw-view payloads as needed. fn resolve_in_place(&mut self, frames: &[Frame]) -> Result<()> where Frame: AsRef<[u8]>, @@ -177,8 +185,8 @@ impl EngineCoreOutputs { } impl EngineCoreOutput { - /// Resolve all wire-format fields in-place by looking up aux frames and decoding raw-view - /// payloads as needed. + /// Resolve all wire-format fields in-place by looking up aux frames and + /// decoding raw-view payloads as needed. 
fn resolve_in_place(&mut self, frames: &[Frame]) -> Result<()> where Frame: AsRef<[u8]>, @@ -196,15 +204,12 @@ impl EngineCoreOutput { impl WireLogprobs { /// Convert semantic per-position logprobs into the Python wire tuple shape. /// - /// This exists mainly so Rust-side tests can inject semantic logprobs into mocked engine-core - /// outputs without manually building ndarray raw-view tuples. + /// This exists mainly so Rust-side tests can inject semantic logprobs into + /// mocked engine-core outputs without manually building ndarray + /// raw-view tuples. fn from_direct(value: &Logprobs) -> std::result::Result { let rows = value.positions.len(); - let cols = value - .positions - .first() - .map(|position| position.entries.len()) - .unwrap_or(0); + let cols = value.positions.first().map(|position| position.entries.len()).unwrap_or(0); let mut token_ids = Vec::with_capacity(rows.saturating_mul(cols).saturating_mul(8)); let mut logprobs = Vec::with_capacity(rows.saturating_mul(cols).saturating_mul(4)); @@ -248,8 +253,9 @@ impl WireLogprobs { }) } - /// Resolve the wire-format logprobs into semantic [`Logprobs`] records by looking up aux - /// frames, decoding raw views, and grouping each row into one [`PositionLogprobs`]. + /// Resolve the wire-format logprobs into semantic [`Logprobs`] records by + /// looking up aux frames, decoding raw views, and grouping each row + /// into one [`PositionLogprobs`]. fn resolve(self, frames: &[Frame], field_prefix: &str) -> Result where Frame: AsRef<[u8]>, @@ -309,15 +315,13 @@ impl WireLogprobs { } } -/// Decode one ordinary or multipart engine-core output message into the strong typed public -/// protocol shape. +/// Decode one ordinary or multipart engine-core output message into the strongly +/// typed public protocol shape.
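Putting the decode path together, mirroring how the tests in this patch consume it; the frame source is left to the caller:

```rust
fn sampled_positions(frames: &[Vec<u8>]) -> usize {
    // Frame 0 is the msgpack `EngineCoreOutputs`; any following frames are
    // aux buffers that raw-view fields reference by index.
    let decoded = decode_engine_core_outputs(frames).expect("decode outputs");
    let logprobs = decoded.outputs[0]
        .new_logprobs
        .clone()
        .expect("request asked for sample logprobs")
        .into_direct()
        .expect("decode resolves wire payloads before returning");
    logprobs.positions.len()
}
```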
pub fn decode_engine_core_outputs(frames: &[Frame]) -> Result where Frame: AsRef<[u8]>, { - let first_frame = frames - .first() - .ok_or_else(|| ext_value_decode!("missing output frame"))?; + let first_frame = frames.first().ok_or_else(|| ext_value_decode!("missing output frame"))?; let mut outputs: EngineCoreOutputs = decode_msgpack(first_frame.as_ref())?; outputs.resolve_in_place(frames)?; diff --git a/src/engine-core-client/src/protocol/logprobs/tests.rs b/src/engine-core-client/src/protocol/logprobs/tests.rs index 84133489..7408b98f 100644 --- a/src/engine-core-client/src/protocol/logprobs/tests.rs +++ b/src/engine-core-client/src/protocol/logprobs/tests.rs @@ -185,12 +185,7 @@ fn decodes_inline_new_logprobs() { )))]; let decoded = decode_engine_core_outputs(&frames).unwrap(); - let logprobs = decoded.outputs[0] - .new_logprobs - .clone() - .unwrap() - .into_direct() - .unwrap(); + let logprobs = decoded.outputs[0].new_logprobs.clone().unwrap().into_direct().unwrap(); assert_eq!(logprobs, expected_sample_logprobs()); assert_eq!( decoded.finished_requests, @@ -221,12 +216,7 @@ fn decodes_multipart_new_logprobs() { ]; let decoded = decode_engine_core_outputs(&frames).unwrap(); - let logprobs = decoded.outputs[0] - .new_logprobs - .clone() - .unwrap() - .into_direct() - .unwrap(); + let logprobs = decoded.outputs[0].new_logprobs.clone().unwrap().into_direct().unwrap(); assert_eq!(logprobs, expected_sample_logprobs()); } @@ -263,12 +253,7 @@ fn decodes_big_endian_payloads() { None, )))]; let decoded = decode_engine_core_outputs(&frames).unwrap(); - let logprobs = decoded.outputs[0] - .new_logprobs - .clone() - .unwrap() - .into_direct() - .unwrap(); + let logprobs = decoded.outputs[0].new_logprobs.clone().unwrap().into_direct().unwrap(); assert_eq!( logprobs, Logprobs { diff --git a/src/engine-core-client/src/protocol/logprobs/wire.rs b/src/engine-core-client/src/protocol/logprobs/wire.rs index 014a6f24..e4e946b7 100644 --- a/src/engine-core-client/src/protocol/logprobs/wire.rs +++ b/src/engine-core-client/src/protocol/logprobs/wire.rs @@ -8,12 +8,12 @@ use serde_tuple::{Deserialize_tuple, Serialize_tuple}; /// const CUSTOM_TYPE_RAW_VIEW: i8 = 3; -/// Python wire representation of `LogprobsLists` / `LogprobsTensors` before aux-frame -/// references and raw-view payloads are resolved. +/// Python wire representation of `LogprobsLists` / `LogprobsTensors` before +/// aux-frame references and raw-view payloads are resolved. /// -/// This mirrors the tuple shape emitted by Python engine-core so serde can first deserialize the -/// raw wire payload before the Rust client converts it into semantic per-position logprobs -/// records. +/// This mirrors the tuple shape emitted by Python engine-core so serde can +/// first deserialize the raw wire payload before the Rust client converts it +/// into semantic per-position logprobs records. /// /// Original Python definition: /// @@ -26,20 +26,20 @@ pub struct WireLogprobs { /// Wire array with shape `[num_positions]`. /// /// Python uses the field name `sampled_token_ranks` for sample logprobs and - /// `selected_token_ranks` for prompt logprobs. Rust keeps one neutral field because both - /// payloads share the same wire representation. + /// `selected_token_ranks` for prompt logprobs. Rust keeps one neutral field + /// because both payloads share the same wire representation. pub token_ranks: WireNdArray, - /// Preserved only for wire compatibility with batch-level Python tensors. 
Scheduler-sliced - /// per-request outputs should emit `None` here, and the semantic Rust decoder rejects any - /// other value. + /// Preserved only for wire compatibility with batch-level Python tensors. + /// Scheduler-sliced per-request outputs should emit `None` here, and + /// the semantic Rust decoder rejects any other value. #[serde(default)] pub cu_num_generated_tokens: Option>, } /// Python ndarray/tensor wire tuple encoded as `(dtype, shape, data)`. /// -/// This matches the custom msgpack representation built by Python `serial_utils.encode_ndarray` -/// / `encode_tensor`. +/// This matches the custom msgpack representation built by Python +/// `serial_utils.encode_ndarray` / `encode_tensor`. #[derive(Debug, Clone, PartialEq, Serialize_tuple, Deserialize_tuple)] pub struct WireNdArray { pub dtype: String, @@ -49,11 +49,12 @@ pub struct WireNdArray { /// Python array payload reference inside [`WireNdArray`]. /// -/// The data can be either an inline msgpack raw-view extension or an index into the multipart -/// aux-frame list carried alongside the primary msgpack frame. +/// The data can be either an inline msgpack raw-view extension or an index into +/// the multipart aux-frame list carried alongside the primary msgpack frame. #[derive(Debug, Clone, PartialEq)] pub enum WireArrayData { - /// The index of the aux frame where the raw bytes of this array/tensor are stored. + /// The index of the aux frame where the raw bytes of this array/tensor are + /// stored. AuxIndex(usize), /// The raw bytes of this array/tensor. RawView(Vec), @@ -70,12 +71,11 @@ impl<'de> Deserialize<'de> for WireArrayData { Value::Ext(tag, _) => Err(serde::de::Error::custom(format!( "unsupported extension type code {tag}" ))), - Value::Integer(index) => index - .as_u64() - .map(|index| Self::AuxIndex(index as usize)) - .ok_or_else(|| { + Value::Integer(index) => { + index.as_u64().map(|index| Self::AuxIndex(index as usize)).ok_or_else(|| { serde::de::Error::custom("aux frame index must be a non-negative integer") - }), + }) + } other => Err(serde::de::Error::custom(format!( "expected raw-view ext or aux frame index, got {other:?}" ))), diff --git a/src/engine-core-client/src/protocol/mod.rs b/src/engine-core-client/src/protocol/mod.rs index c9bd98c7..1cd76c25 100644 --- a/src/engine-core-client/src/protocol/mod.rs +++ b/src/engine-core-client/src/protocol/mod.rs @@ -16,10 +16,10 @@ use crate::protocol::stats::{PrefillStats, SchedulerStats}; // TODO: This module currently mixes reusable frontend-facing semantic types // (for example `FinishReason`, `StopReason`, `RequestOutputKind`, and future // cleaned-up frontend sampling types) with engine-core-specific wire DTOs and -// handshake/control messages. While the Rust frontend is still evolving quickly, -// keep them co-located here for iteration speed. Once the higher-level API -// boundary stabilizes, move the truly reusable semantic types into a lower-level -// common crate and keep the engine transport/wire messages here. +// handshake/control messages. While the Rust frontend is still evolving +// quickly, keep them co-located here for iteration speed. Once the higher-level +// API boundary stabilizes, move the truly reusable semantic types into a +// lower-level common crate and keep the engine transport/wire messages here. /// Dynamic msgpack value used for schema positions that are preserved but not /// yet strongly typed in the early-stage Rust client. 
@@ -231,7 +231,8 @@ pub struct EngineCoreSamplingParams {
     /// Complete stop-token set used by engine-core for `min_tokens` masking.
     ///
     /// This mirrors Python's internal `_all_stop_token_ids` field and should
-    /// contain explicit `stop_token_ids` plus any frontend-derived EOS token IDs.
+    /// contain explicit `stop_token_ids` plus any frontend-derived EOS token
+    /// IDs.
     #[serde(rename = "_all_stop_token_ids")]
     pub all_stop_token_ids: BTreeSet<u32>,
     /// Logit biases to apply during sampling.
@@ -247,14 +248,17 @@
     /// Parameters for configuring structured outputs (guided decoding).
     #[serde(default)]
     pub structured_outputs: Option,
-    /// Specific token IDs for which log probabilities should be returned at each position.
+    /// Specific token IDs for which log probabilities should be returned at
+    /// each position.
     ///
-    /// When set, the engine returns logprobs for exactly these tokens in addition to the
-    /// sampled/scored token. Mutually exclusive with the `logprobs` count field in practice.
+    /// When set, the engine returns logprobs for exactly these tokens in
+    /// addition to the sampled/scored token. Mutually exclusive with the
+    /// `logprobs` count field in practice.
     #[serde(default)]
     pub logprob_token_ids: Option<Vec<u32>>,
-    /// If `Some(true)`, the request will not attempt to read from the prefix cache; newly
-    /// computed blocks may still populate the cache. `None` defers to engine-core defaults.
+    /// If `Some(true)`, the request will not attempt to read from the prefix
+    /// cache; newly computed blocks may still populate the cache. `None`
+    /// defers to engine-core defaults.
     #[serde(default)]
     pub skip_reading_prefix_cache: Option<bool>,
     /// Additional request parameters for custom extensions (from `vllm_xargs`).
@@ -300,10 +304,12 @@ impl EngineCoreSamplingParams {
 pub struct EngineCoreRequest {
     pub request_id: String,
     pub prompt_token_ids: Option<Vec<u32>>,
-    /// Multimodal features are preserved in the schema but not yet strongly typed.
+    /// Multimodal features are preserved in the schema but not yet strongly
+    /// typed.
     pub mm_features: Option<Value>,
     pub sampling_params: Option<EngineCoreSamplingParams>,
-    /// Pooling parameters are preserved in the schema but not yet strongly typed.
+    /// Pooling parameters are preserved in the schema but not yet strongly
+    /// typed.
     pub pooling_params: Option<Value>,
     pub arrival_time: f64,
     #[serde(default)]
@@ -362,8 +368,8 @@ pub struct EngineCoreUtilityRequest {
 }
 impl EngineCoreUtilityRequest {
-    /// Create a new utility request with the given strongly typed arguments, encoding them into the
-    /// expected msgpack value format.
+    /// Create a new utility request with the given strongly typed arguments,
+    /// encoding them into the expected msgpack value format.
     pub fn new(
         client_index: u32,
         call_id: i64,
@@ -402,10 +408,12 @@ impl EngineCoreUtilityRequest {
 pub struct EngineCoreOutput {
     pub request_id: String,
     pub new_token_ids: Vec<u32>,
-    /// Decoded sample logprobs for the newly generated positions in this output.
+    /// Decoded sample logprobs for the newly generated positions in this
+    /// output.
     #[serde(default)]
     pub new_logprobs: Option,
-    /// Decoded prompt logprobs for the scored prompt positions emitted in this output.
+    /// Decoded prompt logprobs for the scored prompt positions emitted in this
+    /// output.
#[serde(default)]
     pub new_prompt_logprobs_tensors: Option,
     #[serde(default)]
@@ -458,8 +466,9 @@ pub struct UtilityOutput {
 ///
 #[derive(Debug, Clone, PartialEq, Serialize_tuple, Deserialize_tuple)]
 pub struct UtilityResultEnvelope {
-    /// Recursive type information encoded on Python side, serving as the hint for deserialization.
-    /// We don't care it here as in Rust frontend all utility calls are strongly-typed.
+    /// Recursive type information encoded on the Python side, serving as the
+    /// hint for deserialization. We don't care about it here, as in the Rust
+    /// frontend all utility calls are strongly typed.
     #[serde(default)]
     type_info: Option,
     /// The actual utility result.
@@ -520,7 +529,8 @@ pub struct EngineCoreOutputs {
     pub utility_output: Option<UtilityOutput>,
     #[serde(default)]
     pub finished_requests: Option>,
-    /// In DP mode, signals that the current wave finished and engines are paused.
+    /// In DP mode, signals that the current wave finished and engines are
+    /// paused.
     #[serde(default)]
     pub wave_complete: Option,
     /// In DP mode, signals that a request arrived for an old wave and the next
@@ -544,7 +554,8 @@ where
     })
 }
-/// Decode a msgpack payload into a strongly typed protocol value, with enhanced error reporting.
+/// Decode a msgpack payload into a strongly typed protocol value, with enhanced
+/// error reporting.
 pub fn decode_msgpack<T>(bytes: &[u8]) -> Result<T>
 where
     T: for<'de> Deserialize<'de>,
diff --git a/src/engine-core-client/src/test_utils.rs b/src/engine-core-client/src/test_utils.rs
index ef69983e..3f12bd02 100644
--- a/src/engine-core-client/src/test_utils.rs
+++ b/src/engine-core-client/src/test_utils.rs
@@ -61,7 +61,8 @@ fn ready_message(status: &str) -> ReadyMessage {
     }
 }
-/// Construct a default ready response payload for mock engine input registration.
+/// Construct a default ready response payload for mock engine input
+/// registration.
 fn ready_response_payload() -> Vec<u8> {
     rmp_serde::to_vec_named(&EngineCoreReadyResponse {
         max_model_len: 4096,
@@ -71,11 +72,14 @@
     .expect("encode ready response payload")
 }
-/// Coordinator-side sockets connected by one mock engine when coordinator mode is enabled.
+/// Coordinator-side sockets connected by one mock engine when coordinator mode
+/// is enabled.
 pub struct MockCoordinatorConnections {
-    /// Subscription socket that receives coordinator broadcasts such as `START_DP_WAVE`.
+    /// Subscription socket that receives coordinator broadcasts such as
+    /// `START_DP_WAVE`.
     pub input_sub: SubSocket,
-    /// Push socket used to send coordinator-only `EngineCoreOutputs` back to the frontend.
+    /// Push socket used to send coordinator-only `EngineCoreOutputs` back to
+    /// the frontend.
     pub output_push: PushSocket,
 }
@@ -87,12 +91,13 @@ pub struct MockEngineConnections {
     pub dealer: DealerSocket,
     /// Socket used to publish normal request outputs back to the frontend.
     pub push: PushSocket,
-    /// Optional coordinator sockets when the client enabled the in-process coordinator.
+    /// Optional coordinator sockets when the client enabled the in-process
+    /// coordinator.
    pub coordinator: Option<MockCoordinatorConnections>,
 }
-/// Complete the engine-core handshake and connect mock input/output sockets plus optional
-/// coordinator sockets.
+/// Complete the engine-core handshake and connect mock input/output sockets
+/// plus optional coordinator sockets.
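// A hedged sketch of the generic decode path documented above: `decode_msgpack`
// accepts any serde-deserializable protocol type, so one helper covers both
// ready responses and batched outputs. (`payload` and `frame` are assumed to be
// raw msgpack byte slices; this is illustrative, not part of the patch.)
fn decode_two_payloads(payload: &[u8], frame: &[u8]) -> Result<()> {
    // Turbofish-free: the target type is picked up from the binding.
    let ready: EngineCoreReadyResponse = decode_msgpack(payload)?;
    let outputs: EngineCoreOutputs = decode_msgpack(frame)?;
    let _ = (ready, outputs);
    Ok(())
}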
pub async fn setup_mock_engine_connections( engine_handshake: String, engine_id: impl Into, @@ -125,11 +130,7 @@ pub async fn setup_mock_engine_connections( .await .expect("send HELLO ready message"); - let init_frames = handshake - .recv() - .await - .expect("receive handshake init message") - .into_vec(); + let init_frames = handshake.recv().await.expect("receive handshake init message").into_vec(); assert_eq!(init_frames.len(), 1); let init: HandshakeInitMessage = rmp_serde::from_slice(init_frames[0].as_ref()).expect("decode handshake init message"); @@ -172,11 +173,8 @@ pub async fn setup_mock_engine_connections( .await .expect("connect mock engine coordinator output socket"); - let ready = input_sub - .recv() - .await - .expect("receive coordinator READY marker") - .into_vec(); + let ready = + input_sub.recv().await.expect("receive coordinator READY marker").into_vec(); assert_eq!(ready.len(), 1); assert_eq!(ready[0].as_ref(), b"READY"); @@ -204,7 +202,8 @@ pub async fn setup_mock_engine_connections( } } -/// Connect one mock engine directly to already-bootstrapped frontend input/output sockets. +/// Connect one mock engine directly to already-bootstrapped frontend +/// input/output sockets. pub async fn setup_bootstrapped_mock_engine( input_address: String, output_address: String, @@ -225,27 +224,23 @@ pub async fn setup_bootstrapped_mock_engine( let mut input_options = SocketOptions::default(); input_options.peer_identity(peer_identity); let mut dealer = DealerSocket::with_options(input_options); - dealer - .connect(&input_address) - .await - .expect("connect mock engine input socket"); + dealer.connect(&input_address).await.expect("connect mock engine input socket"); dealer .send(ZmqMessage::from(ready_response_payload())) .await .expect("send mock engine input ready frame"); let mut push = PushSocket::new(); - push.connect(&output_address) - .await - .expect("connect mock engine output socket"); + push.connect(&output_address).await.expect("connect mock engine output socket"); (dealer, push) } /// Complete the engine-core handshake and connect mock input/output sockets. /// -/// This returns the decoded handshake init message plus the `DealerSocket` used to receive client -/// requests and the `PushSocket` used to send engine outputs back to the client. +/// This returns the decoded handshake init message plus the `DealerSocket` used +/// to receive client requests and the `PushSocket` used to send engine outputs +/// back to the client. 
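// A hedged test-side usage sketch for the helpers in this module, assuming a
// tokio test context and a handshake address string: connect one mock engine,
// then take the optional coordinator sockets only when the client enabled the
// in-process coordinator.
async fn example_mock_engine(handshake_address: String) {
    let mut engine = setup_mock_engine_connections(handshake_address, &[0x00, 0x00]).await;
    if let Some(coordinator) = engine.coordinator.take() {
        // Coordinator-mode tests drive `input_sub` / `output_push` here.
        drop(coordinator);
    }
    // Data-plane tests use `engine.dealer` and `engine.push` directly.
}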
pub async fn setup_mock_engine_with_init( engine_handshake: String, engine_id: impl Into, diff --git a/src/engine-core-client/src/tests/client.rs b/src/engine-core-client/src/tests/client.rs index 619dfcaa..4c93745e 100644 --- a/src/engine-core-client/src/tests/client.rs +++ b/src/engine-core-client/src/tests/client.rs @@ -204,9 +204,7 @@ async fn send_outputs(push: &mut PushSocket, outputs: EngineCoreOutputs) { } async fn send_output_frames(push: &mut PushSocket, frames: Vec) { - push.send(ZmqMessage::try_from(frames).unwrap()) - .await - .unwrap(); + push.send(ZmqMessage::try_from(frames).unwrap()).await.unwrap(); } async fn recv_engine_message(dealer: &mut DealerSocket) -> Vec { @@ -340,10 +338,7 @@ fn init_tracing() { TRACING.call_once(|| { let filter = EnvFilter::try_from_default_env() .unwrap_or_else(|_| EnvFilter::new("vllm_engine_core_client=debug")); - let _ = tracing_subscriber::fmt() - .with_test_writer() - .with_env_filter(filter) - .try_init(); + let _ = tracing_subscriber::fmt().with_test_writer().with_env_filter(filter).try_init(); }); } @@ -487,10 +482,8 @@ async fn coordinator_wave_control_tracks_pause_running_and_rebroadcasts() { let handshake_address = handshake_address.clone(); async move { let mut engine = setup_mock_engine_connections(handshake_address, &[0x00, 0x00]).await; - let mut coordinator = engine - .coordinator - .take() - .expect("coordinator sockets should be present"); + let mut coordinator = + engine.coordinator.take().expect("coordinator sockets should be present"); let (wave, exclude_engine) = recv_start_dp_wave(&mut coordinator.input_sub).await; assert_eq!((wave, exclude_engine), (0, 0)); @@ -568,10 +561,8 @@ async fn coordinator_wave_control_tracks_pause_running_and_rebroadcasts() { let handshake_address = handshake_address.clone(); async move { let mut engine = setup_mock_engine_connections(handshake_address, &[0x01, 0x00]).await; - let mut coordinator = engine - .coordinator - .take() - .expect("coordinator sockets should be present"); + let mut coordinator = + engine.coordinator.take().expect("coordinator sockets should be present"); let (wave, exclude_engine) = recv_start_dp_wave(&mut coordinator.input_sub).await; assert_eq!((wave, exclude_engine), (0, 0)); @@ -645,12 +636,7 @@ async fn coordinator_wave_control_tracks_pause_running_and_rebroadcasts() { .unwrap(); assert_eq!(final_1.request_id, "req-1"); assert_eq!(final_1.finish_reason, Some(EngineCoreFinishReason::Length)); - assert!( - timeout(Duration::from_secs(1), stream_1.next()) - .await - .unwrap() - .is_none() - ); + assert!(timeout(Duration::from_secs(1), stream_1.next()).await.unwrap().is_none()); let final_2 = timeout(Duration::from_secs(1), stream_2.next()) .await @@ -659,12 +645,7 @@ async fn coordinator_wave_control_tracks_pause_running_and_rebroadcasts() { .unwrap(); assert_eq!(final_2.request_id, "req-2"); assert_eq!(final_2.finish_reason, Some(EngineCoreFinishReason::Length)); - assert!( - timeout(Duration::from_secs(1), stream_2.next()) - .await - .unwrap() - .is_none() - ); + assert!(timeout(Duration::from_secs(1), stream_2.next()).await.unwrap().is_none()); tokio::time::sleep(Duration::from_millis(100)).await; @@ -676,12 +657,7 @@ async fn coordinator_wave_control_tracks_pause_running_and_rebroadcasts() { .unwrap(); assert_eq!(final_3.request_id, "req-3"); assert_eq!(final_3.finish_reason, Some(EngineCoreFinishReason::Length)); - assert!( - timeout(Duration::from_secs(1), stream_3.next()) - .await - .unwrap() - .is_none() - ); + assert!(timeout(Duration::from_secs(1), 
stream_3.next()).await.unwrap().is_none()); tokio::time::sleep(Duration::from_millis(100)).await; @@ -703,10 +679,8 @@ async fn coordinator_rebroadcasts_engine_start_wave_control() { let handshake_address = handshake_address.clone(); async move { let mut engine = setup_mock_engine_connections(handshake_address, &[0x00, 0x00]).await; - let mut coordinator = engine - .coordinator - .take() - .expect("coordinator sockets should be present"); + let mut coordinator = + engine.coordinator.take().expect("coordinator sockets should be present"); let (wave, exclude_engine) = recv_start_dp_wave(&mut coordinator.input_sub).await; assert_eq!((wave, exclude_engine), (4, 1)); @@ -720,10 +694,8 @@ async fn coordinator_rebroadcasts_engine_start_wave_control() { let handshake_address = handshake_address.clone(); async move { let mut engine = setup_mock_engine_connections(handshake_address, &[0x01, 0x00]).await; - let mut coordinator = engine - .coordinator - .take() - .expect("coordinator sockets should be present"); + let mut coordinator = + engine.coordinator.take().expect("coordinator sockets should be present"); send_outputs( &mut coordinator.output_push, @@ -773,10 +745,8 @@ async fn coordinator_accepts_stats_only_outputs() { let (shutdown_tx, shutdown_rx) = oneshot::channel(); let engine_task = tokio::spawn(async move { let mut engine = setup_mock_engine_connections(handshake_address, &[0x00, 0x00]).await; - let mut coordinator = engine - .coordinator - .take() - .expect("coordinator sockets should be present"); + let mut coordinator = + engine.coordinator.take().expect("coordinator sockets should be present"); let (wave, exclude_engine) = recv_start_dp_wave(&mut coordinator.input_sub).await; assert_eq!((wave, exclude_engine), (0, 0)); @@ -831,15 +801,9 @@ async fn coordinator_accepts_stats_only_outputs() { ) .await; - let mut stream = client - .call(sample_request_with_id("req-stats")) - .await - .unwrap(); - let final_output = timeout(Duration::from_secs(1), stream.next()) - .await - .unwrap() - .unwrap() - .unwrap(); + let mut stream = client.call(sample_request_with_id("req-stats")).await.unwrap(); + let final_output = + timeout(Duration::from_secs(1), stream.next()).await.unwrap().unwrap().unwrap(); assert_eq!(final_output.request_id, "req-stats"); assert_eq!( final_output.finish_reason, @@ -1054,11 +1018,7 @@ async fn duplicate_request_ids_are_rejected_without_sending_a_second_add() { let request_1: EngineCoreRequest = rmp_serde::from_slice(&add_1[1]).unwrap(); assert_eq!(request_1.request_id, "req-1"); - assert!( - timeout(Duration::from_millis(200), dealer.recv()) - .await - .is_err() - ); + assert!(timeout(Duration::from_millis(200), dealer.recv()).await.is_err()); send_outputs( push, @@ -1100,21 +1060,13 @@ async fn duplicate_request_ids_are_rejected_without_sending_a_second_add() { Error::DuplicateRequestId { request_id } if request_id == "req-1" )); - let final_output = timeout(Duration::from_secs(1), stream.next()) - .await - .unwrap() - .unwrap() - .unwrap(); + let final_output = + timeout(Duration::from_secs(1), stream.next()).await.unwrap().unwrap().unwrap(); assert_eq!( final_output.finish_reason, Some(EngineCoreFinishReason::Length) ); - assert!( - timeout(Duration::from_secs(1), stream.next()) - .await - .unwrap() - .is_none() - ); + assert!(timeout(Duration::from_secs(1), stream.next()).await.unwrap().is_none()); let _ = shutdown_tx.send(()); engine_task.await.unwrap(); client.shutdown().await.unwrap(); @@ -1144,11 +1096,7 @@ async fn 
finished_requests_without_final_output_is_treated_as_unexpected_close() ) .await; - assert!( - timeout(Duration::from_millis(200), dealer.recv()) - .await - .is_err() - ); + assert!(timeout(Duration::from_millis(200), dealer.recv()).await.is_err()); let _ = push; }) }, @@ -1177,12 +1125,7 @@ async fn finished_requests_without_final_output_is_treated_as_unexpected_close() error, Error::RequestStreamClosed { request_id } if request_id == "req-1" )); - assert!( - timeout(Duration::from_secs(1), stream.next()) - .await - .unwrap() - .is_none() - ); + assert!(timeout(Duration::from_secs(1), stream.next()).await.unwrap().is_none()); let _ = shutdown_tx.send(()); engine_task.await.unwrap(); @@ -1212,9 +1155,8 @@ async fn dropping_a_live_stream_triggers_abort() { ) .await; - let abort = timeout(Duration::from_secs(1), recv_engine_message(dealer)) - .await - .unwrap(); + let abort = + timeout(Duration::from_secs(1), recv_engine_message(dealer)).await.unwrap(); assert_eq!(abort[0].as_ref(), &[0x01]); let aborted_ids: Vec = rmp_serde::from_slice(&abort[1]).unwrap(); assert_eq!(aborted_ids, vec!["req-1".to_string()]); @@ -1236,11 +1178,7 @@ async fn dropping_a_live_stream_triggers_abort() { .await; let mut stream = client.call(sample_request()).await.unwrap(); - let first = timeout(Duration::from_secs(1), stream.next()) - .await - .unwrap() - .unwrap() - .unwrap(); + let first = timeout(Duration::from_secs(1), stream.next()).await.unwrap().unwrap().unwrap(); assert_eq!(first.new_token_ids, vec![99]); drop(stream); @@ -1298,10 +1236,7 @@ async fn dispatcher_failure_propagates_to_streams_and_future_calls() { assert!(is_decode_error(&error_1)); assert!(is_decode_error(&error_2)); assert!(is_decode_error( - client - .health_error() - .as_deref() - .expect("health error recorded") + client.health_error().as_deref().expect("health error recorded") )); let abort_error = client.abort(&["req-1".to_string()]).await.unwrap_err(); @@ -1396,10 +1331,8 @@ async fn call_utility_failure_message_surfaces_as_error() { let utility = recv_engine_message(dealer).await; assert_eq!(utility[0].as_ref(), &[0x03]); let payload = decode_value(&utility[1]); - let call_id = payload - .as_array() - .and_then(|array| array[1].as_i64()) - .expect("call_id"); + let call_id = + payload.as_array().and_then(|array| array[1].as_i64()).expect("call_id"); send_outputs( push, @@ -1430,10 +1363,7 @@ async fn call_utility_failure_message_surfaces_as_error() { ) .await; - let error = client - .call_utility::("is_sleeping", ()) - .await - .unwrap_err(); + let error = client.call_utility::("is_sleeping", ()).await.unwrap_err(); assert!(matches!( error, Error::UtilityCallFailed { @@ -1481,16 +1411,10 @@ async fn dispatcher_failure_propagates_to_waiting_utility_calls() { ) .await; - let error = client - .call_utility::("is_sleeping", ()) - .await - .unwrap_err(); + let error = client.call_utility::("is_sleeping", ()).await.unwrap_err(); assert!(is_decode_error(&error)); assert!(is_decode_error( - client - .health_error() - .as_deref() - .expect("health error recorded") + client.health_error().as_deref().expect("health error recorded") )); let _ = shutdown_tx.send(()); @@ -1558,9 +1482,7 @@ async fn engine_core_dead_sentinel_marks_client_unhealthy_and_sticks() { engine_id.clone(), |_dealer, push| { Box::pin(async move { - push.send(ZmqMessage::from(ENGINE_CORE_DEAD_SENTINEL.to_vec())) - .await - .unwrap(); + push.send(ZmqMessage::from(ENGINE_CORE_DEAD_SENTINEL.to_vec())).await.unwrap(); }) }, ); @@ -1592,10 +1514,7 @@ async fn 
engine_core_dead_sentinel_marks_client_unhealthy_and_sticks() { Some(Error::EngineCoreDead) )); - let error = client - .call_utility::("is_sleeping", ()) - .await - .unwrap_err(); + let error = client.call_utility::("is_sleeping", ()).await.unwrap_err(); assert!( is_dispatcher_closed(&error) || is_engine_core_dead(&error), "unexpected error: {error:?}" @@ -1658,10 +1577,7 @@ async fn output_loop_failure_marks_client_unhealthy_and_records_first_error() { assert!(!client.is_healthy()); assert!(is_decode_error( - client - .health_error() - .as_deref() - .expect("health error recorded") + client.health_error().as_deref().expect("health error recorded") )); let _ = shutdown_tx.send(()); @@ -1714,13 +1630,7 @@ async fn client_decodes_multipart_logprob_outputs() { output.output.finish_reason, Some(EngineCoreFinishReason::Length) ); - expect_sample_logprobs( - output - .output - .new_logprobs - .as_ref() - .expect("logprobs decoded"), - ); + expect_sample_logprobs(output.output.new_logprobs.as_ref().expect("logprobs decoded")); let _ = shutdown_tx.send(()); engine_task.await.unwrap(); @@ -1829,14 +1739,8 @@ async fn multi_engine_client_shares_transport_and_routes_by_inflight_count() { ) .await; - let init_0 = timeout(Duration::from_secs(1), init_rx_0) - .await - .unwrap() - .unwrap(); - let init_1 = timeout(Duration::from_secs(1), init_rx_1) - .await - .unwrap() - .unwrap(); + let init_0 = timeout(Duration::from_secs(1), init_rx_0).await.unwrap().unwrap(); + let init_1 = timeout(Duration::from_secs(1), init_rx_1).await.unwrap().unwrap(); assert_eq!(init_0.addresses.inputs, vec![ipc.input_endpoint()]); assert_eq!(init_1.addresses.inputs, vec![ipc.input_endpoint()]); assert_eq!(init_0.addresses.outputs, vec![ipc.output_endpoint()]); @@ -1855,17 +1759,11 @@ async fn multi_engine_client_shares_transport_and_routes_by_inflight_count() { let mut stream_1 = client.call(sample_request_with_id("req-1")).await.unwrap(); let mut stream_2 = client.call(sample_request_with_id("req-2")).await.unwrap(); assert_eq!( - timeout(Duration::from_secs(1), engine_0_seen_rx.recv()) - .await - .unwrap() - .unwrap(), + timeout(Duration::from_secs(1), engine_0_seen_rx.recv()).await.unwrap().unwrap(), "req-1" ); assert_eq!( - timeout(Duration::from_secs(1), engine_1_seen_rx) - .await - .unwrap() - .unwrap(), + timeout(Duration::from_secs(1), engine_1_seen_rx).await.unwrap().unwrap(), "req-2" ); @@ -1881,10 +1779,7 @@ async fn multi_engine_client_shares_transport_and_routes_by_inflight_count() { let mut stream_3 = client.call(sample_request_with_id("req-3")).await.unwrap(); assert_eq!( - timeout(Duration::from_secs(1), engine_0_seen_rx.recv()) - .await - .unwrap() - .unwrap(), + timeout(Duration::from_secs(1), engine_0_seen_rx.recv()).await.unwrap().unwrap(), "req-3" ); @@ -1908,24 +1803,9 @@ async fn multi_engine_client_shares_transport_and_routes_by_inflight_count() { assert_eq!(final_2.new_token_ids, vec![20]); assert_eq!(final_2.finish_reason, Some(EngineCoreFinishReason::Length)); - assert!( - timeout(Duration::from_secs(1), stream_1.next()) - .await - .unwrap() - .is_none() - ); - assert!( - timeout(Duration::from_secs(1), stream_2.next()) - .await - .unwrap() - .is_none() - ); - assert!( - timeout(Duration::from_secs(1), stream_3.next()) - .await - .unwrap() - .is_none() - ); + assert!(timeout(Duration::from_secs(1), stream_1.next()).await.unwrap().is_none()); + assert!(timeout(Duration::from_secs(1), stream_2.next()).await.unwrap().is_none()); + assert!(timeout(Duration::from_secs(1), 
stream_3.next()).await.unwrap().is_none()); let _ = shutdown_tx_0.send(()); let _ = shutdown_tx_1.send(()); @@ -2223,15 +2103,10 @@ fn python_msgpack_fixtures_match_rust_encoding() { let request_hex = lines.next().expect("missing request fixture line"); let outputs_hex = lines.next().expect("missing outputs fixture line"); let inline_logprobs_frames = lines.next().expect("missing inline logprobs fixture line"); - let multipart_logprobs_frames = lines - .next() - .expect("missing multipart logprobs fixture line"); - let inline_prompt_frames = lines - .next() - .expect("missing inline prompt logprobs fixture line"); - let multipart_prompt_frames = lines - .next() - .expect("missing multipart prompt logprobs fixture line"); + let multipart_logprobs_frames = lines.next().expect("missing multipart logprobs fixture line"); + let inline_prompt_frames = lines.next().expect("missing inline prompt logprobs fixture line"); + let multipart_prompt_frames = + lines.next().expect("missing multipart prompt logprobs fixture line"); let request_bytes = hex::decode(request_hex).unwrap(); let outputs_bytes = hex::decode(outputs_hex).unwrap(); @@ -2351,11 +2226,8 @@ async fn bootstrapped_connects_after_single_engine_registration() { let client = client_task.await.unwrap(); assert_eq!(client.engine_count(), 1); - let engine_ids = client - .engine_identities() - .into_iter() - .map(|id| id.to_vec()) - .collect::>(); + let engine_ids = + client.engine_identities().into_iter().map(|id| id.to_vec()).collect::>(); assert_eq!(engine_ids, vec![vec![0x00, 0x00]]); client.shutdown().await.unwrap(); @@ -2396,11 +2268,8 @@ async fn bootstrapped_connects_with_contiguous_engine_ids() { let client = client_task.await.unwrap(); assert_eq!(client.engine_count(), 2); - let engine_ids = client - .engine_identities() - .into_iter() - .map(|id| id.to_vec()) - .collect::>(); + let engine_ids = + client.engine_identities().into_iter().map(|id| id.to_vec()).collect::>(); assert_eq!(engine_ids, vec![vec![0x00, 0x00], vec![0x01, 0x00]]); client.shutdown().await.unwrap(); @@ -2553,9 +2422,7 @@ async fn bootstrapped_external_coordinator_updates_wave_ignores_counts_and_sends ) .await; - let final_output = timeout(Duration::from_secs(1), stream.next()) - .await - .unwrap(); + let final_output = timeout(Duration::from_secs(1), stream.next()).await.unwrap(); assert!(final_output.is_some()); client.shutdown().await.unwrap(); @@ -2630,9 +2497,7 @@ async fn bootstrapped_external_coordinator_running_state_suppresses_wakeup() { ) .await; - let final_output = timeout(Duration::from_secs(1), stream.next()) - .await - .unwrap(); + let final_output = timeout(Duration::from_secs(1), stream.next()).await.unwrap(); assert!(final_output.is_some()); client.shutdown().await.unwrap(); diff --git a/src/engine-core-client/src/transport.rs b/src/engine-core-client/src/transport.rs index 230e3901..bf7fb6e3 100644 --- a/src/engine-core-client/src/transport.rs +++ b/src/engine-core-client/src/transport.rs @@ -22,7 +22,8 @@ use crate::protocol::{ EngineCoreOutputs, decode_engine_core_outputs, decode_msgpack, encode_msgpack, }; -/// Dedicated single-frame sentinel emitted by Python `EngineCoreProc` when the engine dies. +/// Dedicated single-frame sentinel emitted by Python `EngineCoreProc` when the +/// engine dies. pub const ENGINE_CORE_DEAD_SENTINEL: &[u8] = b"ENGINE_CORE_DEAD"; /// Opaque routing identity of one engine on the frontend transport. 
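// A minimal sketch of the identity encoding documented on `EngineId` here and
// on `engine_index` below: Python `EngineCoreProc` identities are a two-byte
// little-endian engine index, so the index round-trips through the helpers.
fn engine_id_round_trip() {
    let id = EngineId::from_engine_index(3);
    assert_eq!(id.engine_index(), Some(3));
}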
@@ -47,12 +48,13 @@ impl EngineId { self.0 } - /// Parse the Python-compatible engine index encoded in the routing identity. + /// Parse the Python-compatible engine index encoded in the routing + /// identity. /// - /// Python `EngineCoreProc` currently uses a two-byte little-endian engine index - /// as its ROUTER/DEALER identity. Coordinator control messages such as - /// `START_DP_WAVE(exclude_engine_index)` need that engine-side index rather than - /// any frontend-local ordering. + /// Python `EngineCoreProc` currently uses a two-byte little-endian engine + /// index as its ROUTER/DEALER identity. Coordinator control messages + /// such as `START_DP_WAVE(exclude_engine_index)` need that engine-side + /// index rather than any frontend-local ordering. pub fn engine_index(&self) -> Option { if self.len() != 2 { return None; @@ -60,8 +62,8 @@ impl EngineId { Some(u16::from_le_bytes([self[0], self[1]]) as u32) } - /// Construct an engine id from the Python-compatible engine index encoding (two-byte - /// little-endian). + /// Construct an engine id from the Python-compatible engine index encoding + /// (two-byte little-endian). pub fn from_engine_index(value: u32) -> Self { Self(Bytes::copy_from_slice(&(value as u16).to_le_bytes())) } @@ -95,7 +97,8 @@ impl TryFrom for PeerIdentity { } } -/// Per-engine handshake result collected while bootstrapping one shared transport. +/// Per-engine handshake result collected while bootstrapping one shared +/// transport. #[derive(Clone, Debug)] pub struct ConnectedEngine { /// The identity of the connected engine. @@ -105,18 +108,19 @@ pub struct ConnectedEngine { pub ready_response: Option, } -/// Represents the connected shared transport plus all registered engines after a successful -/// multi-engine startup handshake. +/// Represents the connected shared transport plus all registered engines after +/// a successful multi-engine startup handshake. pub struct ConnectedTransport { - /// The local address of the shared input socket that all engines connect to for receiving - /// requests. + /// The local address of the shared input socket that all engines connect to + /// for receiving requests. pub input_address: String, - /// The local address of the shared output socket that all engines connect to for sending - /// responses. + /// The local address of the shared output socket that all engines connect + /// to for sending responses. pub output_address: String, /// All engines connected through the startup handshake. pub engines: Vec, - /// Optional engine-facing coordinator transport used for in-process wave coordination. + /// Optional engine-facing coordinator transport used for in-process wave + /// coordination. pub coordinator: Option, /// The sending half of the shared input socket. @@ -131,8 +135,8 @@ enum EngineStartupState { ReadyReceived, } -/// Connect to one or more engines through the startup handshake protocol, returning the shared -/// data-plane transport plus the registered engines. +/// Connect to one or more engines through the startup handshake protocol, +/// returning the shared data-plane transport plus the registered engines. pub async fn connect_handshake( handshake_address: &str, engine_count: usize, @@ -151,8 +155,8 @@ pub async fn connect_handshake( handshake_address, "waiting for engines to connect" ); - // 1. Bind shared local input/output sockets first so every engine receives the same data-plane - // addresses during handshake. + // 1. 
Bind shared local input/output sockets first so every engine receives the
+    //    same data-plane addresses during handshake.
     debug!(
         local_host,
         ?ready_timeout,
@@ -169,15 +173,16 @@
         None
     };
-    // 2. Bind the shared handshake socket once. All engines connect to this socket with their own
-    //    identities, and startup order does not matter.
+    // 2. Bind the shared handshake socket once. All engines connect to this socket
+    //    with their own identities, and startup order does not matter.
     let mut handshake_socket = RouterSocket::new();
     handshake_socket.bind(handshake_address).await?;
     let mut engines = BTreeMap::new();
-    // 3. Receive HELLO from every engine and send a matching INIT. When coordinator mode is
-    //    enabled, the engines will not emit READY until the coordinator barrier below completes.
+    // 3. Receive HELLO from every engine and send a matching INIT. When coordinator
+    //    mode is enabled, the engines will not emit READY until the coordinator
+    //    barrier below completes.
     while engines.len() < engine_count {
         debug!(
             handshake_address,
@@ -185,12 +190,12 @@
             waiting_for = engine_count,
             "waiting for engine HELLO"
         );
-        let message = timeout(ready_timeout, handshake_socket.recv())
-            .await
-            .map_err(|_| Error::HandshakeTimeout {
+        let message = timeout(ready_timeout, handshake_socket.recv()).await.map_err(|_| {
+            Error::HandshakeTimeout {
                 stage: "HELLO",
                 timeout: ready_timeout,
-            })??;
+            }
+        })??;
         let (engine_id, handshake_message) = decode_handshake_message(message, None)?;
         match handshake_message.status.as_deref() {
             Some("HELLO") => {
@@ -241,11 +246,10 @@
         }
     }
-    // 4. Optional coordinator startup gate. Without coordinator there is nothing to do.
+    // 4. Optional coordinator startup gate. Without coordinator there is nothing to
+    //    do.
     if let Some(coordinator) = coordinator.as_mut() {
-        coordinator
-            .wait_for_startup_gate(engine_count, ready_timeout)
-            .await?;
+        coordinator.wait_for_startup_gate(engine_count, ready_timeout).await?;
     }
     // 5. After the optional gate has opened, every engine may now send READY.
@@ -253,19 +257,16 @@
         debug!(
             handshake_address,
             connected = engines.len(),
-            ready = engines
-                .values()
-                .filter(|state| state.is_ready_received())
-                .count(),
+            ready = engines.values().filter(|state| state.is_ready_received()).count(),
             waiting_for = engine_count,
             "waiting for engine READY"
         );
-        let message = timeout(ready_timeout, handshake_socket.recv())
-            .await
-            .map_err(|_| Error::HandshakeTimeout {
+        let message = timeout(ready_timeout, handshake_socket.recv()).await.map_err(|_| {
+            Error::HandshakeTimeout {
                 stage: "READY",
                 timeout: ready_timeout,
-            })??;
+            }
+        })??;
         let (engine_id, handshake_message) = decode_handshake_message(message, None)?;
         match handshake_message.status.as_deref() {
             Some("READY") => {
@@ -296,9 +297,9 @@
         }
     }
-    // 4. Wait for every engine to connect to the shared input socket and register itself. The
-    //    `ready_response` is a placeholder; it is populated for each engine by
-    //    `wait_for_input_registrations` below.
+    // 6. Wait for every engine to connect to the shared input socket and register
+    //    itself. The `ready_response` is a placeholder; it is populated for each
+    //    engine by `wait_for_input_registrations` below.
let mut engines: Vec<_> = engines .into_keys() .map(|engine_id| ConnectedEngine { @@ -327,12 +328,13 @@ pub async fn connect_handshake( }) } -/// Bind to Python-supplied frontend transport addresses and wait for already-initialized engines -/// to register themselves on the input socket. +/// Bind to Python-supplied frontend transport addresses and wait for +/// already-initialized engines to register themselves on the input socket. /// -/// This path mirrors Python's externally managed `AsyncMPClient` bootstrap model: the addresses -/// are already fixed by the supervisor, and engine identities are synthesized from contiguous -/// rank order instead of being discovered through a Rust-owned handshake. +/// This path mirrors Python's externally managed `AsyncMPClient` bootstrap +/// model: the addresses are already fixed by the supervisor, and engine +/// identities are synthesized from contiguous rank order instead of being +/// discovered through a Rust-owned handshake. pub async fn connect_bootstrapped( input_address: &str, output_address: &str, @@ -420,8 +422,8 @@ fn decode_handshake_message( Ok((actual_id, handshake_message)) } -/// Send an INIT message to the engine with the local socket addresses for the engine to connect to, -/// using the handshake socket. +/// Send an INIT message to the engine with the local socket addresses for the +/// engine to connect to, using the handshake socket. async fn send_init_message( handshake_socket: &mut RouterSocket, engine_id: &EngineId, @@ -446,33 +448,33 @@ async fn send_init_message( Ok(()) } -/// Receive the input registration message from each engine and validate its identity. +/// Receive the input registration message from each engine and validate its +/// identity. /// /// Each registration contains 2 frames: `[identity, ready-payload]`. /// -/// Since vLLM commit `c8d98f81f676552c263f35bbde55e6edbe81b4e8` ("[Core] Simplify API server -/// handshake"), the payload is a msgpack-encoded [`EngineCoreReadyResponse`] carrying -/// post-initialization values such as `max_model_len`. +/// Since vLLM commit `c8d98f81f676552c263f35bbde55e6edbe81b4e8` ("[Core] +/// Simplify API server handshake"), the payload is a msgpack-encoded +/// [`EngineCoreReadyResponse`] carrying post-initialization values such as +/// `max_model_len`. /// -/// Older engines sent an empty second frame here just to establish the ROUTER/DEALER backchannel, -/// with no structured payload on the input socket. We continue to tolerate that legacy shape so -/// the frontend can still connect to slightly older local engine checkouts. +/// Older engines sent an empty second frame here just to establish the +/// ROUTER/DEALER backchannel, with no structured payload on the input socket. +/// We continue to tolerate that legacy shape so the frontend can still connect +/// to slightly older local engine checkouts. 
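// A hedged sketch of the registration payload handling described above,
// assuming `payload` is the second frame of one `[identity, ready-payload]`
// registration message: an empty frame is the tolerated legacy shape, anything
// else decodes as a msgpack `EngineCoreReadyResponse`.
fn parse_ready_payload(payload: &[u8]) -> Result<Option<EngineCoreReadyResponse>> {
    if payload.is_empty() {
        // Legacy engines only establish the ROUTER/DEALER backchannel here.
        return Ok(None);
    }
    decode_msgpack(payload).map(Some)
}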
async fn wait_for_input_registrations( input_socket: &mut RouterSocket, engines: &mut [ConnectedEngine], ready_timeout: Duration, ) -> Result<()> { - let mut pending = engines - .iter() - .map(|e| e.engine_id.clone()) - .collect::>(); + let mut pending = engines.iter().map(|e| e.engine_id.clone()).collect::>(); while !pending.is_empty() { - let registration = timeout(ready_timeout, input_socket.recv()) - .await - .map_err(|_| Error::InputRegistrationTimeout { + let registration = timeout(ready_timeout, input_socket.recv()).await.map_err(|_| { + Error::InputRegistrationTimeout { timeout: ready_timeout, - })??; + } + })??; if registration.len() != 2 { bail_unexpected_handshake_message!( @@ -537,7 +539,8 @@ pub async fn send_message( Ok(()) } -/// Run the output loop to receive messages from the engine and send them to the provided channel. +/// Run the output loop to receive messages from the engine and send them to the +/// provided channel. pub async fn run_output_loop( mut output_socket: PullSocket, tx: mpsc::Sender>, @@ -546,9 +549,9 @@ pub async fn run_output_loop( let message = match output_socket.recv().await { Ok(message) => message, Err(error) => { - // If we fail to receive a message from the engine, it's likely that the engine has - // crashed or become unreachable, so we should notify the client and shut down the - // output loop. + // If we fail to receive a message from the engine, it's likely that the engine + // has crashed or become unreachable, so we should notify the + // client and shut down the output loop. error!(error = %error.as_report(), "failed to receive output message"); let _ = tx.send(Err(Error::Transport(error))).await; return; @@ -558,9 +561,7 @@ pub async fn run_output_loop( let frame_count = message.len(); trace!(frame_count, "received output message"); let frames = message.into_vec(); - let frame = frames - .first() - .expect("output message must have at least one frame"); + let frame = frames.first().expect("output message must have at least one frame"); let frame_len = frame.len(); if frame.as_ref() == ENGINE_CORE_DEAD_SENTINEL { warn!("received ENGINE_CORE_DEAD sentinel from engine"); @@ -573,16 +574,18 @@ pub async fn run_output_loop( Ok(decoded) } Err(error) => { - // If we fail to decode the message from the engine, notify the client but keep the - // output loop running to continue processing future messages from the engine. + // If we fail to decode the message from the engine, notify the client but keep + // the output loop running to continue processing future + // messages from the engine. warn!(frame_len, error = %error.as_report(), "failed to decode output message"); Err(error) } }; if tx.send(decoded).await.is_err() { - // If we fail to send the decoded message to the client, it's likely that the client has - // shut down, so we should shut down the output loop as well. + // If we fail to send the decoded message to the client, it's likely that the + // client has shut down, so we should shut down the output loop as + // well. 
warn!("output loop rx dropped, shutting down output loop"); return; } @@ -596,9 +599,7 @@ mod tests { #[tokio::test] async fn bind_local_sockets_resolves_zero_port_bindings() { let (input_address, _input_socket, output_address, _output_socket) = - bind_local_sockets("127.0.0.1", None, None) - .await - .expect("bind local sockets"); + bind_local_sockets("127.0.0.1", None, None).await.expect("bind local sockets"); assert!(input_address.starts_with("tcp://127.0.0.1:")); assert!(output_address.starts_with("tcp://127.0.0.1:")); diff --git a/src/llm/examples/external_engine_smoke.rs b/src/llm/examples/external_engine_smoke.rs index a78f485c..5800e2d9 100644 --- a/src/llm/examples/external_engine_smoke.rs +++ b/src/llm/examples/external_engine_smoke.rs @@ -78,9 +78,7 @@ async fn wait_for_request_completion(mut stream: GenerateOutputStream) -> Result "expected final-only stream to end after the final output" ); - let finish_reason = output - .finish_reason - .expect("final-only output must have a finish reason"); + let finish_reason = output.finish_reason.expect("final-only output must have a finish reason"); let token_ids = output.token_ids; Ok(CompletedRequest { @@ -133,15 +131,10 @@ async fn main() -> Result<()> { println!("request_id={request_id}"); println!("prompt_token_ids={PROMPT_TOKEN_IDS:?}"); - let stream = llm - .generate(request) - .await - .context("failed to submit generate request")?; + let stream = llm.generate(request).await.context("failed to submit generate request")?; let output = wait_for_timeout(stream, output_timeout).await?; - llm.shutdown() - .await - .context("failed to shut down llm client")?; + llm.shutdown().await.context("failed to shut down llm client")?; println!("token_ids={:?}", output.token_ids); println!("finish_reason={:?}", output.finish_reason); diff --git a/src/llm/src/lib.rs b/src/llm/src/lib.rs index c3857247..f50332af 100644 --- a/src/llm/src/lib.rs +++ b/src/llm/src/lib.rs @@ -20,8 +20,9 @@ use crate::request_metrics::RequestMetricsTracker; /// Thin generate-only facade over [`EngineCoreClient`]. /// -/// This mirrors the narrow public shape of Python `AsyncLLM.generate()` and `abort()`, but -/// keeps the boundary close to raw engine-core requests and outputs. +/// This mirrors the narrow public shape of Python `AsyncLLM.generate()` and +/// `abort()`, but keeps the boundary close to raw engine-core requests and +/// outputs. pub struct Llm { client: EngineCoreClient, randomize_request_id: bool, @@ -29,7 +30,8 @@ pub struct Llm { } impl Llm { - /// Create a new minimal LLM facade from an already connected engine-core client. + /// Create a new minimal LLM facade from an already connected engine-core + /// client. pub fn new(client: EngineCoreClient) -> Self { Self { client, @@ -52,18 +54,21 @@ impl Llm { self } - /// Control whether external request ids are randomized before reaching engine-core. + /// Control whether external request ids are randomized before reaching + /// engine-core. pub fn with_request_id_randomization(mut self, enabled: bool) -> Self { self.randomize_request_id = enabled; self } - /// Expose the underlying engine-core client for low-level utility/admin calls. + /// Expose the underlying engine-core client for low-level utility/admin + /// calls. pub fn engine_core_client(&self) -> &EngineCoreClient { &self.client } - /// Submit one tokenized generate request and return a per-request output stream. + /// Submit one tokenized generate request and return a per-request output + /// stream. 
pub async fn generate(&self, req: GenerateRequest) -> Result { let prepared = req.prepare(self.randomize_request_id)?; let prompt_token_ids = prepared.prompt_token_ids().into(); diff --git a/src/llm/src/log_stats.rs b/src/llm/src/log_stats.rs index 79b9fabd..7d3a1497 100644 --- a/src/llm/src/log_stats.rs +++ b/src/llm/src/log_stats.rs @@ -9,9 +9,9 @@ use vllm_metrics::{ const LOG_STATS_INTERVAL: Duration = Duration::from_secs(10); -/// Cached, cloned metric handles for one engine. Each clone shares the same underlying -/// `Arc` as the prometheus `Family` entry, so reads go straight to the atomic with -/// no lock. +/// Cached, cloned metric handles for one engine. Each clone shares the same +/// underlying `Arc` as the prometheus `Family` entry, so reads go +/// straight to the atomic with no lock. struct EngineMetrics { // Counters for throughput deltas. prompt_tokens_computed: U64Counter, @@ -25,7 +25,8 @@ struct EngineMetrics { kv_cache_usage: F64Gauge, } -/// Accumulated snapshot values from the last logging interval, used to compute deltas. +/// Accumulated snapshot values from the last logging interval, used to compute +/// deltas. struct CounterSnapshot { prompt_tokens: u64, generation_tokens: u64, @@ -35,9 +36,10 @@ struct CounterSnapshot { /// Periodic stats logger that mirrors Python vLLM's `LoggingStatLogger`. /// -/// Spawns a background task that logs throughput and scheduler state at a fixed interval. -/// When idle (both current and previous throughputs are zero), logs at DEBUG level. -/// When load drops to zero, emits one final INFO-level line before going quiet. +/// Spawns a background task that logs throughput and scheduler state at a fixed +/// interval. When idle (both current and previous throughputs are zero), logs +/// at DEBUG level. When load drops to zero, emits one final INFO-level line +/// before going quiet. pub(crate) struct StatsLogger { _task: AbortOnDropHandle<()>, } @@ -87,7 +89,8 @@ async fn run_stats_logger(model_name: String, engine_count: usize) { let mut interval = tokio::time::interval(LOG_STATS_INTERVAL); interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - // The first tick fires immediately; skip it so the first log is after one full interval. + // The first tick fires immediately; skip it so the first log is after one full + // interval. interval.tick().await; let mut prev = read_counters(&engines); @@ -122,9 +125,7 @@ async fn run_stats_logger(model_name: String, engine_count: usize) { let (num_running, num_waiting, kv_cache_usage) = read_scheduler_gauges(&engines); // Compute prefix cache hit rate over this interval. - let delta_queries = curr - .prefix_cache_queries - .wrapping_sub(prev.prefix_cache_queries); + let delta_queries = curr.prefix_cache_queries.wrapping_sub(prev.prefix_cache_queries); let prefix_cache_hit_rate = if delta_queries > 0 { let delta_hits = curr.prefix_cache_hits.wrapping_sub(prev.prefix_cache_hits); delta_hits as f64 / delta_queries as f64 * 100.0 diff --git a/src/llm/src/output.rs b/src/llm/src/output.rs index 870b4b96..e17fc113 100644 --- a/src/llm/src/output.rs +++ b/src/llm/src/output.rs @@ -26,24 +26,27 @@ pub struct CollectedGenerateOutput { pub kv_transfer_params: Option, } -/// Prompt-scoped metadata emitted only once on the first [`GenerateOutput`] for one request. +/// Prompt-scoped metadata emitted only once on the first [`GenerateOutput`] for +/// one request. #[derive(Debug, Clone, PartialEq)] pub struct GeneratePromptInfo { /// Original prompt token IDs for this request. 
pub prompt_token_ids: Arc<[u32]>, - /// Prompt logprobs returned by engine-core for scored prompt positions, when requested. + /// Prompt logprobs returned by engine-core for scored prompt positions, + /// when requested. pub prompt_logprobs: Option, } /// The reason a request finished. /// -/// This is a higher-level abstraction over engine-core's finish and stop reasons. +/// This is a higher-level abstraction over engine-core's finish and stop +/// reasons. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, EnumAsInner)] pub enum FinishReason { /// Generation stopped for a stop string, stop token, or EOS. /// - /// The inner stop reason is present for explicit stop strings or stop tokens, and absent for - /// EOS-driven stops. + /// The inner stop reason is present for explicit stop strings or stop + /// tokens, and absent for EOS-driven stops. Stop(Option), /// `max_tokens` or `max_model_len` was reached. Length, @@ -56,12 +59,14 @@ pub enum FinishReason { } impl FinishReason { - /// Construct a stop finish reason caused by EOS rather than an explicit stop string/token. + /// Construct a stop finish reason caused by EOS rather than an explicit + /// stop string/token. pub fn stop_eos() -> Self { Self::Stop(None) } - /// Returns a human-readable string for this finish reason, used for metrics and reporting. + /// Returns a human-readable string for this finish reason, used for metrics + /// and reporting. pub fn as_str(&self) -> &'static str { match self { Self::Stop(_) => "stop", @@ -72,7 +77,8 @@ impl FinishReason { } } - /// If this is a stop finish reason, returns the inner stop reason if it exists. + /// If this is a stop finish reason, returns the inner stop reason if it + /// exists. pub fn as_stop_reason(&self) -> Option<&StopReason> { match self { Self::Stop(stop_reason) => stop_reason.as_ref(), @@ -80,7 +86,8 @@ impl FinishReason { } } - /// If this is a stop finish reason, returns the inner stop reason if it exists. + /// If this is a stop finish reason, returns the inner stop reason if it + /// exists. pub fn into_stop_reason(self) -> Option { match self { Self::Stop(stop_reason) => stop_reason, @@ -110,7 +117,8 @@ fn finish_reason_from_engine( pub struct GenerateOutput { /// Unique ID of the request that produced this output. pub request_id: String, - /// One-time prompt metadata emitted only on the first output for this request. + /// One-time prompt metadata emitted only on the first output for this + /// request. pub prompt_info: Option, /// Newly produced token IDs for this step. pub token_ids: Vec, @@ -123,20 +131,22 @@ pub struct GenerateOutput { } impl GenerateOutput { - /// Returns the prompt token IDs when this output carries [`GeneratePromptInfo`]. + /// Returns the prompt token IDs when this output carries + /// [`GeneratePromptInfo`]. /// - /// Only the first output for a request can return `Some`; all later outputs return `None`. + /// Only the first output for a request can return `Some`; all later outputs + /// return `None`. pub fn prompt_token_ids(&self) -> Option<&Arc<[u32]>> { self.prompt_info.as_ref().map(|info| &info.prompt_token_ids) } - /// Returns the prompt logprobs when this output carries [`GeneratePromptInfo`]. + /// Returns the prompt logprobs when this output carries + /// [`GeneratePromptInfo`]. /// - /// Only the first output for a request can return `Some`; all later outputs return `None`. + /// Only the first output for a request can return `Some`; all later outputs + /// return `None`. 
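// A minimal consumer sketch for the one-shot prompt metadata described above,
// assuming `futures::StreamExt` is in scope for an already-obtained
// `GenerateOutputStream`: only the first output can carry `GeneratePromptInfo`,
// so callers capture it once up front.
async fn capture_prompt_info(mut stream: GenerateOutputStream) -> Option<GeneratePromptInfo> {
    let first = stream.next().await?.ok()?;
    first.prompt_info
}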
pub fn prompt_logprobs(&self) -> Option<&Logprobs> { - self.prompt_info - .as_ref() - .and_then(|info| info.prompt_logprobs.as_ref()) + self.prompt_info.as_ref().and_then(|info| info.prompt_logprobs.as_ref()) } /// Returns whether this output is terminal for the request. @@ -169,8 +179,10 @@ impl GenerateOutput { /// Stream of per-request generate outputs for one request. /// -/// - A normal termination of the stream represents a clean completion of the request. -/// - For errors, unexpected closes, or explicit aborts, the stream terminates with an error. +/// - A normal termination of the stream represents a clean completion of the +/// request. +/// - For errors, unexpected closes, or explicit aborts, the stream terminates +/// with an error. pub struct GenerateOutputStream { pending_prompt_info: Option, raw_stream: EngineCoreOutputStream, @@ -178,7 +190,8 @@ pub struct GenerateOutputStream { } impl GenerateOutputStream { - /// Create a new generate output stream by adapting one raw engine-core output stream. + /// Create a new generate output stream by adapting one raw engine-core + /// output stream. pub(crate) fn new( prompt_token_ids: Arc<[u32]>, raw_stream: EngineCoreOutputStream, @@ -224,17 +237,15 @@ impl Stream for GenerateOutputStream { if let Some(info) = &mut self.pending_prompt_info && info.prompt_logprobs.is_none() { - info.prompt_logprobs = raw - .new_prompt_logprobs_tensors - .map(|value| value.into_direct().unwrap()); + info.prompt_logprobs = + raw.new_prompt_logprobs_tensors.map(|value| value.into_direct().unwrap()); } let logprobs = raw.new_logprobs.map(|value| value.into_direct().unwrap()); let finish_reason = finish_reason_from_engine(raw.finish_reason, raw.stop_reason); if let Some(finish_reason) = finish_reason.as_ref() { - self.request_metrics - .record_finished(received_at, finish_reason.clone()); + self.request_metrics.record_finished(received_at, finish_reason.clone()); } let output = GenerateOutput { @@ -281,7 +292,8 @@ impl Drop for GenerateOutputStream { #[allow(clippy::manual_async_fn, reason = "specify `Send` bound")] #[easy_ext::ext(GenerateOutputStreamExt)] impl> + Send> T { - /// Collect the raw generate stream to completion and return the final token output. + /// Collect the raw generate stream to completion and return the final token + /// output. pub fn collect_output(self) -> impl Future> + Send { async move { let stream = self; diff --git a/src/llm/src/request.rs b/src/llm/src/request.rs index 53e7751f..8c09eeea 100644 --- a/src/llm/src/request.rs +++ b/src/llm/src/request.rs @@ -8,9 +8,10 @@ use crate::error::{Error, Result}; /// Tokenized decoder-only generate request accepted by [`crate::Llm`]. /// -/// This is the first-stage Rust subset of the inputs that eventually flow into Python -/// `AsyncLLM.generate()`. The boundary is intentionally above [`EngineCoreRequest`], but below -/// higher-level text and multimodal preprocessing. +/// This is the first-stage Rust subset of the inputs that eventually flow into +/// Python `AsyncLLM.generate()`. The boundary is intentionally above +/// [`EngineCoreRequest`], but below higher-level text and multimodal +/// preprocessing. 
/// /// Original Python API reference: /// diff --git a/src/llm/src/request_metrics.rs b/src/llm/src/request_metrics.rs index f62a8c26..d28b83be 100644 --- a/src/llm/src/request_metrics.rs +++ b/src/llm/src/request_metrics.rs @@ -18,8 +18,9 @@ const PROMPT_TOKEN_SOURCE_EXTERNAL_KV_TRANSFER: &str = "external_kv_transfer"; /// Request-scoped metrics state tracked across streamed engine-core updates. /// -/// This is the Rust-side counterpart of the Python frontend's request-lifecycle bookkeeping, -/// centered on `RequestStateStats` and the per-output/per-finished update flow. +/// This is the Rust-side counterpart of the Python frontend's request-lifecycle +/// bookkeeping, centered on `RequestStateStats` and the per-output/per-finished +/// update flow. /// /// Original Python definitions: /// @@ -45,7 +46,8 @@ pub(crate) struct RequestMetricsTracker { } impl RequestMetricsTracker { - /// Create the per-request tracker from the normalized `llm`-layer request context. + /// Create the per-request tracker from the normalized `llm`-layer request + /// context. pub(crate) fn new( model_name: String, arrival_time: f64, @@ -119,15 +121,15 @@ impl RequestMetricsTracker { self.last_token_ts = batch_timestamp; } - /// Emit the terminal request metrics once a finished output has been observed. + /// Emit the terminal request metrics once a finished output has been + /// observed. /// /// Original Python finished-request stats: /// pub(crate) fn record_finished(&self, received_at: f64, finish_reason: FinishReason) { let labels = engine_labels(&self.model_name, self.last_seen_engine_index); - let prefill_kv_computed_tokens = self - .prompt_len - .saturating_sub(self.latest_num_cached_tokens); + let prefill_kv_computed_tokens = + self.prompt_len.saturating_sub(self.latest_num_cached_tokens); let e2e_latency_seconds = received_at - self.arrival_time; let queue_time_seconds = diff_or_zero(self.scheduled_ts, self.queued_ts); let prefill_time_seconds = diff_or_zero(self.first_token_ts, self.scheduled_ts); @@ -159,10 +161,7 @@ impl RequestMetricsTracker { .get_or_create(&labels) .observe(max_tokens_param as f64); } - metrics() - .request_params_n - .get_or_create(&labels) - .observe(self.n_param as f64); + metrics().request_params_n.get_or_create(&labels).observe(self.n_param as f64); metrics() .request_prefill_kv_computed_tokens .get_or_create(&labels) @@ -308,9 +307,9 @@ fn diff_or_zero(end: f64, start: f64) -> f64 { /// Return the current wall-clock time in seconds since the Unix epoch. /// -/// This is used for frontend-side latency measurements such as TTFT and E2E, matching the Python -/// frontend's use of wall-clock request arrival/iteration timestamps rather than engine-core's -/// monotonic scheduler timestamps. +/// This is used for frontend-side latency measurements such as TTFT and E2E, +/// matching the Python frontend's use of wall-clock request arrival/iteration +/// timestamps rather than engine-core's monotonic scheduler timestamps. 
/// /// Original Python request timestamp source: /// diff --git a/src/llm/tests/generate.rs b/src/llm/tests/generate.rs index 611ff41f..f66e242b 100644 --- a/src/llm/tests/generate.rs +++ b/src/llm/tests/generate.rs @@ -217,10 +217,7 @@ fn init_tracing() { TRACING.call_once(|| { let filter = EnvFilter::try_from_default_env() .unwrap_or_else(|_| EnvFilter::new("vllm_engine_core_client=debug")); - let _ = tracing_subscriber::fmt() - .with_test_writer() - .with_env_filter(filter) - .try_init(); + let _ = tracing_subscriber::fmt().with_test_writer().with_env_filter(filter).try_init(); }); } @@ -274,10 +271,7 @@ async fn generate_streams_outputs() { ); let llm = connect_async_llm_with_ipc(handshake_address, 7, "test-model", &ipc).await; - let mut stream = llm - .generate(sample_generate_request("req-delta", 3)) - .await - .unwrap(); + let mut stream = llm.generate(sample_generate_request("req-delta", 3)).await.unwrap(); let internal_id = stream.request_id().to_string(); let first = stream.next().await.unwrap().unwrap(); @@ -363,10 +357,7 @@ async fn collect_output_aggregates_raw_tokens_logprobs_and_terminal_metadata() { ); let llm = connect_async_llm_with_ipc(handshake_address, 7, "test-model", &ipc).await; - let stream = llm - .generate(sample_generate_request("req-collect", 4)) - .await - .unwrap(); + let stream = llm.generate(sample_generate_request("req-collect", 4)).await.unwrap(); let internal_id = stream.request_id().to_string(); let collected = stream.collect_output().await.unwrap(); @@ -416,10 +407,7 @@ async fn generate_propagates_unexpected_close_errors() { ); let llm = connect_async_llm_with_ipc(handshake_address, 0, "test-model", &ipc).await; - let mut stream = llm - .generate(sample_generate_request("req-close", 1)) - .await - .unwrap(); + let mut stream = llm.generate(sample_generate_request("req-close", 1)).await.unwrap(); let internal_id = stream.request_id().to_string(); let error = stream.next().await.unwrap().unwrap_err(); @@ -462,9 +450,8 @@ async fn dropping_a_live_generate_stream_triggers_abort() { ) .await; - let abort = timeout(Duration::from_secs(1), recv_engine_message(dealer)) - .await - .unwrap(); + let abort = + timeout(Duration::from_secs(1), recv_engine_message(dealer)).await.unwrap(); assert_eq!(abort[0].as_ref(), &[0x01]); let aborted_ids: Vec = rmp_serde::from_slice(&abort[1]).unwrap(); assert_eq!(aborted_ids, vec![request.request_id]); @@ -473,10 +460,7 @@ async fn dropping_a_live_generate_stream_triggers_abort() { ); let llm = connect_async_llm_with_ipc(handshake_address, 0, "test-model", &ipc).await; - let mut stream = llm - .generate(sample_generate_request("req-drop", 4)) - .await - .unwrap(); + let mut stream = llm.generate(sample_generate_request("req-drop", 4)).await.unwrap(); let output = stream.next().await.unwrap().unwrap(); assert_eq!(output.token_ids, vec![99]); @@ -543,14 +527,8 @@ async fn duplicate_external_request_ids_are_randomized_before_reaching_engine_co ); let llm = connect_async_llm_with_ipc(handshake_address, 0, "test-model", &ipc).await; - let stream_1 = llm - .generate(sample_generate_request("req-dup", 1)) - .await - .unwrap(); - let stream_2 = llm - .generate(sample_generate_request("req-dup", 1)) - .await - .unwrap(); + let stream_1 = llm.generate(sample_generate_request("req-dup", 1)).await.unwrap(); + let stream_2 = llm.generate(sample_generate_request("req-dup", 1)).await.unwrap(); let internal_id_1 = stream_1.request_id().to_string(); let internal_id_2 = stream_2.request_id().to_string(); let collected_1 = 
stream_1.collect_output().await.unwrap();
@@ -736,9 +714,8 @@ async fn dropping_stream_records_abort_terminal_request_metrics() {
         )
         .await;
 
-        let abort = timeout(Duration::from_secs(1), recv_engine_message(dealer))
-            .await
-            .unwrap();
+        let abort =
+            timeout(Duration::from_secs(1), recv_engine_message(dealer)).await.unwrap();
         assert_eq!(abort[0].as_ref(), &[0x01]);
         let aborted_ids: Vec<String> = rmp_serde::from_slice(&abort[1]).unwrap();
         assert_eq!(aborted_ids, vec![request.request_id]);
diff --git a/src/metrics/src/lib.rs b/src/metrics/src/lib.rs
index 11f6c523..8f0db53d 100644
--- a/src/metrics/src/lib.rs
+++ b/src/metrics/src/lib.rs
@@ -17,9 +17,9 @@ pub use api_server::*;
 pub use request::*;
 pub use scheduler::*;
 
-// Note: `prometheus-client` appends the `_total` suffix automatically when encoding counters, so
-// all counter family registration names in this crate must use the base metric name without a
-// trailing `_total`.
+// Note: `prometheus-client` appends the `_total` suffix automatically when
+// encoding counters, so all counter family registration names in this crate
+// must use the base metric name without a trailing `_total`.
 pub type U64Counter = Counter;
 pub type U64Gauge = Gauge;
 pub type F64Gauge = Gauge;
@@ -52,7 +52,8 @@ impl Metrics {
         }
     }
 
-    /// Render the current metrics registry into Prometheus/OpenMetrics text format.
+    /// Render the current metrics registry into Prometheus/OpenMetrics text
+    /// format.
     pub fn render(&self) -> Result<String> {
         let mut output = String::new();
         encode(&mut output, &self.registry)?;
diff --git a/src/metrics/src/scheduler.rs b/src/metrics/src/scheduler.rs
index 9b8a8bae..0acbdf0f 100644
--- a/src/metrics/src/scheduler.rs
+++ b/src/metrics/src/scheduler.rs
@@ -74,7 +74,8 @@ pub struct SchedulerMetrics {
 }
 
 impl SchedulerMetrics {
-    /// Register the scheduler-oriented metric families into the shared registry.
+    /// Register the scheduler-oriented metric families into the shared
+    /// registry.
     pub(crate) fn register(registry: &mut Registry) -> Self {
         // Scheduler state gauges.
         let scheduler_running = Family::default();
@@ -248,21 +249,9 @@ mod tests {
             engine: 0,
         };
 
-        metrics
-            .scheduler
-            .estimated_flops_per_gpu
-            .get_or_create(&labels)
-            .inc();
-        metrics
-            .scheduler
-            .estimated_read_bytes_per_gpu
-            .get_or_create(&labels)
-            .inc();
-        metrics
-            .scheduler
-            .estimated_write_bytes_per_gpu
-            .get_or_create(&labels)
-            .inc();
+        metrics.scheduler.estimated_flops_per_gpu.get_or_create(&labels).inc();
+        metrics.scheduler.estimated_read_bytes_per_gpu.get_or_create(&labels).inc();
+        metrics.scheduler.estimated_write_bytes_per_gpu.get_or_create(&labels).inc();
 
         let rendered = metrics.render().unwrap();
         assert!(
diff --git a/src/server/examples/external_engine_openai_qwen.rs b/src/server/examples/external_engine_openai_qwen.rs
index 930f7496..9034cde0 100644
--- a/src/server/examples/external_engine_openai_qwen.rs
+++ b/src/server/examples/external_engine_openai_qwen.rs
@@ -101,21 +101,14 @@ fn init_tracing() {
 fn unique_local_port() -> Result<u16> {
     let listener = std::net::TcpListener::bind("127.0.0.1:0")
         .context("failed to allocate local smoke-test port")?;
-    let port = listener
-        .local_addr()
-        .context("failed to read local smoke-test port")?
-        .port();
+    let port = listener.local_addr().context("failed to read local smoke-test port")?.port();
     drop(listener);
     Ok(port)
 }
 
 async fn print_models(client: &Client) -> Result<()> {
     let models = wait_for_models(client).await?;
-    let model_ids = models
-        .data
-        .into_iter()
-        .map(|model| model.id)
-        .collect::<Vec<_>>();
+    let model_ids = models.data.into_iter().map(|model| model.id).collect::<Vec<_>>();
     println!("models={model_ids:?}");
     Ok(())
 }
@@ -143,13 +136,15 @@ async fn stream_completion(
     model: &str,
     prompt: &str,
 ) -> Result {
-    // Keep this smoke test on async-openai's standard `create_stream` path so it exercises
-    // the ordinary typed chat-completions client without BYOT request/response types.
+    // Keep this smoke test on async-openai's standard `create_stream` path so it
+    // exercises the ordinary typed chat-completions client without BYOT
+    // request/response types.
     //
-    // The current async-openai chat-completions stream delta type does not expose our
-    // OpenAI-compatible `reasoning_content` extension field, so this example only validates the
-    // assistant role chunk, visible `content` deltas, and terminal finish chunk. Reasoning
-    // coverage lives in our own route tests and in the `vllm-chat` smoke example.
+    // The current async-openai chat-completions stream delta type does not expose
+    // our OpenAI-compatible `reasoning_content` extension field, so this
+    // example only validates the assistant role chunk, visible `content`
+    // deltas, and terminal finish chunk. Reasoning coverage lives in our own
+    // route tests and in the `vllm-chat` smoke example.
     let request: CreateChatCompletionRequest = CreateChatCompletionRequestArgs::default()
         .model(model)
         .stream(true)
diff --git a/src/server/src/config.rs b/src/server/src/config.rs
index d0ddd777..065d0e1d 100644
--- a/src/server/src/config.rs
+++ b/src/server/src/config.rs
@@ -13,17 +13,19 @@ pub enum HttpListenerMode {
     BindTcp { host: String, port: u16 },
     /// Bind a fresh Unix domain listener on the given filesystem path.
     BindUnix { path: String },
-    /// Adopt an already-open listening socket inherited from a supervisor process.
+    /// Adopt an already-open listening socket inherited from a supervisor
+    /// process.
     InheritedFd { fd: i32 },
 }
 
-/// Which coordinator implementation should be active when one is present for a frontend client.
+/// Which coordinator implementation should be active when one is present for a
+/// frontend client.
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub enum CoordinatorMode {
     /// Do not run a coordinator at all.
     None,
-    /// Run the Rust in-process coordinator for managed `serve` deployments, if there are multiple
-    /// engines and the model is MoE.
+    /// Run the Rust in-process coordinator for managed `serve` deployments, if
+    /// there are multiple engines and the model is MoE.
     MaybeInProc,
     /// Connect to an external coordinator owned by another process.
     External { address: String },
@@ -46,7 +48,8 @@ pub struct Config {
     pub reasoning_parser: ParserSelection,
     /// Chat renderer selection.
     pub renderer: RendererSelection,
-    /// Server-default chat template override, as a file path or inline template.
+    /// Server-default chat template override, as a file path or inline
+    /// template.
     pub chat_template: Option<String>,
     /// Server-default keyword arguments merged into every chat-template render.
     pub default_chat_template_kwargs: Option>,
     pub chat_template_content_format: ChatTemplateContentFormatOption,
     /// Log a summary line for each completed request.
pub enable_log_requests: bool, - /// When `true`, suppress periodic stats logging (throughput, queue depth, cache usage). + /// When `true`, suppress periodic stats logging (throughput, queue depth, + /// cache usage). pub disable_log_stats: bool, - /// TCP port for the gRPC Generate service. When `None`, no gRPC server is started. + /// TCP port for the gRPC Generate service. When `None`, no gRPC server is + /// started. pub grpc_port: Option, /// Maximum time to wait for active HTTP/gRPC requests to drain on shutdown. pub shutdown_timeout: Duration, } impl Config { - /// Validate frontend configuration that can be checked before engine startup. + /// Validate frontend configuration that can be checked before engine + /// startup. pub fn validate(&self) -> Result<()> { vllm_chat::validate_parser_overrides(&self.tool_call_parser, &self.reasoning_parser)?; diff --git a/src/server/src/error.rs b/src/server/src/error.rs index 66a1a454..cc425ca0 100644 --- a/src/server/src/error.rs +++ b/src/server/src/error.rs @@ -8,7 +8,8 @@ use crate::routes::openai::utils::types::{ErrorDetail, ErrorResponse}; /// Small OpenAI-style error family used by the minimal HTTP layer. #[derive(Debug, Construct, Macro)] pub enum ApiError { - /// The request is syntactically valid OpenAI JSON but asks for unsupported behavior. + /// The request is syntactically valid OpenAI JSON but asks for unsupported + /// behavior. InvalidRequest { message: String, param: Option<&'static str>, @@ -32,7 +33,8 @@ impl ApiError { } } - /// Convert this error into the standard OpenAI-compatible JSON error payload. + /// Convert this error into the standard OpenAI-compatible JSON error + /// payload. pub fn to_error_response(&self) -> ErrorResponse { let error = match self { Self::InvalidRequest { message, param } => ErrorDetail { diff --git a/src/server/src/grpc/convert.rs b/src/server/src/grpc/convert.rs index f24762f8..883617d3 100644 --- a/src/server/src/grpc/convert.rs +++ b/src/server/src/grpc/convert.rs @@ -1,4 +1,5 @@ -//! Conversion between gRPC protobuf types and internal `vllm-text` request/response types. +//! Conversion between gRPC protobuf types and internal `vllm-text` +//! request/response types. use tonic::Status; use uuid::Uuid; @@ -16,8 +17,9 @@ use super::pb; /// Convert a gRPC `GenerateRequest` into the internal `TextRequest`. /// -/// If `req.model` is non-empty, it must match `configured_model`; otherwise the request is -/// rejected with `NotFound`. An empty string is treated as "unset" (proto3 default) and accepted. +/// If `req.model` is non-empty, it must match `configured_model`; otherwise the +/// request is rejected with `NotFound`. An empty string is treated as "unset" +/// (proto3 default) and accepted. pub fn to_text_request( req: pb::GenerateRequest, stream: bool, @@ -59,12 +61,11 @@ pub fn to_text_request( // Thread KVCacheParameters → SamplingParams fields. if let Some(kv) = kv { - // Thread kv_transfer_params through vllm_xargs, matching the HTTP route convention. + // Thread kv_transfer_params through vllm_xargs, matching the HTTP route + // convention. 
if let Some(kv_struct) = kv.kv_transfer_params.as_ref() { let kv_json = proto_struct_to_json(kv_struct); - let map = sampling_params - .vllm_xargs - .get_or_insert_with(Default::default); + let map = sampling_params.vllm_xargs.get_or_insert_with(Default::default); map.insert("kv_transfer_params".to_string(), kv_json); } if kv.bypass_prefix_cache { @@ -75,10 +76,7 @@ pub fn to_text_request( let decode_options = TextDecodeOptions { skip_special_tokens: true, include_stop_str_in_output: stopping.is_some_and(|s| s.include_stop_strings), - stop_strings: stopping - .map(|s| &s.stop_strings) - .filter(|ss| !ss.is_empty()) - .cloned(), + stop_strings: stopping.map(|s| &s.stop_strings).filter(|ss| !ss.is_empty()).cloned(), min_tokens: stopping.map_or(0, |s| s.min_new_tokens), }; @@ -102,18 +100,20 @@ fn build_sampling_params( stopping: Option<&pb::StoppingCriteria>, response: Option<&pb::ResponseOptions>, ) -> Result { - // Temperature is a top-level GenerateRequest field. Default to greedy (0.0) for the gRPC - // API when the caller does not specify a value. This differs from the HTTP/OpenAI API - // (which defaults to 1.0) and matches the convention of programmatic generation APIs. + // Temperature is a top-level GenerateRequest field. Default to greedy (0.0) for + // the gRPC API when the caller does not specify a value. This differs from + // the HTTP/OpenAI API (which defaults to 1.0) and matches the convention of + // programmatic generation APIs. let temperature = temperature.or(Some(0.0)); let mut params = SamplingParams { temperature, ..SamplingParams::default() }; - // RandomSampling: for every remaining sampling field the protobuf default (`0`) is - // treated as "unset" and leaves the resolved value to the lowering stage, which falls - // back to the model-provided default or a neutral/disabled value otherwise. + // RandomSampling: for every remaining sampling field the protobuf default (`0`) + // is treated as "unset" and leaves the resolved value to the lowering + // stage, which falls back to the model-provided default or a + // neutral/disabled value otherwise. if let Some(s) = sampling { // num_sequences (n > 1) is not supported yet by the TextLlm layer; the response // path also hardcodes SequenceOutput.index = 0, so accepting >1 would silently @@ -197,13 +197,15 @@ fn build_sampling_params( Ok(params) } -/// Map the proto `CandidateTokens` selector to a `(logprobs_count, logprob_token_ids)` pair. +/// Map the proto `CandidateTokens` selector to a `(logprobs_count, +/// logprob_token_ids)` pair. 
///
/// - `top_n(k)` → `(k, None)` — return top-k candidates by probability
/// - `all` → `(-1, None)` — return the full vocabulary
-/// - `token_ids(n)` → `(1, Some(vec of n token ids))` — return logprobs for specific tokens (the
-///   count `n` is stored in the proto as the number of token IDs that follow, but the actual IDs
-///   are carried via `logprob_token_ids` on `SamplingParams`)
+/// - `token_ids(n)` → `(1, Some(vec of n token ids))` — return logprobs for
+///   specific tokens (the count `n` is stored in the proto as the number of
+///   token IDs that follow, but the actual IDs are carried via
+///   `logprob_token_ids` on `SamplingParams`)
/// - absent → `(1, None)` — just the sampled/scored token
fn candidate_logprob_spec(candidates: Option<&pb::CandidateTokens>) -> (i32, Option<Vec<u32>>) {
    match candidates.and_then(|c| c.select.as_ref()) {
@@ -260,7 +262,8 @@ fn convert_structured_output(
 // Response conversion
 // ========================================================================================
 
-/// Convert a `DecodedTextEvent::Start` into the prompt info portion of a gRPC response.
+/// Convert a `DecodedTextEvent::Start` into the prompt info portion of a gRPC
+/// response.
 pub fn to_prompt_info(
     prompt_token_ids: &[u32],
     prompt_logprobs: Option<&DecodedPromptLogprobs>,
@@ -335,10 +338,7 @@ fn to_finish_info(finished: &Finished, token_ids: &[u32]) -> pb::FinishInfo {
             // echo it back as a `stop_reason`. The matched token is, by construction, the
             // last token of the terminal output batch (see vllm's `check_stop` in
             // vllm/v1/core/sched/utils.py), so we recover it from there.
-            None => token_ids
-                .last()
-                .copied()
-                .map(pb::finish_info::StopReason::EosTokenId),
+            None => token_ids.last().copied().map(pb::finish_info::StopReason::EosTokenId),
         };
         (PbFinishReason::Stop as i32, sr)
     }
@@ -352,10 +352,7 @@ fn to_finish_info(finished: &Finished, token_ids: &[u32]) -> pb::FinishInfo {
         num_output_tokens: finished.output_token_count as u32,
         finish_reason,
         stop_reason,
-        kv_transfer_params: finished
-            .kv_transfer_params
-            .as_ref()
-            .and_then(json_to_proto_struct),
+        kv_transfer_params: finished.kv_transfer_params.as_ref().and_then(json_to_proto_struct),
     }
 }
@@ -365,7 +362,8 @@ fn to_finish_info(finished: &Finished, token_ids: &[u32]) -> pb::FinishInfo {
 /// Convert output logprobs to the flat proto representation.
 ///
-/// Returns (logprob_values, ranks, candidate_tokens) — all parallel arrays indexed by position.
+/// Returns (logprob_values, ranks, candidate_tokens) — all parallel arrays
+/// indexed by position.
 fn output_logprobs_to_proto(
     lp: &DecodedLogprobs,
 ) -> (Vec, Vec, Vec) {
@@ -377,8 +375,8 @@ fn prompt_logprobs_to_proto(
     plp: &DecodedPromptLogprobs,
 ) -> (Vec, Vec, Vec) {
     // The proto PromptInfo has flat parallel arrays covering all prompt positions.
-    // DecodedPromptLogprobs has first_token separately + scored_positions for the rest.
-    // The first prompt position has no scores, so we emit zeros for it.
+    // DecodedPromptLogprobs has first_token separately + scored_positions for the
+    // rest. The first prompt position has no scores, so we emit zeros for it.
     let (mut logprobs, mut ranks, mut candidates) = positions_to_proto(&plp.scored_positions);
     logprobs.insert(0, 0.0);
     ranks.insert(0, 0);
     (logprobs, ranks, candidates)
 }
 
-/// Shared helper: convert a slice of decoded position logprobs to flat proto arrays.
+/// Shared helper: convert a slice of decoded position logprobs to flat proto
+/// arrays.
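The selector table above, written out as a self-contained match; the `Select` enum here is a hypothetical stand-in for the prost-generated oneof on `pb::CandidateTokens`:

```rust
// Hypothetical stand-in for the prost-generated oneof.
enum Select {
    TopN(i32),
    All(bool),
    TokenIds(Vec<u32>),
}

// Same mapping as the bullet list above.
fn candidate_logprob_spec(select: Option<&Select>) -> (i32, Option<Vec<u32>>) {
    match select {
        Some(Select::TopN(k)) => (*k, None),                    // top-k by probability
        Some(Select::All(_)) => (-1, None),                     // full vocabulary
        Some(Select::TokenIds(ids)) => (1, Some(ids.clone())),  // specific tokens
        None => (1, None),                                      // sampled/scored token only
    }
}

fn main() {
    assert_eq!(candidate_logprob_spec(Some(&Select::TopN(5))), (5, None));
    assert_eq!(candidate_logprob_spec(Some(&Select::All(true))), (-1, None));
    assert_eq!(
        candidate_logprob_spec(Some(&Select::TokenIds(vec![7, 9]))),
        (1, Some(vec![7, 9]))
    );
    assert_eq!(candidate_logprob_spec(None), (1, None));
}
```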
fn positions_to_proto( positions: &[vllm_text::DecodedPositionLogprobs], ) -> (Vec, Vec, Vec) { @@ -423,10 +422,7 @@ fn positions_to_proto( fn proto_struct_to_json(s: &prost_types::Struct) -> serde_json::Value { serde_json::Value::Object( - s.fields - .iter() - .map(|(k, v)| (k.clone(), proto_value_to_json(v))) - .collect(), + s.fields.iter().map(|(k, v)| (k.clone(), proto_value_to_json(v))).collect(), ) } @@ -447,10 +443,7 @@ fn proto_value_to_json(v: &prost_types::Value) -> serde_json::Value { fn json_to_proto_struct(value: &serde_json::Value) -> Option { match value { serde_json::Value::Object(map) => Some(prost_types::Struct { - fields: map - .iter() - .map(|(k, v)| (k.clone(), json_to_proto_value(v))) - .collect(), + fields: map.iter().map(|(k, v)| (k.clone(), json_to_proto_value(v))).collect(), }), _ => None, } @@ -467,10 +460,7 @@ fn json_to_proto_value(v: &serde_json::Value) -> prost_types::Value { values: arr.iter().map(json_to_proto_value).collect(), }), serde_json::Value::Object(map) => Kind::StructValue(prost_types::Struct { - fields: map - .iter() - .map(|(k, v)| (k.clone(), json_to_proto_value(v))) - .collect(), + fields: map.iter().map(|(k, v)| (k.clone(), json_to_proto_value(v))).collect(), }), }; prost_types::Value { kind: Some(kind) } @@ -629,7 +619,8 @@ mod tests { #[test] fn explicit_stop_token_id_is_preserved() { let fin = finished(FinishReason::Stop(Some(StopReason::TokenId(42)))); - // Terminal token list should be ignored when an explicit stop reason is present. + // Terminal token list should be ignored when an explicit stop reason is + // present. let info = to_finish_info(&fin, &[7, 42]); assert_eq!(info.finish_reason, PbFinishReason::Stop as i32); diff --git a/src/server/src/grpc/tests.rs b/src/server/src/grpc/tests.rs index 06862029..842163f8 100644 --- a/src/server/src/grpc/tests.rs +++ b/src/server/src/grpc/tests.rs @@ -202,8 +202,9 @@ impl ChatRenderer for FakeTextBackend { } } -/// Spin up a gRPC server backed by a mock engine that serves a single request with the -/// given output specs. Returns the client, the gRPC server task, and the mock engine task. +/// Spin up a gRPC server backed by a mock engine that serves a single request +/// with the given output specs. Returns the client, the gRPC server task, and +/// the mock engine task. async fn grpc_test_server( engine_id: impl Into, output_specs: Vec<(Vec, Option)>, @@ -249,9 +250,7 @@ async fn grpc_test_server( let svc = GenerateServer::new(GenerateServiceImpl::new(state)); // Bind to an OS-assigned port. - let listener = tokio::net::TcpListener::bind("127.0.0.1:0") - .await - .expect("bind grpc listener"); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.expect("bind grpc listener"); let addr = listener.local_addr().expect("local addr"); let server_task = tokio::spawn(async move { @@ -451,10 +450,7 @@ async fn streaming_generate_yields_incremental_responses() { // First message should have prompt info. let first = &responses[0]; - let prompt_info = first - .prompt_info - .as_ref() - .expect("first response has prompt_info"); + let prompt_info = first.prompt_info.as_ref().expect("first response has prompt_info"); assert_eq!(prompt_info.num_prompt_tokens, 5); // "hello" // Collect all text deltas. 
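A round-trip sketch for the `Struct` conversion helpers above. These are simplified mirrors of `json_to_proto_struct`/`proto_struct_to_json`, restricted to strings and bools: proto `Value` stores every number as `f64` (`NumberValue`), so integer fields would not round-trip exactly.

```rust
use prost_types::value::Kind;
use prost_types::{Struct, Value};

// Simplified mirror of `json_to_proto_struct`, strings and bools only.
fn to_proto(map: &serde_json::Map<String, serde_json::Value>) -> Struct {
    Struct {
        fields: map
            .iter()
            .map(|(k, v)| {
                let kind = match v {
                    serde_json::Value::Bool(b) => Kind::BoolValue(*b),
                    serde_json::Value::String(s) => Kind::StringValue(s.clone()),
                    _ => Kind::NullValue(0),
                };
                (k.clone(), Value { kind: Some(kind) })
            })
            .collect(),
    }
}

// Simplified mirror of `proto_struct_to_json`.
fn to_json(s: &Struct) -> serde_json::Value {
    serde_json::Value::Object(
        s.fields
            .iter()
            .map(|(k, v)| {
                let value = match v.kind.as_ref() {
                    Some(Kind::BoolValue(b)) => serde_json::Value::Bool(*b),
                    Some(Kind::StringValue(s)) => serde_json::Value::String(s.clone()),
                    _ => serde_json::Value::Null,
                };
                (k.clone(), value)
            })
            .collect(),
    )
}

fn main() {
    // kv_transfer_params-style payload: strings and bools round-trip exactly.
    let json = serde_json::json!({ "engine_id": "prefill-0", "bypass_prefix_cache": true });
    let map = json.as_object().expect("top-level value must be an object").clone();
    assert_eq!(to_json(&to_proto(&map)), json);
}
```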
@@ -471,10 +467,7 @@ async fn streaming_generate_yields_incremental_responses() {
         .rev()
         .find_map(|r| r.outputs.as_ref())
         .expect("at least one output");
-    let finish = last_output
-        .finish_info
-        .as_ref()
-        .expect("finish_info on last output");
+    let finish = last_output.finish_info.as_ref().expect("finish_info on last output");
     assert_eq!(
         finish.finish_reason,
         pb::finish_info::FinishReason::Stop as i32
diff --git a/src/server/src/lib.rs b/src/server/src/lib.rs
index 0eb555ed..6c3abd36 100644
--- a/src/server/src/lib.rs
+++ b/src/server/src/lib.rs
@@ -31,7 +31,8 @@ use crate::listener::Listener;
 use crate::routes::build_router;
 use crate::state::AppState;
 
-/// Build the shared application state for one configured model and one engine client.
+/// Build the shared application state for one configured model and one engine
+/// client.
 async fn build_state(config: &Config) -> Result<Arc<AppState>> {
     // Load both backends from the same model metadata so they stay in sync.
     let loaded = load_model_backends(
@@ -80,14 +81,13 @@ async fn build_state(config: &Config) -> Result<Arc<AppState>> {
     ))
 }
 
-/// Run the OpenAI-compatible HTTP server until the supplied shutdown token is cancelled.
+/// Run the OpenAI-compatible HTTP server until the supplied shutdown token is
+/// cancelled.
 ///
-/// The server owns one `vllm-chat` facade, which in turn owns the lower `vllm-text` and
-/// `vllm-llm` layers, and shuts them down before returning.
+/// The server owns one `vllm-chat` facade, which in turn owns the lower
+/// `vllm-text` and `vllm-llm` layers, and shuts them down before returning.
 pub async fn serve(config: Config, shutdown: CancellationToken) -> Result<()> {
-    config
-        .validate()
-        .context("invalid OpenAI frontend configuration")?;
+    config.validate().context("invalid OpenAI frontend configuration")?;
 
     // Also check shutdown during the (potentially long) startup handshake.
     let state = tokio::select! {
@@ -101,11 +101,12 @@ pub async fn serve(config: Config, shutdown: CancellationToken) -> Result<()> {
     let model = state.model_id.clone();
     let app = build_router(state.clone());
 
-    // Optionally bind the gRPC Generate server on a separate port. Bind synchronously
-    // here so bind errors (port in use, permission denied, ...) surface before we start
-    // serving, rather than being deferred until shutdown. The gRPC listener follows the
-    // same host as the HTTP listener so that enabling --grpc-port does not accidentally
-    // expose the service on all interfaces when HTTP is intentionally local-only.
+    // Optionally bind the gRPC Generate server on a separate port. Bind
+    // synchronously here so bind errors (port in use, permission denied, ...)
+    // surface before we start serving, rather than being deferred until
+    // shutdown. The gRPC listener follows the same host as the HTTP listener so
+    // that enabling --grpc-port does not accidentally expose the service on all
+    // interfaces when HTTP is intentionally local-only.
     let grpc_setup = if let Some(grpc_port) = config.grpc_port {
         let grpc_host = match &config.listener_mode {
             HttpListenerMode::BindTcp { host, .. } => host.as_str(),
@@ -134,9 +135,10 @@ pub async fn serve(config: Config, shutdown: CancellationToken) -> Result<()> {
         }
     });
 
-    // Run HTTP and gRPC concurrently under a child token of the caller's shutdown token.
-    // Caller cancellation propagates into both protocols; if either protocol exits first,
-    // we cancel this child token so its sibling also begins a graceful drain.
+ // Run HTTP and gRPC concurrently under a child token of the caller's shutdown + // token. Caller cancellation propagates into both protocols; if either + // protocol exits first, we cancel this child token so its sibling also + // begins a graceful drain. let server_shutdown = shutdown.child_token(); let force_shutdown = CancellationToken::new(); let shutdown_deadline = Arc::new(OnceLock::new()); @@ -196,12 +198,10 @@ pub async fn serve(config: Config, shutdown: CancellationToken) -> Result<()> { shutdown.cancelled().await; return Ok(()); }; - let server = TonicServer::builder() - .add_service(svc) - .serve_with_incoming_shutdown( - TcpListenerStream::new(grpc_listener), - shutdown.cancelled_owned(), - ); + let server = TonicServer::builder().add_service(svc).serve_with_incoming_shutdown( + TcpListenerStream::new(grpc_listener), + shutdown.cancelled_owned(), + ); let result = tokio::select! { result = server => { diff --git a/src/server/src/listener.rs b/src/server/src/listener.rs index 178c36ee..b7b715b0 100644 --- a/src/server/src/listener.rs +++ b/src/server/src/listener.rs @@ -1,7 +1,8 @@ //! Unified HTTP listener wrapper for the Rust frontend. //! -//! This module hides the difference between TCP and Unix-domain listeners so the rest of the -//! server can bind or inherit one socket and pass it to `axum::serve(...)` through a single type. +//! This module hides the difference between TCP and Unix-domain listeners so +//! the rest of the server can bind or inherit one socket and pass it to +//! `axum::serve(...)` through a single type. use std::io::Result; use std::net::TcpListener as StdTcpListener; @@ -14,8 +15,8 @@ use tokio_util::either::Either; use crate::HttpListenerMode; -/// Runtime listener type used by the OpenAI-compatible HTTP server, which is either a TCP listener -/// or a Unix-domain listener. +/// Runtime listener type used by the OpenAI-compatible HTTP server, which is +/// either a TCP listener or a Unix-domain listener. #[derive(Debug)] pub enum Listener { Tcp(TcpListener), @@ -25,8 +26,8 @@ pub enum Listener { impl Listener { /// Bind or adopt the listener described by the frontend configuration. /// - /// For inherited sockets, the concrete listener kind is detected from the socket family of the - /// supplied file descriptor. + /// For inherited sockets, the concrete listener kind is detected from the + /// socket family of the supplied file descriptor. pub async fn bind(mode: &HttpListenerMode) -> Result { match mode { HttpListenerMode::BindTcp { host, port } => { @@ -37,7 +38,8 @@ impl Listener { } } - /// Return a log-friendly local address string for either TCP or Unix sockets. + /// Return a log-friendly local address string for either TCP or Unix + /// sockets. pub fn local_addr(&self) -> Result { match self { Self::Tcp(listener) => Ok(listener.local_addr()?.to_string()), @@ -49,14 +51,14 @@ impl Listener { } fn from_inherited_fd(fd: i32) -> Result { - // SAFETY: We trust the caller to only pass valid listener fds, and we only use this fd - // once to create a single listener. + // SAFETY: We trust the caller to only pass valid listener fds, and we only use + // this fd once to create a single listener. let owned_fd = unsafe { OwnedFd::from_raw_fd(fd) }; let socket = Socket::from(owned_fd); - // The Python supervisor pre-binds the socket to reserve the endpoint early, but Rust is - // responsible for transitioning inherited stream sockets into the listening state before - // accepting connections. 
+ // The Python supervisor pre-binds the socket to reserve the endpoint early, but + // Rust is responsible for transitioning inherited stream sockets into + // the listening state before accepting connections. socket.listen(libc::SOMAXCONN)?; socket.set_nonblocking(true)?; @@ -110,14 +112,10 @@ mod tests { #[tokio::test(flavor = "current_thread")] async fn inherited_fd_detects_tcp_listener_without_uds_hint() { let socket = Socket::new(Domain::IPV4, Type::STREAM, None).unwrap(); - socket - .bind(&SockAddr::from(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0))) - .unwrap(); + socket.bind(&SockAddr::from(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0))).unwrap(); let fd = socket.into_raw_fd(); - let listener = Listener::bind(&HttpListenerMode::InheritedFd { fd }) - .await - .unwrap(); + let listener = Listener::bind(&HttpListenerMode::InheritedFd { fd }).await.unwrap(); assert!(matches!(listener, Listener::Tcp(_))); } @@ -129,9 +127,7 @@ mod tests { socket.bind(&SockAddr::unix(&path).unwrap()).unwrap(); let fd = socket.into_raw_fd(); - let listener = Listener::bind(&HttpListenerMode::InheritedFd { fd }) - .await - .unwrap(); + let listener = Listener::bind(&HttpListenerMode::InheritedFd { fd }).await.unwrap(); assert!(matches!(listener, Listener::Unix(_))); let _ = std::fs::remove_file(path); diff --git a/src/server/src/middleware/load.rs b/src/server/src/middleware/load.rs index 2893c94e..d03b36fb 100644 --- a/src/server/src/middleware/load.rs +++ b/src/server/src/middleware/load.rs @@ -12,8 +12,9 @@ use crate::state::AppState; /// Endpoints that will be tracked for server load. /// -/// Derived from the Python frontend's actual `@load_aware_call` coverage. This includes alias -/// paths that delegate into decorated handlers, such as `/v1/rerank` and `/v2/rerank`. +/// Derived from the Python frontend's actual `@load_aware_call` coverage. This +/// includes alias paths that delegate into decorated handlers, such as +/// `/v1/rerank` and `/v2/rerank`. const TRACKED_HANDLERS: &[&str] = &[ "/v1/responses", "/v1/responses/{response_id}", @@ -79,8 +80,9 @@ impl Drop for ServerLoadGuard { } } -/// A wrapper around response bodies that tracks server load by holding a `ServerLoadGuard`, which -/// will decrement the load when the body is fully consumed and dropped. +/// A wrapper around response bodies that tracks server load by holding a +/// `ServerLoadGuard`, which will decrement the load when the body is fully +/// consumed and dropped. struct LoadTrackedBody { inner: Body, _guard: ServerLoadGuard, diff --git a/src/server/src/middleware/metrics.rs b/src/server/src/middleware/metrics.rs index 67a73262..366d73dd 100644 --- a/src/server/src/middleware/metrics.rs +++ b/src/server/src/middleware/metrics.rs @@ -26,8 +26,8 @@ const EXCLUDED_HANDLERS: &[&str] = &[ "/is_sleeping", ]; -/// Record API-server HTTP metrics with Python-compatible (`PrometheusFastApiInstrumentator` style) -/// family names and labels. +/// Record API-server HTTP metrics with Python-compatible +/// (`PrometheusFastApiInstrumentator` style) family names and labels. 
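A skeleton of the duration-tracking shape that the metrics middleware below implements; `observe` is a stand-in for the real Prometheus families, and the label set here is illustrative. It would be wired onto the router with `axum::middleware::from_fn(track)`.

```rust
use std::time::Instant;

use axum::extract::Request;
use axum::middleware::Next;
use axum::response::Response;

// Duration-tracking skeleton in the spirit of `track_http_metrics` below.
async fn track(req: Request, next: Next) -> Response {
    let method = req.method().as_str().to_string();
    let handler = req.uri().path().to_string();
    let start = Instant::now();

    let response = next.run(req).await;

    let status = response.status().as_u16();
    observe(&method, &handler, status, start.elapsed().as_secs_f64());
    response
}

fn observe(method: &str, handler: &str, status: u16, seconds: f64) {
    // Stand-in: the real code records (method, handler, status)-labeled
    // request counts and duration histograms.
    eprintln!("{method} {handler} -> {status} in {seconds:.6}s");
}
```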
pub async fn track_http_metrics(req: Request, next: Next) -> Response { let method = req.method().as_str().to_string(); let handler = req diff --git a/src/server/src/routes/cache.rs b/src/server/src/routes/cache.rs index b8bc4cce..580b91d4 100644 --- a/src/server/src/routes/cache.rs +++ b/src/server/src/routes/cache.rs @@ -16,7 +16,8 @@ pub(crate) struct ResetPrefixCacheParams { reset_external: bool, } -/// Reset the local prefix cache and optionally the connector-managed external cache. +/// Reset the local prefix cache and optionally the connector-managed external +/// cache. pub async fn reset_prefix_cache( State(state): State>, Query(params): Query, diff --git a/src/server/src/routes/http_client_tests.rs b/src/server/src/routes/http_client_tests.rs index f6449398..f5af24a9 100644 --- a/src/server/src/routes/http_client_tests.rs +++ b/src/server/src/routes/http_client_tests.rs @@ -1,5 +1,6 @@ -//! Integration tests that exercise the OpenAI-compatible HTTP API through a real TCP connection -//! using the `async-openai` client library, backed by a mock engine. +//! Integration tests that exercise the OpenAI-compatible HTTP API through a +//! real TCP connection using the `async-openai` client library, backed by a +//! mock engine. use std::future::Future; use std::pin::Pin; @@ -219,7 +220,8 @@ impl ChatRenderer for FakeChatBackend { } /// Spin up an HTTP server on a random port backed by a mock engine. -/// Returns the `async-openai` client, the HTTP server task, and the mock engine task. +/// Returns the `async-openai` client, the HTTP server task, and the mock engine +/// task. async fn http_test_server( engine_id: impl Into, output_specs: Vec<(Vec, Option)>, @@ -264,9 +266,7 @@ async fn http_test_server( let state = Arc::new(AppState::new("test-model", chat)); let app = build_router(state); - let listener = tokio::net::TcpListener::bind("127.0.0.1:0") - .await - .expect("bind http listener"); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.expect("bind http listener"); let addr = listener.local_addr().expect("local addr"); let server_task = tokio::spawn(async move { @@ -317,11 +317,7 @@ async fn non_streaming_chat_via_http_client() { .build() .expect("build request"); - let response = client - .chat() - .create(request) - .await - .expect("chat completion"); + let response = client.chat().create(request).await.expect("chat completion"); assert_eq!(response.model, "test-model"); assert_eq!(response.choices.len(), 1); @@ -355,11 +351,7 @@ async fn streaming_chat_via_http_client() { .build() .expect("build request"); - let mut stream = client - .chat() - .create_stream(request) - .await - .expect("streaming chat completion"); + let mut stream = client.chat().create_stream(request).await.expect("streaming chat completion"); let mut full_text = String::new(); let mut saw_role = false; diff --git a/src/server/src/routes/inference/generate.rs b/src/server/src/routes/inference/generate.rs index 7d7995e4..ec653ca6 100644 --- a/src/server/src/routes/inference/generate.rs +++ b/src/server/src/routes/inference/generate.rs @@ -24,7 +24,8 @@ use crate::routes::openai::utils::validated_json::ValidatedJson; use crate::state::AppState; use crate::utils::resolve_request_context; -/// Validate one token-in/token-out request and proxy it into the shared `vllm-text` stack. +/// Validate one token-in/token-out request and proxy it into the shared +/// `vllm-text` stack. 
 pub async fn generate(
     State(state): State<Arc<AppState>>,
     headers: HeaderMap,
@@ -62,11 +63,7 @@ pub async fn generate(
         }
     };
 
-    let collected = match raw_stream
-        .collect_output()
-        .instrument(request_span.clone())
-        .await
-    {
+    let collected = match raw_stream.collect_output().instrument(request_span.clone()).await {
         Ok(collected) => collected,
         Err(error) => {
             return server_error!(
diff --git a/src/server/src/routes/inference/generate/convert.rs b/src/server/src/routes/inference/generate/convert.rs
index 2582f8f9..b75349b3 100644
--- a/src/server/src/routes/inference/generate/convert.rs
+++ b/src/server/src/routes/inference/generate/convert.rs
@@ -14,7 +14,8 @@ pub struct PreparedRequest {
     pub include_prompt_logprobs: bool,
 }
 
-/// Validate and lower one raw generate request into the internal text-generation format.
+/// Validate and lower one raw generate request into the internal
+/// text-generation format.
 pub fn prepare_generate_request(
     request: GenerateRequest,
     configured_model: &str,
diff --git a/src/server/src/routes/inference/generate/validate.rs b/src/server/src/routes/inference/generate/validate.rs
index 0ce2ce02..185654a1 100644
--- a/src/server/src/routes/inference/generate/validate.rs
+++ b/src/server/src/routes/inference/generate/validate.rs
@@ -1,7 +1,8 @@
 use super::types::GenerateRequest;
 use crate::error::{ApiError, bail_invalid_request};
 
-/// Enforce the minimal compatibility contract for the Rust token generate route.
+/// Enforce the minimal compatibility contract for the Rust token generate
+/// route.
 pub(super) fn validate_request_compat(
     request: &GenerateRequest,
     configured_model: &str,
diff --git a/src/server/src/routes/openai/chat_completions.rs b/src/server/src/routes/openai/chat_completions.rs
index 3afd3959..c04f155f 100644
--- a/src/server/src/routes/openai/chat_completions.rs
+++ b/src/server/src/routes/openai/chat_completions.rs
@@ -40,7 +40,8 @@ use crate::routes::openai::utils::validated_json::ValidatedJson;
 use crate::state::AppState;
 use crate::utils::{resolve_request_context, unix_timestamp};
 
-/// Validate one chat completion request and proxy it into the shared `vllm-chat` stack.
+/// Validate one chat completion request and proxy it into the shared
+/// `vllm-chat` stack.
 pub async fn chat_completions(
     State(state): State<Arc<AppState>>,
     headers: HeaderMap,
@@ -62,21 +63,17 @@ pub async fn chat_completions(
     let created = unix_timestamp();
     let log_request = state.enable_log_requests;
 
-    let chat_stream = match state
-        .chat
-        .chat(prepared.chat_request)
-        .instrument(request_span.clone())
-        .await
-    {
-        Ok(stream) => stream,
-        Err(error) => {
-            return server_error!(
-                "failed to submit chat request: {}",
-                error.to_report_string()
-            )
-            .into_response();
-        }
-    };
+    let chat_stream =
+        match state.chat.chat(prepared.chat_request).instrument(request_span.clone()).await {
+            Ok(stream) => stream,
+            Err(error) => {
+                return server_error!(
+                    "failed to submit chat request: {}",
+                    error.to_report_string()
+                )
+                .into_response();
+            }
+        };
 
     if stream {
         let chunk_stream = chat_completion_chunk_stream(
@@ -241,9 +238,9 @@ async fn chat_completion_chunk_stream(
 ) -> Result<(), ApiError> {
     let mut saw_tool_calls = false;
 
-    // If the client requested logprobs or token_ids, we need to buffer chunks until we receive
-    // the separate `LogprobsDelta` event, so that we can emit one combined chunk with both the
-    // semantic delta and its per-update metadata.
+ // If the client requested logprobs or token_ids, we need to buffer chunks until + // we receive the separate `LogprobsDelta` event, so that we can emit one + // combined chunk with both the semantic delta and its per-update metadata. let mut pending_chunk = (requested_logprobs || return_token_ids).then(PendingChatChunk::default); @@ -291,9 +288,8 @@ async fn chat_completion_chunk_stream( .as_ref() .map(|lp| decoded_logprobs_to_openai_chat(lp, return_tokens_as_token_ids)) .transpose()?; - let openai_token_ids = return_token_ids - .then_some(token_ids) - .filter(|t| !t.is_empty()); + let openai_token_ids = + return_token_ids.then_some(token_ids).filter(|t| !t.is_empty()); if let Some(pending_chunk) = pending_chunk.as_mut() { pending_chunk.logprobs = openai_logprobs; pending_chunk.token_ids = openai_token_ids; @@ -457,7 +453,8 @@ struct PendingChatChunk { } impl PendingChatChunk { - /// Append one assistant text/reasoning block delta to the buffered OpenAI delta payload. + /// Append one assistant text/reasoning block delta to the buffered OpenAI + /// delta payload. fn push_block_delta(&mut self, kind: AssistantBlockKind, delta: String) { match kind { AssistantBlockKind::Text => append_delta_text(&mut self.delta.content, delta), @@ -470,34 +467,28 @@ impl PendingChatChunk { /// Append the OpenAI tool-call-start representation to the buffered delta. fn push_tool_call_start(&mut self, index: u32, id: String, name: String) { - self.delta - .tool_calls - .get_or_insert_with(Vec::new) - .push(ToolCallDelta { - index, - id: Some(id), - tool_type: Some("function".to_string()), - function: Some(FunctionCallDelta { - name: Some(name), - arguments: None, - }), - }); + self.delta.tool_calls.get_or_insert_with(Vec::new).push(ToolCallDelta { + index, + id: Some(id), + tool_type: Some("function".to_string()), + function: Some(FunctionCallDelta { + name: Some(name), + arguments: None, + }), + }); } /// Append one incremental tool-call arguments update to the buffered delta. fn push_tool_call_arguments(&mut self, index: u32, delta: String) { - self.delta - .tool_calls - .get_or_insert_with(Vec::new) - .push(ToolCallDelta { - index, - id: None, - tool_type: None, - function: Some(FunctionCallDelta { - name: None, - arguments: Some(delta), - }), - }); + self.delta.tool_calls.get_or_insert_with(Vec::new).push(ToolCallDelta { + index, + id: None, + tool_type: None, + function: Some(FunctionCallDelta { + name: None, + arguments: Some(delta), + }), + }); } /// Finalize the currently buffered SSE chunk, if it contains either a @@ -604,7 +595,8 @@ fn done_sse_event() -> Event { Event::default().data("[DONE]") } -/// Build the initial assistant-role SSE chunk required by the OpenAI streaming protocol. +/// Build the initial assistant-role SSE chunk required by the OpenAI streaming +/// protocol. fn start_chunk( request_id: &str, response_model: &str, @@ -760,7 +752,8 @@ fn chat_finish_reason_to_openai( } } -/// Convert one internal stop reason into the OpenAI-compatible `stop_reason` JSON shape. +/// Convert one internal stop reason into the OpenAI-compatible `stop_reason` +/// JSON shape. 
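In miniature, the buffering flow described above: hold the semantic delta until its per-update metadata arrives, then emit one combined chunk. `Pending` is a hypothetical reduction of `PendingChatChunk`:

```rust
#[derive(Default)]
struct Pending {
    content: Option<String>,
    logprobs: Option<Vec<f64>>,
}

impl Pending {
    fn push_text(&mut self, delta: &str) {
        self.content.get_or_insert_with(String::new).push_str(delta);
    }

    // Emit the buffered chunk only once both halves have been seen.
    fn take_if_ready(&mut self) -> Option<(String, Vec<f64>)> {
        if self.content.is_some() && self.logprobs.is_some() {
            Some((self.content.take().unwrap(), self.logprobs.take().unwrap()))
        } else {
            None
        }
    }
}

fn main() {
    let mut pending = Pending::default();
    pending.push_text("Hel");
    assert!(pending.take_if_ready().is_none()); // metadata not seen yet
    pending.logprobs = Some(vec![-0.11, -0.52, -0.03]);
    let (text, logprobs) = pending.take_if_ready().expect("combined chunk");
    assert_eq!(text, "Hel");
    assert_eq!(logprobs.len(), 3);
}
```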
fn stop_reason_to_json(stop_reason: &StopReason) -> Value { serde_json::to_value(stop_reason).expect("StopReason must serialize to JSON") } diff --git a/src/server/src/routes/openai/chat_completions/convert.rs b/src/server/src/routes/openai/chat_completions/convert.rs index 1d010a9a..f73a3ec1 100644 --- a/src/server/src/routes/openai/chat_completions/convert.rs +++ b/src/server/src/routes/openai/chat_completions/convert.rs @@ -14,7 +14,8 @@ use crate::routes::openai::utils::types::{ }; use crate::utils::{ResolvedRequestContext, convert_logit_bias, merge_kv_transfer_params}; -/// Lowered chat request plus the public response metadata carried by every SSE chunk. +/// Lowered chat request plus the public response metadata carried by every SSE +/// chunk. #[derive(Debug, Clone, PartialEq)] pub struct PreparedRequest { /// Stable OpenAI-style request ID, reused as the external chat request ID. @@ -37,7 +38,8 @@ pub struct PreparedRequest { pub return_tokens_as_token_ids: bool, } -/// Validate and lower one OpenAI chat completion request into the internal chat format. +/// Validate and lower one OpenAI chat completion request into the internal chat +/// format. pub(crate) fn prepare_chat_request( request: ChatCompletionRequest, configured_model: &str, @@ -50,11 +52,7 @@ pub(crate) fn prepare_chat_request( .echo .then(|| extract_last_assistant_content(&request.messages)) .flatten(); - let messages: Vec<_> = request - .messages - .into_iter() - .map(convert_message) - .try_collect()?; + let messages: Vec<_> = request.messages.into_iter().map(convert_message).try_collect()?; let generation_prompt_mode = normalize_generation_prompt_mode( request.add_generation_prompt, request.continue_final_message, @@ -68,7 +66,8 @@ pub(crate) fn prepare_chat_request( .unwrap_or(false); let requested_logprobs = request.logprobs; - // Auto-enable prompt logprobs for non-streaming echo, matching Python vLLM's behavior. + // Auto-enable prompt logprobs for non-streaming echo, matching Python vLLM's + // behavior. let top_logprobs = request.top_logprobs.unwrap_or(0); let prompt_logprobs = request .prompt_logprobs @@ -244,7 +243,8 @@ fn convert_message(message: ChatMessage) -> Result { } } -/// Convert the given OpenAI message content value into the internal format in `vllm-chat`. +/// Convert the given OpenAI message content value into the internal format in +/// `vllm-chat`. fn convert_content(content: MessageContent) -> Result { match content { MessageContent::Text(text) => Ok(ChatContent::Text(text)), @@ -259,7 +259,8 @@ fn convert_content(content: MessageContent) -> Result { } } -/// Convert the given OpenAI assistant message content into the internal format in `vllm-chat`. +/// Convert the given OpenAI assistant message content into the internal format +/// in `vllm-chat`. 
fn convert_assistant_text_blocks( content: MessageContent, ) -> Result, ApiError> { @@ -288,10 +289,7 @@ fn convert_assistant_tool_calls( Ok(AssistantContentBlock::ToolCall(AssistantToolCall { id: tool_call.id, name: tool_call.function.name, - arguments: tool_call - .function - .arguments - .unwrap_or_else(|| "{}".to_string()), + arguments: tool_call.function.arguments.unwrap_or_else(|| "{}".to_string()), })) }) .collect() diff --git a/src/server/src/routes/openai/chat_completions/types.rs b/src/server/src/routes/openai/chat_completions/types.rs index a334e01e..00557ad5 100644 --- a/src/server/src/routes/openai/chat_completions/types.rs +++ b/src/server/src/routes/openai/chat_completions/types.rs @@ -16,9 +16,9 @@ use crate::routes::openai::utils::types::{ /// vLLM-compatible request type for the Chat Completions API. /// -/// Mirrors the Python vLLM `ChatCompletionRequest` class. The local copy keeps the request type -/// route-owned so we can add vLLM-only fields directly instead of layering wrapper deserializers -/// on top. +/// Mirrors the Python vLLM `ChatCompletionRequest` class. The local copy keeps +/// the request type route-owned so we can add vLLM-only fields directly instead +/// of layering wrapper deserializers on top. #[serde_with::skip_serializing_none] #[derive(Debug, Clone, Deserialize, Serialize, Validate)] #[validate(schema(function = "validate_chat_cross_parameters"))] @@ -32,8 +32,8 @@ pub struct ChatCompletionRequest { #[serde(default = "default_model")] pub model: String, - /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing - /// frequency in the text so far + /// Number between -2.0 and 2.0. Positive values penalize new tokens based + /// on their existing frequency in the text so far #[validate(range(min = -2.0, max = 2.0))] pub frequency_penalty: Option, @@ -54,7 +54,8 @@ pub struct ChatCompletionRequest { #[validate(range(min = 1))] pub max_tokens: Option, - /// An upper bound for the number of tokens that can be generated for a completion + /// An upper bound for the number of tokens that can be generated for a + /// completion #[validate(range(min = 1))] pub max_completion_tokens: Option, @@ -62,15 +63,16 @@ pub struct ChatCompletionRequest { #[validate(range(min = 1, max = 10))] pub n: Option, - /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they - /// appear in the text so far + /// Number between -2.0 and 2.0. 
Positive values penalize new tokens based + /// on whether they appear in the text so far #[validate(range(min = -2.0, max = 2.0))] pub presence_penalty: Option, /// An object specifying the format that the model must output pub response_format: Option, - /// If specified, our system will make a best effort to sample deterministically + /// If specified, our system will make a best effort to sample + /// deterministically pub seed: Option, /// Up to 4 sequences where the API will stop generating further tokens @@ -98,7 +100,8 @@ pub struct ChatCompletionRequest { /// Controls which (if any) tool is called by the model pub tool_choice: Option, - /// Effort level for reasoning models (none, minimal, low, medium, high, xhigh, max) + /// Effort level for reasoning models (none, minimal, low, medium, high, + /// xhigh, max) pub reasoning_effort: Option, /// Whether to enable parallel function calling during tool use @@ -169,15 +172,16 @@ pub struct ChatCompletionRequest { #[serde(default = "default_true")] pub include_reasoning: bool, - /// If true, the new message will be prepended with the last message if they belong to the same - /// role. + /// If true, the new message will be prepended with the last message if they + /// belong to the same role. #[serde(default)] pub echo: bool, /// Whether to add the generation prompt to the chat template. /// - /// When omitted, the request follows the API default behavior, which is equivalent to `true` - /// unless `continue_final_message=true` selects final assistant continuation instead. + /// When omitted, the request follows the API default behavior, which is + /// equivalent to `true` unless `continue_final_message=true` selects + /// final assistant continuation instead. pub add_generation_prompt: Option, /// Continue generating from final assistant message @@ -212,7 +216,8 @@ pub struct ChatCompletionRequest { /// External request ID used for response correlation. pub request_id: Option, - /// Tokens represented as strings of the form 'token_id:{token_id}' in logprobs + /// Tokens represented as strings of the form 'token_id:{token_id}' in + /// logprobs pub return_tokens_as_token_ids: Option, /// Include token IDs alongside generated text @@ -224,7 +229,8 @@ pub struct ChatCompletionRequest { /// KV transfer parameters for disaggregated serving pub kv_transfer_params: Option>, - /// Additional request parameters with string or numeric values for custom extensions + /// Additional request parameters with string or numeric values for custom + /// extensions pub vllm_xargs: Option>, /// Parameters for detecting repetitive N-gram patterns in output tokens @@ -349,8 +355,8 @@ pub(super) struct ChatCompletionChoice { pub token_ids: Option>, } -/// A literal type for the "assistant" role, since the API only allows that specific value in -/// responses. +/// A literal type for the "assistant" role, since the API only allows that +/// specific value in responses. 
#[derive(Debug, Clone, Copy, PartialEq, Eq, SerializeDisplay)] pub(super) struct AssistantRole; diff --git a/src/server/src/routes/openai/chat_completions/validate.rs b/src/server/src/routes/openai/chat_completions/validate.rs index 3f4cfb99..c5a0070c 100644 --- a/src/server/src/routes/openai/chat_completions/validate.rs +++ b/src/server/src/routes/openai/chat_completions/validate.rs @@ -89,7 +89,8 @@ pub(super) fn validate_request_compat( ); } - // ---- Reject parameters that are accepted for deserialization but not yet implemented ---- + // ---- Reject parameters that are accepted for deserialization but not yet + // implemented ---- if request.parallel_tool_calls.is_some() { bail_invalid_request!( diff --git a/src/server/src/routes/openai/completions.rs b/src/server/src/routes/openai/completions.rs index 180e142f..cb843242 100644 --- a/src/server/src/routes/openai/completions.rs +++ b/src/server/src/routes/openai/completions.rs @@ -34,7 +34,8 @@ use crate::routes::openai::utils::validated_json::ValidatedJson; use crate::state::AppState; use crate::utils::{resolve_request_context, unix_timestamp}; -/// Validate one completions request and proxy it into the shared `vllm-text` stack. +/// Validate one completions request and proxy it into the shared `vllm-text` +/// stack. pub async fn completions( State(state): State>, headers: HeaderMap, @@ -55,11 +56,7 @@ pub async fn completions( ); let created = unix_timestamp(); - let include_prompt_logprobs = prepared - .text_request - .sampling_params - .prompt_logprobs - .is_some(); + let include_prompt_logprobs = prepared.text_request.sampling_params.prompt_logprobs.is_some(); let log_request = state.enable_log_requests; let text_stream = match state @@ -150,10 +147,7 @@ async fn collect_completion( .as_stop_reason() .map(|sr| serde_json::to_value(sr).expect("StopReason must serialize to JSON")); - let prompt_char_count = echo - .as_ref() - .map(|prompt| text_len(prompt)) - .unwrap_or_default(); + let prompt_char_count = echo.as_ref().map(|prompt| text_len(prompt)).unwrap_or_default(); let prompt_logprobs = if include_prompt_logprobs { let prompt_logprobs = collected.prompt_logprobs.as_ref().ok_or_else(|| { server_error!( @@ -541,10 +535,7 @@ mod tests { .collect::>() .await; - let chunks: Vec<_> = chunks - .into_iter() - .try_collect() - .expect("stream should succeed"); + let chunks: Vec<_> = chunks.into_iter().try_collect().expect("stream should succeed"); match &chunks[0] { CompletionSseChunk::Chunk(chunk) => { @@ -554,11 +545,7 @@ mod tests { vec!["h".to_string()] ); assert_eq!( - chunk.choices[0] - .logprobs - .as_ref() - .expect("logprobs") - .text_offset, + chunk.choices[0].logprobs.as_ref().expect("logprobs").text_offset, vec![0] ); } @@ -573,11 +560,7 @@ mod tests { vec!["!".to_string()] ); assert_eq!( - chunk.choices[0] - .logprobs - .as_ref() - .expect("logprobs") - .text_offset, + chunk.choices[0].logprobs.as_ref().expect("logprobs").text_offset, vec![1] ); } diff --git a/src/server/src/routes/openai/completions/convert.rs b/src/server/src/routes/openai/completions/convert.rs index 165def3f..a2a2baf2 100644 --- a/src/server/src/routes/openai/completions/convert.rs +++ b/src/server/src/routes/openai/completions/convert.rs @@ -6,7 +6,8 @@ use crate::routes::openai::completions::validate; use crate::routes::openai::utils::structured_outputs::convert_from_response_format_value; use crate::utils::{ResolvedRequestContext, convert_logit_bias, merge_kv_transfer_params}; -/// Lowered completion request plus the public response metadata 
carried by every SSE chunk. +/// Lowered completion request plus the public response metadata carried by +/// every SSE chunk. #[derive(Debug, Clone, PartialEq)] pub struct PreparedRequest { /// Stable OpenAI-style request ID, reused as the external text request ID. @@ -17,7 +18,8 @@ pub struct PreparedRequest { pub include_usage: bool, /// Lowered text request for the shared `vllm-text` facade. pub text_request: TextRequest, - /// Original text prompt that should be echoed back northbound when `echo=true`. + /// Original text prompt that should be echoed back northbound when + /// `echo=true`. pub echo: Option, /// Whether to include token IDs alongside generated text. pub return_token_ids: bool, @@ -25,7 +27,8 @@ pub struct PreparedRequest { pub return_tokens_as_token_ids: bool, } -/// Validate and lower one OpenAI completions request into the internal text-generation format. +/// Validate and lower one OpenAI completions request into the internal +/// text-generation format. pub(crate) fn prepare_completion_request( request: CompletionRequest, configured_model: &str, @@ -44,20 +47,15 @@ pub(crate) fn prepare_completion_request( })?), None => None, }; - let prompt_logprobs = request - .prompt_logprobs - .or(if request.echo && !request.stream { - logprobs - } else { - None - }); + let prompt_logprobs = request.prompt_logprobs.or(if request.echo && !request.stream { + logprobs + } else { + None + }); let include_usage = (request.stream_options.as_ref()) .and_then(|options| options.include_usage) .unwrap_or(false); - let echo = request - .echo - .then(|| request.prompt.as_text().cloned()) - .flatten(); + let echo = request.echo.then(|| request.prompt.as_text().cloned()).flatten(); let structured_outputs = convert_from_response_format_value(&request.response_format, &request.structured_outputs)?; diff --git a/src/server/src/routes/openai/completions/types.rs b/src/server/src/routes/openai/completions/types.rs index f8d21369..adc8a7ba 100644 --- a/src/server/src/routes/openai/completions/types.rs +++ b/src/server/src/routes/openai/completions/types.rs @@ -9,16 +9,18 @@ use crate::routes::openai::utils::types::{ LogProbs, Normalizable, StreamOptions, StringOrArray, Usage, default_true, validate_stop, }; -/// Serde default for `CompletionRequest::max_tokens`, matching the Python vLLM / OpenAI default. +/// Serde default for `CompletionRequest::max_tokens`, matching the Python vLLM +/// / OpenAI default. fn default_completion_max_tokens() -> Option { Some(16) } /// vLLM-compatible request type for the Completions API. /// -/// Mirrors the Python vLLM `CompletionRequest` class. The local copy keeps the request type -/// route-owned so we can accept token-id prompts via [`vllm_text::Prompt`] and add vLLM-only -/// fields directly instead of layering wrapper deserializers on top. +/// Mirrors the Python vLLM `CompletionRequest` class. The local copy keeps the +/// request type route-owned so we can accept token-id prompts via +/// [`vllm_text::Prompt`] and add vLLM-only fields directly instead of layering +/// wrapper deserializers on top. #[serde_with::skip_serializing_none] #[derive(Debug, Clone, Deserialize, Serialize, Validate)] pub struct CompletionRequest { @@ -35,8 +37,8 @@ pub struct CompletionRequest { #[serde(default)] pub echo: bool, - /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing - /// frequency in the text so far + /// Number between -2.0 and 2.0. 
Positive values penalize new tokens based + /// on their existing frequency in the text so far pub frequency_penalty: Option, /// Modify the likelihood of specified tokens appearing in the completion @@ -45,19 +47,20 @@ pub struct CompletionRequest { /// Include the log probabilities on the logprobs most likely tokens pub logprobs: Option, - /// The maximum number of tokens to generate (defaults to 16 when absent, matching the - /// Python vLLM / OpenAI API convention) + /// The maximum number of tokens to generate (defaults to 16 when absent, + /// matching the Python vLLM / OpenAI API convention) #[serde(default = "default_completion_max_tokens")] pub max_tokens: Option, /// How many completions to generate for each prompt pub n: Option, - /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they - /// appear in the text so far + /// Number between -2.0 and 2.0. Positive values penalize new tokens based + /// on whether they appear in the text so far pub presence_penalty: Option, - /// If specified, our system will make a best effort to sample deterministically + /// If specified, our system will make a best effort to sample + /// deterministically pub seed: Option, /// Up to 4 sequences where the API will stop generating further tokens @@ -136,7 +139,8 @@ pub struct CompletionRequest { #[serde(default = "default_true")] pub add_special_tokens: bool, - /// Format specification for structured output (JSON mode, JSON schema, etc.) + /// Format specification for structured output (JSON mode, JSON schema, + /// etc.) pub response_format: Option, /// Additional kwargs for structured outputs @@ -148,7 +152,8 @@ pub struct CompletionRequest { /// External request ID used for response correlation. pub request_id: Option, - /// Tokens represented as strings of the form 'token_id:{token_id}' in logprobs + /// Tokens represented as strings of the form 'token_id:{token_id}' in + /// logprobs pub return_tokens_as_token_ids: Option, /// Include token IDs alongside generated text @@ -160,7 +165,8 @@ pub struct CompletionRequest { /// KV transfer parameters for disaggregated serving pub kv_transfer_params: Option>, - /// Additional request parameters with string or numeric values for custom extensions + /// Additional request parameters with string or numeric values for custom + /// extensions pub vllm_xargs: Option>, /// Additional fields diff --git a/src/server/src/routes/openai/completions/validate.rs b/src/server/src/routes/openai/completions/validate.rs index 3246acea..bf4911fa 100644 --- a/src/server/src/routes/openai/completions/validate.rs +++ b/src/server/src/routes/openai/completions/validate.rs @@ -8,8 +8,9 @@ pub(super) fn validate_request_compat( request: &CompletionRequest, configured_model: &str, ) -> Result<(), ApiError> { - // This path is intentionally scoped to the minimum surface needed by `vllm-bench` random - // workload compatibility, so unsupported legacy completions features fail early here. + // This path is intentionally scoped to the minimum surface needed by + // `vllm-bench` random workload compatibility, so unsupported legacy + // completions features fail early here. 
if request.model != configured_model { return Err(ApiError::model_not_found(request.model.clone())); } @@ -72,7 +73,8 @@ pub(super) fn validate_request_compat( ); } - // ---- Reject parameters that are accepted for deserialization but not yet implemented ---- + // ---- Reject parameters that are accepted for deserialization but not yet + // implemented ---- if request.length_penalty.is_some() { bail_invalid_request!(param = "length_penalty", "length_penalty is not supported."); diff --git a/src/server/src/routes/openai/utils/logprobs.rs b/src/server/src/routes/openai/utils/logprobs.rs index bf54be8a..90052647 100644 --- a/src/server/src/routes/openai/utils/logprobs.rs +++ b/src/server/src/routes/openai/utils/logprobs.rs @@ -9,7 +9,8 @@ use vllm_text::{ use super::types::{ChatLogProbs, ChatLogProbsContent, LogProbs, TopLogProb}; use crate::error::{ApiError, server_error}; -/// Convert decoded token-position logprobs into the OpenAI completions `logprobs` shape. +/// Convert decoded token-position logprobs into the OpenAI completions +/// `logprobs` shape. pub fn decoded_logprobs_to_openai( logprobs: &DecodedLogprobs, initial_text_offset: u32, @@ -45,10 +46,11 @@ pub fn decoded_logprobs_to_openai( }) } -/// Convert decoded prompt logprobs into the OpenAI completions `logprobs` shape. +/// Convert decoded prompt logprobs into the OpenAI completions `logprobs` +/// shape. /// -/// The first prompt token is included with `None` logprob metadata, matching Python vLLM's -/// echoed completions behavior. +/// The first prompt token is included with `None` logprob metadata, matching +/// Python vLLM's echoed completions behavior. pub fn decoded_prompt_logprobs_to_openai( prompt_logprobs: &DecodedPromptLogprobs, initial_text_offset: u32, @@ -95,7 +97,8 @@ pub fn decoded_prompt_logprobs_to_openai( }) } -/// Convert decoded prompt logprobs into the vLLM-style prompt-logprobs response shape. +/// Convert decoded prompt logprobs into the vLLM-style prompt-logprobs response +/// shape. pub fn decoded_prompt_logprobs_to_maps( prompt_logprobs: &DecodedPromptLogprobs, return_tokens_as_token_ids: bool, @@ -110,7 +113,8 @@ pub fn decoded_prompt_logprobs_to_maps( .collect() } -/// Convert decoded token-position logprobs into the OpenAI chat `logprobs` shape. +/// Convert decoded token-position logprobs into the OpenAI chat `logprobs` +/// shape. pub fn decoded_logprobs_to_openai_chat( logprobs: &DecodedLogprobs, return_tokens_as_token_ids: bool, @@ -126,7 +130,8 @@ pub fn decoded_logprobs_to_openai_chat( }) } -/// Count visible text positions using OpenAI completions' character-offset convention. +/// Count visible text positions using OpenAI completions' character-offset +/// convention. pub fn text_len(text: &str) -> u32 { u32::try_from(text.chars().count()).unwrap_or(u32::MAX) } @@ -140,10 +145,12 @@ pub fn append_openai_logprobs(mut prefix: LogProbs, suffix: LogProbs) -> LogProb prefix } -/// Build the non-stream completions `logprobs` payload from collected text output. +/// Build the non-stream completions `logprobs` payload from collected text +/// output. /// -/// When `echoed_prompt` is true, the returned payload matches Python vLLM's echoed completions -/// behavior by concatenating prompt and completion logprobs into one OpenAI `LogProbs` object. +/// When `echoed_prompt` is true, the returned payload matches Python vLLM's +/// echoed completions behavior by concatenating prompt and completion logprobs +/// into one OpenAI `LogProbs` object. 
pub fn collected_logprobs_to_openai( collected: &CollectedTextOutput, echoed_prompt: bool, diff --git a/src/server/src/routes/openai/utils/structured_outputs.rs b/src/server/src/routes/openai/utils/structured_outputs.rs index 79d35492..e974c836 100644 --- a/src/server/src/routes/openai/utils/structured_outputs.rs +++ b/src/server/src/routes/openai/utils/structured_outputs.rs @@ -22,8 +22,8 @@ pub struct JsonSchemaFormat { /// Supported `response_format` types for chat and completion requests. /// -/// This is our own definition (rather than the `openai-protocol` crate's) so that we can support -/// the vLLM-specific `structural_tag` variant. +/// This is our own definition (rather than the `openai-protocol` crate's) so +/// that we can support the vLLM-specific `structural_tag` variant. /// /// Original Python definitions: /// @@ -35,18 +35,21 @@ pub enum ResponseFormat { JsonSchema { json_schema: JsonSchemaFormat, }, - /// vLLM-specific structural tag format. The entire object (including the `type` field) is - /// JSON-serialized and passed as `StructuredOutputsParams.structural_tag`. + /// vLLM-specific structural tag format. The entire object (including the + /// `type` field) is JSON-serialized and passed as + /// `StructuredOutputsParams.structural_tag`. /// - /// We capture the payload as a catch-all map so both the legacy (`structures`/`triggers`) - /// and current (`format`) shapes are preserved without needing typed structs. + /// We capture the payload as a catch-all map so both the legacy + /// (`structures`/`triggers`) and current (`format`) shapes are + /// preserved without needing typed structs. StructuralTag { #[serde(flatten)] extra: serde_json::Map, }, } -/// Convert an explicit `structured_outputs` JSON blob into [`StructuredOutputsParams`]. +/// Convert an explicit `structured_outputs` JSON blob into +/// [`StructuredOutputsParams`]. fn deserialize_structured_outputs( raw: &serde_json::Value, ) -> Result { @@ -58,11 +61,11 @@ fn deserialize_structured_outputs( }) } -/// Convert a typed [`ResponseFormat`] and/or raw `structured_outputs` blob into engine-core -/// [`StructuredOutputsParams`]. +/// Convert a typed [`ResponseFormat`] and/or raw `structured_outputs` blob into +/// engine-core [`StructuredOutputsParams`]. /// -/// Mirrors the Python vLLM conversion in `ChatCompletionRequest.to_sampling_params()`: -/// +/// Mirrors the Python vLLM conversion in +/// `ChatCompletionRequest.to_sampling_params()`: pub fn convert_from_response_format( response_format: Option<&ResponseFormat>, structured_outputs: &Option, @@ -101,10 +104,11 @@ pub fn convert_from_response_format( } } -/// Convert raw `response_format` and/or `structured_outputs` JSON blobs into engine-core -/// [`StructuredOutputsParams`]. +/// Convert raw `response_format` and/or `structured_outputs` JSON blobs into +/// engine-core [`StructuredOutputsParams`]. /// -/// Used by the completions endpoint which keeps both fields as opaque `serde_json::Value`. +/// Used by the completions endpoint which keeps both fields as opaque +/// `serde_json::Value`. 
pub fn convert_from_response_format_value( response_format: &Option, structured_outputs: &Option, diff --git a/src/server/src/routes/openai/utils/validated_json.rs b/src/server/src/routes/openai/utils/validated_json.rs index 700282af..d07158cc 100644 --- a/src/server/src/routes/openai/utils/validated_json.rs +++ b/src/server/src/routes/openai/utils/validated_json.rs @@ -10,11 +10,13 @@ use validator::Validate; use super::types::Normalizable; use crate::error::{ApiError, invalid_request}; -/// A JSON extractor that automatically validates and normalizes the request body. +/// A JSON extractor that automatically validates and normalizes the request +/// body. /// -/// This extractor deserializes the request body and automatically calls `.validate()` -/// on types that implement the `Validate` trait. If validation fails, it returns -/// [`ApiError::InvalidRequest`] with details about the validation errors. +/// This extractor deserializes the request body and automatically calls +/// `.validate()` on types that implement the `Validate` trait. If validation +/// fails, it returns [`ApiError::InvalidRequest`] with details about the +/// validation errors. pub struct ValidatedJson(pub T); impl FromRequest for ValidatedJson diff --git a/src/server/src/routes/tests.rs b/src/server/src/routes/tests.rs index d561bca3..482be563 100644 --- a/src/server/src/routes/tests.rs +++ b/src/server/src/routes/tests.rs @@ -138,9 +138,7 @@ fn default_stream_output_specs() -> Vec<(Vec, Option Vec<&str> { - text.lines() - .filter_map(|line| line.strip_prefix("data: ")) - .collect() + text.lines().filter_map(|line| line.strip_prefix("data: ")).collect() } type TestFuture<'a> = Pin + Send + 'a>>; @@ -729,9 +727,7 @@ async fn server_load(app: &axum::Router) -> u64 { .expect("call app"); assert_eq!(response.status(), StatusCode::OK); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); let value: serde_json::Value = serde_json::from_slice(&body).expect("json body"); value["server_load"].as_u64().expect("server_load") } @@ -750,9 +746,7 @@ async fn health_status(app: &axum::Router) -> (StatusCode, Bytes) { .expect("call app"); let status = response.status(); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); (status, body) } @@ -767,18 +761,13 @@ fn metric_value(rendered: &str, metric: &str, labels: Option<&str>) -> Option().ok() } else { None } } - None => rest - .strip_prefix(' ') - .and_then(|value| value.parse::().ok()), + None => rest.strip_prefix(' ').and_then(|value| value.parse::().ok()), } }) } @@ -798,19 +787,12 @@ fn metric_delta( async fn list_models_returns_configured_model() { let mut app = test_app().await; let response = app - .call( - Request::builder() - .uri("/v1/models") - .body(Body::empty()) - .expect("build request"), - ) + .call(Request::builder().uri("/v1/models").body(Body::empty()).expect("build request")) .await .expect("call app"); assert_eq!(response.status(), StatusCode::OK); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); assert_eq!(json["data"][0]["id"], "Qwen/Qwen1.5-0.5B-Chat"); } @@ -915,9 +897,7 @@ async fn invalid_request_returns_openai_error() 
{ .expect("call app"); assert_eq!(response.status(), StatusCode::BAD_REQUEST); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); assert_eq!(json["error"]["type"], "invalid_request_error"); } @@ -955,9 +935,7 @@ async fn non_stream_chat_returns_json_response() { .is_some_and(|value| value.starts_with("application/json")) ); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); engine_task.await.expect("mock engine task"); let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); @@ -1044,9 +1022,7 @@ async fn non_stream_chat_includes_logprobs_and_prompt_logprobs() { .expect("call app"); assert_eq!(response.status(), StatusCode::OK); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); engine_task.await.expect("mock engine task"); let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); @@ -1089,16 +1065,11 @@ async fn happy_path_returns_sse_stream() { assert_eq!(response.status(), StatusCode::OK); assert_eq!( - response - .headers() - .get("content-type") - .and_then(|value| value.to_str().ok()), + response.headers().get("content-type").and_then(|value| value.to_str().ok()), Some("text/event-stream") ); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); engine_task.await.expect("mock engine task"); let text = String::from_utf8(body.to_vec()).expect("utf8 body"); let after = METRICS.render().unwrap(); @@ -1251,9 +1222,7 @@ async fn load_endpoint_tracks_chat_stream_lifecycle() { assert_eq!(response.status(), StatusCode::OK); assert_eq!(server_load(&app).await, 1); - let _body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let _body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); engine_task.await.expect("mock engine task"); assert_eq!(server_load(&app).await, 0); @@ -1369,9 +1338,7 @@ async fn stream_error_is_returned_as_openai_error_sse() { assert_eq!(response.status(), StatusCode::OK); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); engine_task.await.expect("mock engine task"); let text = String::from_utf8(body.to_vec()).expect("utf8 body"); @@ -1414,9 +1381,7 @@ async fn invalid_terminal_finish_reason_is_returned_as_openai_error_sse() { assert_eq!(response.status(), StatusCode::OK); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); engine_task.await.expect("mock engine task"); let text = String::from_utf8(body.to_vec()).expect("utf8 body"); @@ -1454,9 +1419,7 @@ async fn include_usage_adds_final_usage_chunk_before_done() { assert_eq!(response.status(), StatusCode::OK); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); engine_task.await.expect("mock engine task"); let 
text = String::from_utf8(body.to_vec()).expect("utf8 body"); @@ -1469,10 +1432,8 @@ async fn include_usage_adds_final_usage_chunk_before_done() { .iter() .position(|payload| payload.contains("\"usage\":")) .expect("usage chunk"); - let done_index = payloads - .iter() - .position(|payload| *payload == "[DONE]") - .expect("done sentinel"); + let done_index = + payloads.iter().position(|payload| *payload == "[DONE]").expect("done sentinel"); assert!(finish_index < usage_index, "{text}"); assert!(usage_index < done_index, "{text}"); @@ -1511,9 +1472,7 @@ async fn stream_without_include_usage_keeps_existing_shape() { assert_eq!(response.status(), StatusCode::OK); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); engine_task.await.expect("mock engine task"); let text = String::from_utf8(body.to_vec()).expect("utf8 body"); @@ -1547,9 +1506,7 @@ async fn completions_invalid_request_returns_openai_error() { .expect("call app"); assert_eq!(response.status(), StatusCode::BAD_REQUEST); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); assert_eq!(json["error"]["type"], "invalid_request_error"); } @@ -1587,9 +1544,7 @@ async fn non_stream_completions_return_json_response() { .is_some_and(|value| value.starts_with("application/json")) ); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); engine_task.await.expect("mock engine task"); let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); @@ -1626,9 +1581,7 @@ async fn non_stream_completions_echo_prepends_prompt_text() { assert_eq!(response.status(), StatusCode::OK); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); engine_task.await.expect("mock engine task"); let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); @@ -1718,9 +1671,7 @@ async fn non_stream_completions_include_logprobs() { .expect("call app"); assert_eq!(response.status(), StatusCode::OK); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); engine_task.await.expect("mock engine task"); let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); @@ -1820,9 +1771,7 @@ async fn non_stream_completions_include_prompt_logprobs() { .expect("call app"); assert_eq!(response.status(), StatusCode::OK); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); engine_task.await.expect("mock engine task"); let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); @@ -2025,9 +1974,7 @@ async fn chat_completions_header_request_id_takes_precedence() { .expect("call app"); assert_eq!(response.status(), StatusCode::OK); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); 
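A compact editor's sketch of the ordering check these streaming tests perform, reusing the `sse_data_payloads` helper defined earlier in this test file; the sample SSE body is invented:

fn sse_data_payloads(text: &str) -> Vec<&str> {
    text.lines().filter_map(|line| line.strip_prefix("data: ")).collect()
}

fn main() {
    let body = "data: {\"choices\":[{\"finish_reason\":\"stop\"}]}\n\n\
                data: {\"choices\":[],\"usage\":{\"total_tokens\":3}}\n\n\
                data: [DONE]\n\n";
    let payloads = sse_data_payloads(body);
    // The finish chunk must precede the final-usage chunk, which must
    // precede the [DONE] sentinel.
    let finish = payloads.iter().position(|p| p.contains("\"finish_reason\":\"stop\"")).unwrap();
    let usage = payloads.iter().position(|p| p.contains("\"usage\":")).unwrap();
    let done = payloads.iter().position(|p| *p == "[DONE]").unwrap();
    assert!(finish < usage && usage < done);
}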
engine_task.await.expect("mock engine task"); let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); @@ -2126,9 +2073,7 @@ async fn non_stream_raw_generate_returns_token_output_envelope() { .expect("call app"); assert_eq!(response.status(), StatusCode::OK); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); engine_task.await.expect("mock engine task"); let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); @@ -2178,9 +2123,7 @@ async fn raw_generate_rejects_streaming() { .expect("call app"); assert_eq!(response.status(), StatusCode::BAD_REQUEST); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); assert_eq!(json["error"]["param"], "stream"); } @@ -2210,9 +2153,7 @@ async fn raw_generate_rejects_empty_token_ids() { .expect("call app"); assert_eq!(response.status(), StatusCode::BAD_REQUEST); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); assert_eq!(json["error"]["param"], "token_ids"); } @@ -2271,16 +2212,11 @@ async fn completions_happy_path_returns_sse_stream() { assert_eq!(response.status(), StatusCode::OK); assert_eq!( - response - .headers() - .get("content-type") - .and_then(|value| value.to_str().ok()), + response.headers().get("content-type").and_then(|value| value.to_str().ok()), Some("text/event-stream") ); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); engine_task.await.expect("mock engine task"); let text = String::from_utf8(body.to_vec()).expect("utf8 body"); let payloads = sse_data_payloads(&text); @@ -2288,21 +2224,15 @@ async fn completions_happy_path_returns_sse_stream() { .iter() .position(|payload| payload.contains("\"usage\":")) .expect("usage chunk"); - let done_index = payloads - .iter() - .position(|payload| *payload == "[DONE]") - .expect("done sentinel"); + let done_index = + payloads.iter().position(|payload| *payload == "[DONE]").expect("done sentinel"); assert!( - payloads - .iter() - .any(|payload| payload.contains("\"text\":\"h\"")), + payloads.iter().any(|payload| payload.contains("\"text\":\"h\"")), "{text}" ); assert!( - payloads - .iter() - .any(|payload| payload.contains("\"finish_reason\":\"stop\"")), + payloads.iter().any(|payload| payload.contains("\"finish_reason\":\"stop\"")), "{text}" ); assert!(usage_index < done_index, "{text}"); @@ -2341,9 +2271,7 @@ async fn completions_echo_stream_emits_separate_prompt_chunk() { assert_eq!(response.status(), StatusCode::OK); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); engine_task.await.expect("mock engine task"); let text = String::from_utf8(body.to_vec()).expect("utf8 body"); let payloads = sse_data_payloads(&text); @@ -2358,9 +2286,7 @@ async fn completions_echo_stream_emits_separate_prompt_chunk() { assert!(hello_index < h_index, "{text}"); assert!( - payloads - .iter() - 
.any(|payload| payload.contains("\"text\":\"i\"")), + payloads.iter().any(|payload| payload.contains("\"text\":\"i\"")), "{text}" ); @@ -2432,10 +2358,7 @@ async fn prepared_openai_request_streams_text_events() { ) .expect("prepare request"); - let mut stream = chat - .chat(prepared.chat_request) - .await - .expect("submit chat request"); + let mut stream = chat.chat(prepared.chat_request).await.expect("submit chat request"); let mut saw_text = false; let mut saw_done = false; @@ -2499,9 +2422,7 @@ async fn reasoning_blocks_are_mapped_to_reasoning_sse_chunks() { .expect("call app"); assert_eq!(response.status(), StatusCode::OK); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); engine_task.await.expect("mock engine task"); let text = String::from_utf8(body.to_vec()).expect("utf8 body"); @@ -2561,9 +2482,7 @@ async fn tool_calls_are_mapped_to_tool_call_sse_chunks() { .expect("call app"); assert_eq!(response.status(), StatusCode::OK); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); engine_task.await.expect("mock engine task"); let text = String::from_utf8(body.to_vec()).expect("utf8 body"); @@ -2713,9 +2632,7 @@ async fn tool_call_sse_chunks_can_carry_logprobs() { .expect("call app"); assert_eq!(response.status(), StatusCode::OK); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); engine_task.finish().await; let text = String::from_utf8(body.to_vec()).expect("utf8 body"); @@ -2788,9 +2705,7 @@ async fn reset_prefix_cache_route_sends_expected_utility_call() { .expect("call app"); let status = response.status(); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); assert_eq!(status, StatusCode::OK, "{}", String::from_utf8_lossy(&body)); assert!(body.is_empty()); engine_task.await.expect("mock engine task"); @@ -2829,9 +2744,7 @@ async fn reset_mm_cache_route_sends_expected_utility_call() { .expect("call app"); let status = response.status(); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); assert_eq!(status, StatusCode::OK, "{}", String::from_utf8_lossy(&body)); assert!(body.is_empty()); engine_task.await.expect("mock engine task"); @@ -2870,9 +2783,7 @@ async fn reset_encoder_cache_route_sends_expected_utility_call() { .expect("call app"); let status = response.status(); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); assert_eq!(status, StatusCode::OK, "{}", String::from_utf8_lossy(&body)); assert!(body.is_empty()); engine_task.await.expect("mock engine task"); @@ -2945,9 +2856,7 @@ async fn collective_rpc_route_sends_expected_utility_call_and_returns_results() .expect("call app"); assert_eq!(response.status(), StatusCode::OK); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); engine_task.await.expect("mock engine task"); assert_eq!( @@ -3001,9 +2910,7 
@@ async fn sleep_route_uses_python_compatible_default_query_values() { .expect("call app"); let status = response.status(); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); assert_eq!(status, StatusCode::OK, "{}", String::from_utf8_lossy(&body)); assert!(body.is_empty()); engine_task.await.expect("mock engine task"); @@ -3042,9 +2949,7 @@ async fn wake_up_route_without_tags_sends_none() { .expect("call app"); let status = response.status(); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); assert_eq!(status, StatusCode::OK, "{}", String::from_utf8_lossy(&body)); assert!(body.is_empty()); engine_task.await.expect("mock engine task"); @@ -3083,9 +2988,7 @@ async fn is_sleeping_route_returns_json_payload() { .expect("call app"); assert_eq!(response.status(), StatusCode::OK); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); engine_task.await.expect("mock engine task"); assert_eq!( @@ -3135,7 +3038,8 @@ async fn admin_routes_are_hidden_when_dev_mode_is_disabled() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[serial] async fn non_stream_completions_stop_string_excluded_from_output() { - // Engine generates "say world" but stop string "wor" truncates output to "say ". + // Engine generates "say world" but stop string "wor" truncates output to "say + // ". let output_specs = vec![ (bytes_to_token_ids(b"say"), None), ( @@ -3168,9 +3072,7 @@ async fn non_stream_completions_stop_string_excluded_from_output() { assert_eq!(response.status(), StatusCode::OK); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); engine_task.await.expect("mock engine task"); let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); @@ -3182,7 +3084,8 @@ async fn non_stream_completions_stop_string_excluded_from_output() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[serial] async fn non_stream_completions_stop_string_included_in_output() { - // Same tokens but include_stop_str_in_output=true includes the stop string in the output. + // Same tokens but include_stop_str_in_output=true includes the stop string in + // the output. 
let output_specs = vec![ (bytes_to_token_ids(b"say"), None), ( @@ -3216,9 +3119,7 @@ async fn non_stream_completions_stop_string_included_in_output() { assert_eq!(response.status(), StatusCode::OK); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); engine_task.await.expect("mock engine task"); let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); @@ -3262,9 +3163,7 @@ async fn stream_completions_stop_string_excluded_from_output() { assert_eq!(response.status(), StatusCode::OK); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); engine_task.await.expect("mock engine task"); let text = String::from_utf8(body.to_vec()).expect("utf8 body"); let payloads = sse_data_payloads(&text); @@ -3286,9 +3185,7 @@ async fn stream_completions_stop_string_excluded_from_output() { // The final chunk should have finish_reason "stop". assert!( - payloads - .iter() - .any(|p| p.contains("\"finish_reason\":\"stop\"")), + payloads.iter().any(|p| p.contains("\"finish_reason\":\"stop\"")), "{text}" ); } @@ -3329,9 +3226,7 @@ async fn stream_completions_stop_string_included_in_output() { assert_eq!(response.status(), StatusCode::OK); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); engine_task.await.expect("mock engine task"); let text = String::from_utf8(body.to_vec()).expect("utf8 body"); let payloads = sse_data_payloads(&text); @@ -3351,9 +3246,7 @@ async fn stream_completions_stop_string_included_in_output() { assert_eq!(full_text, "say wor", "full streamed text: {text}"); assert!( - payloads - .iter() - .any(|p| p.contains("\"finish_reason\":\"stop\"")), + payloads.iter().any(|p| p.contains("\"finish_reason\":\"stop\"")), "{text}" ); } @@ -3361,7 +3254,8 @@ async fn stream_completions_stop_string_included_in_output() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[serial] async fn non_stream_completions_no_stop_string_match_preserves_original_finish_reason() { - // Stop string "xyz" does not appear in "hi!" so the original finish reason is preserved. + // Stop string "xyz" does not appear in "hi!" so the original finish reason is + // preserved. let (app, engine_task) = test_app_with_engine_handle().await; let response = app @@ -3387,13 +3281,12 @@ async fn non_stream_completions_no_stop_string_match_preserves_original_finish_r assert_eq!(response.status(), StatusCode::OK); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); engine_task.await.expect("mock engine task"); let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); - // Default output is "hi" (stop token '!' suppressed), finish_reason remains "stop" from EOS. + // Default output is "hi" (stop token '!' suppressed), finish_reason remains + // "stop" from EOS. assert_eq!(json["choices"][0]["text"], "hi"); assert_eq!(json["choices"][0]["finish_reason"], "stop"); // No text stop string matched — stop_reason should be absent. 
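The stop-string tests above pin down one truncation rule; here is a minimal editor's sketch of it. Simple whole-output `find` matching is assumed for illustration, not the crate's streaming implementation:

fn truncate_at_stop(output: &str, stop: &str, include_stop_str_in_output: bool) -> String {
    match output.find(stop) {
        // Keep the stop string itself when the caller opted in.
        Some(start) if include_stop_str_in_output => output[..start + stop.len()].to_string(),
        // Otherwise cut just before it.
        Some(start) => output[..start].to_string(),
        // No match: output passes through, and the original finish reason stands.
        None => output.to_string(),
    }
}

fn main() {
    // Mirrors the expectations above: "say world" with stop string "wor".
    assert_eq!(truncate_at_stop("say world", "wor", false), "say ");
    assert_eq!(truncate_at_stop("say world", "wor", true), "say wor");
}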
@@ -3433,9 +3326,7 @@ async fn non_stream_completions_stop_string_array_matches_first_occurrence() { assert_eq!(response.status(), StatusCode::OK); - let body = to_bytes(response.into_body(), usize::MAX) - .await - .expect("read body"); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); engine_task.await.expect("mock engine task"); let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); diff --git a/src/server/src/state.rs b/src/server/src/state.rs index bee2a3c9..32ebfc50 100644 --- a/src/server/src/state.rs +++ b/src/server/src/state.rs @@ -37,30 +37,36 @@ impl AppState { self } - /// Return a reference to the underlying engine core client for utility calls. + /// Return a reference to the underlying engine core client for utility + /// calls. pub(crate) fn engine_core_client(&self) -> &EngineCoreClient { self.chat.engine_core_client() } - /// Return the current in-flight inference request count for the `/load` endpoint. + /// Return the current in-flight inference request count for the `/load` + /// endpoint. pub fn server_load(&self) -> u64 { self.server_load.load(Ordering::Relaxed) } - /// Increment the in-flight inference request count, called by the load tracking middleware. + /// Increment the in-flight inference request count, called by the load + /// tracking middleware. pub(crate) fn increment_server_load(&self) { self.server_load.fetch_add(1, Ordering::Relaxed); } - /// Decrement the in-flight inference request count, called by the load tracking middleware. + /// Decrement the in-flight inference request count, called by the load + /// tracking middleware. pub(crate) fn decrement_server_load(&self) { self.server_load.fetch_sub(1, Ordering::Relaxed); } - /// Wait until all request-owned references are dropped, then shut down the engine client. + /// Wait until all request-owned references are dropped, then shut down the + /// engine client. /// - /// If the deadline elapses while request/connection tasks still hold state references, skip the - /// clean engine-client shutdown and let process teardown reclaim the remaining resources. + /// If the deadline elapses while request/connection tasks still hold state + /// references, skip the clean engine-client shutdown and let process + /// teardown reclaim the remaining resources. pub async fn shutdown(mut self: Arc, deadline: Instant) -> anyhow::Result<()> { loop { match Arc::try_unwrap(self) { diff --git a/src/server/src/utils.rs b/src/server/src/utils.rs index f9cc822c..13fa0dfa 100644 --- a/src/server/src/utils.rs +++ b/src/server/src/utils.rs @@ -27,8 +27,9 @@ pub fn utility_call_error(method: &str, error: impl AsReport) -> ApiError { ApiError::server_error(format!("failed to call {method}: {}", error.as_report())) } -/// Merge `kv_transfer_params` into the `vllm_xargs` map, mirroring the Python vLLM behavior -/// where `kv_transfer_params` is injected into `extra_args` for engine-core consumption. +/// Merge `kv_transfer_params` into the `vllm_xargs` map, mirroring the Python +/// vLLM behavior where `kv_transfer_params` is injected into `extra_args` for +/// engine-core consumption. pub fn merge_kv_transfer_params( mut xargs: Option>, kv_transfer_params: Option<&HashMap>, @@ -44,8 +45,9 @@ pub fn merge_kv_transfer_params( xargs } -/// Convert OpenAI-style `logit_bias` with string token-ID keys into the internal -/// `HashMap` representation, validating that every key parses as a `u32`. 
+/// Convert OpenAI-style `logit_bias` with string token-ID keys into the +/// internal `HashMap` representation, validating that every key +/// parses as a `u32`. pub fn convert_logit_bias( logit_bias: Option>, ) -> Result>, ApiError> { @@ -81,9 +83,7 @@ pub fn resolve_request_context( .and_then(|s| s.trim().parse().ok()); // Extract request id from header. - let request_id_header = headers - .get("X-Request-Id") - .and_then(|value| value.to_str().ok()); + let request_id_header = headers.get("X-Request-Id").and_then(|value| value.to_str().ok()); let request_id = resolve_base_request_id(request_id_header, request_id); ResolvedRequestContext { @@ -92,17 +92,15 @@ pub fn resolve_request_context( } } -/// Resolve the base external request ID before API-specific prefixes such as `chatcmpl-`. +/// Resolve the base external request ID before API-specific prefixes such as +/// `chatcmpl-`. pub fn resolve_base_request_id( request_id_header: Option<&str>, request_id: Option<&str>, ) -> String { - request_id_header - .or(request_id) - .map(ToOwned::to_owned) - .unwrap_or_else(|| { - let mut id = Uuid::new_v4().simple().to_string(); - id.truncate(8); - id - }) + request_id_header.or(request_id).map(ToOwned::to_owned).unwrap_or_else(|| { + let mut id = Uuid::new_v4().simple().to_string(); + id.truncate(8); + id + }) } diff --git a/src/text/src/backend/hf/config.rs b/src/text/src/backend/hf/config.rs index 9199574b..5f2ecf8b 100644 --- a/src/text/src/backend/hf/config.rs +++ b/src/text/src/backend/hf/config.rs @@ -14,9 +14,10 @@ pub struct HfTokenizerConfig { #[serde(flatten)] pub special_tokens: HfSpecialTokens, pub chat_template: Option, - /// The `tokenizer_class` field from HuggingFace tokenizer configs. Some tiktoken-based models - /// (e.g. DeepSeek, Kimi K2) set this to a value containing "Tiktoken" which can be used as a - /// hint for backend selection. + /// The `tokenizer_class` field from HuggingFace tokenizer configs. Some + /// tiktoken-based models (e.g. DeepSeek, Kimi K2) set this to a value + /// containing "Tiktoken" which can be used as a hint for backend + /// selection. pub tokenizer_class: Option, } @@ -79,11 +80,13 @@ impl HfSpecialTokens { /// Minimal subset of `config.json` (the model's main HF config). /// -/// This intentionally supports only the two layouts we currently care about in the Rust frontend: +/// This intentionally supports only the two layouts we currently care about in +/// the Rust frontend: /// - pure text models that keep text metadata at the top level /// - composite models that expose a single nested `text_config` /// -/// We do not support additional entry points such as `decoder`, `generator`, or `text_encoder`. +/// We do not support additional entry points such as `decoder`, `generator`, or +/// `text_encoder`. #[derive(Debug, Default, Deserialize)] #[serde(default)] pub struct ModelConfig { @@ -166,14 +169,14 @@ impl ModelConfig { self.text_config.as_deref().unwrap_or(self) } - /// Return the effective Hugging Face `model_type` used by the Rust frontend. + /// Return the effective Hugging Face `model_type` used by the Rust + /// frontend. /// - /// This follows the same simplified text-config selection as the rest of this type: the - /// top-level config wins, otherwise a single nested `text_config` may provide the value. + /// This follows the same simplified text-config selection as the rest of + /// this type: the top-level config wins, otherwise a single nested + /// `text_config` may provide the value. 
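Since `resolve_base_request_id` above is small enough to read in the hunk, here is an editor's usage sketch of its precedence (requires the `uuid` crate with the `v4` feature): the `X-Request-Id` header wins, then any body-provided ID, then a random 8-character fallback.

use uuid::Uuid;

fn resolve_base_request_id(header: Option<&str>, body: Option<&str>) -> String {
    header.or(body).map(ToOwned::to_owned).unwrap_or_else(|| {
        // Hyphen-free UUID, truncated to 8 characters.
        let mut id = Uuid::new_v4().simple().to_string();
        id.truncate(8);
        id
    })
}

fn main() {
    assert_eq!(resolve_base_request_id(Some("hdr-1"), Some("body-1")), "hdr-1");
    assert_eq!(resolve_base_request_id(None, Some("body-1")), "body-1");
    assert_eq!(resolve_base_request_id(None, None).len(), 8);
}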
pub fn model_type(&self) -> Option<&str> { - self.model_type - .as_deref() - .or_else(|| self.text_config.as_deref()?.model_type()) + self.model_type.as_deref().or_else(|| self.text_config.as_deref()?.model_type()) } /// Reject partially nested `text_config` payloads that are unlikely to be @@ -194,7 +197,8 @@ impl ModelConfig { Ok(()) } - /// Match Python's current expert-count priority on the selected text config. + /// Match Python's current expert-count priority on the selected text + /// config. /// /// The only intentional simplification here is how we pick the text config: /// Rust only looks at the top level or `text_config`, not the broader @@ -364,10 +368,6 @@ mod tests { serde_json::from_str(r#"{"text_config":{"max_position_embeddings":4096}}"#).unwrap(); let error = config.validate_text_config_selection().unwrap_err(); - assert!( - error - .to_string() - .contains("does not have `num_attention_heads`"), - ); + assert!(error.to_string().contains("does not have `num_attention_heads`"),); } } diff --git a/src/text/src/backend/hf/mod.rs b/src/text/src/backend/hf/mod.rs index de9265f7..528f3d6b 100644 --- a/src/text/src/backend/hf/mod.rs +++ b/src/text/src/backend/hf/mod.rs @@ -33,8 +33,8 @@ pub struct HfTextBackend { primary_eos_token_id: Option, /// Additional EOS ids that should flow through stop-token handling. extra_eos_token_ids: BTreeSet, - /// Generation-config for sampling defaults that may be inherited when the user does not - /// explicitly override them. + /// Generation-config for sampling defaults that may be inherited when the + /// user does not explicitly override them. generation_config: GenerationConfig, /// Model config (`config.json`). model_config: ModelConfig, @@ -84,7 +84,8 @@ impl HfTextBackend { }) } - /// Expose the resolved model files for use by the chat backend to load the chat template. + /// Expose the resolved model files for use by the chat backend to load the + /// chat template. pub fn resolved_model_files(&self) -> &ResolvedModelFiles { &self.files } diff --git a/src/text/src/backend/hf/model_files.rs b/src/text/src/backend/hf/model_files.rs index a30a1f42..2a03796b 100644 --- a/src/text/src/backend/hf/model_files.rs +++ b/src/text/src/backend/hf/model_files.rs @@ -18,9 +18,10 @@ pub enum TokenizerSource { Tiktoken(PathBuf), /// Path to `tekken.json` when present (Mistral native tokenizer format). /// - /// When set, the Tekken tokenizer should be preferred over the Hugging Face tokenizer - /// because the HuggingFace `tokenizer.json` for Mistral models has a known regex bug that - /// produces incorrect token IDs for some inputs. + /// When set, the Tekken tokenizer should be preferred over the Hugging Face + /// tokenizer because the HuggingFace `tokenizer.json` for Mistral + /// models has a known regex bug that produces incorrect token IDs for + /// some inputs. Tekken(PathBuf), } @@ -44,8 +45,9 @@ pub struct ResolvedModelFiles { } impl ResolvedModelFiles { - /// Resolve tokenizer/config files from a local model directory first when `model_id` - /// points to one, otherwise consult the local HF cache and finally the Hub. + /// Resolve tokenizer/config files from a local model directory first when + /// `model_id` points to one, otherwise consult the local HF cache and + /// finally the Hub. 
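An editor's sketch of the three-step resolution order that `ResolvedModelFiles::new` describes above; `lookup_hf_cache` and `download_from_hub` are invented stand-ins, not this crate's API:

use std::path::{Path, PathBuf};

fn resolve(model_id: &str) -> std::io::Result<PathBuf> {
    // 1. A local model directory wins outright.
    if Path::new(model_id).is_dir() {
        return Ok(PathBuf::from(model_id));
    }
    // 2. Otherwise consult the local HF cache.
    if let Some(cached) = lookup_hf_cache(model_id) {
        return Ok(cached);
    }
    // 3. Finally fall back to downloading from the Hub.
    download_from_hub(model_id)
}

fn lookup_hf_cache(_model_id: &str) -> Option<PathBuf> {
    None // stand-in: inspect the on-disk cache layout here
}

fn download_from_hub(model_id: &str) -> std::io::Result<PathBuf> {
    // stand-in: fetch the snapshot and return its directory
    Ok(PathBuf::from(format!("/tmp/hub/{model_id}")))
}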
pub async fn new(model_id: &str) -> Result { if Path::new(model_id).is_dir() { return resolve_local_model_files(Path::new(model_id)); @@ -104,17 +106,8 @@ async fn resolve_remote_model_files(model_id: &str) -> Result Some(download_known_file(&repo, model_id, name).await?), None => None, @@ -242,11 +235,14 @@ fn resolve_local_tokenizer_source( /// Choose the tokenizer. /// /// Selection order: -/// 1. `tekken.json` — Mistral native tokenizer (preferred over HF `tokenizer.json` because the HF -/// version has a known regex bug for Mistral models). -/// 2. File extension — `.tiktoken` / `tiktoken.model` files use tiktoken from BPE data. -/// 3. `tokenizer_class` in `tokenizer_config.json` — classes containing "Tiktoken" (case- -/// insensitive) trigger tiktoken loading from a sibling BPE file. +/// 1. `tekken.json` — Mistral native tokenizer (preferred over HF +/// `tokenizer.json` because the HF version has a known regex bug for Mistral +/// models). +/// 2. File extension — `.tiktoken` / `tiktoken.model` files use tiktoken from +/// BPE data. +/// 3. `tokenizer_class` in `tokenizer_config.json` — classes containing +/// "Tiktoken" (case- insensitive) trigger tiktoken loading from a sibling +/// BPE file. /// 4. Default — `tokenizer.json` in HuggingFace format. fn resolve_tokenizer_source( tokenizer_path: PathBuf, @@ -279,9 +275,7 @@ async fn download_if_present( filename: &str, ) -> Result> { match siblings.contains(filename) { - true => download_known_file(repo, model_id, filename) - .await - .map(Some), + true => download_known_file(repo, model_id, filename).await.map(Some), false => Ok(None), } } @@ -315,10 +309,7 @@ fn find_tiktoken_sibling<'a>(siblings: &std::collections::BTreeSet<&'a str>) -> if siblings.contains("tiktoken.model") { return Some("tiktoken.model"); } - siblings - .iter() - .copied() - .find(|name| name.ends_with(".tiktoken")) + siblings.iter().copied().find(|name| name.ends_with(".tiktoken")) } /// Discover a tiktoken model file in a local directory. @@ -348,8 +339,8 @@ pub(super) fn is_tiktoken_file(path: &std::path::Path) -> bool { .is_some_and(|name| name == "tiktoken.model" || name.ends_with(".tiktoken")) } -/// Chat templates are sometimes stored as dedicated .jinja files rather than as a fixed-name config -/// entry, so we scan the cached model dir. +/// Chat templates are sometimes stored as dedicated .jinja files rather than as +/// a fixed-name config entry, so we scan the cached model dir. fn discover_chat_template_in_dir(dir: &std::path::Path) -> Option { let json_template_path = dir.join("chat_template.json"); if json_template_path.exists() { @@ -361,15 +352,11 @@ fn discover_chat_template_in_dir(dir: &std::path::Path) -> Option { return Some(jinja_path); } - std::fs::read_dir(dir) - .ok()? - .flatten() - .map(|entry| entry.path()) - .find(|path| { - path.file_name() - .and_then(|name| name.to_str()) - .is_some_and(|name| name.ends_with(".jinja")) - }) + std::fs::read_dir(dir).ok()?.flatten().map(|entry| entry.path()).find(|path| { + path.file_name() + .and_then(|name| name.to_str()) + .is_some_and(|name| name.ends_with(".jinja")) + }) } #[cfg(test)] diff --git a/src/text/src/backend/mod.rs b/src/text/src/backend/mod.rs index 8e633a2e..82d947f9 100644 --- a/src/text/src/backend/mod.rs +++ b/src/text/src/backend/mod.rs @@ -5,8 +5,8 @@ use std::sync::Arc; use crate::error::Result; use crate::tokenizer::DynTokenizer; -/// Tokenizer/model-derived hints used to enrich text-generation requests before they are lowered -/// into engine-core. 
+/// Tokenizer/model-derived hints used to enrich text-generation requests before +/// they are lowered into engine-core. #[derive(Debug, Clone, Default, PartialEq)] pub struct SamplingHints { pub primary_eos_token_id: Option, @@ -17,7 +17,8 @@ pub struct SamplingHints { pub default_min_p: Option, pub default_repetition_penalty: Option, pub default_max_tokens: Option, - /// Model context window size (`max_position_embeddings` from `config.json`). + /// Model context window size (`max_position_embeddings` from + /// `config.json`). pub max_model_len: Option, } @@ -34,7 +35,8 @@ pub trait TextBackend: Send + Sync { /// Return the backend model ID. fn model_id(&self) -> &str; - /// Return tokenizer/model-derived hints used to enrich southbound sampling parameters. + /// Return tokenizer/model-derived hints used to enrich southbound sampling + /// parameters. fn sampling_hints(&self) -> Result { Ok(SamplingHints::default()) } diff --git a/src/text/src/incremental.rs b/src/text/src/incremental.rs index 1a715270..d8260767 100644 --- a/src/text/src/incremental.rs +++ b/src/text/src/incremental.rs @@ -5,7 +5,8 @@ use crate::tokenizer::Tokenizer; /// Stateful incremental decoder that emits text chunks one token at a time. pub trait IncrementalDecoder: Send { - /// Push one generated token and return how many new string bytes were added. + /// Push one generated token and return how many new string bytes were + /// added. fn push_token(&mut self, token_id: u32) -> Result; /// Consume any text which is currently ready. @@ -13,7 +14,8 @@ pub trait IncrementalDecoder: Send { /// Flush any remaining buffered text that has not yet been emitted. /// - /// Called after the final generated token to force out buffered/incomplete fragments. + /// Called after the final generated token to force out buffered/incomplete + /// fragments. fn flush(&mut self, truncate_output_to: Option) -> Result<(Option, String)>; /// Return cumulative decoded text so far. @@ -72,9 +74,8 @@ impl DecodeStream<'_, T> { let max_try = SAFE_SUFFIX_MAX.min(prompt_len - 1); for suffix_len in SAFE_SUFFIX_MIN..=max_try { let start = prompt_len - suffix_len; - let decoded = self - .tokenizer - .decode(&self.ids[start..], self.skip_special_tokens)?; + let decoded = + self.tokenizer.decode(&self.ids[start..], self.skip_special_tokens)?; if !decoded.contains('\u{FFFD}') { self.prefix = decoded; self.ids.drain(..start); @@ -114,10 +115,7 @@ impl IncrementalDecoder for DecodeStream<'_, T> { } fn next_chunk(&mut self) -> Option { - let cutoff = self - .cumulative_output - .len() - .saturating_sub(self.min_bytes_to_buffer); + let cutoff = self.cumulative_output.len().saturating_sub(self.min_bytes_to_buffer); (cutoff > self.output_index).then(|| { let chunk = self.cumulative_output[self.output_index..cutoff].to_string(); self.output_index = cutoff; @@ -294,8 +292,9 @@ mod tests { } /// Backend simulating non-monotonic decode where adding a token changes how - /// earlier tokens decode (context-dependent normalization), causing prefix_len - /// to land mid-UTF-8. Reproduces the class of bug from vllm-project/vllm#17448. + /// earlier tokens decode (context-dependent normalization), causing + /// prefix_len to land mid-UTF-8. Reproduces the class of bug from + /// vllm-project/vllm#17448. #[derive(Debug)] struct NonMonotonicBackend; diff --git a/src/text/src/lib.rs b/src/text/src/lib.rs index 66d69201..08f598e6 100644 --- a/src/text/src/lib.rs +++ b/src/text/src/lib.rs @@ -39,23 +39,26 @@ trait_set! { /// Raw text facade above [`Llm`]. 
/// -/// This layer stays below chat semantics: prompt text or prompt token IDs flow in, decoded text -/// deltas and terminal metadata flow out. +/// This layer stays below chat semantics: prompt text or prompt token IDs flow +/// in, decoded text deltas and terminal metadata flow out. pub struct TextLlm { /// Generate-only client owned by this text facade. llm: Llm, - /// Tokenizer/model metadata backend responsible for prompt encode/decode and sampling hints. + /// Tokenizer/model metadata backend responsible for prompt encode/decode + /// and sampling hints. backend: DynTextBackend, - /// Context window size derived by the backend or from engine startup handshake, with optional - /// override from config. + /// Context window size derived by the backend or from engine startup + /// handshake, with optional override from config. max_model_len: Option, } impl TextLlm { - /// Create a new text-generation facade from a shared LLM client plus a text backend. + /// Create a new text-generation facade from a shared LLM client plus a text + /// backend. pub fn new(llm: Llm, backend: DynTextBackend) -> Self { - // Prefer the engine-reported max_model_len because it reflects the post-profiling, - // auto-fitted KV cache limit rather than static frontend metadata. + // Prefer the engine-reported max_model_len because it reflects the + // post-profiling, auto-fitted KV cache limit rather than static + // frontend metadata. let max_model_len = llm.engine_core_client().max_model_len(); Self { @@ -67,8 +70,8 @@ impl TextLlm { /// Override the maximum model context length explicitly. /// - /// This takes priority over both the engine-reported default and any tokenizer/model metadata - /// exposed by the backend. + /// This takes priority over both the engine-reported default and any + /// tokenizer/model metadata exposed by the backend. pub fn with_max_model_len(mut self, max_model_len: u32) -> Self { self.max_model_len = Some(max_model_len); self @@ -79,7 +82,8 @@ impl TextLlm { self.backend.model_id() } - /// Expose the underlying engine-core client for low-level utility/admin calls. + /// Expose the underlying engine-core client for low-level utility/admin + /// calls. pub fn engine_core_client(&self) -> &EngineCoreClient { self.llm.engine_core_client() } @@ -89,13 +93,15 @@ impl TextLlm { self.backend.tokenizer() } - /// Tokenize if needed, lower to a generate request, and return the raw token stream. + /// Tokenize if needed, lower to a generate request, and return the raw + /// token stream. pub async fn generate_raw(&self, request: TextRequest) -> Result { let (_, raw_stream) = self.generate_inner(request).await?; Ok(raw_stream) } - /// Tokenize if needed, lower to a generate request, and stream incrementally decoded text. + /// Tokenize if needed, lower to a generate request, and stream + /// incrementally decoded text. pub async fn generate(&self, request: TextRequest) -> Result { let (text_request, raw_stream) = self.generate_inner(request).await?; let tokenizer = self.backend.tokenizer(); diff --git a/src/text/src/lower.rs b/src/text/src/lower.rs index 832324c0..e5e8e246 100644 --- a/src/text/src/lower.rs +++ b/src/text/src/lower.rs @@ -11,14 +11,15 @@ use crate::tokenizer::Tokenizer; /// One text request after it has been lowered into the raw generate boundary. #[derive(Debug)] pub struct PreparedTextRequest { - /// The original high-level request, preserved for response-side metadata and decoding options. 
+ /// The original high-level request, preserved for response-side metadata + /// and decoding options. pub text_request: TextRequest, /// The southbound request ready to be sent to `vllm-llm`. pub generate_request: GenerateRequest, } -/// Convert a high-level [`TextRequest`] into one lower-level [`GenerateRequest`] ready for the -/// `llm` crate. +/// Convert a high-level [`TextRequest`] into one lower-level +/// [`GenerateRequest`] ready for the `llm` crate. pub fn lower_text_request( request: TextRequest, prompt_token_ids: Vec, @@ -51,8 +52,8 @@ pub fn lower_text_request( }) } -/// Convert [`SamplingParams`] into [`EngineCoreSamplingParams`], enriching omitted user values with -/// tokenizer/model-derived hints when available. +/// Convert [`SamplingParams`] into [`EngineCoreSamplingParams`], enriching +/// omitted user values with tokenizer/model-derived hints when available. pub fn lower_sampling_params( sampling_params: SamplingParams, SamplingHints { @@ -93,17 +94,16 @@ pub fn lower_sampling_params( vllm_xargs, } = sampling_params; - // Mirrors the model-generation-config inheritance used by vLLM's OpenAI chat path: - // https://github.com/vllm-project/vllm/blob/bc2c0c86efb28e77677a3cfb8687e976914a313a/vllm/entrypoints/openai/chat_completion/protocol.py#L424-L450 - // If neither the caller nor the model provides a value, fall back to 1.0 — the default - // used by the Python vLLM OpenAI-compatible API (via `_DEFAULT_SAMPLING_PARAMS`). + // Mirrors the model-generation-config inheritance used by vLLM's OpenAI chat + // path: https://github.com/vllm-project/vllm/blob/bc2c0c86efb28e77677a3cfb8687e976914a313a/vllm/entrypoints/openai/chat_completion/protocol.py#L424-L450 + // If neither the caller nor the model provides a value, fall back to 1.0 — the + // default used by the Python vLLM OpenAI-compatible API (via + // `_DEFAULT_SAMPLING_PARAMS`). let temperature = temperature.or(default_temperature).unwrap_or(1.0); let top_p = top_p.or(default_top_p).unwrap_or(1.0); let top_k = top_k.or(default_top_k).unwrap_or(0); let min_p = min_p.or(default_min_p).unwrap_or(0.0); - let repetition_penalty = repetition_penalty - .or(default_repetition_penalty) - .unwrap_or(1.0); + let repetition_penalty = repetition_penalty.or(default_repetition_penalty).unwrap_or(1.0); let max_tokens = resolve_max_tokens(max_tokens, default_max_tokens, max_model_len, prompt_len)?; let min_tokens = min_tokens.unwrap_or(0); let frequency_penalty = frequency_penalty.unwrap_or(0.0); @@ -146,12 +146,13 @@ pub fn lower_sampling_params( }) } -/// Convert bad-word strings into token-ID sequences, following the Python vLLM logic in -/// `SamplingParams.update_from_tokenizer()`. +/// Convert bad-word strings into token-ID sequences, following the Python vLLM +/// logic in `SamplingParams.update_from_tokenizer()`. /// -/// Each word is encoded both with and without a leading space so that the ban applies regardless of -/// whether the word appears at the beginning or in the middle of generated text (this accounts for -/// tokenizers that use an `add_prefix_space` convention). +/// Each word is encoded both with and without a leading space so that the ban +/// applies regardless of whether the word appears at the beginning or in the +/// middle of generated text (this accounts for tokenizers that use an +/// `add_prefix_space` convention). 
/// /// Reference: fn tokenize_bad_words( @@ -184,12 +185,13 @@ fn tokenize_bad_words( Ok((!all_token_ids.is_empty()).then_some(all_token_ids)) } -/// Resolve the effective `max_tokens` for generation, mirroring vLLM Python's `get_max_tokens()` -/// in `vllm/entrypoints/utils.py`. +/// Resolve the effective `max_tokens` for generation, mirroring vLLM Python's +/// `get_max_tokens()` in `vllm/entrypoints/utils.py`. /// -/// Takes the minimum of all available limits (user-specified, generation-config default, and -/// `max_model_len - prompt_len`). When nothing is known, falls back to `u32::MAX` so the -/// engine-core can apply its own context-window limit. +/// Takes the minimum of all available limits (user-specified, generation-config +/// default, and `max_model_len - prompt_len`). When nothing is known, falls +/// back to `u32::MAX` so the engine-core can apply its own context-window +/// limit. pub fn resolve_max_tokens( user_max_tokens: Option, default_max_tokens: Option, @@ -219,7 +221,8 @@ fn merge_unique_token_ids( stop_token_ids: &mut Vec, extra_token_ids: impl Iterator, ) { - // Keep user-provided ordering stable while still folding in backend-derived EOS aliases. + // Keep user-provided ordering stable while still folding in backend-derived EOS + // aliases. for token_id in extra_token_ids { if !stop_token_ids.contains(&token_id) { stop_token_ids.push(token_id); @@ -236,8 +239,8 @@ mod tests { use crate::backend::{SamplingHints, TextBackend as _}; use crate::request::{Prompt, TextRequest}; - /// Stub tokenizer that returns empty token IDs — sufficient for tests that don't exercise - /// bad-words tokenization. + /// Stub tokenizer that returns empty token IDs — sufficient for tests that + /// don't exercise bad-words tokenization. struct StubTokenizer; impl Tokenizer for StubTokenizer { diff --git a/src/text/src/output/decoded.rs b/src/text/src/output/decoded.rs index 02c4ccc6..9121dae8 100644 --- a/src/text/src/output/decoded.rs +++ b/src/text/src/output/decoded.rs @@ -21,8 +21,9 @@ pub struct TextDecodeOptions { pub skip_special_tokens: bool, pub include_stop_str_in_output: bool, pub stop_strings: Option>, - /// Minimum number of tokens to generate before stop-string checking kicks in. - /// Stop strings found within the first `min_tokens` tokens are ignored. + /// Minimum number of tokens to generate before stop-string checking kicks + /// in. Stop strings found within the first `min_tokens` tokens are + /// ignored. pub min_tokens: u32, } @@ -47,26 +48,31 @@ pub struct Finished { pub kv_transfer_params: Option, } -/// Internal decoded-text event emitted before higher-level assistant adaptation. +/// Internal decoded-text event emitted before higher-level assistant +/// adaptation. #[derive(Debug, Clone, PartialEq)] pub enum DecodedTextEvent { - /// The request has reached the point where prompt-scoped decoding metadata is ready. + /// The request has reached the point where prompt-scoped decoding metadata + /// is ready. Start { /// The actual prompt token IDs for this request. prompt_token_ids: Arc<[u32]>, /// Once-only prompt logprobs metadata, when requested. /// - /// The first prompt token is carried separately because it has no left context to score - /// against; `scored_positions` covers the remaining prompt positions. + /// The first prompt token is carried separately because it has no left + /// context to score against; `scored_positions` covers the + /// remaining prompt positions. 
prompt_logprobs: Option, }, - /// A delta of text has been decoded, optionally alongside token-position logprobs. + /// A delta of text has been decoded, optionally alongside token-position + /// logprobs. /// /// `delta` is the newly visible decoded text fragment for this update. /// - /// `logprobs` covers the newly generated token positions from the same update, but is not - /// guaranteed to align with `delta` by character span. One update may carry token logprobs - /// but no newly visible text yet, and one visible text fragment may reflect multiple token + /// `logprobs` covers the newly generated token positions from the same + /// update, but is not guaranteed to align with `delta` by character + /// span. One update may carry token logprobs but no newly visible text + /// yet, and one visible text fragment may reflect multiple token /// positions becoming decodable together. /// /// Upper-level may further parse `delta` as reasoning or tool calls. @@ -80,7 +86,8 @@ pub enum DecodedTextEvent { }, } -/// Convert the output token stream from the `vllm_llm` layer into incrementally decoded text. +/// Convert the output token stream from the `vllm_llm` layer into incrementally +/// decoded text. #[try_stream] pub async fn decoded_text_event_stream( request_id: String, @@ -101,9 +108,8 @@ pub async fn decoded_text_event_stream( // If it's the first output, init states and yield `Start` event. if decoder.is_none() { - let prompt_token_ids = output - .prompt_token_ids() - .expect("first llm output must carry prompt token ids"); + let prompt_token_ids = + output.prompt_token_ids().expect("first llm output must carry prompt token ids"); prompt_token_count = prompt_token_ids.len(); let dec = tokenizer.create_decode_stream( @@ -152,11 +158,7 @@ pub async fn decoded_text_event_stream( let decodable_token_ids = if suppress_terminal_stop_token { // Match Python V1 token-stop detokenization by keeping the stop token // in metadata while excluding it from user-visible text. - output - .token_ids - .split_last() - .map(|(_, rest)| rest) - .unwrap_or(&[]) + output.token_ids.split_last().map(|(_, rest)| rest).unwrap_or(&[]) } else { &output.token_ids }; @@ -254,8 +256,9 @@ pub async fn decoded_text_event_stream( trace!(full_text, "request finished with terminal decoded text"); } - // Intentionally drop the stream with explicit cause, so that the engine core can - // distinguish between such normal completion vs an unexpected early drop. + // Intentionally drop the stream with explicit cause, so that the engine core + // can distinguish between such normal completion vs an unexpected + // early drop. if stop_str_matched { AbortCause::StopStringMatched.drop_as(raw_stream); } @@ -290,7 +293,8 @@ pub async fn decoded_text_event_stream( } /// If stop string matches, returns tuple -/// (index into stop string vec, byte index of first byte of stop string in output) +/// (index into stop string vec, byte index of first byte of stop string in +/// output) fn matches_stop_string(stops: &[String], output: &str, new_bytes: usize) -> Option<(usize, usize)> { // We compare byte subslices to avoid utf8 boundary problem let output = output.as_bytes(); @@ -343,7 +347,8 @@ mod tests { } } - /// Helper: run `decoded_text_event_stream` to completion and return the collected output. + /// Helper: run `decoded_text_event_stream` to completion and return the + /// collected output. 
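The byte-window scan implemented by `matches_stop_string` above reduces to the following stand-alone sketch (same idea, simplified bounds handling; not the exact implementation):

// Only windows overlapping the newly appended bytes can be fresh matches;
// comparing raw byte slices sidesteps UTF-8 boundary issues.
fn find_stop(stops: &[String], output: &str, new_bytes: usize) -> Option<(usize, usize)> {
    let out = output.as_bytes();
    for (stop_idx, stop) in stops.iter().enumerate() {
        let stop = stop.as_bytes();
        if stop.is_empty() || stop.len() > out.len() {
            continue;
        }
        // Earliest window start that still overlaps the new suffix.
        let first = out.len().saturating_sub(new_bytes + stop.len() - 1);
        for start in first..=out.len() - stop.len() {
            if &out[start..start + stop.len()] == stop {
                return Some((stop_idx, start));
            }
        }
    }
    None
}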
async fn run_to_completion( token_ids: Vec, decode_options: TextDecodeOptions, diff --git a/src/text/src/output/logprobs.rs b/src/text/src/output/logprobs.rs index d733ff1c..d1811b1a 100644 --- a/src/text/src/output/logprobs.rs +++ b/src/text/src/output/logprobs.rs @@ -39,19 +39,21 @@ pub struct DecodedPromptLogprobs { pub first_token_id: u32, /// Best-effort decoded string for the first prompt token. /// - /// The first prompt token has no left context to score against, so it is stored separately - /// instead of appearing in `scored_positions`. + /// The first prompt token has no left context to score against, so it is + /// stored separately instead of appearing in `scored_positions`. pub first_token: String, /// Scored prompt positions after the first prompt token. /// - /// `scored_positions[i]` corresponds to the prompt token at position `i + 1`. + /// `scored_positions[i]` corresponds to the prompt token at position `i + + /// 1`. pub scored_positions: Vec, } -/// Decode generated-token logprobs from the raw `llm` token-ID shape into the text-layer -/// decoded-token representation. +/// Decode generated-token logprobs from the raw `llm` token-ID shape into the +/// text-layer decoded-token representation. /// -/// Each returned position corresponds to one generated token position from the same `llm` update. +/// Each returned position corresponds to one generated token position from the +/// same `llm` update. pub(super) fn decode_logprobs( tokenizer: &T, logprobs: &Logprobs, @@ -66,11 +68,12 @@ pub(super) fn decode_logprobs( }) } -/// Decode prompt logprobs from the raw `llm` token-ID shape into the text-layer decoded-token -/// representation. +/// Decode prompt logprobs from the raw `llm` token-ID shape into the text-layer +/// decoded-token representation. /// -/// The returned payload stores the first prompt token separately and decodes the remaining scored -/// prompt positions into `scored_positions`, matching vLLM's prompt-logprobs semantics. +/// The returned payload stores the first prompt token separately and decodes +/// the remaining scored prompt positions into `scored_positions`, matching +/// vLLM's prompt-logprobs semantics. pub(super) fn decode_prompt_logprobs( tokenizer: &T, prompt_token_ids: &[u32], @@ -95,9 +98,11 @@ pub(super) fn decode_prompt_logprobs( }) } -/// Decode one token position's raw candidate set into decoded token strings plus logprob metadata. +/// Decode one token position's raw candidate set into decoded token strings +/// plus logprob metadata. /// -/// This decodes every candidate token ID independently through the active text backend. +/// This decodes every candidate token ID independently through the active text +/// backend. 
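The "decode every candidate independently" behaviour described above is easy to see with stand-in types; the crate's real `Tokenizer` trait and logprob structs are richer than this sketch:

// Stand-in trait and types for illustration only.
trait Decode {
    fn decode_one(&self, token_id: u32) -> String;
}

struct Candidate {
    token_id: u32,
    logprob: f32,
}

struct DecodedCandidate {
    token_id: u32,
    token: String,
    logprob: f32,
}

// Each candidate id at one position is decoded on its own, so a candidate
// that is an incomplete UTF-8 fragment still yields some text and cannot
// affect how its neighbours decode.
fn decode_position<D: Decode>(tok: &D, entries: &[Candidate]) -> Vec<DecodedCandidate> {
    entries
        .iter()
        .map(|e| DecodedCandidate {
            token_id: e.token_id,
            token: tok.decode_one(e.token_id),
            logprob: e.logprob,
        })
        .collect()
}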
fn decode_position_logprobs( tokenizer: &T, position: &PositionLogprobs, @@ -108,14 +113,14 @@ fn decode_position_logprobs( .entries .iter() .map(|entry| { - tokenizer - .decode(&[entry.token_id], skip_special_tokens) - .map(|token| DecodedTokenLogprob { + tokenizer.decode(&[entry.token_id], skip_special_tokens).map(|token| { + DecodedTokenLogprob { token_id: entry.token_id, token, logprob: entry.logprob, rank: entry.rank, - }) + } + }) }) .try_collect()?, }) @@ -137,10 +142,7 @@ mod tests { fn decode(&self, token_ids: &[u32], _skip_special_tokens: bool) -> crate::Result { Ok(String::from_utf8_lossy( - &token_ids - .iter() - .map(|token_id| *token_id as u8) - .collect::>(), + &token_ids.iter().map(|token_id| *token_id as u8).collect::>(), ) .into_owned()) } diff --git a/src/text/src/output/mod.rs b/src/text/src/output/mod.rs index 325eff61..064b820d 100644 --- a/src/text/src/output/mod.rs +++ b/src/text/src/output/mod.rs @@ -30,7 +30,8 @@ pub struct CollectedTextOutput { #[allow(clippy::manual_async_fn, reason = "specify `Send` bound")] #[easy_ext::ext(TextOutputStreamExt)] impl T { - /// Collect the stream to completion and return the final decoded text plus terminal metadata. + /// Collect the stream to completion and return the final decoded text plus + /// terminal metadata. pub fn collect_output(self) -> impl Future> + Send { async move { let stream = self; @@ -87,8 +88,8 @@ impl T { } } - // Note: this is actually unreachable, as the underlying stream always emit an error on - // unexpected close. + // Note: this is actually unreachable, as the underlying stream always emit an + // error on unexpected close. Err(Error::StreamClosedBeforeTerminalOutput { request_id: "unknown".to_string(), }) diff --git a/src/text/src/request.rs b/src/text/src/request.rs index 24f10b2f..8a6154ca 100644 --- a/src/text/src/request.rs +++ b/src/text/src/request.rs @@ -10,14 +10,16 @@ use crate::output::TextDecodeOptions; /// One raw text-generation prompt. /// -/// This supports either ordinary text that still needs tokenization or already-tokenized prompt -/// IDs that should bypass tokenizer work entirely. +/// This supports either ordinary text that still needs tokenization or +/// already-tokenized prompt IDs that should bypass tokenizer work entirely. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, EnumAsInner)] #[serde(untagged)] pub enum Prompt { - /// Untokenized prompt text that still needs tokenizer work before generation. + /// Untokenized prompt text that still needs tokenizer work before + /// generation. Text(String), - /// Pre-tokenized prompt IDs that should be forwarded southbound without re-encoding. + /// Pre-tokenized prompt IDs that should be forwarded southbound without + /// re-encoding. TokenIds(Vec), } @@ -47,7 +49,8 @@ pub struct SamplingParams { pub top_k: Option, /// Random seed used by the sampler when present. pub seed: Option, - /// Maximum number of tokens to generate. `None` means no explicit user override. + /// Maximum number of tokens to generate. `None` means no explicit user + /// override. pub max_tokens: Option, /// Minimum number of tokens to generate before EOS or stop-token handling. pub min_tokens: Option, @@ -59,15 +62,20 @@ pub struct SamplingParams { /// /// `None` disables prompt logprobs. `-1` requests the full vocabulary. pub prompt_logprobs: Option, - /// Minimum probability threshold for token sampling. `None` means no explicit user override. + /// Minimum probability threshold for token sampling. 
`None` means no + /// explicit user override. pub min_p: Option, - /// Frequency penalty applied by the sampler. `None` means no explicit user override. + /// Frequency penalty applied by the sampler. `None` means no explicit user + /// override. pub frequency_penalty: Option, - /// Presence penalty applied by the sampler. `None` means no explicit user override. + /// Presence penalty applied by the sampler. `None` means no explicit user + /// override. pub presence_penalty: Option, - /// Repetition penalty applied by the sampler. `None` means no explicit user override. + /// Repetition penalty applied by the sampler. `None` means no explicit user + /// override. pub repetition_penalty: Option, - /// Explicit stop token IDs provided by the caller. `None` means no explicit user override. + /// Explicit stop token IDs provided by the caller. `None` means no explicit + /// user override. pub stop_token_ids: Option>, /// If true, do not stop on the model's primary EOS token. pub ignore_eos: bool, @@ -78,16 +86,19 @@ pub struct SamplingParams { pub allowed_token_ids: Option>, /// Words to avoid during generation (tokenized to IDs during lowering). pub bad_words: Option>, - /// Specific token IDs for which log probabilities should be returned at each position. + /// Specific token IDs for which log probabilities should be returned at + /// each position. /// - /// When set, the engine returns logprobs for exactly these tokens in addition to the - /// sampled/scored token. Mutually exclusive with `logprobs` in practice. + /// When set, the engine returns logprobs for exactly these tokens in + /// addition to the sampled/scored token. Mutually exclusive with + /// `logprobs` in practice. pub logprob_token_ids: Option>, /// Parameters for configuring structured outputs (guided decoding). pub structured_outputs: Option, - /// If true, bypass reads from the prefix cache for this request (the prompt will not - /// reuse cached KV blocks from earlier requests, though newly computed blocks may still - /// populate the cache). `None` defers to engine-core defaults. + /// If true, bypass reads from the prefix cache for this request (the prompt + /// will not reuse cached KV blocks from earlier requests, though newly + /// computed blocks may still populate the cache). `None` defers to + /// engine-core defaults. pub skip_reading_prefix_cache: Option, /// Additional request parameters for custom extensions. pub vllm_xargs: Option>, @@ -122,7 +133,8 @@ impl Default for SamplingParams { } } -/// One raw text-generation request ready to be tokenized or sent directly to the engine. +/// One raw text-generation request ready to be tokenized or sent directly to +/// the engine. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct TextRequest { /// Stable caller-supplied request ID. @@ -133,10 +145,12 @@ pub struct TextRequest { pub sampling_params: SamplingParams, /// Incremental detokenization options for the response path. pub decode_options: TextDecodeOptions, - /// Whether to emit intermediate northbound deltas before the terminal result. + /// Whether to emit intermediate northbound deltas before the terminal + /// result. /// - /// If `false`, callers only observe the terminal accumulated output. If `true`, callers may - /// receive zero or more incremental decoded updates before the final terminal event. + /// If `false`, callers only observe the terminal accumulated output. 
If + /// `true`, callers may receive zero or more incremental decoded updates + /// before the final terminal event. pub intermediate: bool, /// Request scheduling priority (lower means earlier handling; default 0). pub priority: i32, diff --git a/src/text/src/tokenizer/byte_level_decode.rs b/src/text/src/tokenizer/byte_level_decode.rs index 5733d4db..8208e8e8 100644 --- a/src/text/src/tokenizer/byte_level_decode.rs +++ b/src/text/src/tokenizer/byte_level_decode.rs @@ -106,10 +106,8 @@ mod tests { fn decode_multibyte_euro() { // € → 0xE2 0x82 0xAC, each mapped to a specific GPT-2 char. let byte_to_char = build_byte_to_char_ref(); - let encoded: String = [0xE2u8, 0x82, 0xAC] - .iter() - .map(|&b| byte_to_char[b as usize]) - .collect(); + let encoded: String = + [0xE2u8, 0x82, 0xAC].iter().map(|&b| byte_to_char[b as usize]).collect(); assert_eq!(decode_byte_level([encoded.as_str()]), "€"); } diff --git a/src/text/src/tokenizer/hf.rs b/src/text/src/tokenizer/hf.rs index 614922c6..949ca67b 100644 --- a/src/text/src/tokenizer/hf.rs +++ b/src/text/src/tokenizer/hf.rs @@ -51,9 +51,9 @@ fn decode_fastokens_byte_level( /// Tokenizer from `tokenizer.json` in HuggingFace format. /// -/// This tries to load with `fastokens` first for better performance, then falls back to -/// HuggingFace's `tokenizers` if the former fails (e.g. due to unsupported tokenizer features or -/// file formats). +/// This tries to load with `fastokens` first for better performance, then falls +/// back to HuggingFace's `tokenizers` if the former fails (e.g. due to +/// unsupported tokenizer features or file formats). pub struct HuggingFaceTokenizer { backend: Backend, special_token_ids: Arc<[u32]>, @@ -146,11 +146,11 @@ impl Tokenizer for HuggingFaceTokenizer { })?; Ok(encoding.get_ids().to_vec()) } - Backend::Fastokens(t) | Backend::FastokensByteLevel(t) => t - .encode_with_special_tokens(text, add_special_tokens) - .map_err(|error| { + Backend::Fastokens(t) | Backend::FastokensByteLevel(t) => { + t.encode_with_special_tokens(text, add_special_tokens).map_err(|error| { Error::Tokenizer(format!("encoding failed: {}", error.as_report())) - }), + }) + } } } @@ -224,9 +224,7 @@ mod tests { tokenizer.save(&path, false).expect("save tokenizer json"); let wrapper = HuggingFaceTokenizer::new_hf(&path).expect("load hf wrapper"); - let special_id = wrapper - .token_to_id("<|im_end|>") - .expect("resolve added special token id"); + let special_id = wrapper.token_to_id("<|im_end|>").expect("resolve added special token id"); assert!(wrapper.is_special_id(special_id)); } @@ -245,9 +243,7 @@ mod tests { wrapper.backend, super::Backend::Fastokens(_) | super::Backend::FastokensByteLevel(_), )); - let special_id = wrapper - .token_to_id("<|im_end|>") - .expect("resolve added special token id"); + let special_id = wrapper.token_to_id("<|im_end|>").expect("resolve added special token id"); assert!(wrapper.is_special_id(special_id)); } diff --git a/src/text/src/tokenizer/mod.rs b/src/text/src/tokenizer/mod.rs index 260c3a2d..a5992898 100644 --- a/src/text/src/tokenizer/mod.rs +++ b/src/text/src/tokenizer/mod.rs @@ -19,8 +19,8 @@ pub trait Tokenizer: Send + Sync { /// Decode one token sequence into text. fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result; - /// Convert one token string into a token ID, returning `None` if the token is not in the - /// tokenizer vocabulary. + /// Convert one token string into a token ID, returning `None` if the token + /// is not in the tokenizer vocabulary. 
fn token_to_id(&self, token: &str) -> Option; /// Return whether the given token ID is special. @@ -28,10 +28,11 @@ pub trait Tokenizer: Send + Sync { false } - /// Create a stateful incremental decoder primed with the given prompt tokens. + /// Create a stateful incremental decoder primed with the given prompt + /// tokens. /// - /// The prompt tokens provide left context for the first generated token; the decoder does not - /// re-emit prompt text. + /// The prompt tokens provide left context for the first generated token; + /// the decoder does not re-emit prompt text. fn create_decode_stream( &self, prompt_token_ids: &[u32], diff --git a/src/text/src/tokenizer/tekken.rs b/src/text/src/tokenizer/tekken.rs index a8e8c2c5..e059aba4 100644 --- a/src/text/src/tokenizer/tekken.rs +++ b/src/text/src/tokenizer/tekken.rs @@ -46,8 +46,8 @@ impl Tokenizer for TekkenTokenizer { } fn token_to_id(&self, token: &str) -> Option { - // tekken-rs exposes `get_control_token` for special tokens. Try that first, then - // fall back to encoding. + // tekken-rs exposes `get_control_token` for special tokens. Try that first, + // then fall back to encoding. self.inner.get_control_token(token).ok().or_else(|| { let ids = self.inner.encode(token, false, false).ok()?; if ids.len() == 1 { Some(ids[0]) } else { None } diff --git a/src/text/src/tokenizer/tiktoken.rs b/src/text/src/tokenizer/tiktoken.rs index a39e8431..f130d747 100644 --- a/src/text/src/tokenizer/tiktoken.rs +++ b/src/text/src/tokenizer/tiktoken.rs @@ -12,32 +12,34 @@ use super::Tokenizer; use crate::Error; use crate::error::Result; -/// Default regex pattern used when loading tiktoken from a BPE file. This is the same -/// `cl100k_base` pattern that HuggingFace transformers uses as its default in -/// `TikTokenConverter`. +/// Default regex pattern used when loading tiktoken from a BPE file. This is +/// the same `cl100k_base` pattern that HuggingFace transformers uses as its +/// default in `TikTokenConverter`. const CL100K_BASE_PATTERN: &str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"; /// Kimi BPE pattern from `moonshotai/Kimi-K2-Instruct/tokenization_kimi.py`. const KIMI_PATTERN: &str = r"[\p{Han}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"; -/// Fallback number of reserved special-token slots to assume when the model's `config.json` -/// is not available (so we cannot read `vocab_size` directly). +/// Fallback number of reserved special-token slots to assume when the model's +/// `config.json` is not available (so we cannot read `vocab_size` directly). /// /// 256 is the value used by Kimi K2 / K2.5 (`tokenization_kimi.py`'s -/// `num_reserved_special_tokens`) and by Llama 3, and it appears to be the most common -/// convention among modern tiktoken-based HF tokenizers. When `config.json` *is* present we -/// honour the model's actual `vocab_size` instead of this fallback — see `Self::new`. +/// `num_reserved_special_tokens`) and by Llama 3, and it appears to be the most +/// common convention among modern tiktoken-based HF tokenizers. When +/// `config.json` *is* present we honour the model's actual `vocab_size` instead +/// of this fallback — see `Self::new`. 
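The sizing rule this fallback feeds into (honour `config.json`'s `vocab_size` when present, otherwise assume 256 reserved slots, and never cut off an explicitly listed added token) condenses to a short helper. This is a sketch, not the constructor's exact code:

// Hedged sketch of how the exclusive end of the reserved range is chosen.
fn reserved_end(
    vocab_size_from_config: Option<u32>,
    num_base_tokens: u32,
    max_added_id: u32,
) -> u32 {
    vocab_size_from_config
        // No config.json: fall back to the common 256-slot convention.
        .unwrap_or_else(|| num_base_tokens.saturating_add(256))
        // Never exclude an id that added_tokens_decoder explicitly names.
        .max(max_added_id.saturating_add(1))
}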
const FALLBACK_NUM_RESERVED_SPECIAL_TOKENS: u32 = 256; /// Parsed entry from `tokenizer_config.json`'s `added_tokens_decoder`. #[derive(Debug, Clone, Deserialize)] struct AddedToken { content: String, - /// HuggingFace `added_tokens_decoder` entries can be marked `"special": true|false`. - /// Special tokens are dropped from output when `decode` is called with - /// `skip_special_tokens = true`. Defaults to `false` when the field is omitted, matching - /// HuggingFace's `AddedToken` default — so only tokens explicitly marked special are - /// stripped during normal decode (where `skip_special_tokens` itself defaults to true). + /// HuggingFace `added_tokens_decoder` entries can be marked `"special": + /// true|false`. Special tokens are dropped from output when `decode` is + /// called with `skip_special_tokens = true`. Defaults to `false` when + /// the field is omitted, matching HuggingFace's `AddedToken` default — + /// so only tokens explicitly marked special are stripped during normal + /// decode (where `skip_special_tokens` itself defaults to true). #[serde(default)] special: bool, } @@ -47,7 +49,8 @@ struct AddedToken { #[serde(default)] struct TiktokenTokenizerConfig { /// Format: - /// `{ "added_tokens_decoder": { "163584": { "content": "[BOS]", "special": true }, ... } }` + /// `{ "added_tokens_decoder": { "163584": { "content": "[BOS]", "special": + /// true }, ... } }` #[serde(default)] added_tokens_decoder: FxHashMap, } @@ -62,65 +65,77 @@ struct TiktokenModelConfig { } impl TiktokenModelConfig { - /// Read `model_type` from a model `config.json` value, falling back to a single-level nested - /// `text_config.model_type` for composite (e.g. multimodal) configs that keep text metadata - /// under a `text_config` object. + /// Read `model_type` from a model `config.json` value, falling back to a + /// single-level nested `text_config.model_type` for composite (e.g. + /// multimodal) configs that keep text metadata under a `text_config` + /// object. fn effective_model_type(&self) -> Option<&str> { self.model_type .as_deref() .or_else(|| self.text_config.as_deref()?.effective_model_type()) } - /// Read `vocab_size` from a model `config.json` value, falling back to a single-level nested - /// `text_config.vocab_size` for composite (e.g. multimodal) configs that keep text metadata - /// under a `text_config` object — matching the same shape `ModelConfig` parses. + /// Read `vocab_size` from a model `config.json` value, falling back to a + /// single-level nested `text_config.vocab_size` for composite (e.g. + /// multimodal) configs that keep text metadata under a `text_config` + /// object — matching the same shape `ModelConfig` parses. fn effective_vocab_size(&self) -> Option { - self.vocab_size - .or_else(|| self.text_config.as_deref()?.effective_vocab_size()) + self.vocab_size.or_else(|| self.text_config.as_deref()?.effective_vocab_size()) } } /// Tiktoken tokenizer from `tiktoken.model` or `*.tiktoken` BPE files. pub struct TiktokenTokenizer { inner: CoreBPE, - /// Number of regular BPE tokens. Token ids in `[0, num_base_tokens)` are BPE tokens that - /// always decode to text; ids in `[num_base_tokens, vocab_upper_bound)` live in the - /// special-token slots and are subject to `skip_special_tokens` filtering. + /// Number of regular BPE tokens. Token ids in `[0, num_base_tokens)` are + /// BPE tokens that always decode to text; ids in `[num_base_tokens, + /// vocab_upper_bound)` live in the special-token slots and are subject + /// to `skip_special_tokens` filtering. 
num_base_tokens: u32,
-    /// Exclusive upper bound on token IDs that `inner` is guaranteed to know how to decode.
+    /// Exclusive upper bound on token IDs that `inner` is guaranteed to know
+    /// how to decode.
     ///
-    /// The constructor registers every id in `[num_base_tokens, vocab_upper_bound)` with the
-    /// inner `CoreBPE` as a (named or `<|reserved_token_{id}|>`) special token, and the BPE
-    /// encoder densely covers `[0, num_base_tokens)`. So any id below this bound is in one of
-    /// the inner `CoreBPE`'s decoder maps and `_decode_native_and_split` will not panic on it.
-    /// `decode` filters out ids at or above this bound to keep that guarantee.
+    /// The constructor registers every id in `[num_base_tokens,
+    /// vocab_upper_bound)` with the inner `CoreBPE` as a (named or
+    /// `<|reserved_token_{id}|>`) special token, and the BPE
+    /// encoder densely covers `[0, num_base_tokens)`. So any id below this
+    /// bound is in one of the inner `CoreBPE`'s decoder maps and
+    /// `_decode_native_and_split` will not panic on it. `decode` filters
+    /// out ids at or above this bound to keep that guarantee.
     vocab_upper_bound: u32,
-    /// Ids in `[num_base_tokens, vocab_upper_bound)` whose `added_tokens_decoder` entry was
-    /// explicitly marked `"special": false` — i.e. tokens that should still appear in output
-    /// even when `skip_special_tokens = true`. For Kimi K2 / K2.5 this typically holds the
-    /// tool-call markers and `` / ``. Reserved-slot placeholders are not in
-    /// this set (they default to special and get skipped).
+    /// Ids in `[num_base_tokens, vocab_upper_bound)` whose
+    /// `added_tokens_decoder` entry was explicitly marked `"special":
+    /// false` — i.e. tokens that should still appear in output
+    /// even when `skip_special_tokens = true`. For Kimi K2 / K2.5 this
+    /// typically holds the tool-call markers and `` / ``.
+    /// Reserved-slot placeholders are not in this set (they default to
+    /// special and get skipped).
     non_special_added_ids: FxHashSet<u32>,
-    /// Reverse map for special / added token strings populated from the reserved range. This lets
-    /// `token_to_id` answer special-token lookups directly without round-tripping through
-    /// `tiktoken-rs`'s encoder, which can panic for unknown special-looking strings.
+    /// Reverse map for special / added token strings populated from the
+    /// reserved range. This lets `token_to_id` answer special-token lookups
+    /// directly without round-tripping through `tiktoken-rs`'s encoder,
+    /// which can panic for unknown special-looking strings.
     special_token_ids_by_text: FxHashMap<String, u32>,
-    /// Set of out-of-vocab token IDs we have already warned about. The reserved-slot population
-    /// in the constructor should keep this empty under normal operation; it only fills up if a
-    /// model emits ids at or above `vocab_upper_bound` (e.g. an engine sampling bug). We dedupe
-    /// so streaming decode (which calls `decode` repeatedly on the same prefix) does not spam.
+    /// Set of out-of-vocab token IDs we have already warned about. The
+    /// reserved-slot population in the constructor should keep this empty
+    /// under normal operation; it only fills up if a model emits ids at or
+    /// above `vocab_upper_bound` (e.g. an engine sampling bug). We dedupe
+    /// so streaming decode (which calls `decode` repeatedly on the same prefix)
+    /// does not spam.
     warned_unknown_ids: Mutex<FxHashSet<u32>>,
 }
 
 impl TiktokenTokenizer {
-    /// Load a tiktoken tokenizer from a `.tiktoken` / `tiktoken.model` BPE file.
+ /// Load a tiktoken tokenizer from a `.tiktoken` / `tiktoken.model` BPE + /// file. /// - /// The BPE file format is one ` ` pair per line, the same format - /// used by OpenAI's tiktoken and by HuggingFace model repos that ship tiktoken files (e.g. - /// DeepSeek, Kimi K2). + /// The BPE file format is one ` ` pair per line, + /// the same format used by OpenAI's tiktoken and by HuggingFace model + /// repos that ship tiktoken files (e.g. DeepSeek, Kimi K2). /// - /// Special / added tokens are read from `tokenizer_config.json` in the same directory when - /// present. The `cl100k_base` regex pattern is used as a reasonable default. + /// Special / added tokens are read from `tokenizer_config.json` in the same + /// directory when present. The `cl100k_base` regex pattern is used as a + /// reasonable default. pub fn new(path: &Path) -> Result { info!(path = %path.display(), "loading tokenizer with tiktoken (BPE file)"); @@ -145,9 +160,8 @@ impl TiktokenTokenizer { let rank_str = parts .next() .ok_or_else(|| Error::Tokenizer("missing rank in tiktoken file".to_string()))?; - let token_bytes = base64::engine::general_purpose::STANDARD - .decode(token_b64) - .map_err(|error| { + let token_bytes = + base64::engine::general_purpose::STANDARD.decode(token_b64).map_err(|error| { Error::Tokenizer(format!("invalid base64 in tiktoken file: {error}")) })?; let rank: u32 = rank_str.parse().map_err(|error| { @@ -158,8 +172,8 @@ impl TiktokenTokenizer { let parent_dir = path.parent(); - // Read added/special tokens (id → {name, special}) from tokenizer_config.json in the - // same dir. + // Read added/special tokens (id → {name, special}) from tokenizer_config.json + // in the same dir. let added_tokens_by_id = parent_dir .map(|dir| dir.join("tokenizer_config.json")) .filter(|p| p.exists()) @@ -170,8 +184,8 @@ impl TiktokenTokenizer { .map(|config: TiktokenTokenizerConfig| config.added_tokens_decoder) .unwrap_or_default(); - // Read `config.json` once so both `vocab_size` and model-specific tokenizer behavior can - // be derived from the same source of truth. + // Read `config.json` once so both `vocab_size` and model-specific tokenizer + // behavior can be derived from the same source of truth. let model_config: Option = parent_dir .map(|dir| dir.join("config.json")) .filter(|p| p.exists()) @@ -181,30 +195,33 @@ impl TiktokenTokenizer { }); let vocab_size_from_config = model_config.as_ref().and_then(|c| c.effective_vocab_size()); - // Build the full special-tokens encoder by populating the reserved range that follows - // the BPE vocabulary. The Python reference does this in `tokenization_kimi.py`: + // Build the full special-tokens encoder by populating the reserved range that + // follows the BPE vocabulary. The Python reference does this in + // `tokenization_kimi.py`: // - // for i in range(num_base_tokens, num_base_tokens + num_reserved_special_tokens): - // name = added_tokens_decoder.get(i, f"<|reserved_token_{i}|>") + // for i in range(num_base_tokens, num_base_tokens + + // num_reserved_special_tokens): name = + // added_tokens_decoder.get(i, f"<|reserved_token_{i}|>") // - // The same idea generalises to any tiktoken-based HF model: any id that the model is - // allowed to sample but is not listed in `added_tokens_decoder` is a "reserved" slot - // that should still decode to *something* rather than panic. Without this step, the - // model could emit a reserved id (e.g. 
id 163589 for Kimi K2.5) and decoding would + // The same idea generalises to any tiktoken-based HF model: any id that the + // model is allowed to sample but is not listed in + // `added_tokens_decoder` is a "reserved" slot that should still decode + // to *something* rather than panic. Without this step, the model could + // emit a reserved id (e.g. id 163589 for Kimi K2.5) and decoding would // panic in `CoreBPE::_decode_native_and_split`. // // We size the reserved range using whichever upper bound is largest: - // 1. `vocab_size` from config.json if present (the accurate, per-model answer), + // 1. `vocab_size` from config.json if present (the accurate, per-model + // answer), // 2. otherwise `num_base_tokens + 256` (the Kimi/Llama 3 default convention), - // 3. extended further to cover any explicit `added_tokens_decoder` id beyond either. + // 3. extended further to cover any explicit `added_tokens_decoder` id beyond + // either. // - // Note: `*.tiktoken` ranks are token ids, and they are not guaranteed to be contiguous. - // We therefore define the base-vocab boundary as `max_rank + 1`, not `encoder.len()`. - let num_base_tokens = encoder - .values() - .copied() - .max() - .map_or(0, |max_rank| max_rank.saturating_add(1)); + // Note: `*.tiktoken` ranks are token ids, and they are not guaranteed to be + // contiguous. We therefore define the base-vocab boundary as `max_rank + // + 1`, not `encoder.len()`. + let num_base_tokens = + encoder.values().copied().max().map_or(0, |max_rank| max_rank.saturating_add(1)); let max_added_id = added_tokens_by_id.keys().copied().max().unwrap_or(0); let reserved_end = vocab_size_from_config .unwrap_or_else(|| num_base_tokens.saturating_add(FALLBACK_NUM_RESERVED_SPECIAL_TOKENS)) @@ -230,9 +247,7 @@ impl TiktokenTokenizer { special_tokens_encoder.insert(name, id); } - let pattern = model_config - .as_ref() - .map_or(CL100K_BASE_PATTERN, detect_bpe_pattern); + let pattern = model_config.as_ref().map_or(CL100K_BASE_PATTERN, detect_bpe_pattern); let special_token_ids_by_text = special_tokens_encoder.clone(); let bpe = CoreBPE::new(encoder, special_tokens_encoder, pattern).map_err(|error| { Error::Tokenizer(format!( @@ -251,8 +266,9 @@ impl TiktokenTokenizer { }) } - /// Log a warning the first time an unknown token id is seen during decode, deduped across - /// calls so streaming decode does not spam the log for the same id. + /// Log a warning the first time an unknown token id is seen during decode, + /// deduped across calls so streaming decode does not spam the log for + /// the same id. fn warn_unknown_id(&self, token_id: u32) { let newly_inserted = self .warned_unknown_ids @@ -280,23 +296,25 @@ impl Tokenizer for TiktokenTokenizer { fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result { // Filter passes: // - // 1. Ids at or above `vocab_upper_bound` are dropped (with a warn-once log) — without this, - // `_decode_native_and_split` would panic on ids missing from both of `CoreBPE`'s - // internal decoder maps. The constructor registers every id in `[num_base_tokens, - // vocab_upper_bound)` as a special token (named or `<|reserved_token_{id}|>` - // placeholder, matching `tokenization_kimi.py`), so any in-range id is safe and this - // branch only fires for genuinely out-of-vocab ids — e.g. an engine sampling bug - // emitting an id above the model's stated vocab_size. + // 1. 
Ids at or above `vocab_upper_bound` are dropped (with a warn-once log) — + // without this, `_decode_native_and_split` would panic on ids missing from + // both of `CoreBPE`'s internal decoder maps. The constructor registers every + // id in `[num_base_tokens, vocab_upper_bound)` as a special token (named or + // `<|reserved_token_{id}|>` placeholder, matching `tokenization_kimi.py`), + // so any in-range id is safe and this branch only fires for genuinely + // out-of-vocab ids — e.g. an engine sampling bug emitting an id above the + // model's stated vocab_size. // - // 2. When `skip_special_tokens = true`, ids in `[num_base_tokens, vocab_upper_bound)` are - // dropped *unless* they were marked `"special": false` in `added_tokens_decoder`. This - // matches HuggingFace's tokenizer semantics: tool-call markers and `` / - // `` (which Kimi K2 / K2.5 declare as non-special) stay in the output, while - // BOS/EOS/header tokens and reserved-slot placeholders are stripped. + // 2. When `skip_special_tokens = true`, ids in `[num_base_tokens, + // vocab_upper_bound)` are dropped *unless* they were marked `"special": + // false` in `added_tokens_decoder`. This matches HuggingFace's tokenizer + // semantics: tool-call markers and `` / `` (which Kimi K2 / + // K2.5 declare as non-special) stay in the output, while BOS/EOS/header + // tokens and reserved-slot placeholders are stripped. // - // Lossy UTF-8 decoding (instead of strict `String::from_utf8`) is used so partial - // multi-byte sequences become `\u{FFFD}`, which `DecodeStream` relies on to detect - // incomplete characters during streaming. + // Lossy UTF-8 decoding (instead of strict `String::from_utf8`) is used so + // partial multi-byte sequences become `\u{FFFD}`, which `DecodeStream` + // relies on to detect incomplete characters during streaming. let safe_ids: Vec = token_ids .iter() .copied() @@ -310,11 +328,7 @@ impl Tokenizer for TiktokenTokenizer { || self.non_special_added_ids.contains(&id) }) .collect(); - let bytes: Vec = self - .inner - ._decode_native_and_split(safe_ids) - .flatten() - .collect(); + let bytes: Vec = self.inner._decode_native_and_split(safe_ids).flatten().collect(); Ok(String::from_utf8_lossy(&bytes).into_owned()) } @@ -323,9 +337,10 @@ impl Tokenizer for TiktokenTokenizer { return Some(token_id); } - // Fall back to ordinary encoding for regular vocabulary items. This deliberately avoids - // `encode_with_special_tokens`: older `tiktoken-rs` versions can panic if the input text - // merely *looks* like a special token but is not registered in `special_tokens_encoder`. + // Fall back to ordinary encoding for regular vocabulary items. This + // deliberately avoids `encode_with_special_tokens`: older `tiktoken-rs` + // versions can panic if the input text merely *looks* like a special + // token but is not registered in `special_tokens_encoder`. let ids = self.inner.encode_ordinary(token); if ids.len() == 1 { Some(ids[0]) } else { None } } @@ -339,9 +354,10 @@ impl Tokenizer for TiktokenTokenizer { /// Select the BPE regex pattern for a tiktoken model based on `config.json`. /// -/// Most tiktoken models use the `cl100k_base` regex. Kimi models ship a custom regex in their -/// Python tokenizer implementation; we mirror the explicit `model_type` switch used by Dynamo -/// instead of heuristically parsing Python source files. +/// Most tiktoken models use the `cl100k_base` regex. 
Kimi models ship a custom +/// regex in their Python tokenizer implementation; we mirror the explicit +/// `model_type` switch used by Dynamo instead of heuristically parsing Python +/// source files. fn detect_bpe_pattern(config: &TiktokenModelConfig) -> &'static str { let model_type = config.effective_model_type(); @@ -372,9 +388,10 @@ mod tests { }; } - /// Write a minimal `*.tiktoken` BPE file (one token per byte 0..=255) into `dir` and - /// return its path. The single-byte vocab is enough to exercise the multi-byte / streaming - /// UTF-8 paths without depending on any pretrained tokenizer asset. + /// Write a minimal `*.tiktoken` BPE file (one token per byte 0..=255) into + /// `dir` and return its path. The single-byte vocab is enough to + /// exercise the multi-byte / streaming UTF-8 paths without depending on + /// any pretrained tokenizer asset. fn write_synthetic_bpe_file(dir: &std::path::Path) -> PathBuf { let mut content = String::new(); for byte in 0u8..=255 { @@ -386,11 +403,12 @@ mod tests { path } - /// Write a synthetic `*.tiktoken` file whose base-vocab ranks are sparse/non-contiguous. + /// Write a synthetic `*.tiktoken` file whose base-vocab ranks are + /// sparse/non-contiguous. /// - /// This reproduces the important edge case for `num_base_tokens`: it must be derived from - /// `max_rank + 1`, not `encoder.len()`, otherwise high-rank base tokens get misclassified as - /// reserved/special ids. + /// This reproduces the important edge case for `num_base_tokens`: it must + /// be derived from `max_rank + 1`, not `encoder.len()`, otherwise + /// high-rank base tokens get misclassified as reserved/special ids. fn write_sparse_rank_bpe_file(dir: &std::path::Path) -> PathBuf { let mut content = String::new(); for byte in 0u8..=255 { @@ -406,8 +424,9 @@ mod tests { path } - /// Build a `TiktokenTokenizer` from the synthetic BPE file with no sibling config files, - /// so the constructor takes the `FALLBACK_NUM_RESERVED_SPECIAL_TOKENS` (256) path. + /// Build a `TiktokenTokenizer` from the synthetic BPE file with no sibling + /// config files, so the constructor takes the + /// `FALLBACK_NUM_RESERVED_SPECIAL_TOKENS` (256) path. fn tiktoken_backend() -> (TiktokenTokenizer, TempDir) { let dir = tempfile::tempdir().expect("create temp dir"); let path = write_synthetic_bpe_file(dir.path()); @@ -415,9 +434,10 @@ mod tests { (backend, dir) } - /// Verify that tiktoken decode uses lossy UTF-8 (producing `\u{FFFD}`) rather than - /// returning an error for incomplete multi-byte sequences. This is critical for streaming - /// decode — `DecodeStream` relies on `\u{FFFD}` to detect incomplete characters. + /// Verify that tiktoken decode uses lossy UTF-8 (producing `\u{FFFD}`) + /// rather than returning an error for incomplete multi-byte sequences. + /// This is critical for streaming decode — `DecodeStream` relies on + /// `\u{FFFD}` to detect incomplete characters. #[test] fn tiktoken_decode_incomplete_utf8_produces_replacement_char() { let (backend, _dir) = tiktoken_backend(); @@ -434,16 +454,18 @@ mod tests { } } - /// When `config.json` exposes a `vocab_size`, the reserved-token range must be sized to it - /// rather than to the 256-slot fallback. This is the general (non-Kimi-specific) path: any - /// tiktoken model whose own `config.json` says e.g. `vocab_size = 280` should populate + /// When `config.json` exposes a `vocab_size`, the reserved-token range must + /// be sized to it rather than to the 256-slot fallback. 
This is the + /// general (non-Kimi-specific) path: any tiktoken model whose own + /// `config.json` says e.g. `vocab_size = 280` should populate /// reserved slots for `[num_base_tokens, 280)` and nothing beyond. #[test] fn tiktoken_reserved_range_uses_vocab_size_from_config_json() { let dir = tempfile::tempdir().expect("create temp dir"); let bpe_path = write_synthetic_bpe_file(dir.path()); - // num_base_tokens = 256, vocab_size = 280 → reserved range = [256, 280) (24 slots, - // smaller than the 256 fallback so we can prove the config value is honoured). + // num_base_tokens = 256, vocab_size = 280 → reserved range = [256, 280) (24 + // slots, smaller than the 256 fallback so we can prove the config value + // is honoured). fs::write(dir.path().join("config.json"), r#"{"vocab_size": 280}"#) .expect("write config.json"); let backend = TiktokenTokenizer::new(&bpe_path).expect("load tiktoken backend"); @@ -457,9 +479,9 @@ mod tests { vec![in_range_id] ); - // Outside the configured range: not registered as a reserved slot — falls through to - // the warn-and-skip backstop. The point is that we *don't* over-populate beyond what - // the model actually exposes. + // Outside the configured range: not registered as a reserved slot — falls + // through to the warn-and-skip backstop. The point is that we *don't* + // over-populate beyond what the model actually exposes. let out_of_range_id: u32 = 290; let out_of_range_placeholder = format!("<|reserved_token_{out_of_range_id}|>"); assert_eq!(backend.decode(&[out_of_range_id], false).unwrap(), ""); @@ -469,9 +491,11 @@ mod tests { /// Sparse/non-contiguous BPE ranks must still count as base-vocab ids. /// /// Regression shape: - /// - base vocabulary contains ids 0..=255 and also a normal BPE token at id 1000 - /// - if `num_base_tokens` were computed as `encoder.len()` (257), id 1000 would be - /// misclassified as special/reserved and disappear under `skip_special_tokens = true` + /// - base vocabulary contains ids 0..=255 and also a normal BPE token at id + /// 1000 + /// - if `num_base_tokens` were computed as `encoder.len()` (257), id 1000 + /// would be misclassified as special/reserved and disappear under + /// `skip_special_tokens = true` #[test] fn tiktoken_sparse_base_ranks_are_not_misclassified_as_special() { let dir = tempfile::tempdir().expect("create temp dir"); @@ -491,12 +515,14 @@ mod tests { /// * keep regular BPE token text unchanged, /// * drop ids whose `added_tokens_decoder` entry says `"special": true`, /// * drop reserved-slot placeholder ids (which default to special), - /// * keep ids whose `added_tokens_decoder` entry says `"special": false` — this is how Kimi K2 - /// / K2.5 marks tool-call markers and `` / ``. + /// * keep ids whose `added_tokens_decoder` entry says `"special": false` — + /// this is how Kimi K2 / K2.5 marks tool-call markers and `` / + /// ``. /// - /// Synthetic backend has `num_base_tokens = 256`. We write a `tokenizer_config.json` that - /// names ids 257 (special) and 258 (non-special), and a `config.json` with `vocab_size` - /// covering both. Id 259 stays a default reserved placeholder (special). + /// Synthetic backend has `num_base_tokens = 256`. We write a + /// `tokenizer_config.json` that names ids 257 (special) and 258 + /// (non-special), and a `config.json` with `vocab_size` covering both. + /// Id 259 stays a default reserved placeholder (special). 
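The keep/drop decision that the next test pins down (and that the two decode filter passes documented earlier implement) reduces to a small predicate. The names below are stand-ins, not the actual method:

use std::collections::HashSet;

fn keep_token_id(
    id: u32,
    num_base_tokens: u32,
    vocab_upper_bound: u32,
    skip_special_tokens: bool,
    non_special_added_ids: &HashSet<u32>,
) -> bool {
    // Pass 1: ids beyond the reserved range are genuinely out-of-vocab.
    if id >= vocab_upper_bound {
        return false;
    }
    // Base-vocab BPE ids (and everything when specials are kept) pass through.
    if id < num_base_tokens || !skip_special_tokens {
        return true;
    }
    // Pass 2: special-slot ids survive only when explicitly marked
    // "special": false in added_tokens_decoder.
    non_special_added_ids.contains(&id)
}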
#[test] fn tiktoken_skip_special_tokens_filters_special_but_keeps_non_special_added_tokens() { let dir = tempfile::tempdir().expect("create temp dir"); @@ -515,7 +541,8 @@ mod tests { .expect("write config.json"); let backend = TiktokenTokenizer::new(&bpe_path).expect("load tiktoken backend"); - // Resolve the BPE ids for "Hi" so we can interleave them with special-token ids. + // Resolve the BPE ids for "Hi" so we can interleave them with special-token + // ids. let h = backend.encode("H", false).unwrap()[0]; let i = backend.encode("i", false).unwrap()[0]; @@ -532,13 +559,14 @@ mod tests { "H<|im_end|>i<|tool_call_begin|><|reserved_token_259|>" ); - // skip_special_tokens = true: special token (257) and reserved placeholder (259) are - // dropped; the non-special added token (258) survives. + // skip_special_tokens = true: special token (257) and reserved placeholder + // (259) are dropped; the non-special added token (258) survives. let stripped = backend.decode(&ids, true).unwrap(); assert_eq!(stripped, "Hi<|tool_call_begin|>"); } - /// `vocab_size` may live under `text_config` for composite (e.g. multimodal) configs. + /// `vocab_size` may live under `text_config` for composite (e.g. + /// multimodal) configs. #[test] fn tiktoken_reserved_range_reads_text_config_vocab_size() { let dir = tempfile::tempdir().expect("create temp dir"); @@ -617,20 +645,20 @@ mod tests { ); assert_eq!(added_tokens.get(&257).map(|t| t.special), Some(false)); assert_eq!( - added_tokens - .get(&258) - .map(|t| (t.content.as_str(), t.special)), + added_tokens.get(&258).map(|t| (t.content.as_str(), t.special)), Some(("", true)) ); } - /// Reserved token ids in `[num_base_tokens, num_base_tokens + 256)` must decode to their - /// placeholder name (matching `tokenization_kimi.py`'s `<|reserved_token_{i}|>` format), - /// even when the source `tokenizer_config.json` does not list them in `added_tokens_decoder`. + /// Reserved token ids in `[num_base_tokens, num_base_tokens + 256)` must + /// decode to their placeholder name (matching `tokenization_kimi.py`'s + /// `<|reserved_token_{i}|>` format), even when the source + /// `tokenizer_config.json` does not list them in `added_tokens_decoder`. /// - /// In our synthetic backend `num_base_tokens = 256` (256 single-byte BPE tokens), so the - /// reserved range is `[256, 512)`. Picking id 300 — well inside that range and absent - /// from any `added_tokens_decoder` — should round-trip both ways. + /// In our synthetic backend `num_base_tokens = 256` (256 single-byte BPE + /// tokens), so the reserved range is `[256, 512)`. Picking id 300 — + /// well inside that range and absent from any `added_tokens_decoder` — + /// should round-trip both ways. #[test] fn tiktoken_reserved_token_round_trip() { let (backend, _dir) = tiktoken_backend(); @@ -649,13 +677,15 @@ mod tests { assert_eq!(backend.token_to_id(&placeholder), Some(reserved_id)); } - /// Decoding a token id that is beyond even the reserved range must not panic — it falls - /// through to the warn-and-skip backstop instead of crashing the worker thread. + /// Decoding a token id that is beyond even the reserved range must not + /// panic — it falls through to the warn-and-skip backstop instead of + /// crashing the worker thread. #[test] fn tiktoken_decode_unknown_token_id_does_not_panic() { let (backend, _dir) = tiktoken_backend(); - // ID well above num_base_tokens (256) + reserved (256) = 512 — guaranteed unknown. + // ID well above num_base_tokens (256) + reserved (256) = 512 — guaranteed + // unknown. 
let unknown_id: u32 = 999_999; let result = backend.decode(&[unknown_id], false); assert_eq!(result.unwrap(), ""); @@ -667,8 +697,9 @@ mod tests { assert_eq!(result, "Hi"); } - /// Streaming decode of CJK text through tiktoken should produce the original text without - /// errors, even though individual tokens may represent partial UTF-8 byte sequences. + /// Streaming decode of CJK text through tiktoken should produce the + /// original text without errors, even though individual tokens may + /// represent partial UTF-8 byte sequences. #[test] fn tiktoken_streaming_decode_multibyte() { let (backend, _dir) = tiktoken_backend(); @@ -692,7 +723,8 @@ mod tests { assert_eq!(full_text, text); } - /// Mixed ASCII and multi-byte text should stream correctly through tiktoken. + /// Mixed ASCII and multi-byte text should stream correctly through + /// tiktoken. #[test] fn tiktoken_streaming_decode_mixed_ascii_and_multibyte() { let (backend, _dir) = tiktoken_backend(); From 8b00aac66b295f9f3503dfda403c22949dbc9ca0 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Tue, 5 May 2026 19:52:36 +0800 Subject: [PATCH 2/2] minor Signed-off-by: Bugen Zhao --- rustfmt.unstable.toml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/rustfmt.unstable.toml b/rustfmt.unstable.toml index 8585f04f..a9514a8c 100644 --- a/rustfmt.unstable.toml +++ b/rustfmt.unstable.toml @@ -2,9 +2,13 @@ # Apply manually with: # cargo +nightly fmt -- --config-path rustfmt.unstable.toml -unstable_features = true style_edition = "2024" chain_width = 80 +use_field_init_shorthand = true + +# Unstable features go here. +unstable_features = true + format_code_in_doc_comments = true format_macro_matchers = true normalize_comments = true @@ -13,4 +17,3 @@ imports_granularity = "Module" group_imports = "StdExternalCrate" reorder_impl_items = true wrap_comments = true -use_field_init_shorthand = true
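A closing note on the streaming-decode tests above: the buffering trick they exercise (hold text back while the lossy decode still ends in U+FFFD) can be sketched independently of any tokenizer. This simplification ignores the rare case of a genuine U+FFFD in the source text:

// Returns the newly stable text to emit, if any, given the full lossy decode
// so far and how much of it was already emitted.
fn ready_to_emit(decoded: &str, already_emitted: usize) -> Option<&str> {
    // A trailing U+FFFD means the last token ended mid-character; wait for
    // more tokens before emitting.
    let stable_end = if decoded.ends_with('\u{FFFD}') {
        decoded.len() - '\u{FFFD}'.len_utf8()
    } else {
        decoded.len()
    };
    (stable_end > already_emitted).then(|| &decoded[already_emitted..stable_end])
}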