Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/grpc_client/src/mlx_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ impl MlxEngineClient {

let http_endpoint = if let Some(addr) = endpoint.strip_prefix("grpc://") {
format!("http://{addr}")
} else if let Some(addr) = endpoint.strip_prefix("grpcs://") {
format!("https://{addr}")
} else {
endpoint.to_string()
};
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Expand Down
4 changes: 3 additions & 1 deletion crates/grpc_client/src/sglang_scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,11 @@ impl SglangSchedulerClient {
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
debug!("Connecting to SGLang scheduler at {}", endpoint);

// Convert grpc:// to http:// for tonic
// Convert gRPC schemes to tonic-compatible HTTP(S) endpoints.
let http_endpoint = if let Some(addr) = endpoint.strip_prefix("grpc://") {
format!("http://{addr}")
} else if let Some(addr) = endpoint.strip_prefix("grpcs://") {
format!("https://{addr}")
} else {
endpoint.to_string()
};
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Expand Down
4 changes: 3 additions & 1 deletion crates/grpc_client/src/trtllm_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,11 @@ impl TrtllmServiceClient {
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
debug!("Connecting to TensorRT-LLM gRPC server at {}", endpoint);

// Convert grpc:// to http:// for tonic
// Convert gRPC schemes to tonic-compatible HTTP(S) endpoints.
let http_endpoint = if let Some(addr) = endpoint.strip_prefix("grpc://") {
format!("http://{addr}")
} else if let Some(addr) = endpoint.strip_prefix("grpcs://") {
format!("https://{addr}")
} else {
endpoint.to_string()
};
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Expand Down
4 changes: 3 additions & 1 deletion crates/grpc_client/src/vllm_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,11 @@ impl VllmEngineClient {
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
debug!("Connecting to vLLM gRPC server at {}", endpoint);

// Convert grpc:// to http:// for tonic
// Convert gRPC schemes to tonic-compatible HTTP(S) endpoints.
let http_endpoint = if let Some(addr) = endpoint.strip_prefix("grpc://") {
format!("http://{addr}")
} else if let Some(addr) = endpoint.strip_prefix("grpcs://") {
format!("https://{addr}")
} else {
endpoint.to_string()
};
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Expand Down
28 changes: 24 additions & 4 deletions model_gateway/src/config/validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -835,14 +835,18 @@ impl ConfigValidator {
});
}

if !url.starts_with("http://")
&& !url.starts_with("https://")
&& !url.starts_with("grpc://")
// Perform case-insensitive scheme check
let url_lower = url.to_ascii_lowercase();
if !url_lower.starts_with("http://")
&& !url_lower.starts_with("https://")
&& !url_lower.starts_with("grpc://")
&& !url_lower.starts_with("grpcs://")
{
return Err(ConfigError::InvalidValue {
field: "worker_url".to_string(),
value: url.clone(),
reason: "URL must start with http://, https://, or grpc://".to_string(),
reason: "URL must start with http://, https://, grpc://, or grpcs://"
.to_string(),
});
}
Comment thread
heymrbox marked this conversation as resolved.

Expand Down Expand Up @@ -1235,6 +1239,22 @@ mod tests {
assert!(result.is_ok());
}

/// A `grpcs://` worker URL must pass validation when the router is in gRPC mode.
#[test]
fn test_validate_grpcs_worker_url() {
    let routing = RoutingMode::Regular {
        worker_urls: vec![String::from("grpcs://worker:50051")],
    };
    let mut config = RouterConfig::new(routing, PolicyConfig::Random);

    // gRPC mode is configured together with a model path in this scenario.
    config.connection_mode = ConnectionMode::Grpc;
    config.model_path = Some(String::from("meta-llama/Llama-3-8B"));

    assert!(ConfigValidator::validate(&config).is_ok());
}

#[test]
fn test_validate_grpc_with_tokenizer_path() {
let mut config = RouterConfig::new(
Expand Down
43 changes: 42 additions & 1 deletion model_gateway/src/workflow/steps/local/create_worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,11 @@ fn infer_non_generation_type(labels: &HashMap<String, String>) -> ModelType {
}

fn normalize_url(url: &str, connection_mode: ConnectionMode) -> String {
if url.starts_with("http://") || url.starts_with("https://") || url.starts_with("grpc://") {
if url.starts_with("http://")
|| url.starts_with("https://")
|| url.starts_with("grpc://")
|| url.starts_with("grpcs://")
Comment thread
heymrbox marked this conversation as resolved.
{
url.to_string()
} else {
match connection_mode {
Expand All @@ -307,3 +311,40 @@ fn normalize_url(url: &str, connection_mode: ConnectionMode) -> String {
}
}
}

#[cfg(test)]
mod tests {
    use super::*;

    /// URLs that already carry a recognized scheme come back unchanged.
    #[test]
    fn normalize_url_preserves_existing_schemes() {
        let plain_http = normalize_url("http://localhost:30000", ConnectionMode::Http);
        assert_eq!(plain_http, "http://localhost:30000");

        let tls_http = normalize_url("https://localhost:30000", ConnectionMode::Http);
        assert_eq!(tls_http, "https://localhost:30000");

        let plain_grpc = normalize_url("grpc://localhost:30001", ConnectionMode::Grpc);
        assert_eq!(plain_grpc, "grpc://localhost:30001");

        let tls_grpc = normalize_url("grpcs://localhost:30001", ConnectionMode::Grpc);
        assert_eq!(tls_grpc, "grpcs://localhost:30001");
    }

    /// Bare `host:port` inputs get the scheme that matches the connection mode.
    #[test]
    fn normalize_url_adds_scheme_for_bare_urls() {
        let via_http = normalize_url("localhost:30000", ConnectionMode::Http);
        assert_eq!(via_http, "http://localhost:30000");

        let via_grpc = normalize_url("localhost:30001", ConnectionMode::Grpc);
        assert_eq!(via_grpc, "grpc://localhost:30001");
    }
}
74 changes: 73 additions & 1 deletion model_gateway/src/workflow/steps/local/detect_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,20 @@ use crate::{
},
};

fn explicit_connection_mode(url: &str) -> Option<ConnectionMode> {
if url.starts_with("grpc://") || url.starts_with("grpcs://") {
Some(ConnectionMode::Grpc)
} else if url.starts_with("http://") || url.starts_with("https://") {
Some(ConnectionMode::Http)
} else {
None
}
Comment thread
heymrbox marked this conversation as resolved.
}
Comment thread
heymrbox marked this conversation as resolved.
Comment thread
coderabbitai[bot] marked this conversation as resolved.

/// Step 1: Detect connection mode (HTTP vs gRPC).
///
/// Explicit URL schemes are honored. For bare host:port URLs, probes both
/// protocols in parallel and HTTP takes priority if both succeed.
/// Does NOT detect backend runtime — that's handled by DetectBackendStep.
pub struct DetectConnectionModeStep;

Expand Down Expand Up @@ -51,6 +62,33 @@ impl StepExecutor<WorkerWorkflowData> for DetectConnectionModeStep {
.unwrap_or(app_context.router_config.health_check.timeout_secs);
let client = &app_context.client;

if let Some(connection_mode) = explicit_connection_mode(&url) {
let result = match connection_mode {
ConnectionMode::Http => try_http_reachable(&url, timeout, client).await,
ConnectionMode::Grpc => try_grpc_reachable(&url, timeout).await,
};
Comment thread
heymrbox marked this conversation as resolved.

match result {
Ok(()) => {
debug!(
"{} explicitly configured as {}",
config.url, connection_mode
);
context.data.connection_mode = Some(connection_mode);
return Ok(StepResult::Success);
}
Err(err) => {
return Err(WorkflowError::StepFailed {
step_id: StepId::new("detect_connection_mode"),
message: format!(
"{connection_mode} health check failed for explicitly configured worker URL {}: {}",
config.url, err
),
});
}
}
}

let (http_result, grpc_result) = tokio::join!(
try_http_reachable(&url, timeout, client),
try_grpc_reachable(&url, timeout)
Expand Down Expand Up @@ -84,3 +122,37 @@ impl StepExecutor<WorkerWorkflowData> for DetectConnectionModeStep {
true
}
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Both plaintext and TLS gRPC schemes select gRPC mode.
    #[test]
    fn explicit_connection_mode_honors_grpc_scheme() {
        for url in ["grpc://localhost:30001", "grpcs://localhost:30001"] {
            assert_eq!(explicit_connection_mode(url), Some(ConnectionMode::Grpc));
        }
    }

    /// Both plaintext and TLS HTTP schemes select HTTP mode.
    #[test]
    fn explicit_connection_mode_honors_http_schemes() {
        for url in ["http://localhost:30000", "https://example.com"] {
            assert_eq!(explicit_connection_mode(url), Some(ConnectionMode::Http));
        }
    }

    /// A scheme-less address yields no explicit mode.
    #[test]
    fn explicit_connection_mode_leaves_bare_urls_for_probe_detection() {
        let detected = explicit_connection_mode("localhost:30000");
        assert_eq!(detected, None);
    }
}
117 changes: 99 additions & 18 deletions model_gateway/src/workflow/steps/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,42 +6,78 @@ use reqwest::Client;

use crate::routers::grpc::client::GrpcClient;

/// Return the part of `url` that follows `scheme`, if `url` starts with it.
///
/// The prefix comparison is ASCII case-insensitive. Yields `None` when `url`
/// is shorter than `scheme`, when cutting at `scheme.len()` would land inside
/// a multi-byte character, or when the prefix does not match.
fn strip_scheme<'a>(url: &'a str, scheme: &str) -> Option<&'a str> {
    let cut = scheme.len();
    // `is_char_boundary` is false both past the end and mid-character, so
    // the `split_at` below cannot panic.
    if !url.is_char_boundary(cut) {
        return None;
    }

    let (head, tail) = url.split_at(cut);
    if head.eq_ignore_ascii_case(scheme) {
        Some(tail)
    } else {
        None
    }
}

/// Extract the lowercased text that precedes the first `://` in `url`.
///
/// Returns `None` when `url` contains no `://` separator.
fn url_scheme(url: &str) -> Option<String> {
    let separator_at = url.find("://")?;
    Some(url[..separator_at].to_ascii_lowercase())
}

/// Strip protocol prefix (http://, https://, grpc://, grpcs://) from URL.
pub(crate) fn strip_protocol(url: &str) -> String {
url.trim_start_matches("http://")
.trim_start_matches("https://")
.trim_start_matches("grpc://")
.to_string()
for scheme in ["http://", "https://", "grpc://", "grpcs://"] {
if let Some(rest) = strip_scheme(url, scheme) {
return rest.to_string();
}
}
url.to_string()
}

/// Ensure URL has an HTTP(S) scheme — handles bare `host:port` and `grpc://` inputs.
/// Ensure URL has an HTTP(S) scheme — handles bare `host:port` and gRPC inputs.
pub(crate) fn http_base_url(url: &str) -> String {
if url.starts_with("http://") || url.starts_with("https://") {
if strip_scheme(url, "http://").is_some() || strip_scheme(url, "https://").is_some() {
url.trim_end_matches('/').to_string()
} else {
format!("http://{}", strip_protocol(url).trim_end_matches('/'))
}
}

/// Ensure URL has a gRPC scheme — handles bare `host:port` and `http://` inputs.
/// Ensure URL has a gRPC scheme — handles bare `host:port` and HTTP(S) inputs.
pub(crate) fn grpc_base_url(url: &str) -> String {
if url.starts_with("grpc://") {
if strip_scheme(url, "grpc://").is_some() || strip_scheme(url, "grpcs://").is_some() {
url.trim_end_matches('/').to_string()
} else {
format!("grpc://{}", strip_protocol(url).trim_end_matches('/'))
}
}

/// Build the `/health` URL used by the HTTP reachability probe.
///
/// Bare inputs (no `://`) get an `http://` prefix; `http`/`https` inputs keep
/// their scheme. Trailing `/` characters are dropped before `/health` is
/// appended.
///
/// # Errors
///
/// Returns a descriptive message when `url` carries a gRPC scheme or any
/// other non-HTTP scheme.
fn http_health_url(url: &str) -> Result<String, String> {
    let scheme = match url_scheme(url) {
        Some(scheme) => scheme,
        // No scheme at all: assume plain HTTP.
        None => return Ok(format!("http://{}/health", url.trim_end_matches('/'))),
    };

    if scheme == "http" || scheme == "https" {
        return Ok(format!("{}/health", url.trim_end_matches('/')));
    }

    if scheme == "grpc" || scheme == "grpcs" {
        return Err(format!(
            "HTTP health check does not accept gRPC URL scheme: {url}"
        ));
    }

    Err(format!(
        "HTTP health check does not accept URL scheme '{scheme}': {url}"
    ))
}

/// Build the endpoint URL used by the gRPC reachability probe.
///
/// Bare inputs (no `://`) get a `grpc://` prefix; `grpc`/`grpcs` inputs keep
/// their scheme. Trailing `/` characters are dropped.
///
/// # Errors
///
/// Returns a descriptive message when `url` carries an HTTP(S) scheme or any
/// other non-gRPC scheme.
fn grpc_reachable_url(url: &str) -> Result<String, String> {
    let scheme = match url_scheme(url) {
        Some(scheme) => scheme,
        // No scheme at all: assume plaintext gRPC.
        None => return Ok(format!("grpc://{}", url.trim_end_matches('/'))),
    };

    if scheme == "grpc" || scheme == "grpcs" {
        return Ok(url.trim_end_matches('/').to_string());
    }

    if scheme == "http" || scheme == "https" {
        return Err(format!(
            "gRPC health check does not accept HTTP URL scheme: {url}"
        ));
    }

    Err(format!(
        "gRPC health check does not accept URL scheme '{scheme}': {url}"
    ))
}

/// Try HTTP health check (2xx response required).
pub(crate) async fn try_http_reachable(
url: &str,
timeout_secs: u64,
client: &Client,
) -> Result<(), String> {
let is_https = url.starts_with("https://");
let protocol = if is_https { "https" } else { "http" };
let clean_url = strip_protocol(url);
let health_url = format!("{protocol}://{clean_url}/health");
let health_url = http_health_url(url)?;

client
.get(&health_url)
Expand Down Expand Up @@ -82,11 +118,7 @@ pub(crate) async fn do_grpc_health_check(
/// We don't care which runtime it is here — that's `DetectBackendStep`'s job.
/// We just need to know: does this endpoint speak gRPC at all?
pub(crate) async fn try_grpc_reachable(url: &str, timeout_secs: u64) -> Result<(), String> {
let grpc_url = if url.starts_with("grpc://") {
url.to_string()
} else {
format!("grpc://{}", strip_protocol(url))
};
let grpc_url = grpc_reachable_url(url)?;

let (sglang, vllm, trtllm, mlx) = tokio::join!(
do_grpc_health_check(&grpc_url, timeout_secs, "sglang"),
Expand All @@ -102,3 +134,52 @@ pub(crate) async fn try_grpc_reachable(url: &str, timeout_secs: u64) -> Result<(
)),
}
}

#[cfg(test)]
mod tests {
    use super::*;

    /// HTTP, HTTPS and scheme-less inputs all produce a `/health` URL.
    #[test]
    fn http_health_url_accepts_http_https_and_bare_urls() {
        let cases = [
            ("http://localhost:30000", "http://localhost:30000/health"),
            ("https://example.com/", "https://example.com/health"),
            ("localhost:30000", "http://localhost:30000/health"),
        ];
        for (input, expected) in cases {
            assert_eq!(http_health_url(input).unwrap(), expected);
        }
    }

    /// gRPC schemes are refused by the HTTP probe URL builder.
    #[test]
    fn http_health_url_rejects_grpc_schemes() {
        for input in ["grpc://localhost:30001", "grpcs://localhost:30001"] {
            assert!(http_health_url(input).is_err());
        }
    }

    /// gRPC, gRPC-over-TLS and scheme-less inputs all produce a gRPC URL.
    #[test]
    fn grpc_reachable_url_accepts_grpc_grpcs_and_bare_urls() {
        let cases = [
            ("grpc://localhost:30001", "grpc://localhost:30001"),
            ("grpcs://localhost:30001/", "grpcs://localhost:30001"),
            ("localhost:30001", "grpc://localhost:30001"),
        ];
        for (input, expected) in cases {
            assert_eq!(grpc_reachable_url(input).unwrap(), expected);
        }
    }

    /// HTTP(S) schemes are refused by the gRPC probe URL builder.
    #[test]
    fn grpc_reachable_url_rejects_http_schemes() {
        for input in ["http://localhost:30000", "https://localhost:30000"] {
            assert!(grpc_reachable_url(input).is_err());
        }
    }
}
Loading