9 changes: 2 additions & 7 deletions crates/grpc_client/src/mlx_engine.rs
@@ -242,11 +242,8 @@ impl MlxEngineClient {
// - response_format — same as constrained decoding
//
// Servicer limitations (fixable without mlx-lm changes):
// - TODO(mlx): String stop sequences — mlx-lm supports this via
// tokenizer.encode() → SequenceStateMachine. Fix by converting stop
// strings to token IDs in the preparation stage (which already has the
// Rust tokenizer) and passing them as stop_token_ids in the proto.
//
// - String stop sequences: supported in chat and completion pipelines.
// Messages and Generate pipelines still reject string stops (see reject_stop_strings).
Comment on lines +245 to +246

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Update the comment to mention all three pipelines that still reject string stops.

The comment states that "Messages and Generate pipelines still reject string stops," but the Responses pipeline (line 422) also retains the reject_stop_strings check. For completeness, the comment should list all three.

📝 Proposed fix to make the documentation complete
-    //   - String stop sequences: supported in chat and completion pipelines.
-    //     Messages and Generate pipelines still reject string stops (see reject_stop_strings).
+    //   - String stop sequences: supported in chat and completion pipelines.
+    //     Messages, Generate, and Responses pipelines still reject string stops (see reject_stop_strings).
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-    //   - String stop sequences: supported in chat and completion pipelines.
-    //     Messages and Generate pipelines still reject string stops (see reject_stop_strings).
+    //   - String stop sequences: supported in chat and completion pipelines.
+    //     Messages, Generate, and Responses pipelines still reject string stops (see reject_stop_strings).
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@crates/grpc_client/src/mlx_engine.rs` around lines 245-246, update the
inline comment that currently says "Messages and Generate pipelines still reject
string stops" to list all three pipelines that reject string stops (Messages,
Generate, and Responses). Locate the comment block describing stop-sequence
support and amend the sentence to explicitly include "Responses" alongside
"Messages" and "Generate", keeping the existing reference to the
reject_stop_strings check used in the Responses handling.
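
For context, the guard in question takes a precomputed boolean and returns a String error, matching the call sites visible later in this diff. A minimal sketch of such a guard follows (hypothetical; the actual reject_stop_strings in mlx_engine.rs is not shown in this diff and may differ):

// Hypothetical sketch of the guard discussed above; the real helper in
// mlx_engine.rs is not part of this diff and may differ. Call sites pass a
// precomputed bool, e.g.:
//   Self::reject_stop_strings(body.stop.as_ref().is_some_and(|s| !s.is_empty()))?;
fn reject_stop_strings(has_stop_strings: bool) -> Result<(), String> {
    if has_stop_strings {
        // The MLX proto only carries stop_token_ids, so pipelines that cannot
        // tokenize stop strings themselves must reject them up front.
        return Err("string stop sequences are not supported on this pipeline; \
                    use stop_token_ids instead"
            .to_string());
    }
    Ok(())
}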

// Track upstream: https://github.com/ml-explore/mlx-lm

fn reject_constraint(constraint: Option<&(String, String)>) -> Result<(), String> {
@@ -309,7 +306,6 @@ impl MlxEngineClient {
) -> Result<proto::GenerateRequest, String> {
Self::reject_constraint(constraint.as_ref())?;
Self::reject_n(body.n)?;
Self::reject_stop_strings(body.stop.as_ref().is_some_and(|s| !s.is_empty()))?;
Self::reject_response_format(body.response_format.is_some())?;

let sampling_params = Self::build_sampling_params_from_chat(body);
@@ -335,7 +331,6 @@
token_ids: Vec<u32>,
) -> Result<proto::GenerateRequest, String> {
Self::reject_n(body.n)?;
Self::reject_stop_strings(body.stop.as_ref().is_some_and(|s| !s.is_empty()))?;
Self::reject_if_any_constraint(
body.json_schema.as_ref(),
body.regex.as_ref(),
4 changes: 3 additions & 1 deletion crates/protocols/src/completion.rs
@@ -266,11 +266,13 @@ pub struct CompletionStreamResponse {
pub usage: Option<Usage>,
}

#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
#[derive(Default, Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct CompletionStreamChoice {
pub text: String,
pub index: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<LogProbs>,
pub finish_reason: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub matched_stop: Option<Value>,
}
13 changes: 12 additions & 1 deletion crates/tokenizer/src/mock.rs
@@ -11,6 +11,7 @@ pub struct MockTokenizer {
vocab: HashMap<String, u32>,
reverse_vocab: HashMap<u32, String>,
special_tokens: SpecialTokens,
fail_encode: bool,
}

impl Default for MockTokenizer {
@@ -62,19 +63,29 @@ impl MockTokenizer {
vocab,
reverse_vocab,
special_tokens,
fail_encode: false,
}
}

pub fn failing() -> Self {
Self {
fail_encode: true,
..Self::new()
}
}
}

impl Encoder for MockTokenizer {
fn encode(&self, input: &str, _add_special_tokens: bool) -> Result<Encoding> {
if self.fail_encode {
return Err(anyhow::anyhow!("test encode error"));
}
// Simple word-based tokenization using the vocab
// Split by whitespace and look up each word (decoder adds spaces back)
let tokens: Vec<u32> = input
.split_whitespace()
.filter_map(|word| self.vocab.get(word).copied())
.collect();

Ok(Encoding::Plain(tokens))
}

46 changes: 43 additions & 3 deletions model_gateway/src/routers/grpc/common/stages/helpers.rs
@@ -2,6 +2,9 @@

use std::sync::Arc;

use axum::response::Response;
use llm_tokenizer::traits::Tokenizer;
use openai_protocol::common::StringOrArray;
use rand::Rng;
use smg_grpc_client::{
mlx_proto,
@@ -11,9 +14,13 @@ use smg_grpc_client::{
use tracing::{debug, warn};

use crate::{
routers::grpc::{
context::{RequestType, WorkerSelection},
proto_wrapper::ProtoGenerateRequest,
routers::{
error,
grpc::{
context::{RequestType, WorkerSelection},
proto_wrapper::ProtoGenerateRequest,
utils::resolve_mlx_stop_ids,
},
},
worker::{
sampling_defaults::SamplingDefaults, RuntimeType, Worker, DEFAULT_BOOTSTRAP_PORT,
@@ -263,3 +270,36 @@ fn inject_sglang_bootstrap_metadata(
hostname, bootstrap_port, room_id
);
}

/// Convert string stop sequences to token IDs and append them to the MLX proto request.
///
/// The MLX proto only supports stop_token_ids; string stop sequences from the
/// CompletionRequest must be tokenized here before the request is dispatched.
/// No-op if the request has no string stop sequences.
#[expect(
clippy::result_large_err,
reason = "Response is the standard error type in the pipeline stage pattern"
)]
pub(crate) fn apply_mlx_stop_sequences(
proto_request: &mut ProtoGenerateRequest,
stop: Option<&StringOrArray>,
tokenizer: Option<&dyn Tokenizer>,
) -> Result<(), Response> {
let Some(stop) = stop else {
return Ok(());
};

if let ProtoGenerateRequest::Mlx(req) = proto_request {
let token_ids = resolve_mlx_stop_ids(stop, tokenizer)?;
let sampling = req.sampling_params.as_mut().ok_or_else(|| {
error::internal_error(
"mlx_sampling_params_missing",
"MLX GenerateRequest has no sampling_params; cannot inject stop IDs",
)
})?;
sampling.stop_token_ids.extend(token_ids);
}

// Non-MLX backends handle string stop sequences natively; no-op for them.
Ok(())
}
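
resolve_mlx_stop_ids is imported from routers::grpc::utils but is not included in this diff. A minimal sketch of what such a helper might look like is below; the StringOrArray variants, the tokenizer encode signature, and the Encoding accessor are assumptions based on the rest of this PR (the mock tokenizer and the error constructors used above), not the actual implementation:

// Hypothetical sketch only; the real resolve_mlx_stop_ids lives in
// routers::grpc::utils and is not shown in this diff.
pub(crate) fn resolve_mlx_stop_ids(
    stop: &StringOrArray,
    tokenizer: Option<&dyn Tokenizer>,
) -> Result<Vec<u32>, Response> {
    // Assumption: StringOrArray exposes String / Array variants.
    let stop_strings: Vec<String> = match stop {
        StringOrArray::String(s) => vec![s.clone()],
        StringOrArray::Array(v) => v.clone(),
    };
    if stop_strings.is_empty() {
        return Ok(Vec::new());
    }

    // String stops can only be converted when a tokenizer is available.
    let tokenizer = tokenizer.ok_or_else(|| {
        error::internal_error(
            "mlx_tokenizer_missing",
            "string stop sequences require a tokenizer to convert them to token IDs",
        )
    })?;

    let mut ids = Vec::new();
    for s in &stop_strings {
        // Assumption: encode(&str, add_special_tokens) returns an Encoding whose
        // token IDs can be read back, as the mock tokenizer in crates/tokenizer does.
        let encoding = tokenizer
            .encode(s, false)
            .map_err(|e| error::internal_error("mlx_stop_encode_failed", &format!("{e}")))?;
        ids.extend_from_slice(encoding.token_ids());
    }
    Ok(ids)
}
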
44 changes: 41 additions & 3 deletions model_gateway/src/routers/grpc/proto_wrapper.rs
@@ -6,6 +6,8 @@
use std::collections::HashMap;

use futures_util::StreamExt;
use llm_tokenizer::traits::Tokenizer;
use openai_protocol::common::StringOrArray;
use smg_grpc_client::{
mlx_engine::AbortOnDropStream as MlxStream,
mlx_proto::{self as mlx},
@@ -17,6 +19,8 @@ use smg_grpc_client::{
vllm_proto::{self as vllm, generate_complete::MatchedStop as VllmMatchedStop},
};

use crate::routers::grpc::utils::resolve_mlx_matched_stop_json;

// =====================
// Multimodal Data
// =====================
@@ -738,6 +742,14 @@ impl ProtoGenerateComplete {
matches!(self, Self::Mlx(_))
}

/// Return the raw matched stop token ID for MLX responses; None for all other backends.
fn mlx_matched_stop_token_id(&self) -> Option<u32> {
match self {
Self::Mlx(c) => c.matched_stop_token_id,
_ => None,
}
}

/// Get token IDs from either backend (output_ids in proto)
pub fn token_ids(&self) -> &[u32] {
match self {
@@ -795,6 +807,10 @@ impl ProtoGenerateComplete {
/// - MatchedTokenId → Number
/// - MatchedStopStr → String
/// - None → None
#[expect(
clippy::unreachable,
reason = "MLX must use matched_stop_json_with_context"
)]
pub fn matched_stop_json(&self) -> Option<serde_json::Value> {
macro_rules! convert {
($oneof:expr, $token_id:path, $stop_str:path) => {
@@ -820,9 +836,31 @@
TrtllmMatchedStop::MatchedTokenId,
TrtllmMatchedStop::MatchedStopStr
),
Self::Mlx(c) => c
.matched_stop_token_id
.map(|id| serde_json::Value::Number(id.into())),
// MLX requires request context to resolve the token ID; use matched_stop_json_with_context.
Self::Mlx(_) => unreachable!("matched_stop_json called for MLX backend"),

P2: Keep MLX matched-stop lookup non-panicking for Harmony

Changing matched_stop_json() to unreachable!() for Self::Mlx now makes active Harmony MLX paths crash at runtime, because those flows still call the old method (model_gateway/src/routers/grpc/harmony/processor.rs:62 and .../harmony/streaming.rs:299) instead of matched_stop_json_with_context. Any Harmony request routed to MLX that reaches a Complete frame will panic rather than returning a response, so this should remain non-panicking until all Harmony call sites are migrated.
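
One interim fix is to restore the previous fallback instead of panicking until the Harmony call sites are migrated to matched_stop_json_with_context (sketch only, reusing the arm this hunk removes):

-            // MLX requires request context to resolve the token ID; use matched_stop_json_with_context.
-            Self::Mlx(_) => unreachable!("matched_stop_json called for MLX backend"),
+            // Interim fallback for callers not yet migrated to matched_stop_json_with_context:
+            // report the raw matched stop token ID as a JSON number (previous behavior).
+            Self::Mlx(c) => c
+                .matched_stop_token_id
+                .map(|id| serde_json::Value::Number(id.into())),

If the fallback is restored, the #[expect(clippy::unreachable)] attribute added above matched_stop_json would also need to be removed.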


}
}

/// Resolve the matched stop for any backend, using request context for MLX.
///
/// MLX only stores a token ID; this maps it back to the user-facing string or integer
/// (see `chat_utils::resolve_mlx_matched_stop_json`). All other backends return
/// `matched_stop_json()` directly.
pub fn matched_stop_json_with_context(
&self,
stop: Option<&StringOrArray>,
stop_token_ids: Option<&Vec<u32>>,
tokenizer: &dyn Tokenizer,
) -> Option<serde_json::Value> {
if self.is_mlx() {
resolve_mlx_matched_stop_json(
self.mlx_matched_stop_token_id(),
stop,
stop_token_ids,
tokenizer,
)
} else {
self.matched_stop_json()
}
}
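
resolve_mlx_matched_stop_json is likewise imported from routers::grpc::utils but not part of this diff. A rough sketch of the mapping the doc comment describes is below; the StringOrArray variants, the tokenizer encode signature, and the Encoding accessor are assumptions (based on the mock tokenizer in this PR), and the real helper may resolve matches differently:

// Hypothetical sketch only; the real resolve_mlx_matched_stop_json lives in
// routers::grpc::utils and is not shown in this diff.
pub(crate) fn resolve_mlx_matched_stop_json(
    matched_token_id: Option<u32>,
    stop: Option<&StringOrArray>,
    stop_token_ids: Option<&Vec<u32>>,
    tokenizer: &dyn Tokenizer,
) -> Option<serde_json::Value> {
    let id = matched_token_id?;

    // If the user supplied this ID explicitly, report it back as a number.
    if stop_token_ids.is_some_and(|ids| ids.contains(&id)) {
        return Some(serde_json::Value::Number(id.into()));
    }

    // Otherwise find which user stop string tokenizes to a sequence containing
    // this ID, and report that string back to the caller.
    if let Some(stop) = stop {
        let stop_strings: Vec<String> = match stop {
            StringOrArray::String(s) => vec![s.clone()],
            StringOrArray::Array(v) => v.clone(),
        };
        for s in stop_strings {
            // Assumption: encode(&str, add_special_tokens), as in crates/tokenizer.
            if let Ok(encoding) = tokenizer.encode(&s, false) {
                if encoding.token_ids().contains(&id) {
                    return Some(serde_json::Value::String(s));
                }
            }
        }
    }

    // Fall back to the raw token ID so callers still receive a value.
    Some(serde_json::Value::Number(id.into()))
}
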

14 changes: 11 additions & 3 deletions model_gateway/src/routers/grpc/regular/processor.rs
@@ -184,7 +184,11 @@ impl ResponseProcessor {
finish_reason_str
};

let matched_stop = complete.matched_stop_json();
let matched_stop = complete.matched_stop_json_with_context(
original_request.stop.as_ref(),
original_request.stop_token_ids.as_ref(),
tokenizer.as_ref(),
);

// Step 4: Convert output logprobs if present
let logprobs = complete.output_logprobs().map(|ref proto_logprobs| {
@@ -760,7 +764,7 @@ impl ResponseProcessor {
execution_result: ExecutionResult,
completion_req: Arc<CompletionRequest>,
dispatch: DispatchMetadata,
_tokenizer: Arc<dyn Tokenizer>,
tokenizer: Arc<dyn Tokenizer>,
stop_decoder: &mut StopSequenceDecoder,
prompt_text: &str,
) -> Result<CompletionResponse, axum::response::Response> {
@@ -822,7 +826,11 @@
}
};

let matched_stop = complete.matched_stop_json();
let matched_stop = complete.matched_stop_json_with_context(
completion_req.stop.as_ref(),
completion_req.stop_token_ids.as_ref(),
tokenizer.as_ref(),
);

let suffix_len = completion_req.suffix.as_ref().map_or(0, |s| s.len());
let echo_len = if completion_req.echo {
@@ -114,6 +114,12 @@ impl PipelineStage for ChatRequestBuildingStage {
}
}

helpers::apply_mlx_stop_sequences(
&mut proto_request,
chat_request.stop.as_ref(),
ctx.state.tokenizer.as_deref(),
)?;

ctx.state.proto_request = Some(ProtoRequest::Generate(proto_request));
Ok(None)
}
@@ -93,7 +93,11 @@ impl PipelineStage for CompletionRequestBuildingStage {
helpers::maybe_inject_pd_metadata(&mut proto_request, workers);
}
}

helpers::apply_mlx_stop_sequences(
&mut proto_request,
completion_request.stop.as_ref(),
ctx.state.tokenizer.as_deref(),
)?;
ctx.state.proto_request = Some(ProtoRequest::Generate(proto_request));
Ok(None)
}
34 changes: 21 additions & 13 deletions model_gateway/src/routers/grpc/regular/streaming.rs
@@ -492,7 +492,14 @@ impl StreamingProcessor {
cached_tokens.insert(index, complete.cached_tokens());
finish_reasons.insert(index, complete.finish_reason().to_string());

matched_stops.insert(index, complete.matched_stop_json());
matched_stops.insert(
index,
complete.matched_stop_json_with_context(
original_request.stop.as_ref(),
original_request.stop_token_ids.as_ref(),
tokenizer.as_ref(),
),
);

// Don't break - continue reading all Complete messages for n>1
}
@@ -2405,8 +2412,7 @@ impl StreamingProcessor {
choices: vec![CompletionStreamChoice {
text: std::mem::take(&mut chunk_text),
index,
logprobs: None,
finish_reason: None,
..Default::default()
}],
model: model.clone(),
system_fingerprint: system_fingerprint.map(String::from),
@@ -2432,8 +2438,7 @@
choices: vec![CompletionStreamChoice {
text: sfx.to_string(),
index,
logprobs: None,
finish_reason: None,
..Default::default()
}],
model: model.clone(),
system_fingerprint: system_fingerprint.map(String::from),
Expand All @@ -2449,10 +2454,9 @@ impl StreamingProcessor {
object: "text_completion".to_string(),
created,
choices: vec![CompletionStreamChoice {
text: String::new(),
index,
logprobs: None,
finish_reason: Some("stop".to_string()),
..Default::default()
}],
model: model.clone(),
system_fingerprint: system_fingerprint.map(String::from),
@@ -2485,8 +2489,7 @@
choices: vec![CompletionStreamChoice {
text: prompt_text.to_string(),
index,
logprobs: None,
finish_reason: None,
..Default::default()
}],
model: model.clone(),
system_fingerprint: system_fingerprint.map(String::from),
@@ -2508,8 +2511,7 @@
choices: vec![CompletionStreamChoice {
text,
index,
logprobs: None,
finish_reason: None,
..Default::default()
}],
model: model.clone(),
system_fingerprint: system_fingerprint.map(String::from),
@@ -2530,8 +2532,7 @@
choices: vec![CompletionStreamChoice {
text: sfx.to_string(),
index,
logprobs: None,
finish_reason: None,
..Default::default()
}],
model: model.clone(),
system_fingerprint: system_fingerprint.map(String::from),
@@ -2568,6 +2569,12 @@
}
};

let matched_stop = complete.matched_stop_json_with_context(
completion_request.stop.as_ref(),
completion_request.stop_token_ids.as_ref(),
tokenizer.as_ref(),
);

let final_chunk = CompletionStreamResponse {
id: request_id.clone(),
object: "text_completion".to_string(),
Expand All @@ -2577,6 +2584,7 @@ impl StreamingProcessor {
index,
logprobs: None,
finish_reason,
matched_stop,
}],
model: model.clone(),
system_fingerprint: system_fingerprint.map(String::from),