diff --git a/crates/grpc_client/src/mlx_engine.rs b/crates/grpc_client/src/mlx_engine.rs index fab8f44fe..b581f73b9 100644 --- a/crates/grpc_client/src/mlx_engine.rs +++ b/crates/grpc_client/src/mlx_engine.rs @@ -242,11 +242,8 @@ impl MlxEngineClient { // - response_format — same as constrained decoding // // Servicer limitations (fixable without mlx-lm changes): - // - TODO(mlx): String stop sequences — mlx-lm supports this via - // tokenizer.encode() → SequenceStateMachine. Fix by converting stop - // strings to token IDs in the preparation stage (which already has the - // Rust tokenizer) and passing them as stop_token_ids in the proto. - // + // - String stop sequences: single-token stops supported in chat and completion pipelines. + // Messages and Generate pipelines still reject string stops (see reject_stop_strings). // Track upstream: https://github.com/ml-explore/mlx-lm fn reject_constraint(constraint: Option<&(String, String)>) -> Result<(), String> { @@ -309,7 +306,6 @@ impl MlxEngineClient { ) -> Result { Self::reject_constraint(constraint.as_ref())?; Self::reject_n(body.n)?; - Self::reject_stop_strings(body.stop.as_ref().is_some_and(|s| !s.is_empty()))?; Self::reject_response_format(body.response_format.is_some())?; let sampling_params = Self::build_sampling_params_from_chat(body); @@ -335,7 +331,6 @@ impl MlxEngineClient { token_ids: Vec, ) -> Result { Self::reject_n(body.n)?; - Self::reject_stop_strings(body.stop.as_ref().is_some_and(|s| !s.is_empty()))?; Self::reject_if_any_constraint( body.json_schema.as_ref(), body.regex.as_ref(), diff --git a/crates/protocols/src/completion.rs b/crates/protocols/src/completion.rs index 393596fe9..82f91020d 100644 --- a/crates/protocols/src/completion.rs +++ b/crates/protocols/src/completion.rs @@ -266,11 +266,13 @@ pub struct CompletionStreamResponse { pub usage: Option, } -#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)] +#[derive(Default, Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)] pub struct 
CompletionStreamChoice { pub text: String, pub index: u32, #[serde(skip_serializing_if = "Option::is_none")] pub logprobs: Option, pub finish_reason: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub matched_stop: Option, } diff --git a/crates/tokenizer/src/mock.rs b/crates/tokenizer/src/mock.rs index f057aabe2..7b0e2a0a0 100644 --- a/crates/tokenizer/src/mock.rs +++ b/crates/tokenizer/src/mock.rs @@ -11,6 +11,7 @@ pub struct MockTokenizer { vocab: HashMap, reverse_vocab: HashMap, special_tokens: SpecialTokens, + fail_encode: bool, } impl Default for MockTokenizer { @@ -62,19 +63,29 @@ impl MockTokenizer { vocab, reverse_vocab, special_tokens, + fail_encode: false, + } + } + + pub fn failing() -> Self { + Self { + fail_encode: true, + ..Self::new() } } } impl Encoder for MockTokenizer { fn encode(&self, input: &str, _add_special_tokens: bool) -> Result { + if self.fail_encode { + return Err(anyhow::anyhow!("test encode error")); + } // Simple word-based tokenization using the vocab // Split by whitespace and look up each word (decoder adds spaces back) let tokens: Vec = input .split_whitespace() .filter_map(|word| self.vocab.get(word).copied()) .collect(); - Ok(Encoding::Plain(tokens)) } diff --git a/model_gateway/src/routers/grpc/common/stages/helpers.rs b/model_gateway/src/routers/grpc/common/stages/helpers.rs index 65ef6dded..3cacf0ff6 100644 --- a/model_gateway/src/routers/grpc/common/stages/helpers.rs +++ b/model_gateway/src/routers/grpc/common/stages/helpers.rs @@ -2,6 +2,9 @@ use std::sync::Arc; +use axum::response::Response; +use llm_tokenizer::traits::Tokenizer; +use openai_protocol::common::StringOrArray; use rand::Rng; use smg_grpc_client::{ mlx_proto, @@ -11,9 +14,13 @@ use smg_grpc_client::{ use tracing::{debug, warn}; use crate::{ - routers::grpc::{ - context::{RequestType, WorkerSelection}, - proto_wrapper::ProtoGenerateRequest, + routers::{ + error, + grpc::{ + context::{RequestType, WorkerSelection}, + 
proto_wrapper::ProtoGenerateRequest, + utils::resolve_mlx_stop_ids, + }, }, worker::{ sampling_defaults::SamplingDefaults, RuntimeType, Worker, DEFAULT_BOOTSTRAP_PORT, @@ -263,3 +270,36 @@ fn inject_sglang_bootstrap_metadata( hostname, bootstrap_port, room_id ); } + +/// Convert string stop sequences to token IDs and append them to the MLX proto request. +/// +/// The MLX proto only supports stop_token_ids; string stop sequences from the +/// chat or completion request must be tokenized here before the request is dispatched. +/// No-op if `stop` is `None` or the backend is not MLX. +#[expect( + clippy::result_large_err, + reason = "Response is the standard error type in the pipeline stage pattern" +)] +pub(crate) fn apply_mlx_stop_sequences( + proto_request: &mut ProtoGenerateRequest, + stop: Option<&StringOrArray>, + tokenizer: Option<&dyn Tokenizer>, +) -> Result<(), Response> { + let Some(stop) = stop else { + return Ok(()); + }; + + if let ProtoGenerateRequest::Mlx(req) = proto_request { + let token_ids = resolve_mlx_stop_ids(stop, tokenizer)?; + let sampling = req.sampling_params.as_mut().ok_or_else(|| { + error::internal_error( + "mlx_sampling_params_missing", + "MLX GenerateRequest has no sampling_params; cannot inject stop IDs", + ) + })?; + sampling.stop_token_ids.extend(token_ids); + } + + // Non-MLX backends handle string stop sequences natively; no-op for them. 
+ Ok(()) +} diff --git a/model_gateway/src/routers/grpc/proto_wrapper.rs b/model_gateway/src/routers/grpc/proto_wrapper.rs index 971ff388b..279a280a9 100644 --- a/model_gateway/src/routers/grpc/proto_wrapper.rs +++ b/model_gateway/src/routers/grpc/proto_wrapper.rs @@ -6,6 +6,8 @@ use std::collections::HashMap; use futures_util::StreamExt; +use llm_tokenizer::traits::Tokenizer; +use openai_protocol::common::StringOrArray; use smg_grpc_client::{ mlx_engine::AbortOnDropStream as MlxStream, mlx_proto::{self as mlx}, @@ -17,6 +19,8 @@ use smg_grpc_client::{ vllm_proto::{self as vllm, generate_complete::MatchedStop as VllmMatchedStop}, }; +use crate::routers::grpc::utils::resolve_mlx_matched_stop_json; + // ===================== // Multimodal Data // ===================== @@ -738,6 +742,14 @@ impl ProtoGenerateComplete { matches!(self, Self::Mlx(_)) } + /// Return the raw matched stop token ID for MLX responses; None for all other backends. + fn mlx_matched_stop_token_id(&self) -> Option { + match self { + Self::Mlx(c) => c.matched_stop_token_id, + _ => None, + } + } + /// Get token IDs from either backend (output_ids in proto) pub fn token_ids(&self) -> &[u32] { match self { @@ -795,6 +807,10 @@ impl ProtoGenerateComplete { /// - MatchedTokenId → Number /// - MatchedStopStr → String /// - None → None + #[expect( + clippy::unreachable, + reason = "MLX must use matched_stop_json_with_context" + )] pub fn matched_stop_json(&self) -> Option { macro_rules! convert { ($oneof:expr, $token_id:path, $stop_str:path) => { @@ -820,9 +836,31 @@ impl ProtoGenerateComplete { TrtllmMatchedStop::MatchedTokenId, TrtllmMatchedStop::MatchedStopStr ), - Self::Mlx(c) => c - .matched_stop_token_id - .map(|id| serde_json::Value::Number(id.into())), + // MLX needs request context: this arm panics; use matched_stop_json_with_context. 
+ Self::Mlx(_) => unreachable!("matched_stop_json called for MLX backend"), + } + } + + /// Resolve the matched stop for any backend, using request context for MLX. + /// + /// MLX only stores a token ID; this maps it back to the user-facing string or integer + /// (see `chat_utils::resolve_mlx_matched_stop_json`). All other backends return + /// `matched_stop_json()` directly. + pub fn matched_stop_json_with_context( + &self, + stop: Option<&StringOrArray>, + stop_token_ids: Option<&Vec>, + tokenizer: &dyn Tokenizer, + ) -> Option { + if self.is_mlx() { + resolve_mlx_matched_stop_json( + self.mlx_matched_stop_token_id(), + stop, + stop_token_ids, + tokenizer, + ) + } else { + self.matched_stop_json() } } diff --git a/model_gateway/src/routers/grpc/regular/processor.rs b/model_gateway/src/routers/grpc/regular/processor.rs index 582b401ad..d6e504f90 100644 --- a/model_gateway/src/routers/grpc/regular/processor.rs +++ b/model_gateway/src/routers/grpc/regular/processor.rs @@ -184,7 +184,11 @@ impl ResponseProcessor { finish_reason_str }; - let matched_stop = complete.matched_stop_json(); + let matched_stop = complete.matched_stop_json_with_context( + original_request.stop.as_ref(), + original_request.stop_token_ids.as_ref(), + tokenizer.as_ref(), + ); // Step 4: Convert output logprobs if present let logprobs = complete.output_logprobs().map(|ref proto_logprobs| { @@ -760,7 +764,7 @@ impl ResponseProcessor { execution_result: ExecutionResult, completion_req: Arc, dispatch: DispatchMetadata, - _tokenizer: Arc, + tokenizer: Arc, stop_decoder: &mut StopSequenceDecoder, prompt_text: &str, ) -> Result { @@ -822,7 +826,11 @@ impl ResponseProcessor { } }; - let matched_stop = complete.matched_stop_json(); + let matched_stop = complete.matched_stop_json_with_context( + completion_req.stop.as_ref(), + completion_req.stop_token_ids.as_ref(), + tokenizer.as_ref(), + ); let suffix_len = completion_req.suffix.as_ref().map_or(0, |s| s.len()); let echo_len = if completion_req.echo { 
diff --git a/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs b/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs index d8be6b682..cac790142 100644 --- a/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs +++ b/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs @@ -114,6 +114,12 @@ impl PipelineStage for ChatRequestBuildingStage { } } + helpers::apply_mlx_stop_sequences( + &mut proto_request, + chat_request.stop.as_ref(), + ctx.state.tokenizer.as_deref(), + )?; + ctx.state.proto_request = Some(ProtoRequest::Generate(proto_request)); Ok(None) } diff --git a/model_gateway/src/routers/grpc/regular/stages/completion/request_building.rs b/model_gateway/src/routers/grpc/regular/stages/completion/request_building.rs index d5f7db249..f829d7b2e 100644 --- a/model_gateway/src/routers/grpc/regular/stages/completion/request_building.rs +++ b/model_gateway/src/routers/grpc/regular/stages/completion/request_building.rs @@ -93,7 +93,11 @@ impl PipelineStage for CompletionRequestBuildingStage { helpers::maybe_inject_pd_metadata(&mut proto_request, workers); } } - + helpers::apply_mlx_stop_sequences( + &mut proto_request, + completion_request.stop.as_ref(), + ctx.state.tokenizer.as_deref(), + )?; ctx.state.proto_request = Some(ProtoRequest::Generate(proto_request)); Ok(None) } diff --git a/model_gateway/src/routers/grpc/regular/streaming.rs b/model_gateway/src/routers/grpc/regular/streaming.rs index 031c93ab0..b9dbbab73 100644 --- a/model_gateway/src/routers/grpc/regular/streaming.rs +++ b/model_gateway/src/routers/grpc/regular/streaming.rs @@ -492,7 +492,14 @@ impl StreamingProcessor { cached_tokens.insert(index, complete.cached_tokens()); finish_reasons.insert(index, complete.finish_reason().to_string()); - matched_stops.insert(index, complete.matched_stop_json()); + matched_stops.insert( + index, + complete.matched_stop_json_with_context( + original_request.stop.as_ref(), + 
original_request.stop_token_ids.as_ref(), + tokenizer.as_ref(), + ), + ); // Don't break - continue reading all Complete messages for n>1 } @@ -2405,8 +2412,7 @@ impl StreamingProcessor { choices: vec![CompletionStreamChoice { text: std::mem::take(&mut chunk_text), index, - logprobs: None, - finish_reason: None, + ..Default::default() }], model: model.clone(), system_fingerprint: system_fingerprint.map(String::from), @@ -2432,8 +2438,7 @@ impl StreamingProcessor { choices: vec![CompletionStreamChoice { text: sfx.to_string(), index, - logprobs: None, - finish_reason: None, + ..Default::default() }], model: model.clone(), system_fingerprint: system_fingerprint.map(String::from), @@ -2449,10 +2454,9 @@ impl StreamingProcessor { object: "text_completion".to_string(), created, choices: vec![CompletionStreamChoice { - text: String::new(), index, - logprobs: None, finish_reason: Some("stop".to_string()), + ..Default::default() }], model: model.clone(), system_fingerprint: system_fingerprint.map(String::from), @@ -2485,8 +2489,7 @@ impl StreamingProcessor { choices: vec![CompletionStreamChoice { text: prompt_text.to_string(), index, - logprobs: None, - finish_reason: None, + ..Default::default() }], model: model.clone(), system_fingerprint: system_fingerprint.map(String::from), @@ -2508,8 +2511,7 @@ impl StreamingProcessor { choices: vec![CompletionStreamChoice { text, index, - logprobs: None, - finish_reason: None, + ..Default::default() }], model: model.clone(), system_fingerprint: system_fingerprint.map(String::from), @@ -2530,8 +2532,7 @@ impl StreamingProcessor { choices: vec![CompletionStreamChoice { text: sfx.to_string(), index, - logprobs: None, - finish_reason: None, + ..Default::default() }], model: model.clone(), system_fingerprint: system_fingerprint.map(String::from), @@ -2568,6 +2569,12 @@ impl StreamingProcessor { } }; + let matched_stop = complete.matched_stop_json_with_context( + completion_request.stop.as_ref(), + 
completion_request.stop_token_ids.as_ref(), + tokenizer.as_ref(), + ); + let final_chunk = CompletionStreamResponse { id: request_id.clone(), object: "text_completion".to_string(), @@ -2577,6 +2584,7 @@ impl StreamingProcessor { index, logprobs: None, finish_reason, + matched_stop, }], model: model.clone(), system_fingerprint: system_fingerprint.map(String::from), diff --git a/model_gateway/src/routers/grpc/utils/chat_utils.rs b/model_gateway/src/routers/grpc/utils/chat_utils.rs index b75b8abf6..504260ccb 100644 --- a/model_gateway/src/routers/grpc/utils/chat_utils.rs +++ b/model_gateway/src/routers/grpc/utils/chat_utils.rs @@ -419,6 +419,100 @@ pub fn create_stop_decoder( builder.build() } +/// Tokenizes stop strings into token IDs for the MLX backend. +/// +/// Returns `Err` for any string that cannot be honored as a stop condition: +/// - encodes to more than one token +/// - encodes to zero tokens (not in vocabulary) +/// - tokenizer returns an error +/// +/// The caller should surface all errors as HTTP 400. +pub(crate) fn stop_strings_to_token_ids<'a>( + stop: impl IntoIterator, + tokenizer: &dyn Tokenizer, +) -> Result, String> { + let mut ids = Vec::new(); + for s in stop { + match tokenizer.encode(s, false) { + Ok(enc) => match enc.token_ids() { + [id] => ids.push(*id), + tokens if !tokens.is_empty() => { + return Err(format!( + "stop string {s:?} encodes to {} tokens; \ + MLX backend only supports single-token stop strings", + tokens.len() + )); + } + _ => { + return Err(format!( + "stop string {s:?} produced no tokens; \ + it may not be present in the model vocabulary" + )); + } + }, + Err(e) => { + return Err(format!("failed to tokenize stop string {s:?}: {e}")); + } + } + } + Ok(ids) +} + +/// Resolve the `matched_stop` JSON value for an MLX response. +/// +/// MLX only returns a token ID; this reverses the mapping back to the user-facing form: +/// - If the token ID was tokenized from a user stop string → return the string. 
+/// - If the token ID was an explicit user stop_token_id → return the integer. +/// - Otherwise (EOS or other internal stop) → return None. +pub(crate) fn resolve_mlx_matched_stop_json( + matched_token_id: Option, + stop: Option<&StringOrArray>, + stop_token_ids: Option<&Vec>, + tokenizer: &dyn Tokenizer, +) -> Option { + let id = matched_token_id?; + + // Check stop strings first: find the string that tokenizes to this single token. + if let Some(stop_strings) = stop { + for s in stop_strings.iter() { + if let Ok(enc) = tokenizer.encode(s, false) { + if enc.token_ids() == [id] { + return Some(Value::String(s.to_string())); + } + } + } + } + + // Check explicit stop_token_ids provided by the user. + if stop_token_ids.is_some_and(|ids| ids.contains(&id)) { + return Some(Value::Number(id.into())); + } + + // EOS or other internal stop condition — don't surface to the caller. + None +} + +/// For MLX: tokenize string stop sequences into stop token IDs (the caller merges them). +/// Returns an HTTP 400 response if the tokenizer is missing, tokenization fails, or a stop +/// string does not encode to exactly one token (propagate with `?` from a pipeline stage). 
+#[expect( + clippy::result_large_err, + reason = "Response is the standard error type in the pipeline stage pattern" +)] +pub(crate) fn resolve_mlx_stop_ids( + stop_strings: &StringOrArray, + tokenizer: Option<&dyn Tokenizer>, +) -> Result, Response> { + let tok = tokenizer.ok_or_else(|| { + error::bad_request( + "tokenizer_unavailable", + "MLX backend requires a tokenizer to convert string stop sequences", + ) + })?; + stop_strings_to_token_ids(stop_strings.iter(), tok) + .map_err(|e| error::bad_request("unsupported_stop_string", e)) +} + /// Parse tool calls from JSON schema constrained response pub(crate) fn parse_json_schema_response( processed_text: &str, @@ -582,7 +676,8 @@ pub(crate) fn parse_finish_reason( #[cfg(test)] mod tests { - use llm_tokenizer::chat_template::ChatTemplateContentFormat; + use axum::http::StatusCode; + use llm_tokenizer::{chat_template::ChatTemplateContentFormat, MockTokenizer}; use openai_protocol::{ chat::{ChatMessage, MessageContent}, common::{ContentPart, ImageUrl}, @@ -591,6 +686,15 @@ mod tests { use super::*; + type StopTokenCase<'a> = (&'a [&'a str], Option<&'a [u32]>, &'a str); + type MatchedStopCase<'a> = ( + Option, + Option<&'a str>, + &'a [u32], + Option, + &'a str, + ); + #[test] fn test_transform_messages_string_format() { let messages = vec![ChatMessage::User { @@ -780,4 +884,103 @@ mod tests { assert_eq!(content_array[0]["type"], "text"); assert_eq!(content_array[1], json!({"type": "image"})); } + #[test] + fn test_stop_strings_to_token_ids() { + // MockTokenizer vocab: "Hello"→1, "world"→2, "test"→3, "<|im_end|>"→1002. + // expected = None means the call should return Err. 
+ let cases: &[StopTokenCase<'_>] = &[ + (&["Hello"], Some(&[1]), "single token regular"), + (&["world"], Some(&[2]), "single token another regular"), + (&["<|im_end|>"], Some(&[1002]), "single token special"), + (&["Hello world"], None, "multi token returns err"), + (&["zzzunknown"], None, "unknown vocab returns err"), + ( + &["Hello", "Hello world"], + None, + "array with multi token err", + ), + (&["Hello", "test"], Some(&[1, 3]), "array all single token"), + (&[], Some(&[]), "empty array"), + ]; + let tok = MockTokenizer::new(); + for &(inputs, expected, name) in cases { + let result = stop_strings_to_token_ids(inputs.iter().copied(), &tok); + match expected { + Some(ids) => assert_eq!(result.unwrap(), ids, "{name}"), + None => assert!(result.is_err(), "{name}"), + } + } + } + + #[test] + fn test_stop_encode_error_returns_err() { + let tok = MockTokenizer::failing(); + let result = stop_strings_to_token_ids(["Hello", "test"].iter().copied(), &tok); + assert!(result.is_err()); + } + + #[test] + fn test_resolve_mlx_stop_ids_zero_token_is_400() { + let tok = MockTokenizer::new(); + let stop = StringOrArray::String("zzzunknown".to_string()); + let resp = resolve_mlx_stop_ids(&stop, Some(&tok)).unwrap_err(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } + + #[test] + fn test_resolve_mlx_stop_ids_tokenizer_error_is_400() { + let tok = MockTokenizer::failing(); + let stop = StringOrArray::String("Hello".to_string()); + let resp = resolve_mlx_stop_ids(&stop, Some(&tok)).unwrap_err(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } + + #[test] + fn test_resolve_mlx_stop_ids_missing_tokenizer_is_400() { + let stop = StringOrArray::String("Hello".to_string()); + let resp = resolve_mlx_stop_ids(&stop, None).unwrap_err(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } + + #[test] + fn test_resolve_mlx_matched_stop() { + // MockTokenizer vocab: "Hello"→1, "world"→2, "test"→3, "<|im_end|>"→1002. 
+ // stop_ids=&[] is treated as None (no user stop_token_ids supplied). + let cases: &[MatchedStopCase<'_>] = &[ + (None, None, &[], None, "no id returns none"), + ( + Some(1), + Some("Hello"), + &[], + Some(Value::String("Hello".to_string())), + "string match", + ), + ( + Some(42), + None, + &[42], + Some(Value::Number(42u32.into())), + "token id match", + ), + ( + Some(1), + Some("Hello"), + &[1], + Some(Value::String("Hello".to_string())), + "string wins over token id", + ), + (Some(999), None, &[], None, "eos returns none"), + ]; + let tok = MockTokenizer::new(); + for (id, stop_str, stop_ids, expected, name) in cases { + let stop = stop_str.map(|s| StringOrArray::String(s.to_string())); + let ids: Vec = stop_ids.to_vec(); + let ids_opt = if ids.is_empty() { None } else { Some(&ids) }; + assert_eq!( + resolve_mlx_matched_stop_json(*id, stop.as_ref(), ids_opt, &tok), + *expected, + "{name}", + ); + } + } } diff --git a/model_gateway/src/routers/grpc/utils/mod.rs b/model_gateway/src/routers/grpc/utils/mod.rs index 2415494ae..a4cc89b7f 100644 --- a/model_gateway/src/routers/grpc/utils/mod.rs +++ b/model_gateway/src/routers/grpc/utils/mod.rs @@ -12,7 +12,7 @@ pub use chat_utils::{create_stop_decoder, process_chat_messages}; pub(crate) use chat_utils::{ filter_chat_request_by_tool_choice, filter_tools_by_tool_choice, generate_tool_call_id, get_history_tool_calls_count, parse_finish_reason, parse_json_schema_response, - resolve_tokenizer, send_error_sse, + resolve_mlx_matched_stop_json, resolve_mlx_stop_ids, resolve_tokenizer, send_error_sse, }; pub(crate) use logprobs::{ convert_generate_input_logprobs, convert_generate_output_logprobs, convert_proto_logprobs,