13 changes: 13 additions & 0 deletions src/cmd/src/cli.rs
@@ -168,6 +168,17 @@ pub struct SharedRuntimeArgs {
#[serde(default)]
pub disable_log_stats: bool,

/// The model name(s) used in the API. If multiple names are provided, the
/// server will respond to any of the provided names. The model name in the
/// model field of a response will be the first name in this list. If not
/// specified, the model name will be the same as the `--model` argument.
/// Note that these names are also used in the `model_name` tag of
/// Prometheus metrics; if multiple names are provided, the metrics tag
/// uses the first one.
#[arg(long, num_args = 0..)]
#[serde(default)]
pub served_model_name: Vec<String>,

/// Unsupported Python vLLM frontend arguments recognized but not yet
/// implemented in Rust.
#[educe(Debug(ignore))]
@@ -215,6 +226,7 @@ impl SharedRuntimeArgs {
None => CoordinatorMode::None,
},
model: self.model,
served_model_name: self.served_model_name,
listener_mode: HttpListenerMode::InheritedFd { fd: listen_fd },
tool_call_parser: self.tool_call_parser,
reasoning_parser: self.reasoning_parser,
@@ -254,6 +266,7 @@ impl SharedRuntimeArgs {
},
coordinator_mode: CoordinatorMode::MaybeInProc,
model: self.model,
served_model_name: self.served_model_name,
listener_mode,
tool_call_parser: self.tool_call_parser,
reasoning_parser: self.reasoning_parser,
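The new flag uses clap's `num_args = 0..`, so a single `--served-model-name` occurrence can carry several values. A minimal, self-contained sketch of that parsing behavior (hypothetical `Args` struct standing in for `SharedRuntimeArgs`; assumes clap with the `derive` feature):

use clap::Parser;

/// Hypothetical stand-in for `SharedRuntimeArgs`, reduced to the new flag.
#[derive(Parser, Debug)]
struct Args {
    /// `num_args = 0..` lets one occurrence of the flag take many values.
    #[arg(long, num_args = 0..)]
    served_model_name: Vec<String>,
}

fn main() {
    // e.g. `serve --served-model-name qwen3 qwen3-latest`
    let args = Args::parse_from(["serve", "--served-model-name", "qwen3", "qwen3-latest"]);
    assert_eq!(args.served_model_name, ["qwen3", "qwen3-latest"]);
}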
6 changes: 6 additions & 0 deletions src/cmd/src/cli/tests.rs
@@ -49,6 +49,7 @@ fn serve_args_forward_python_flags_with_separator() {
chat_template_content_format: Auto,
enable_log_requests: false,
disable_log_stats: false,
served_model_name: [],
},
python_args: [
"--dtype",
@@ -210,6 +211,7 @@ fn frontend_args_accept_json() {
chat_template_content_format: Auto,
enable_log_requests: false,
disable_log_stats: false,
served_model_name: [],
},
},
),
@@ -596,6 +598,7 @@ fn serve_args_accept_handshake_aliases() {
chat_template_content_format: Auto,
enable_log_requests: false,
disable_log_stats: false,
served_model_name: [],
},
python_args: [],
},
@@ -690,6 +693,7 @@ fn serve_frontend_config_uses_dp_address_as_advertised_host() {
},
coordinator_mode: MaybeInProc,
model: "Qwen/Qwen3-0.6B",
served_model_name: [],
listener_mode: BindTcp {
host: "127.0.0.1",
port: 8000,
@@ -751,6 +755,7 @@ fn serve_frontend_config_keeps_tcp_transport_for_non_local_only_topology() {
},
coordinator_mode: MaybeInProc,
model: "Qwen/Qwen3-0.6B",
served_model_name: [],
listener_mode: BindTcp {
host: "127.0.0.1",
port: 8000,
@@ -828,6 +833,7 @@ fn frontend_config_uses_external_coordinator_when_coordinator_address_is_present
address: "tcp://127.0.0.1:7000",
},
model: "Qwen/Qwen3-0.6B",
served_model_name: [],
listener_mode: InheritedFd {
fd: 3,
},
10 changes: 0 additions & 10 deletions src/cmd/src/cli/unsupported.rs
@@ -234,16 +234,6 @@ pub struct EngineUnsupportedArgs {
)]
pub enable_prompt_embeds: Option<Unsupported>,

/// The model name(s) used in the API. If multiple names are provided, the
/// server will respond to any of the provided names. The model name in the
/// model field of a response will be the first name in this list. If not
/// specified, the model name will be the same as the `--model` argument.
/// Note that these names are also used in the `model_name` tag of
/// Prometheus metrics; if multiple names are provided, the metrics tag
/// uses the first one.
#[arg(long)]
pub served_model_name: Option<Unsupported>,

/// The token to use as HTTP bearer authorization for remote files. If
/// `True`, will use the token generated when running `hf auth login`
/// (stored in `~/.cache/huggingface/token`).
1 change: 1 addition & 0 deletions src/server/examples/external_engine_openai_qwen.rs
@@ -56,6 +56,7 @@ async fn main() -> Result<()> {
},
coordinator_mode: CoordinatorMode::MaybeInProc,
model: args.model,
served_model_name: vec![],
listener_mode: HttpListenerMode::BindTcp {
host: "127.0.0.1".to_string(),
port,
6 changes: 5 additions & 1 deletion src/server/src/config.rs
@@ -38,8 +38,12 @@ pub struct Config {
pub transport_mode: TransportMode,
/// Requested frontend-side coordinator behavior.
pub coordinator_mode: CoordinatorMode,
/// Backend model identifier and exposed OpenAI model ID.
/// Backend model identifier used for engine-core loading.
pub model: String,
/// Model name(s) exposed to clients via the OpenAI API. When non-empty,
/// the first entry is used as the primary ID in responses and all entries
/// are accepted in requests. When empty, falls back to `model`.
pub served_model_name: Vec<String>,
/// HTTP listener setup.
pub listener_mode: HttpListenerMode,
/// Tool-call parser selection.
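The fallback described in the new doc comment is applied once at startup (see `build_state` in src/server/src/lib.rs later in this diff), so downstream code can assume a non-empty list. A minimal sketch of that rule, assuming only the two `Config` fields shown here:

/// Sketch of the startup fallback: an empty `served_model_name` list
/// collapses to the backend `model` identifier.
fn resolve_served_names(model: &str, served_model_name: &[String]) -> Vec<String> {
    if served_model_name.is_empty() {
        vec![model.to_string()]
    } else {
        served_model_name.to_vec()
    }
}

fn main() {
    // No aliases configured: the backend path doubles as the API model ID.
    assert_eq!(resolve_served_names("Qwen/Qwen3-0.6B", &[]), ["Qwen/Qwen3-0.6B"]);

    // With aliases, the first entry becomes the primary ID in responses.
    let aliases = ["qwen3".to_string(), "qwen3-latest".to_string()];
    assert_eq!(resolve_served_names("Qwen/Qwen3-0.6B", &aliases)[0], "qwen3");
}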
23 changes: 12 additions & 11 deletions src/server/src/grpc/convert.rs
@@ -17,15 +17,15 @@ use super::pb;

/// Convert a gRPC `GenerateRequest` into the internal `TextRequest`.
///
/// If `req.model` is non-empty, it must match `configured_model`; otherwise the
/// request is rejected with `NotFound`. An empty string is treated as "unset"
/// (proto3 default) and accepted.
/// If `req.model` is non-empty, it must match one of `served_model_names`;
/// otherwise the request is rejected with `NotFound`. An empty string is
/// treated as "unset" (proto3 default) and accepted.
pub fn to_text_request(
req: pb::GenerateRequest,
stream: bool,
configured_model: &str,
served_model_names: &[String],
) -> Result<TextRequest, Status> {
if !req.model.is_empty() && req.model != configured_model {
if !req.model.is_empty() && !served_model_names.iter().any(|n| n == &req.model) {
return Err(Status::not_found(format!(
"model `{}` not found",
req.model
@@ -521,13 +521,14 @@ mod tests {
temperature: Some(0.7),
..base_request()
};
let text = to_text_request(req, false, "test-model").expect("convert ok");
let text = to_text_request(req, false, &["test-model".to_string()]).expect("convert ok");
assert_eq!(text.sampling_params.temperature, Some(0.7));
}

#[test]
fn unset_temperature_defaults_to_greedy() {
let text = to_text_request(base_request(), false, "test-model").expect("convert ok");
let text = to_text_request(base_request(), false, &["test-model".to_string()])
.expect("convert ok");
// The gRPC API defaults to greedy (0.0) when temperature is not specified.
assert_eq!(text.sampling_params.temperature, Some(0.0));
}
@@ -541,7 +542,7 @@ mod tests {
}),
..base_request()
};
let text = to_text_request(req, false, "test-model").expect("convert ok");
let text = to_text_request(req, false, &["test-model".to_string()]).expect("convert ok");
assert_eq!(text.sampling_params.seed, None);
}

@@ -554,7 +555,7 @@ mod tests {
}),
..base_request()
};
let text = to_text_request(req, false, "test-model").expect("convert ok");
let text = to_text_request(req, false, &["test-model".to_string()]).expect("convert ok");
assert_eq!(text.sampling_params.seed, Some(0));
}

@@ -567,7 +568,7 @@ mod tests {
}),
..base_request()
};
let text = to_text_request(req, false, "test-model").expect("convert ok");
let text = to_text_request(req, false, &["test-model".to_string()]).expect("convert ok");
assert_eq!(text.sampling_params.skip_reading_prefix_cache, Some(true));
}

@@ -580,7 +581,7 @@ mod tests {
}),
..base_request()
};
let text = to_text_request(req, false, "test-model").expect("convert ok");
let text = to_text_request(req, false, &["test-model".to_string()]).expect("convert ok");
assert_eq!(text.sampling_params.skip_reading_prefix_cache, None);
// Prompt conversion still succeeds and reaches the expected variant.
assert!(matches!(text.prompt, Prompt::Text(s) if s == "hi"));
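The matching rule above reduces to a small predicate: proto3's empty-string default counts as "unset" and is accepted, while any other value must equal one of the served names. A sketch with a hypothetical helper (the real check is inlined in `to_text_request`):

/// Hypothetical predicate mirroring the check in `to_text_request`.
fn model_accepted(requested: &str, served_model_names: &[String]) -> bool {
    requested.is_empty() || served_model_names.iter().any(|n| n == requested)
}

fn main() {
    let served = ["test-model".to_string(), "test-alias".to_string()];
    assert!(model_accepted("", &served)); // proto3 default => unset, accepted
    assert!(model_accepted("test-alias", &served)); // any served name matches
    assert!(!model_accepted("other", &served)); // rejected upstream with `NotFound`
}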
6 changes: 4 additions & 2 deletions src/server/src/grpc/mod.rs
@@ -49,7 +49,8 @@ impl pb::generate_server::Generate for GenerateServiceImpl {
) -> Result<Response<pb::GenerateResponse>, Status> {
let proto_req = request.into_inner();
let response_opts = ResponseOpts::from_proto(proto_req.response.as_ref());
let text_request = convert::to_text_request(proto_req, false, &self.state.model_id)?;
let text_request =
convert::to_text_request(proto_req, false, self.state.served_model_names())?;

let request_id = text_request.request_id.clone();
info!(%request_id, "grpc generate (unary)");
@@ -97,7 +98,8 @@ impl pb::generate_server::Generate for GenerateServiceImpl {
) -> Result<Response<Self::GenerateStreamStream>, Status> {
let proto_req = request.into_inner();
let response_opts = ResponseOpts::from_proto(proto_req.response.as_ref());
let text_request = convert::to_text_request(proto_req, true, &self.state.model_id)?;
let text_request =
convert::to_text_request(proto_req, true, self.state.served_model_names())?;

let request_id = text_request.request_id.clone();
info!(%request_id, "grpc generate (stream)");
2 changes: 1 addition & 1 deletion src/server/src/grpc/tests.rs
@@ -246,7 +246,7 @@ async fn grpc_test_server(
test_llm(client),
Arc::new(FakeTextBackend) as Arc<dyn ChatTextBackend>,
);
let state = Arc::new(AppState::new("test-model", chat));
let state = Arc::new(AppState::new(vec!["test-model".to_string()], chat));
let svc = GenerateServer::new(GenerateServiceImpl::new(state));

// Bind to an OS-assigned port.
12 changes: 10 additions & 2 deletions src/server/src/lib.rs
@@ -76,8 +76,16 @@ async fn build_state(config: &Config) -> Result<Arc<AppState>> {
.with_tool_call_parser(config.tool_call_parser.clone())
.with_reasoning_parser(config.reasoning_parser.clone());

// If no served names are specified, fall back to the backend model path so
// that the API always has at least one valid model ID.
let served_model_names = if config.served_model_name.is_empty() {
vec![config.model.clone()]
} else {
config.served_model_name.clone()
};

Ok(Arc::new(
AppState::new(config.model.clone(), chat).with_log_requests(config.enable_log_requests),
AppState::new(served_model_names, chat).with_log_requests(config.enable_log_requests),
))
}

@@ -98,7 +106,7 @@ pub async fn serve(config: Config, shutdown: CancellationToken) -> Result<()> {
.await
.context("failed to bind listener for OpenAI server")?;
let bind_address = listener.local_addr()?;
let model = state.model_id.clone();
let model = state.primary_model_name().to_owned();
let app = build_router(state.clone());

// Optionally bind the gRPC Generate server on a separate port. Bind
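With the fallback above, `AppState` always holds at least one served name, which is what makes `primary_model_name()` a safe accessor for the first entry. A stripped-down sketch of that invariant (hypothetical struct; the real `AppState` also carries the chat backend and more):

/// Hypothetical, reduced `AppState` showing only the served-name API.
struct AppState {
    served_model_names: Vec<String>,
}

impl AppState {
    fn new(served_model_names: Vec<String>) -> Self {
        // Assumption: callers such as `build_state` never pass an empty list.
        assert!(!served_model_names.is_empty());
        Self { served_model_names }
    }

    fn served_model_names(&self) -> &[String] {
        &self.served_model_names
    }

    fn primary_model_name(&self) -> &str {
        // Safe: the constructor enforces at least one entry.
        &self.served_model_names[0]
    }
}

fn main() {
    let state = AppState::new(vec!["qwen3".into(), "qwen3-latest".into()]);
    assert_eq!(state.primary_model_name(), "qwen3");
    assert_eq!(state.served_model_names().len(), 2);
}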
2 changes: 1 addition & 1 deletion src/server/src/routes/http_client_tests.rs
@@ -263,7 +263,7 @@ async fn http_test_server(
test_llm(client),
Arc::new(FakeChatBackend) as Arc<dyn ChatTextBackend>,
);
let state = Arc::new(AppState::new("test-model", chat));
let state = Arc::new(AppState::new(vec!["test-model".to_string()], chat));
let app = build_router(state);

let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.expect("bind http listener");
3 changes: 2 additions & 1 deletion src/server/src/routes/inference/generate.rs
@@ -32,7 +32,8 @@ pub async fn generate(
ValidatedJson(body): ValidatedJson<GenerateRequest>,
) -> Response {
let request_context = resolve_request_context(&headers, body.request_id.as_deref());
let prepared = match prepare_generate_request(body, &state.model_id, request_context) {
let prepared = match prepare_generate_request(body, state.served_model_names(), request_context)
{
Ok(prepared) => prepared,
Err(error) => return error.into_response(),
};
6 changes: 3 additions & 3 deletions src/server/src/routes/inference/generate/convert.rs
@@ -18,10 +18,10 @@ pub struct PreparedRequest {
/// text-generation format.
pub fn prepare_generate_request(
request: GenerateRequest,
configured_model: &str,
served_model_names: &[String],
ctx: ResolvedRequestContext,
) -> Result<PreparedRequest, ApiError> {
validate::validate_request_compat(&request, configured_model)?;
validate::validate_request_compat(&request, served_model_names)?;

let include_logprobs = request.sampling_params.logprobs.is_some();
let include_prompt_logprobs = request.sampling_params.prompt_logprobs.is_some();
@@ -81,7 +81,7 @@ mod tests {

let prepared = prepare_generate_request(
request,
"Qwen/Qwen1.5-0.5B-Chat",
&["Qwen/Qwen1.5-0.5B-Chat".to_string()],
ResolvedRequestContext::default(),
)
.expect("prepare");
12 changes: 8 additions & 4 deletions src/server/src/routes/inference/generate/validate.rs
@@ -5,10 +5,10 @@ use crate::error::{ApiError, bail_invalid_request};
/// route.
pub(super) fn validate_request_compat(
request: &GenerateRequest,
configured_model: &str,
served_model_names: &[String],
) -> Result<(), ApiError> {
if let Some(model) = request.model.as_ref()
&& model != configured_model
&& !served_model_names.iter().any(|n| n == model)
{
return Err(ApiError::model_not_found(model.clone()));
}
@@ -60,13 +60,17 @@ mod tests {
.expect("parse request")
}

fn served(names: &[&str]) -> Vec<String> {
names.iter().map(|s| s.to_string()).collect()
}

#[test]
fn validate_request_compat_rejects_streaming() {
let request = GenerateRequest {
stream: true,
..base_request()
};
assert!(validate_request_compat(&request, "Qwen/Qwen1.5-0.5B-Chat").is_err());
assert!(validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])).is_err());
}

#[test]
@@ -75,6 +79,6 @@
token_ids: Vec::new(),
..base_request()
};
assert!(validate_request_compat(&request, "Qwen/Qwen1.5-0.5B-Chat").is_err());
assert!(validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])).is_err());
}
}
2 changes: 1 addition & 1 deletion src/server/src/routes/openai/chat_completions.rs
@@ -50,7 +50,7 @@ pub async fn chat_completions(
let stream = body.stream;
let request_context = resolve_request_context(&headers, body.request_id.as_deref());

let prepared = match prepare_chat_request(body, &state.model_id, request_context) {
let prepared = match prepare_chat_request(body, state.served_model_names(), request_context) {
Ok(prepared) => prepared,
Err(error) => return error.into_response(),
};