diff --git a/src/cmd/src/cli.rs b/src/cmd/src/cli.rs
index 15f9d98..0458425 100644
--- a/src/cmd/src/cli.rs
+++ b/src/cmd/src/cli.rs
@@ -168,6 +168,17 @@ pub struct SharedRuntimeArgs {
     #[serde(default)]
     pub disable_log_stats: bool,
 
+    /// The model name(s) used in the API. If multiple names are provided, the
+    /// server will respond to any of the provided names. The model name in the
+    /// `model` field of a response will be the first name in this list. If not
+    /// specified, the model name will be the same as the `--model` argument.
+    /// Note that these names are also used in the `model_name` tag of
+    /// Prometheus metrics; if multiple names are provided, the metrics tag
+    /// uses the first one.
+    #[arg(long, num_args = 0..)]
+    #[serde(default)]
+    pub served_model_name: Vec<String>,
+
     /// Unsupported Python vLLM frontend arguments recognized but not yet
     /// implemented in Rust.
     #[educe(Debug(ignore))]
@@ -215,6 +226,7 @@ impl SharedRuntimeArgs {
                 None => CoordinatorMode::None,
             },
             model: self.model,
+            served_model_name: self.served_model_name,
             listener_mode: HttpListenerMode::InheritedFd { fd: listen_fd },
             tool_call_parser: self.tool_call_parser,
             reasoning_parser: self.reasoning_parser,
@@ -254,6 +266,7 @@ impl SharedRuntimeArgs {
             },
             coordinator_mode: CoordinatorMode::MaybeInProc,
             model: self.model,
+            served_model_name: self.served_model_name,
             listener_mode,
             tool_call_parser: self.tool_call_parser,
             reasoning_parser: self.reasoning_parser,
diff --git a/src/cmd/src/cli/tests.rs b/src/cmd/src/cli/tests.rs
index 8de2adb..9d9e1dc 100644
--- a/src/cmd/src/cli/tests.rs
+++ b/src/cmd/src/cli/tests.rs
@@ -49,6 +49,7 @@ fn serve_args_forward_python_flags_with_separator() {
                 chat_template_content_format: Auto,
                 enable_log_requests: false,
                 disable_log_stats: false,
+                served_model_name: [],
             },
             python_args: [
                 "--dtype",
@@ -210,6 +211,7 @@ fn frontend_args_accept_json() {
                     chat_template_content_format: Auto,
                     enable_log_requests: false,
                     disable_log_stats: false,
+                    served_model_name: [],
                 },
             },
         ),
@@ -596,6 +598,7 @@ fn serve_args_accept_handshake_aliases() {
                 chat_template_content_format: Auto,
                 enable_log_requests: false,
                 disable_log_stats: false,
+                served_model_name: [],
             },
             python_args: [],
         },
@@ -690,6 +693,7 @@ fn serve_frontend_config_uses_dp_address_as_advertised_host() {
         },
         coordinator_mode: MaybeInProc,
         model: "Qwen/Qwen3-0.6B",
+        served_model_name: [],
         listener_mode: BindTcp {
             host: "127.0.0.1",
             port: 8000,
@@ -751,6 +755,7 @@ fn serve_frontend_config_keeps_tcp_transport_for_non_local_only_topology() {
         },
         coordinator_mode: MaybeInProc,
         model: "Qwen/Qwen3-0.6B",
+        served_model_name: [],
         listener_mode: BindTcp {
             host: "127.0.0.1",
             port: 8000,
@@ -828,6 +833,7 @@ fn frontend_config_uses_external_coordinator_when_coordinator_address_is_present
             address: "tcp://127.0.0.1:7000",
         },
         model: "Qwen/Qwen3-0.6B",
+        served_model_name: [],
         listener_mode: InheritedFd {
             fd: 3,
         },
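Usage note (not part of the patch; the subcommand and aliases below are illustrative): with `num_args = 0..`, multiple names are passed space-separated, e.g.

    serve --model Qwen/Qwen3-0.6B --served-model-name qwen3 qwen3-prod

The frontend then answers requests addressed to either alias and reports `qwen3`, the first entry, in response bodies and in the `model_name` metrics tag.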
diff --git a/src/cmd/src/cli/unsupported.rs b/src/cmd/src/cli/unsupported.rs
index 7112ded..eeaa083 100644
--- a/src/cmd/src/cli/unsupported.rs
+++ b/src/cmd/src/cli/unsupported.rs
@@ -234,16 +234,6 @@ pub struct EngineUnsupportedArgs {
     )]
     pub enable_prompt_embeds: Option<bool>,
 
-    /// The model name(s) used in the API. If multiple names are provided, the
-    /// server will respond to any of the provided names. The model name in the
-    /// model field of a response will be the first name in this list. If not
-    /// specified, the model name will be the same as the `--model` argument.
-    /// Noted that this name(s) will also be used in `model_name` tag
-    /// content of prometheus metrics, if multiple names provided, metrics
-    /// tag will take the first one.
-    #[arg(long)]
-    pub served_model_name: Option<String>,
-
     /// The token to use as HTTP bearer authorization for remote files. If
     /// `True`, will use the token generated when running `hf auth login`
     /// (stored in `~/.cache/huggingface/token`).
diff --git a/src/server/examples/external_engine_openai_qwen.rs b/src/server/examples/external_engine_openai_qwen.rs
index 9034cde..6ef2e1a 100644
--- a/src/server/examples/external_engine_openai_qwen.rs
+++ b/src/server/examples/external_engine_openai_qwen.rs
@@ -56,6 +56,7 @@ async fn main() -> Result<()> {
         },
         coordinator_mode: CoordinatorMode::MaybeInProc,
         model: args.model,
+        served_model_name: vec![],
         listener_mode: HttpListenerMode::BindTcp {
             host: "127.0.0.1".to_string(),
             port,
diff --git a/src/server/src/config.rs b/src/server/src/config.rs
index 065d0e1..5f49c6b 100644
--- a/src/server/src/config.rs
+++ b/src/server/src/config.rs
@@ -38,8 +38,12 @@ pub struct Config {
     pub transport_mode: TransportMode,
     /// Requested frontend-side coordinator behavior.
     pub coordinator_mode: CoordinatorMode,
-    /// Backend model identifier and exposed OpenAI model ID.
+    /// Backend model identifier used for engine-core loading.
     pub model: String,
+    /// Model name(s) exposed to clients via the OpenAI API. When non-empty,
+    /// the first entry is used as the primary ID in responses and all entries
+    /// are accepted in requests. When empty, falls back to `model`.
+    pub served_model_name: Vec<String>,
     /// HTTP listener setup.
     pub listener_mode: HttpListenerMode,
     /// Tool-call parser selection.
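The precedence this new `Config` field implies can be summarized as a small standalone sketch (hypothetical helper name; the real fallback is wired into `build_state` in src/server/src/lib.rs below):

    // Hypothetical sketch of the served-name resolution rule; not patch code.
    fn resolve_served_names(model: &str, served_model_name: &[String]) -> Vec<String> {
        if served_model_name.is_empty() {
            vec![model.to_owned()] // no aliases configured: expose the backend model ID
        } else {
            served_model_name.to_vec() // aliases win; the first entry is the primary ID
        }
    }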
diff --git a/src/server/src/grpc/convert.rs b/src/server/src/grpc/convert.rs
index 883617d..91b7c7d 100644
--- a/src/server/src/grpc/convert.rs
+++ b/src/server/src/grpc/convert.rs
@@ -17,15 +17,15 @@ use super::pb;
 
 /// Convert a gRPC `GenerateRequest` into the internal `TextRequest`.
 ///
-/// If `req.model` is non-empty, it must match `configured_model`; otherwise the
-/// request is rejected with `NotFound`. An empty string is treated as "unset"
-/// (proto3 default) and accepted.
+/// If `req.model` is non-empty, it must match one of `served_model_names`;
+/// otherwise the request is rejected with `NotFound`. An empty string is
+/// treated as "unset" (proto3 default) and accepted.
 pub fn to_text_request(
     req: pb::GenerateRequest,
     stream: bool,
-    configured_model: &str,
+    served_model_names: &[String],
 ) -> Result<TextRequest, Status> {
-    if !req.model.is_empty() && req.model != configured_model {
+    if !req.model.is_empty() && !served_model_names.iter().any(|n| n == &req.model) {
         return Err(Status::not_found(format!(
             "model `{}` not found",
             req.model
@@ -521,13 +521,14 @@ mod tests {
             temperature: Some(0.7),
             ..base_request()
         };
-        let text = to_text_request(req, false, "test-model").expect("convert ok");
+        let text = to_text_request(req, false, &["test-model".to_string()]).expect("convert ok");
         assert_eq!(text.sampling_params.temperature, Some(0.7));
     }
 
     #[test]
     fn unset_temperature_defaults_to_greedy() {
-        let text = to_text_request(base_request(), false, "test-model").expect("convert ok");
+        let text = to_text_request(base_request(), false, &["test-model".to_string()])
+            .expect("convert ok");
         // The gRPC API defaults to greedy (0.0) when temperature is not specified.
         assert_eq!(text.sampling_params.temperature, Some(0.0));
     }
@@ -541,7 +542,7 @@ mod tests {
             }),
             ..base_request()
         };
-        let text = to_text_request(req, false, "test-model").expect("convert ok");
+        let text = to_text_request(req, false, &["test-model".to_string()]).expect("convert ok");
         assert_eq!(text.sampling_params.seed, None);
     }
 
@@ -554,7 +555,7 @@ mod tests {
             }),
             ..base_request()
         };
-        let text = to_text_request(req, false, "test-model").expect("convert ok");
+        let text = to_text_request(req, false, &["test-model".to_string()]).expect("convert ok");
         assert_eq!(text.sampling_params.seed, Some(0));
     }
 
@@ -567,7 +568,7 @@ mod tests {
             }),
             ..base_request()
         };
-        let text = to_text_request(req, false, "test-model").expect("convert ok");
+        let text = to_text_request(req, false, &["test-model".to_string()]).expect("convert ok");
         assert_eq!(text.sampling_params.skip_reading_prefix_cache, Some(true));
     }
 
@@ -580,7 +581,7 @@ mod tests {
             }),
             ..base_request()
         };
-        let text = to_text_request(req, false, "test-model").expect("convert ok");
+        let text = to_text_request(req, false, &["test-model".to_string()]).expect("convert ok");
         assert_eq!(text.sampling_params.skip_reading_prefix_cache, None);
         // Prompt conversion still succeeds and reaches the expected variant.
         assert!(matches!(text.prompt, Prompt::Text(s) if s == "hi"));
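The acceptance rule above reduces to a one-line predicate; a runnable standalone sketch (hypothetical function, illustrative names):

    // Mirrors to_text_request's model check; "" is proto3 "unset" and passes.
    fn model_accepted(requested: &str, served_model_names: &[String]) -> bool {
        requested.is_empty() || served_model_names.iter().any(|n| n == requested)
    }

    fn main() {
        let served = vec!["qwen3".to_string(), "qwen3-prod".to_string()];
        assert!(model_accepted("", &served)); // unset -> accepted
        assert!(model_accepted("qwen3-prod", &served)); // any served alias -> accepted
        assert!(!model_accepted("other", &served)); // unknown -> NotFound
    }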
diff --git a/src/server/src/grpc/mod.rs b/src/server/src/grpc/mod.rs
index d70cef3..2f648aa 100644
--- a/src/server/src/grpc/mod.rs
+++ b/src/server/src/grpc/mod.rs
@@ -49,7 +49,8 @@ impl pb::generate_server::Generate for GenerateServiceImpl {
     ) -> Result<Response<pb::GenerateResponse>, Status> {
         let proto_req = request.into_inner();
         let response_opts = ResponseOpts::from_proto(proto_req.response.as_ref());
-        let text_request = convert::to_text_request(proto_req, false, &self.state.model_id)?;
+        let text_request =
+            convert::to_text_request(proto_req, false, self.state.served_model_names())?;
         let request_id = text_request.request_id.clone();
         info!(%request_id, "grpc generate (unary)");
 
@@ -97,7 +98,8 @@ impl pb::generate_server::Generate for GenerateServiceImpl {
     ) -> Result<Response<Self::GenerateStreamStream>, Status> {
         let proto_req = request.into_inner();
         let response_opts = ResponseOpts::from_proto(proto_req.response.as_ref());
-        let text_request = convert::to_text_request(proto_req, true, &self.state.model_id)?;
+        let text_request =
+            convert::to_text_request(proto_req, true, self.state.served_model_names())?;
         let request_id = text_request.request_id.clone();
         info!(%request_id, "grpc generate (stream)");
 
diff --git a/src/server/src/grpc/tests.rs b/src/server/src/grpc/tests.rs
index 842163f..7e5fd7a 100644
--- a/src/server/src/grpc/tests.rs
+++ b/src/server/src/grpc/tests.rs
@@ -246,7 +246,7 @@ async fn grpc_test_server(
         test_llm(client),
         Arc::new(FakeTextBackend) as Arc<dyn TextBackend>,
     );
-    let state = Arc::new(AppState::new("test-model", chat));
+    let state = Arc::new(AppState::new(vec!["test-model".to_string()], chat));
     let svc = GenerateServer::new(GenerateServiceImpl::new(state));
 
     // Bind to an OS-assigned port.
diff --git a/src/server/src/lib.rs b/src/server/src/lib.rs
index 6c3abd3..2b68428 100644
--- a/src/server/src/lib.rs
+++ b/src/server/src/lib.rs
@@ -76,8 +76,16 @@ async fn build_state(config: &Config) -> Result<Arc<AppState>> {
         .with_tool_call_parser(config.tool_call_parser.clone())
         .with_reasoning_parser(config.reasoning_parser.clone());
 
+    // If no served names are specified, fall back to the backend model path so
+    // that the API always has at least one valid model ID.
+    let served_model_names = if config.served_model_name.is_empty() {
+        vec![config.model.clone()]
+    } else {
+        config.served_model_name.clone()
+    };
+
     Ok(Arc::new(
-        AppState::new(config.model.clone(), chat).with_log_requests(config.enable_log_requests),
+        AppState::new(served_model_names, chat).with_log_requests(config.enable_log_requests),
     ))
 }
 
@@ -98,7 +106,7 @@ pub async fn serve(config: Config, shutdown: CancellationToken) -> Result<()> {
         .await
         .context("failed to bind listener for OpenAI server")?;
     let bind_address = listener.local_addr()?;
-    let model = state.model_id.clone();
+    let model = state.primary_model_name().to_owned();
     let app = build_router(state.clone());
 
     // Optionally bind the gRPC Generate server on a separate port. Bind
diff --git a/src/server/src/routes/http_client_tests.rs b/src/server/src/routes/http_client_tests.rs
index f5af24a..66fee92 100644
--- a/src/server/src/routes/http_client_tests.rs
+++ b/src/server/src/routes/http_client_tests.rs
@@ -263,7 +263,7 @@ async fn http_test_server(
         test_llm(client),
         Arc::new(FakeChatBackend) as Arc<dyn ChatBackend>,
     );
-    let state = Arc::new(AppState::new("test-model", chat));
+    let state = Arc::new(AppState::new(vec!["test-model".to_string()], chat));
     let app = build_router(state);
 
     let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.expect("bind http listener");
diff --git a/src/server/src/routes/inference/generate.rs b/src/server/src/routes/inference/generate.rs
index ec653ca..9289e65 100644
--- a/src/server/src/routes/inference/generate.rs
+++ b/src/server/src/routes/inference/generate.rs
@@ -32,7 +32,8 @@ pub async fn generate(
     ValidatedJson(body): ValidatedJson<GenerateRequest>,
 ) -> Response {
     let request_context = resolve_request_context(&headers, body.request_id.as_deref());
-    let prepared = match prepare_generate_request(body, &state.model_id, request_context) {
+    let prepared = match prepare_generate_request(body, state.served_model_names(), request_context)
+    {
         Ok(prepared) => prepared,
         Err(error) => return error.into_response(),
     };
diff --git a/src/server/src/routes/inference/generate/convert.rs b/src/server/src/routes/inference/generate/convert.rs
index b75349b..7c16c02 100644
--- a/src/server/src/routes/inference/generate/convert.rs
+++ b/src/server/src/routes/inference/generate/convert.rs
@@ -18,10 +18,10 @@ pub struct PreparedRequest {
 /// text-generation format.
 pub fn prepare_generate_request(
     request: GenerateRequest,
-    configured_model: &str,
+    served_model_names: &[String],
     ctx: ResolvedRequestContext,
 ) -> Result<PreparedRequest, ApiError> {
-    validate::validate_request_compat(&request, configured_model)?;
+    validate::validate_request_compat(&request, served_model_names)?;
 
     let include_logprobs = request.sampling_params.logprobs.is_some();
     let include_prompt_logprobs = request.sampling_params.prompt_logprobs.is_some();
@@ -81,7 +81,7 @@ mod tests {
 
         let prepared = prepare_generate_request(
             request,
-            "Qwen/Qwen1.5-0.5B-Chat",
+            &["Qwen/Qwen1.5-0.5B-Chat".to_string()],
             ResolvedRequestContext::default(),
         )
         .expect("prepare");
diff --git a/src/server/src/routes/inference/generate/validate.rs b/src/server/src/routes/inference/generate/validate.rs
index 185654a..74a5bbb 100644
--- a/src/server/src/routes/inference/generate/validate.rs
+++ b/src/server/src/routes/inference/generate/validate.rs
@@ -5,10 +5,10 @@ use crate::error::{ApiError, bail_invalid_request};
 /// route.
 pub(super) fn validate_request_compat(
     request: &GenerateRequest,
-    configured_model: &str,
+    served_model_names: &[String],
 ) -> Result<(), ApiError> {
     if let Some(model) = request.model.as_ref()
-        && model != configured_model
+        && !served_model_names.iter().any(|n| n == model)
     {
         return Err(ApiError::model_not_found(model.clone()));
     }
@@ -60,13 +60,17 @@ mod tests {
             .expect("parse request")
     }
 
+    fn served(names: &[&str]) -> Vec<String> {
+        names.iter().map(|s| s.to_string()).collect()
+    }
+
     #[test]
     fn validate_request_compat_rejects_streaming() {
         let request = GenerateRequest {
             stream: true,
             ..base_request()
         };
-        assert!(validate_request_compat(&request, "Qwen/Qwen1.5-0.5B-Chat").is_err());
+        assert!(validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])).is_err());
     }
 
     #[test]
@@ -75,6 +79,6 @@ mod tests {
             token_ids: Vec::new(),
             ..base_request()
         };
-        assert!(validate_request_compat(&request, "Qwen/Qwen1.5-0.5B-Chat").is_err());
+        assert!(validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])).is_err());
     }
 }
diff --git a/src/server/src/routes/openai/chat_completions.rs b/src/server/src/routes/openai/chat_completions.rs
index c04f155..c0894bb 100644
--- a/src/server/src/routes/openai/chat_completions.rs
+++ b/src/server/src/routes/openai/chat_completions.rs
@@ -50,7 +50,7 @@ pub async fn chat_completions(
     let stream = body.stream;
     let request_context = resolve_request_context(&headers, body.request_id.as_deref());
 
-    let prepared = match prepare_chat_request(body, &state.model_id, request_context) {
+    let prepared = match prepare_chat_request(body, state.served_model_names(), request_context) {
         Ok(prepared) => prepared,
         Err(error) => return error.into_response(),
     };
diff --git a/src/server/src/routes/openai/chat_completions/convert.rs b/src/server/src/routes/openai/chat_completions/convert.rs
index f73a3ec..939ee1e 100644
--- a/src/server/src/routes/openai/chat_completions/convert.rs
+++ b/src/server/src/routes/openai/chat_completions/convert.rs
@@ -40,12 +40,15 @@ pub struct PreparedRequest {
 
 /// Validate and lower one OpenAI chat completion request into the internal chat
 /// format.
+///
+/// `served_model_names` must be non-empty; the first entry is used as the
+/// `model` field in responses.
 pub(crate) fn prepare_chat_request(
     request: ChatCompletionRequest,
-    configured_model: &str,
+    served_model_names: &[String],
     ctx: ResolvedRequestContext,
 ) -> Result<PreparedRequest, ApiError> {
-    validate::validate_request_compat(&request, configured_model)?;
+    validate::validate_request_compat(&request, served_model_names)?;
 
     let request_id = format!("chatcmpl-{}", ctx.request_id);
     let echo = request
@@ -132,7 +135,7 @@ pub(crate) fn prepare_chat_request(
 
     Ok(PreparedRequest {
         request_id,
-        response_model: configured_model.to_string(),
+        response_model: served_model_names.first().cloned().unwrap_or_default(),
         include_usage,
         requested_logprobs,
         include_prompt_logprobs,
@@ -354,6 +357,10 @@ mod tests {
         resolve_request_context(headers, request_id)
     }
 
+    fn served(names: &[&str]) -> Vec<String> {
+        names.iter().map(|s| s.to_string()).collect()
+    }
+
     fn base_request() -> ChatCompletionRequest {
         ChatCompletionRequest {
             model: "Qwen/Qwen1.5-0.5B-Chat".to_string(),
@@ -384,7 +391,7 @@ mod tests {
 
         let prepared = prepare_chat_request(
             request,
-            "Qwen/Qwen1.5-0.5B-Chat",
+            &served(&["Qwen/Qwen1.5-0.5B-Chat"]),
             ResolvedRequestContext::default(),
         )
         .expect("request is valid");
@@ -423,7 +430,7 @@ mod tests {
     fn prepare_chat_request_keeps_optional_sampling_fields_unset() {
         let prepared = prepare_chat_request(
             base_request(),
-            "Qwen/Qwen1.5-0.5B-Chat",
+            &served(&["Qwen/Qwen1.5-0.5B-Chat"]),
             ResolvedRequestContext::default(),
         )
         .expect("request is valid");
@@ -467,7 +474,7 @@ mod tests {
 
         let prepared = prepare_chat_request(
             request,
-            "Qwen/Qwen1.5-0.5B-Chat",
+            &served(&["Qwen/Qwen1.5-0.5B-Chat"]),
             ResolvedRequestContext::default(),
         )
         .expect("request is valid");
@@ -506,7 +513,7 @@ mod tests {
 
         let prepared = prepare_chat_request(
             request,
-            "Qwen/Qwen1.5-0.5B-Chat",
+            &served(&["Qwen/Qwen1.5-0.5B-Chat"]),
             ResolvedRequestContext::default(),
         )
         .expect("request is valid");
@@ -546,7 +553,7 @@ mod tests {
         assert!(
             prepare_chat_request(
                 request,
-                "Qwen/Qwen1.5-0.5B-Chat",
+                &served(&["Qwen/Qwen1.5-0.5B-Chat"]),
                 ResolvedRequestContext::default(),
             )
             .is_err()
@@ -572,7 +579,7 @@ mod tests {
         assert!(
             prepare_chat_request(
                 request,
-                "Qwen/Qwen1.5-0.5B-Chat",
+                &served(&["Qwen/Qwen1.5-0.5B-Chat"]),
                 ResolvedRequestContext::default(),
             )
             .is_err()
@@ -599,7 +606,7 @@ mod tests {
 
         let prepared = prepare_chat_request(
             request,
-            "Qwen/Qwen1.5-0.5B-Chat",
+            &served(&["Qwen/Qwen1.5-0.5B-Chat"]),
             ResolvedRequestContext::default(),
         )
         .expect("request is valid");
@@ -635,7 +642,7 @@ mod tests {
 
         let prepared = prepare_chat_request(
             request,
-            "Qwen/Qwen1.5-0.5B-Chat",
+            &served(&["Qwen/Qwen1.5-0.5B-Chat"]),
             ResolvedRequestContext::default(),
         )
         .expect("request is valid");
@@ -692,7 +699,7 @@ mod tests {
 
         let prepared = prepare_chat_request(
             request,
-            "Qwen/Qwen1.5-0.5B-Chat",
+            &served(&["Qwen/Qwen1.5-0.5B-Chat"]),
             ResolvedRequestContext::default(),
         )
         .expect("request is valid");
@@ -735,7 +742,7 @@ mod tests {
 
         let prepared = prepare_chat_request(
             request,
-            "Qwen/Qwen1.5-0.5B-Chat",
+            &served(&["Qwen/Qwen1.5-0.5B-Chat"]),
             ResolvedRequestContext::default(),
         )
         .expect("request is valid");
@@ -760,7 +767,7 @@ mod tests {
 
         let prepared = prepare_chat_request(
             request,
-            "Qwen/Qwen1.5-0.5B-Chat",
+            &served(&["Qwen/Qwen1.5-0.5B-Chat"]),
             ResolvedRequestContext::default(),
         )
         .expect("request is valid");
@@ -776,7 +783,7 @@ mod tests {
         headers.insert("X-data-parallel-rank", "7".parse().unwrap());
         let prepared = prepare_chat_request(
             base_request(),
-            "Qwen/Qwen1.5-0.5B-Chat",
+            &served(&["Qwen/Qwen1.5-0.5B-Chat"]),
             request_context(&headers, None),
         )
         .expect("request is valid");
@@ -787,7 +794,7 @@ mod tests {
     fn prepare_chat_request_leaves_data_parallel_rank_none_when_absent() {
         let prepared = prepare_chat_request(
             base_request(),
-            "Qwen/Qwen1.5-0.5B-Chat",
+            &served(&["Qwen/Qwen1.5-0.5B-Chat"]),
             ResolvedRequestContext::default(),
         )
         .expect("request is valid");
@@ -801,7 +808,7 @@ mod tests {
 
         let prepared = prepare_chat_request(
             request,
-            "Qwen/Qwen1.5-0.5B-Chat",
+            &served(&["Qwen/Qwen1.5-0.5B-Chat"]),
             ResolvedRequestContext::default(),
         )
         .expect("request is valid");
@@ -820,7 +827,7 @@ mod tests {
 
         let error = prepare_chat_request(
             request,
-            "Qwen/Qwen1.5-0.5B-Chat",
+            &served(&["Qwen/Qwen1.5-0.5B-Chat"]),
             ResolvedRequestContext::default(),
         )
         .unwrap_err();
@@ -842,7 +849,7 @@ mod tests {
 
         let prepared = prepare_chat_request(
             request,
-            "Qwen/Qwen1.5-0.5B-Chat",
+            &served(&["Qwen/Qwen1.5-0.5B-Chat"]),
             ResolvedRequestContext::default(),
         )
         .expect("request is valid");
@@ -860,7 +867,7 @@ mod tests {
 
         let error = prepare_chat_request(
             request,
-            "Qwen/Qwen1.5-0.5B-Chat",
+            &served(&["Qwen/Qwen1.5-0.5B-Chat"]),
             ResolvedRequestContext::default(),
         )
         .unwrap_err();
@@ -883,7 +890,7 @@ mod tests {
 
         let prepared = prepare_chat_request(
             request,
-            "Qwen/Qwen1.5-0.5B-Chat",
+            &served(&["Qwen/Qwen1.5-0.5B-Chat"]),
             ResolvedRequestContext::default(),
         )
         .expect("request is valid");
diff --git a/src/server/src/routes/openai/chat_completions/validate.rs b/src/server/src/routes/openai/chat_completions/validate.rs
index c5a0070..fbd10ee 100644
--- a/src/server/src/routes/openai/chat_completions/validate.rs
+++ b/src/server/src/routes/openai/chat_completions/validate.rs
@@ -5,9 +5,9 @@ use crate::routes::openai::utils::types::{ChatMessage, Tool, ToolChoice, ToolCho
 /// Enforce the minimal compatibility contract for the Rust OpenAI server.
 pub(super) fn validate_request_compat(
     request: &ChatCompletionRequest,
-    configured_model: &str,
+    served_model_names: &[String],
 ) -> Result<(), ApiError> {
-    if request.model != configured_model {
+    if !served_model_names.iter().any(|n| n == &request.model) {
         return Err(ApiError::model_not_found(request.model.clone()));
     }
 
@@ -190,6 +190,10 @@ mod tests {
         ToolChoiceValue, ToolReference,
     };
 
+    fn served(names: &[&str]) -> Vec<String> {
+        names.iter().map(|s| s.to_string()).collect()
+    }
+
     fn base_request() -> ChatCompletionRequest {
         ChatCompletionRequest {
             model: "Qwen/Qwen1.5-0.5B-Chat".to_string(),
@@ -209,7 +213,7 @@ mod tests {
             ..base_request()
         };
 
-        validate_request_compat(&request, "Qwen/Qwen1.5-0.5B-Chat")
+        validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"]))
             .expect("stop strings should be accepted");
     }
 
@@ -223,7 +227,7 @@ mod tests {
             seed: Some(7),
             ..base_request()
         };
-        validate_request_compat(&request, "Qwen/Qwen1.5-0.5B-Chat")
+        validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"]))
             .expect("sampling fields should be accepted");
 
         let request = ChatCompletionRequest {
@@ -238,7 +242,7 @@ mod tests {
             }]),
             ..base_request()
         };
-        validate_request_compat(&request, "Qwen/Qwen1.5-0.5B-Chat")
+        validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"]))
             .expect("function tools should be accepted");
 
         let request = ChatCompletionRequest {
@@ -257,7 +261,7 @@ mod tests {
             }],
             ..base_request()
         };
-        validate_request_compat(&request, "Qwen/Qwen1.5-0.5B-Chat")
+        validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"]))
             .expect("developer function tools should be accepted");
     }
 
@@ -280,7 +284,7 @@ mod tests {
             ..base_request()
         };
 
-        assert!(validate_request_compat(&request, "Qwen/Qwen1.5-0.5B-Chat").is_err());
+        assert!(validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])).is_err());
     }
 
     #[test]
@@ -289,7 +293,7 @@ mod tests {
             logprobs: true,
             ..base_request()
         };
-        validate_request_compat(&request, "Qwen/Qwen1.5-0.5B-Chat")
+        validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"]))
             .expect("logprobs should be accepted");
     }
 
@@ -304,7 +308,7 @@ mod tests {
             ..base_request()
         };
 
-        validate_request_compat(&request, "Qwen/Qwen1.5-0.5B-Chat")
+        validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"]))
             .expect("reasoning_effort should be accepted");
     }
 
@@ -314,7 +318,7 @@ mod tests {
             top_logprobs: Some(0),
             ..base_request()
         };
-        assert!(validate_request_compat(&request, "Qwen/Qwen1.5-0.5B-Chat").is_err());
+        assert!(validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])).is_err());
     }
 
     #[test]
@@ -323,13 +327,13 @@ mod tests {
             prompt_logprobs: Some(1),
             ..base_request()
         };
-        assert!(validate_request_compat(&request, "Qwen/Qwen1.5-0.5B-Chat").is_err());
+        assert!(validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])).is_err());
 
         let request = ChatCompletionRequest {
             prompt_logprobs: Some(-1),
             ..base_request()
         };
-        assert!(validate_request_compat(&request, "Qwen/Qwen1.5-0.5B-Chat").is_err());
+        assert!(validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])).is_err());
     }
 
     #[test]
@@ -339,7 +343,7 @@ mod tests {
             prompt_logprobs: Some(-2),
             ..base_request()
         };
-        assert!(validate_request_compat(&request, "Qwen/Qwen1.5-0.5B-Chat").is_err());
+        assert!(validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])).is_err());
     }
 
     #[test]
@@ -348,14 +352,14 @@ mod tests {
             response_format: Some(ResponseFormat::Text),
             ..base_request()
         };
-        validate_request_compat(&request, "Qwen/Qwen1.5-0.5B-Chat")
+        validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"]))
             .expect("response_format=text should be accepted");
 
         let request = ChatCompletionRequest {
             response_format: Some(ResponseFormat::JsonObject),
             ..base_request()
         };
-        validate_request_compat(&request, "Qwen/Qwen1.5-0.5B-Chat")
+        validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"]))
             .expect("response_format=json_object should be accepted");
     }
 
@@ -366,7 +370,7 @@ mod tests {
             ..base_request()
         };
 
-        validate_request_compat(&request, "Qwen/Qwen1.5-0.5B-Chat")
+        validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"]))
             .expect("tool_choice=none is ok");
     }
 
@@ -376,7 +380,7 @@ mod tests {
             tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Required)),
             ..base_request()
         };
-        assert!(validate_request_compat(&required, "Qwen/Qwen1.5-0.5B-Chat").is_err());
+        assert!(validate_request_compat(&required, &served(&["Qwen/Qwen1.5-0.5B-Chat"])).is_err());
 
         let named = ChatCompletionRequest {
             tool_choice: Some(ToolChoice::Function {
@@ -387,7 +391,7 @@ mod tests {
             }),
             ..base_request()
         };
-        assert!(validate_request_compat(&named, "Qwen/Qwen1.5-0.5B-Chat").is_err());
+        assert!(validate_request_compat(&named, &served(&["Qwen/Qwen1.5-0.5B-Chat"])).is_err());
 
         let allowed_tools = ChatCompletionRequest {
             tool_choice: Some(ToolChoice::AllowedTools {
@@ -399,6 +403,8 @@ mod tests {
             }),
             ..base_request()
         };
-        assert!(validate_request_compat(&allowed_tools, "Qwen/Qwen1.5-0.5B-Chat").is_err());
+        assert!(
+            validate_request_compat(&allowed_tools, &served(&["Qwen/Qwen1.5-0.5B-Chat"])).is_err()
+        );
     }
 }
diff --git a/src/server/src/routes/openai/completions.rs b/src/server/src/routes/openai/completions.rs
index cb84324..33813e6 100644
--- a/src/server/src/routes/openai/completions.rs
+++ b/src/server/src/routes/openai/completions.rs
@@ -45,10 +45,11 @@ pub async fn completions(
     let logprobs = body.logprobs;
     let request_context = resolve_request_context(&headers, body.request_id.as_deref());
 
-    let prepared = match prepare_completion_request(body, &state.model_id, request_context) {
-        Ok(prepared) => prepared,
-        Err(error) => return error.into_response(),
-    };
+    let prepared =
+        match prepare_completion_request(body, state.served_model_names(), request_context) {
+            Ok(prepared) => prepared,
+            Err(error) => return error.into_response(),
+        };
     let request_span = tracing::info_span!(
         "completions",
         request_id = %prepared.request_id,
diff --git a/src/server/src/routes/openai/completions/convert.rs b/src/server/src/routes/openai/completions/convert.rs
index a2a2baf..7c77009 100644
--- a/src/server/src/routes/openai/completions/convert.rs
+++ b/src/server/src/routes/openai/completions/convert.rs
@@ -29,12 +29,15 @@ pub struct PreparedRequest {
 
 /// Validate and lower one OpenAI completions request into the internal
 /// text-generation format.
+///
+/// `served_model_names` must be non-empty; the first entry is used as the
+/// `model` field in responses.
 pub(crate) fn prepare_completion_request(
     request: CompletionRequest,
-    configured_model: &str,
+    served_model_names: &[String],
     ctx: ResolvedRequestContext,
 ) -> Result<PreparedRequest, ApiError> {
-    validate::validate_request_compat(&request, configured_model)?;
+    validate::validate_request_compat(&request, served_model_names)?;
 
     let request_id = format!("cmpl-{}", ctx.request_id);
 
@@ -104,7 +107,7 @@ pub(crate) fn prepare_completion_request(
 
     Ok(PreparedRequest {
         request_id,
-        response_model: configured_model.to_string(),
+        response_model: served_model_names.first().cloned().unwrap_or_default(),
         include_usage,
         text_request,
         echo,
@@ -127,6 +130,10 @@ mod tests {
         resolve_request_context(headers, request_id)
    }
 
+    fn served(names: &[&str]) -> Vec<String> {
+        names.iter().map(|s| s.to_string()).collect()
+    }
+
     fn base_request_json() -> serde_json::Value {
         json!({
             "model": "Qwen/Qwen1.5-0.5B-Chat",
@@ -182,7 +189,7 @@ mod tests {
 
         let prepared = prepare_completion_request(
             request,
-            "Qwen/Qwen1.5-0.5B-Chat",
+            &served(&["Qwen/Qwen1.5-0.5B-Chat"]),
             ResolvedRequestContext::default(),
         )
         .expect("prepare");
@@ -226,7 +233,7 @@ mod tests {
 
         let prepared = prepare_completion_request(
             request,
-            "Qwen/Qwen1.5-0.5B-Chat",
+            &served(&["Qwen/Qwen1.5-0.5B-Chat"]),
             ResolvedRequestContext::default(),
         )
         .expect("prepare");
@@ -248,7 +255,7 @@ mod tests {
 
         let prepared = prepare_completion_request(
             request,
-            "Qwen/Qwen1.5-0.5B-Chat",
+            &served(&["Qwen/Qwen1.5-0.5B-Chat"]),
             ResolvedRequestContext::default(),
         )
         .expect("prepare");
@@ -273,7 +280,7 @@ mod tests {
         assert!(
             prepare_completion_request(
                 request,
-                "Qwen/Qwen1.5-0.5B-Chat",
+                &served(&["Qwen/Qwen1.5-0.5B-Chat"]),
                 ResolvedRequestContext::default(),
             )
             .is_err()
@@ -293,7 +300,7 @@ mod tests {
 
         let prepared = prepare_completion_request(
             request,
-            "Qwen/Qwen1.5-0.5B-Chat",
+            &served(&["Qwen/Qwen1.5-0.5B-Chat"]),
             ResolvedRequestContext::default(),
         )
         .expect("prepare");
@@ -317,7 +324,7 @@ mod tests {
         headers.insert("X-data-parallel-rank", "3".parse().unwrap());
         let prepared = prepare_completion_request(
             request,
-            "Qwen/Qwen1.5-0.5B-Chat",
+            &served(&["Qwen/Qwen1.5-0.5B-Chat"]),
             request_context(&headers, None),
         )
         .expect("prepare");
@@ -335,7 +342,7 @@ mod tests {
 
         let prepared = prepare_completion_request(
             request,
-            "Qwen/Qwen1.5-0.5B-Chat",
+            &served(&["Qwen/Qwen1.5-0.5B-Chat"]),
             ResolvedRequestContext::default(),
         )
         .expect("prepare");
diff --git a/src/server/src/routes/openai/completions/validate.rs b/src/server/src/routes/openai/completions/validate.rs
index bf4911f..a536092 100644
--- a/src/server/src/routes/openai/completions/validate.rs
+++ b/src/server/src/routes/openai/completions/validate.rs
@@ -6,12 +6,12 @@ use crate::error::{ApiError, bail_invalid_request};
 /// Enforce the minimal compatibility contract for the Rust OpenAI server.
 pub(super) fn validate_request_compat(
     request: &CompletionRequest,
-    configured_model: &str,
+    served_model_names: &[String],
 ) -> Result<(), ApiError> {
     // This path is intentionally scoped to the minimum surface needed by
     // `vllm-bench` random workload compatibility, so unsupported legacy
     // completions features fail early here.
-    if request.model != configured_model {
+    if !served_model_names.iter().any(|n| n == &request.model) {
         return Err(ApiError::model_not_found(request.model.clone()));
     }
 
@@ -120,13 +120,37 @@ mod tests {
         .expect("parse request")
     }
 
+    fn served_names(names: &[&str]) -> Vec<String> {
+        names.iter().map(|s| s.to_string()).collect()
+    }
+
     #[test]
     fn validate_request_compat_accepts_logprobs() {
         let request = CompletionRequest {
             logprobs: Some(1),
             ..base_request()
         };
-        assert!(validate_request_compat(&request, "Qwen/Qwen1.5-0.5B-Chat").is_ok());
+        assert!(
+            validate_request_compat(&request, &served_names(&["Qwen/Qwen1.5-0.5B-Chat"])).is_ok()
+        );
+    }
+
+    #[test]
+    fn validate_request_compat_accepts_any_served_name() {
+        let request = base_request();
+        assert!(
+            validate_request_compat(
+                &request,
+                &served_names(&["other-alias", "Qwen/Qwen1.5-0.5B-Chat"])
+            )
+            .is_ok()
+        );
+    }
+
+    #[test]
+    fn validate_request_compat_rejects_unknown_model() {
+        let request = base_request();
+        assert!(validate_request_compat(&request, &served_names(&["other-model"])).is_err());
     }
 
     #[test]
@@ -135,7 +159,9 @@ mod tests {
             prompt_logprobs: Some(1),
             ..base_request()
         };
-        assert!(validate_request_compat(&request, "Qwen/Qwen1.5-0.5B-Chat").is_err());
+        assert!(
+            validate_request_compat(&request, &served_names(&["Qwen/Qwen1.5-0.5B-Chat"])).is_err()
+        );
     }
 
     #[test]
@@ -145,6 +171,8 @@ mod tests {
             prompt_logprobs: Some(-1),
             ..base_request()
         };
-        assert!(validate_request_compat(&request, "Qwen/Qwen1.5-0.5B-Chat").is_ok());
+        assert!(
+            validate_request_compat(&request, &served_names(&["Qwen/Qwen1.5-0.5B-Chat"])).is_ok()
+        );
     }
 }
diff --git a/src/server/src/routes/openai/models.rs b/src/server/src/routes/openai/models.rs
index 9544dae..42e3098 100644
--- a/src/server/src/routes/openai/models.rs
+++ b/src/server/src/routes/openai/models.rs
@@ -6,15 +6,19 @@ use axum::extract::State;
 use crate::routes::openai::utils::types::{ListModelsResponse, ModelObject};
 use crate::state::AppState;
 
-/// Return the single configured model in OpenAI `list models` format.
+/// Return all configured served model names in OpenAI `list models` format.
 pub async fn list_models(State(state): State<Arc<AppState>>) -> Json<ListModelsResponse> {
     Json(ListModelsResponse {
         object: "list".to_string(),
-        data: vec![ModelObject {
-            id: state.model_id.clone(),
-            object: "model".to_string(),
-            created: 0,
-            owned_by: "vllm-frontend-rs".to_string(),
-        }],
+        data: state
+            .served_model_names()
+            .iter()
+            .map(|name| ModelObject {
+                id: name.clone(),
+                object: "model".to_string(),
+                created: 0,
+                owned_by: "vllm-frontend-rs".to_string(),
+            })
+            .collect(),
     })
 }
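With two served names, GET /v1/models now lists one entry per alias. Illustrative payload (names invented; field values follow the `ModelObject` construction above, assuming the obvious serde serialization):

    {"object":"list","data":[
      {"id":"qwen3","object":"model","created":0,"owned_by":"vllm-frontend-rs"},
      {"id":"qwen3-prod","object":"model","created":0,"owned_by":"vllm-frontend-rs"}]}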
vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()], + chat, + ))); let response = app .call( @@ -1830,7 +1854,10 @@ async fn non_stream_chat_completions_still_succeed() { .await .expect("connect client"); let chat = ChatLlm::from_shared_backend(test_llm(client), Arc::new(FakeChatBackend::new())); - let mut app = build_router(Arc::new(AppState::new("Qwen/Qwen1.5-0.5B-Chat", chat))); + let mut app = build_router(Arc::new(AppState::new( + vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()], + chat, + ))); let response = app .call( @@ -1887,7 +1914,10 @@ async fn non_stream_completions_still_succeed() { .await .expect("connect client"); let chat = ChatLlm::from_shared_backend(test_llm(client), Arc::new(FakeChatBackend::new())); - let mut app = build_router(Arc::new(AppState::new("Qwen/Qwen1.5-0.5B-Chat", chat))); + let mut app = build_router(Arc::new(AppState::new( + vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()], + chat, + ))); let response = app .call( @@ -1951,7 +1981,10 @@ async fn chat_completions_header_request_id_takes_precedence() { .await .expect("connect client"); let chat = ChatLlm::from_shared_backend(Llm::new(client), Arc::new(FakeChatBackend::new())); - let mut app = build_router(Arc::new(AppState::new("Qwen/Qwen1.5-0.5B-Chat", chat))); + let mut app = build_router(Arc::new(AppState::new( + vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()], + chat, + ))); let response = app .call( @@ -2045,7 +2078,10 @@ async fn non_stream_raw_generate_returns_token_output_envelope() { .await .expect("connect client"); let chat = ChatLlm::from_shared_backend(Llm::new(client), Arc::new(FakeChatBackend::new())); - let mut app = build_router(Arc::new(AppState::new("Qwen/Qwen1.5-0.5B-Chat", chat))); + let mut app = build_router(Arc::new(AppState::new( + vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()], + chat, + ))); let response = app .call( @@ -2353,7 +2389,7 @@ async fn prepared_openai_request_streams_text_events() { "messages": [{"role": "user", "content": "hello"}] })) .expect("decode request"), - "Qwen/Qwen1.5-0.5B-Chat", + &["Qwen/Qwen1.5-0.5B-Chat".to_string()], crate::utils::ResolvedRequestContext::default(), ) .expect("prepare request"); @@ -2597,7 +2633,10 @@ async fn tool_call_sse_chunks_can_carry_logprobs() { test_llm(client), Arc::new(FakeChatBackend::with_model_id("Qwen/Qwen3-0.6B")), ); - let app = build_router(Arc::new(AppState::new("Qwen/Qwen1.5-0.5B-Chat", chat))); + let app = build_router(Arc::new(AppState::new( + vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()], + chat, + ))); let response = app .clone() @@ -3002,7 +3041,10 @@ async fn is_sleeping_route_returns_json_payload() { async fn admin_routes_are_hidden_when_dev_mode_is_disabled() { let (chat, engine_task) = test_chat_with_engine_handle().await; let app = build_router_with_dev_mode( - Arc::new(AppState::new("Qwen/Qwen1.5-0.5B-Chat", chat)), + Arc::new(AppState::new( + vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()], + chat, + )), false, ); diff --git a/src/server/src/state.rs b/src/server/src/state.rs index 32ebfc5..04d37f1 100644 --- a/src/server/src/state.rs +++ b/src/server/src/state.rs @@ -10,8 +10,9 @@ const SHUTDOWN_REFCOUNT_POLL_INTERVAL: Duration = Duration::from_millis(100); /// Shared router state for the minimal single-model OpenAI server. pub struct AppState { - /// Public model ID returned by `/v1/models` and validated on chat requests. - pub model_id: String, + /// All public model IDs served by this frontend. The first entry is the + /// primary ID used in responses; all entries are valid in requests. 
diff --git a/src/server/src/state.rs b/src/server/src/state.rs
index 32ebfc5..04d37f1 100644
--- a/src/server/src/state.rs
+++ b/src/server/src/state.rs
@@ -10,8 +10,9 @@ const SHUTDOWN_REFCOUNT_POLL_INTERVAL: Duration = Duration::from_millis(100);
 
 /// Shared router state for the minimal single-model OpenAI server.
 pub struct AppState {
-    /// Public model ID returned by `/v1/models` and validated on chat requests.
-    pub model_id: String,
+    /// All public model IDs served by this frontend. The first entry is the
+    /// primary ID used in responses; all entries are valid in requests.
+    served_model_names: Vec<String>,
     /// Shared chat facade used by all requests.
     pub chat: ChatLlm,
     /// Whether to log a summary line for each completed request.
@@ -22,9 +23,20 @@ pub struct AppState {
 
 impl AppState {
     /// Construct one application state instance.
-    pub fn new(model_id: impl Into<String>, chat: ChatLlm) -> Self {
+    ///
+    /// `served_model_names` must be non-empty; the first entry is the primary
+    /// model ID returned in API responses.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `served_model_names` is empty.
+    pub fn new(served_model_names: Vec<String>, chat: ChatLlm) -> Self {
+        assert!(
+            !served_model_names.is_empty(),
+            "served_model_names must not be empty"
+        );
         Self {
-            model_id: model_id.into(),
+            served_model_names,
             chat,
             enable_log_requests: false,
             server_load: AtomicU64::new(0),
@@ -37,6 +49,17 @@ impl AppState {
         self
     }
 
+    /// The primary model name echoed back in API responses (the first served
+    /// name).
+    pub fn primary_model_name(&self) -> &str {
+        self.served_model_names.first().map(String::as_str).unwrap_or_default()
+    }
+
+    /// All model names served by this frontend.
+    pub fn served_model_names(&self) -> &[String] {
+        &self.served_model_names
+    }
+
     /// Return a reference to the underlying engine core client for utility
     /// calls.
     pub(crate) fn engine_core_client(&self) -> &EngineCoreClient {
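The constructor contract above (non-empty list, first entry primary) can be exercised in isolation; a runnable sketch that mirrors `AppState::new`'s guard without the unrelated fields:

    // Standalone mirror of AppState's served-name invariants; not patch code.
    struct ServedNames(Vec<String>);

    impl ServedNames {
        fn new(names: Vec<String>) -> Self {
            assert!(!names.is_empty(), "served_model_names must not be empty");
            Self(names)
        }

        fn primary(&self) -> &str {
            self.0.first().map(String::as_str).unwrap_or_default()
        }
    }

    fn main() {
        let names = ServedNames::new(vec!["qwen3".into(), "qwen3-prod".into()]);
        assert_eq!(names.primary(), "qwen3");
    }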