diff --git a/crates/protocols/src/worker.rs b/crates/protocols/src/worker.rs index 03d6d7547..e91892494 100644 --- a/crates/protocols/src/worker.rs +++ b/crates/protocols/src/worker.rs @@ -267,6 +267,9 @@ pub enum ProviderType { /// Google Gemini — special logprobs handling. #[serde(alias = "gemini", alias = "google")] Gemini, + /// Amazon Bedrock Runtime. + #[serde(alias = "bedrock", alias = "aws")] + Bedrock, /// Custom provider with string identifier. #[serde(untagged)] Custom(String), @@ -280,6 +283,7 @@ impl ProviderType { Self::XAI => "xai", Self::Anthropic => "anthropic", Self::Gemini => "gemini", + Self::Bedrock => "bedrock", Self::Custom(s) => s.as_str(), } } @@ -297,6 +301,8 @@ impl ProviderType { Some(Self::Anthropic) } else if host.ends_with("googleapis.com") { Some(Self::Gemini) + } else if host.contains("bedrock-runtime.") && host.ends_with(".amazonaws.com") { + Some(Self::Bedrock) } else { None } @@ -310,6 +316,7 @@ impl ProviderType { Self::XAI => Some("XAI_ADMIN_KEY"), Self::Anthropic => Some("ANTHROPIC_ADMIN_KEY"), Self::Gemini => Some("GEMINI_ADMIN_KEY"), + Self::Bedrock => Some("AWS_BEDROCK_API_KEY"), Self::Custom(_) => None, } } @@ -329,6 +336,12 @@ impl ProviderType { Some(Self::Gemini) } else if model_lower.starts_with("claude") { Some(Self::Anthropic) + } else if model_lower.starts_with("anthropic.claude") + || model_lower.starts_with("amazon.") + || model_lower.starts_with("meta.") + || model_lower.starts_with("mistral.") + { + Some(Self::Bedrock) } else if model_lower.starts_with("gpt") || model_lower.starts_with("o1") || model_lower.starts_with("o3") diff --git a/docs/getting-started/external-providers.md b/docs/getting-started/external-providers.md index bf424f146..d4064176f 100644 --- a/docs/getting-started/external-providers.md +++ b/docs/getting-started/external-providers.md @@ -4,7 +4,7 @@ title: External Providers # External Providers -SMG can route requests to external LLM provider APIs (OpenAI, Anthropic, xAI, Google Gemini), 
acting as a unified gateway. This enables provider-agnostic applications, load balancing across providers, and centralized observability. +SMG can route requests to external LLM provider APIs (OpenAI, Anthropic, xAI, Google Gemini, and AWS Bedrock), acting as a unified gateway. This enables provider-agnostic applications, load balancing across providers, and centralized observability.
@@ -27,6 +27,24 @@ SMG auto-detects the provider from the model name in each request and applies th | Anthropic | `claude-*` models | `x-api-key` (plus `anthropic-version`) | | xAI | `grok-*` models | `Authorization: Bearer` | | Google Gemini | `gemini-*` models | `x-goog-api-key` | +| AWS Bedrock | `anthropic.claude*`, `amazon.*`, `meta.*`, `mistral.*` | AWS SigV4 (`Authorization`, `X-Amz-Date`, etc.) | + +### Bedrock notes + +- Bedrock runs as a first-class routing mode (`routing.mode.type: bedrock`) and is HTTP-only. +- SMG signs Bedrock requests with SigV4 using this credential resolution order: + 1) environment variables (`AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, optional `AWS_SESSION_TOKEN`) + 2) shared credentials profile (`AWS_PROFILE` / `~/.aws/credentials`) + 3) container/instance metadata (ECS/EC2 IMDS). +- Configure Bedrock region/service in `router_config.bedrock`: + +```yaml +bedrock: + region: us-east-1 + service: bedrock + model_map: + claude-opus-4-5: us.anthropic.claude-opus-4-5-20251101-v1:0 +``` --- diff --git a/model_gateway/Cargo.toml b/model_gateway/Cargo.toml index c3a9ee816..c6a746ae1 100644 --- a/model_gateway/Cargo.toml +++ b/model_gateway/Cargo.toml @@ -116,7 +116,9 @@ openmetrics-parser = "0.4.4" arc-swap = "1.7.1" bitflags.workspace = true once_cell = "1.21.3" -sha2 = "0.10" +sha2 = "0.11" +crc32fast = "1.4" +crc32c = "0.6" base64 = "0.22" image = { version = "0.25.4", default-features = false } tokio-tungstenite = { workspace = true } @@ -125,6 +127,8 @@ wasmtime = { workspace = true } tempfile = "3.8" multer = { workspace = true } str0m = { workspace = true } +hmac = "0.13" +hex = "0.4.3" [build-dependencies] chrono = { version = "0.4", features = ["clock"] } diff --git a/model_gateway/src/config/builder.rs b/model_gateway/src/config/builder.rs index 209daefdc..912b6f8d7 100644 --- a/model_gateway/src/config/builder.rs +++ b/model_gateway/src/config/builder.rs @@ -96,6 +96,11 @@ impl RouterConfigBuilder { self } + pub fn 
bedrock_mode(mut self, worker_urls: Vec) -> Self { + self.config.mode = RoutingMode::Bedrock { worker_urls }; + self + } + pub fn mode(mut self, mode: RoutingMode) -> Self { self.config.mode = mode; self diff --git a/model_gateway/src/config/types.rs b/model_gateway/src/config/types.rs index 1a675bd4b..073480d78 100644 --- a/model_gateway/src/config/types.rs +++ b/model_gateway/src/config/types.rs @@ -127,6 +127,8 @@ pub struct RouterConfig { pub disable_circuit_breaker: bool, pub health_check: HealthCheckConfig, #[serde(default)] + pub bedrock: BedrockConfig, + #[serde(default)] pub enable_igw: bool, /// Can be a HuggingFace model ID or local path pub model_path: Option, @@ -215,6 +217,55 @@ pub struct TokenizerCacheConfig { pub l1_max_memory: usize, } +/// AWS Bedrock routing configuration +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(default)] +pub struct BedrockConfig { + /// AWS region used for request signing (e.g. us-east-1). + pub region: Option, + /// AWS service identifier for SigV4 (defaults to `bedrock`). + /// + /// Note: Bedrock uses `bedrock` as the SigV4 signing service for BOTH the + /// control plane (`bedrock..amazonaws.com`) and the runtime data + /// plane (`bedrock-runtime..amazonaws.com`). The hostname differs + /// but the credential scope service does not. See AWS SDK for Go + /// `bedrockruntime/service.go` and AWS SDK for JS v3 `client-bedrock-runtime`, + /// both of which set `SigningName = "bedrock"`. + pub service: String, + /// Optional model-id remapping from incoming model -> Bedrock model ID. + pub model_map: HashMap, +} + +impl Default for BedrockConfig { + fn default() -> Self { + Self { + region: None, + service: "bedrock".to_string(), + model_map: HashMap::new(), + } + } +} + +impl BedrockConfig { + /// AWS region for SigV4: non-empty `region`, else `AWS_REGION`, else `AWS_DEFAULT_REGION`. 
+ pub(crate) fn resolved_signing_region(&self) -> Option { + if let Some(r) = &self.region { + let trimmed = r.trim(); + if !trimmed.is_empty() { + return Some(trimmed.to_string()); + } + } + std::env::var("AWS_REGION") + .ok() + .filter(|s| !s.trim().is_empty()) + .or_else(|| { + std::env::var("AWS_DEFAULT_REGION") + .ok() + .filter(|s| !s.trim().is_empty()) + }) + } +} + fn default_load_monitor_interval_secs() -> u64 { 10 } @@ -284,6 +335,8 @@ pub enum RoutingMode { Anthropic { worker_urls: Vec }, #[serde(rename = "gemini")] Gemini { worker_urls: Vec }, + #[serde(rename = "bedrock")] + Bedrock { worker_urls: Vec }, } impl RoutingMode { @@ -302,6 +355,7 @@ impl RoutingMode { RoutingMode::OpenAI { worker_urls } => worker_urls.len(), RoutingMode::Anthropic { worker_urls } => worker_urls.len(), RoutingMode::Gemini { worker_urls } => worker_urls.len(), + RoutingMode::Bedrock { worker_urls } => worker_urls.len(), } } @@ -656,6 +710,7 @@ impl Default for RouterConfig { disable_retries: false, disable_circuit_breaker: false, health_check: HealthCheckConfig::default(), + bedrock: BedrockConfig::default(), enable_igw: false, connection_mode: ConnectionMode::Http, model_path: None, @@ -705,6 +760,7 @@ impl RouterConfig { RoutingMode::OpenAI { .. } => "openai", RoutingMode::Anthropic { .. } => "anthropic", RoutingMode::Gemini { .. } => "gemini", + RoutingMode::Bedrock { .. 
} => "bedrock", } } diff --git a/model_gateway/src/config/validation.rs b/model_gateway/src/config/validation.rs index 7dfe99d00..2189f41aa 100644 --- a/model_gateway/src/config/validation.rs +++ b/model_gateway/src/config/validation.rs @@ -43,6 +43,7 @@ impl ConfigValidator { } Self::validate_tokenizer_cache(&config.tokenizer_cache)?; + Self::validate_bedrock(&config.mode, &config.bedrock)?; Self::validate_skills(config)?; Self::validate_background(&config.background)?; @@ -99,6 +100,31 @@ impl ConfigValidator { Self::validate_skills_config(skills) } + fn validate_bedrock(mode: &RoutingMode, cfg: &BedrockConfig) -> ConfigResult<()> { + if cfg.service.trim().is_empty() { + return Err(ConfigError::InvalidValue { + field: "bedrock.service".to_string(), + value: cfg.service.clone(), + reason: "Must not be empty".to_string(), + }); + } + if let Some(region) = &cfg.region { + if region.trim().is_empty() { + return Err(ConfigError::InvalidValue { + field: "bedrock.region".to_string(), + value: region.clone(), + reason: "Must not be empty when set".to_string(), + }); + } + } + if matches!(mode, RoutingMode::Bedrock { .. 
}) && cfg.resolved_signing_region().is_none() { + return Err(ConfigError::ValidationFailed { + reason: "Bedrock routing requires a non-empty bedrock.region, or AWS_REGION / AWS_DEFAULT_REGION (used for SigV4 signing)".to_string(), + }); + } + Ok(()) + } + fn validate_skills_config(skills: &SkillsConfig) -> ConfigResult<()> { if skills.blob_store.path.trim().is_empty() { return Err(ConfigError::InvalidValue { @@ -349,6 +375,20 @@ impl ConfigValidator { Self::validate_urls(worker_urls)?; } } + RoutingMode::Bedrock { worker_urls } => { + // Allow empty URLs to support dynamic worker addition + if !worker_urls.is_empty() { + Self::validate_urls(worker_urls)?; + if worker_urls.iter().any(|u| u.starts_with("grpc://")) { + return Err(ConfigError::InvalidValue { + field: "worker_url".to_string(), + value: "grpc://...".to_string(), + reason: "Bedrock mode only supports http:// or https:// URLs" + .to_string(), + }); + } + } + } } Ok(()) } @@ -593,6 +633,11 @@ impl ConfigValidator { reason: "Gemini mode does not support service discovery".to_string(), }); } + RoutingMode::Bedrock { .. 
} => { + return Err(ConfigError::ValidationFailed { + reason: "Bedrock mode does not support service discovery".to_string(), + }); + } } Ok(()) @@ -885,6 +930,8 @@ fn validate_mebibyte_limit(field: &str, value_mb: usize) -> ConfigResult<()> { #[cfg(test)] mod tests { + use serial_test::serial; + use super::*; use crate::worker::ConnectionMode; @@ -1217,6 +1264,60 @@ mod tests { // Should pass validation even with empty URLs assert!(ConfigValidator::validate(&config).is_ok()); + + // Test that empty URLs are allowed in Bedrock mode + let mut config = RouterConfig::new( + RoutingMode::Bedrock { + worker_urls: vec![], + }, + PolicyConfig::Random, + ); + config.bedrock.region = Some("us-east-1".to_string()); + assert!(ConfigValidator::validate(&config).is_ok()); + } + + #[test] + #[serial] + fn bedrock_routing_rejects_missing_signing_region() { + struct EnvRestore { + saved: Vec<(&'static str, Option)>, + } + impl EnvRestore { + fn new(keys: &[&'static str]) -> Self { + let saved = keys + .iter() + .map(|k| (*k, std::env::var_os(k))) + .collect::>(); + Self { saved } + } + } + impl Drop for EnvRestore { + fn drop(&mut self) { + for (k, v) in &self.saved { + match v { + Some(val) => std::env::set_var(k, val), + None => std::env::remove_var(k), + } + } + } + } + + let _guard = EnvRestore::new(&["AWS_REGION", "AWS_DEFAULT_REGION"]); + std::env::remove_var("AWS_REGION"); + std::env::remove_var("AWS_DEFAULT_REGION"); + + let mut config = RouterConfig::new( + RoutingMode::Bedrock { + worker_urls: vec![], + }, + PolicyConfig::Random, + ); + config.bedrock.region = None; + let result = ConfigValidator::validate(&config); + assert!( + result.is_err(), + "expected validation error when Bedrock mode has no region and no AWS_* env" + ); } #[test] diff --git a/model_gateway/src/main.rs b/model_gateway/src/main.rs index 41a949d0f..e6f254350 100644 --- a/model_gateway/src/main.rs +++ b/model_gateway/src/main.rs @@ -65,6 +65,8 @@ pub enum Backend { Anthropic, #[value(name = "gemini")] 
Gemini, + #[value(name = "bedrock")] + Bedrock, } impl std::fmt::Display for Backend { @@ -76,6 +78,7 @@ impl std::fmt::Display for Backend { Backend::Openai => "openai", Backend::Anthropic => "anthropic", Backend::Gemini => "gemini", + Backend::Bedrock => "bedrock", }; write!(f, "{s}") } @@ -1097,6 +1100,10 @@ impl CliArgs { RoutingMode::Gemini { worker_urls: self.worker_urls.clone(), } + } else if matches!(self.backend, Some(Backend::Bedrock)) { + RoutingMode::Bedrock { + worker_urls: self.worker_urls.clone(), + } } else if self.pd_disaggregation { RoutingMode::PrefillDecode { prefill_urls, @@ -1164,6 +1171,9 @@ impl CliArgs { RoutingMode::Gemini { worker_urls } => { all_urls.extend(worker_urls.clone()); } + RoutingMode::Bedrock { worker_urls } => { + all_urls.extend(worker_urls.clone()); + } } let connection_mode = Self::determine_connection_mode(&all_urls); @@ -1429,6 +1439,8 @@ fn main() -> Result<(), Box> { "OpenAI Backend".to_string() } else if matches!(cli_args.backend, Some(Backend::Anthropic)) { "Anthropic Backend".to_string() + } else if matches!(cli_args.backend, Some(Backend::Bedrock)) { + "Bedrock Backend".to_string() } else if cli_args.pd_disaggregation { "PD Disaggregated".to_string() } else if let Some(backend) = &cli_args.backend { diff --git a/model_gateway/src/routers/bedrock/context.rs b/model_gateway/src/routers/bedrock/context.rs new file mode 100644 index 000000000..833666b58 --- /dev/null +++ b/model_gateway/src/routers/bedrock/context.rs @@ -0,0 +1,27 @@ +use std::sync::Arc; + +use super::signing::AwsSigner; +use crate::{config::types::BedrockConfig, worker::WorkerRegistry}; + +pub(crate) struct RouterContext { + pub worker_registry: Arc, + pub http_client: reqwest::Client, + pub bedrock: BedrockConfig, + pub signer: AwsSigner, +} + +impl RouterContext { + pub fn new( + worker_registry: Arc, + http_client: reqwest::Client, + bedrock: BedrockConfig, + signer: AwsSigner, + ) -> Self { + Self { + worker_registry, + http_client, + bedrock, + 
signer, + } + } +} diff --git a/model_gateway/src/routers/bedrock/converse_stream.rs b/model_gateway/src/routers/bedrock/converse_stream.rs new file mode 100644 index 000000000..0129749fe --- /dev/null +++ b/model_gateway/src/routers/bedrock/converse_stream.rs @@ -0,0 +1,295 @@ +//! Map Bedrock `ConverseStream` AWS event-stream frames to OpenAI-style SSE lines. + +use std::time::{SystemTime, UNIX_EPOCH}; + +use bytes::{Bytes, BytesMut}; +use futures_util::StreamExt; +use serde_json::{json, Value}; +use tokio::sync::mpsc; + +use super::event_stream::{pop_next_event, DecodeError, StreamEvent}; + +fn now_epoch() -> i64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_secs() as i64) + .unwrap_or(0) +} + +fn map_bedrock_stop_to_openai(stop: &str) -> &'static str { + match stop { + "max_tokens" => "length", + "tool_use" => "tool_calls", + "stop_sequence" => "stop", + "end_turn" => "stop", + "content_filtered" => "content_filter", + "guardrail_intervened" => "content_filter", + _ => "stop", + } +} + +fn openai_chunk( + id: &str, + created: i64, + model: &str, + delta: Value, + finish_reason: Value, + usage: Option, +) -> String { + let mut obj = json!({ + "id": id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [{ + "index": 0, + "delta": delta, + "finish_reason": finish_reason, + }] + }); + if let Some(u) = usage { + if let Some(map) = obj.as_object_mut() { + map.insert("usage".to_string(), u); + } + } + format!("data: {obj}\n\n") +} + +fn extract_delta_text(payload: &Value) -> Option { + payload + .get("contentBlockDelta") + .and_then(|c| c.get("delta")) + .and_then(|d| d.get("text")) + .and_then(|t| t.as_str()) + .map(str::to_owned) + .or_else(|| { + payload + .get("delta") + .and_then(|d| d.get("text")) + .and_then(|t| t.as_str()) + .map(str::to_owned) + }) +} + +fn extract_message_stop(payload: &Value) -> Option<&str> { + payload + .get("messageStop") + .and_then(|m| m.get("stopReason")) + 
.and_then(Value::as_str) + .or_else(|| payload.get("stopReason").and_then(Value::as_str)) +} + +fn extract_metadata_usage(payload: &Value) -> Option { + let usage = payload + .get("metadata") + .and_then(|m| m.get("usage")) + .or_else(|| payload.get("usage"))?; + let num = |k: &str| -> u64 { + usage + .get(k) + .and_then(|v| { + v.as_u64() + .or_else(|| v.as_i64().and_then(|i| u64::try_from(i).ok())) + }) + .unwrap_or(0) + }; + let input = num("inputTokens"); + let output = num("outputTokens"); + let total = num("totalTokens").max(input.saturating_add(output)); + Some(json!({ + "prompt_tokens": input, + "completion_tokens": output, + "total_tokens": total, + })) +} + +fn extract_assistant_role(payload: &Value) -> bool { + payload + .get("messageStart") + .and_then(|m| m.get("role")) + .and_then(Value::as_str) + == Some("assistant") + || payload.get("role").and_then(Value::as_str) == Some("assistant") +} + +fn aws_error_sse(payload: &[u8]) -> String { + let msg = serde_json::from_slice::(payload) + .ok() + .and_then(|v| { + v.get("message") + .and_then(Value::as_str) + .map(str::to_owned) + .or_else(|| v.get("__type").and_then(Value::as_str).map(str::to_owned)) + }) + .unwrap_or_else(|| String::from_utf8_lossy(payload).into_owned()); + let err = json!({"message": msg, "type": "bedrock_stream_error"}); + format!("data: {err}\n\n") +} + +fn send_sse(tx: &mpsc::UnboundedSender>, line: String) { + let _ = tx.send(Ok(Bytes::from(line))); +} + +/// Drain Bedrock `application/vnd.amazon.eventstream` bytes and emit OpenAI-style `data: …` lines. 
+pub(crate) async fn forward_bedrock_converse_stream_as_sse( + mut upstream: S, + model_id: String, + tx: mpsc::UnboundedSender>, +) where + S: futures_util::Stream> + Unpin, +{ + let created = now_epoch(); + let id = format!("chatcmpl-bedrock-{created}"); + let mut buf = BytesMut::new(); + let mut sent_role = false; + let mut pending_finish: Option<&'static str> = None; + + while let Some(chunk) = upstream.next().await { + let chunk = match chunk { + Ok(c) => c, + Err(e) => { + let _ = tx.send(Err(format!("Bedrock stream read failed: {e}"))); + return; + } + }; + buf.extend_from_slice(&chunk); + + loop { + match pop_next_event(&mut buf) { + Ok(ev) => { + if let Err(e) = handle_stream_event( + &ev, + &model_id, + &id, + created, + &mut sent_role, + &mut pending_finish, + &tx, + ) { + let _ = tx.send(Err(e)); + return; + } + } + Err(DecodeError::Truncated) => break, + Err(DecodeError::Invalid(reason)) => { + let _ = tx.send(Err(format!("Bedrock event stream decode error: {reason}"))); + return; + } + } + } + } + + // If the upstream ended while a partial frame remained, surface this as an + // error rather than emitting `[DONE]`. Silently completing would let clients + // treat a truncated response as a successful one and corrupt downstream + // tool-call / conversation state. 
+ if !buf.is_empty() { + let _ = tx.send(Err(format!( + "Bedrock stream ended with {} unread byte(s); upstream terminated mid-frame", + buf.len() + ))); + return; + } + + if let Some(fr) = pending_finish.take() { + send_sse( + &tx, + openai_chunk(&id, created, &model_id, json!({}), json!(fr), None), + ); + } + + send_sse(&tx, "data: [DONE]\n\n".to_string()); +} + +fn handle_stream_event( + ev: &StreamEvent, + model_id: &str, + id: &str, + created: i64, + sent_role: &mut bool, + pending_finish: &mut Option<&'static str>, + tx: &mpsc::UnboundedSender>, +) -> Result<(), String> { + let payload: Value = serde_json::from_slice(&ev.payload) + .map_err(|e| format!("invalid Bedrock stream JSON: {e}"))?; + + if ev.event_type.ends_with("Exception") || ev.event_type.ends_with("Fault") { + send_sse(tx, aws_error_sse(&ev.payload)); + return Err("bedrock upstream exception".to_string()); + } + + if payload.get("__type").is_some() && payload.get("message").and_then(|m| m.as_str()).is_some() + { + send_sse(tx, aws_error_sse(&ev.payload)); + return Err("bedrock upstream error".to_string()); + } + + if extract_assistant_role(&payload) && !*sent_role { + *sent_role = true; + send_sse( + tx, + openai_chunk( + id, + created, + model_id, + json!({"role": "assistant"}), + Value::Null, + None, + ), + ); + } + + if let Some(text) = extract_delta_text(&payload) { + if !text.is_empty() { + send_sse( + tx, + openai_chunk( + id, + created, + model_id, + json!({"content": text}), + Value::Null, + None, + ), + ); + } + } + + if let Some(stop) = extract_message_stop(&payload) { + *pending_finish = Some(map_bedrock_stop_to_openai(stop)); + } + + if let Some(usage) = extract_metadata_usage(&payload) { + let finish = pending_finish + .take() + .map(|s| json!(s)) + .unwrap_or(Value::Null); + send_sse( + tx, + openai_chunk(id, created, model_id, json!({}), finish, Some(usage)), + ); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn maps_stop_reasons() { + 
assert_eq!(map_bedrock_stop_to_openai("end_turn"), "stop"); + assert_eq!(map_bedrock_stop_to_openai("max_tokens"), "length"); + assert_eq!(map_bedrock_stop_to_openai("tool_use"), "tool_calls"); + } + + #[test] + fn extracts_nested_delta_text() { + let v: Value = serde_json::from_str( + r#"{"contentBlockDelta":{"contentBlockIndex":0,"delta":{"text":"Hello"}}}"#, + ) + .unwrap(); + assert_eq!(extract_delta_text(&v).as_deref(), Some("Hello")); + } +} diff --git a/model_gateway/src/routers/bedrock/errors.rs b/model_gateway/src/routers/bedrock/errors.rs new file mode 100644 index 000000000..f391ed86c --- /dev/null +++ b/model_gateway/src/routers/bedrock/errors.rs @@ -0,0 +1,51 @@ +use axum::{ + body::Body, + http::{header::CONTENT_TYPE, StatusCode}, + response::Response, +}; +use serde::Deserialize; + +use crate::routers::error; + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct BedrockErrorPayload { + #[serde(default)] + message: Option, +} + +pub(crate) fn map_upstream_error(status: StatusCode, body: &[u8]) -> Response { + let parsed = serde_json::from_slice::(body).ok(); + let message = parsed + .and_then(|p| p.message) + .unwrap_or_else(|| String::from_utf8_lossy(body).to_string()); + let code = status + .canonical_reason() + .unwrap_or("bedrock_error") + .to_lowercase() + .replace(' ', "_"); + + error::create_error(status, code, message) +} + +pub(crate) fn map_send_error(err: impl std::fmt::Display) -> Response { + error::service_unavailable("bedrock_upstream_unreachable", format!("{err}")) +} + +pub(crate) fn map_bad_mapping_error(err: impl std::fmt::Display) -> Response { + error::bad_gateway("bedrock_response_mapping_failed", format!("{err}")) +} + +pub(crate) fn map_signing_error(err: impl std::fmt::Display) -> Response { + error::internal_error("bedrock_signing_error", format!("{err}")) +} + +pub(crate) fn unsupported_endpoint() -> Response { + Response::builder() + .status(StatusCode::NOT_IMPLEMENTED) + .header(CONTENT_TYPE, 
"application/json") + .body(Body::from( + "{\"error\":{\"code\":\"not_supported\",\"message\":\"Endpoint not yet supported for bedrock router\"}}", + )) + .unwrap_or_else(|_| error::not_implemented("not_supported", "Unsupported endpoint")) +} diff --git a/model_gateway/src/routers/bedrock/event_stream.rs b/model_gateway/src/routers/bedrock/event_stream.rs new file mode 100644 index 000000000..3d4be2711 --- /dev/null +++ b/model_gateway/src/routers/bedrock/event_stream.rs @@ -0,0 +1,157 @@ +//! Minimal decoder for AWS binary event streams (`application/vnd.amazon.eventstream`) +//! used by Bedrock `ConverseStream`. + +use bytes::{Buf, BytesMut}; + +/// One decoded Bedrock stream event (payload is the raw JSON union body). +#[derive(Debug, Clone)] +pub(crate) struct StreamEvent { + pub event_type: String, + pub payload: Vec, +} + +#[derive(Debug)] +pub(crate) enum DecodeError { + Truncated, + Invalid(&'static str), +} + +/// Try to pull the next complete message from `buf`. On success, advances `buf`. 
+pub(crate) fn pop_next_event(buf: &mut BytesMut) -> Result { + if buf.len() < 12 { + return Err(DecodeError::Truncated); + } + + let total_len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize; + let headers_len = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]) as usize; + let prelude_crc = u32::from_be_bytes([buf[8], buf[9], buf[10], buf[11]]); + + if total_len < 16 { + return Err(DecodeError::Invalid("message too short")); + } + if buf.len() < total_len { + return Err(DecodeError::Truncated); + } + + let prelude_crc_calc = crc32fast::hash(&buf[0..8]); + if prelude_crc_calc != prelude_crc { + return Err(DecodeError::Invalid("prelude crc mismatch")); + } + + let headers_end = 12 + headers_len; + if headers_end + 4 > total_len { + return Err(DecodeError::Invalid("header length inconsistent")); + } + + let payload_end = total_len - 4; + let msg_crc_expected = u32::from_be_bytes([ + buf[payload_end], + buf[payload_end + 1], + buf[payload_end + 2], + buf[payload_end + 3], + ]); + let msg_crc_calc = crc32fast::hash(&buf[0..payload_end]); + if msg_crc_calc != msg_crc_expected { + return Err(DecodeError::Invalid("message crc mismatch")); + } + + let headers_bytes = &buf[12..headers_end]; + let payload = buf[headers_end..payload_end].to_vec(); + + let event_type = parse_event_type_header(headers_bytes).unwrap_or_default(); + + buf.advance(total_len); + + Ok(StreamEvent { + event_type, + payload, + }) +} + +fn parse_event_type_header(headers: &[u8]) -> Option { + let mut pos = 0; + let mut event_type: Option = None; + while pos < headers.len() { + let name_len = *headers.get(pos)? 
as usize; + pos += 1; + let name = headers.get(pos..pos + name_len)?; + pos += name_len; + let value_type = *headers.get(pos)?; + pos += 1; + let vlen = header_value_byte_len(value_type, headers.get(pos..)?)?; + let value_bytes = headers.get(pos..pos + vlen)?; + pos += vlen; + if name == b":event-type" || name == b"event-type" { + if value_type == 7 && value_bytes.len() >= 2 { + let slen = u16::from_be_bytes([value_bytes[0], value_bytes[1]]) as usize; + if value_bytes.len() >= 2 + slen { + let s = value_bytes.get(2..2 + slen)?; + event_type = std::str::from_utf8(s).ok().map(str::to_owned); + } + } else { + event_type = Some(String::from_utf8_lossy(value_bytes).into_owned()); + } + } + } + event_type +} + +/// Byte length of the header **value** for AWS event-stream header value types. +fn header_value_byte_len(value_type: u8, rest: &[u8]) -> Option { + match value_type { + 0 | 1 => Some(0), + 2 => Some(1), + 3 => Some(2), + 4 => Some(4), + 5 => Some(8), + 6 => { + let len = u16::from_be_bytes([*rest.first()?, *rest.get(1)?]) as usize; + Some(2 + len) + } + 7 => { + let len = u16::from_be_bytes([*rest.first()?, *rest.get(1)?]) as usize; + Some(2 + len) + } + _ => None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn decodes_single_event_with_string_header_and_json_payload() { + let name = b":event-type"; + assert_eq!(name.len(), 11); + let value_type = 7u8; + let ev = b"contentBlockDelta"; + let mut value_bytes = Vec::new(); + value_bytes.extend_from_slice(&(ev.len() as u16).to_be_bytes()); + value_bytes.extend_from_slice(ev); + + let mut headers = Vec::new(); + headers.push(name.len() as u8); + headers.extend_from_slice(name); + headers.push(value_type); + headers.extend_from_slice(&value_bytes); + + let headers_len = headers.len(); + let payload = br#"{"contentBlockDelta":{"contentBlockIndex":0,"delta":{"text":"Hi"}}}"#; + let total_len = 12 + headers_len + payload.len() + 4; + let mut msg = Vec::new(); + msg.extend_from_slice(&(total_len as 
u32).to_be_bytes()); + msg.extend_from_slice(&(headers_len as u32).to_be_bytes()); + let prelude_crc = crc32fast::hash(&msg[0..8]); + msg.extend_from_slice(&prelude_crc.to_be_bytes()); + msg.extend_from_slice(&headers); + msg.extend_from_slice(payload); + let msg_crc = crc32fast::hash(&msg[..msg.len()]); + msg.extend_from_slice(&msg_crc.to_be_bytes()); + + let mut buf = BytesMut::from(&msg[..]); + let ev = pop_next_event(&mut buf).expect("decode"); + assert_eq!(ev.event_type, "contentBlockDelta"); + assert!(buf.is_empty()); + } +} diff --git a/model_gateway/src/routers/bedrock/mod.rs b/model_gateway/src/routers/bedrock/mod.rs new file mode 100644 index 000000000..c8c11de86 --- /dev/null +++ b/model_gateway/src/routers/bedrock/mod.rs @@ -0,0 +1,13 @@ +mod context; +mod converse_stream; +mod errors; +mod event_stream; +mod request_map; +mod response_map; +mod router; +mod signing; + +pub use router::BedrockRouter; + +#[cfg(test)] +mod tests; diff --git a/model_gateway/src/routers/bedrock/request_map.rs b/model_gateway/src/routers/bedrock/request_map.rs new file mode 100644 index 000000000..3de68a63e --- /dev/null +++ b/model_gateway/src/routers/bedrock/request_map.rs @@ -0,0 +1,182 @@ +use openai_protocol::chat::{ChatCompletionRequest, ChatMessage, MessageContent}; +use serde::Serialize; + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct BedrockConverseRequest { + #[serde(skip_serializing_if = "Vec::is_empty")] + pub system: Vec, + pub messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub inference_config: Option, +} + +#[derive(Debug, Serialize)] +pub(crate) struct BedrockMessage { + pub role: String, + pub content: Vec, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct ContentBlock { + pub text: String, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct InferenceConfig { + #[serde(skip_serializing_if = "Option::is_none")] + pub 
max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop_sequences: Option>, +} + +pub(crate) fn map_chat_request(request: &ChatCompletionRequest) -> BedrockConverseRequest { + let mut system = Vec::new(); + let mut messages = Vec::new(); + + for msg in &request.messages { + match msg { + ChatMessage::System { content, .. } | ChatMessage::Developer { content, .. } => { + push_text_blocks(&mut system, content); + } + ChatMessage::User { content, .. } => { + let content = content_blocks_from_content(content); + if !content.is_empty() { + messages.push(BedrockMessage { + role: "user".to_string(), + content, + }); + } + } + ChatMessage::Assistant { + content, + tool_calls, + .. + } => { + let mut blocks = Vec::new(); + if let Some(content) = content { + blocks.extend(content_blocks_from_content(content)); + } + if let Some(calls) = tool_calls { + let tools_text = calls + .iter() + .map(|c| { + format!( + "tool_call {} {}", + c.function.name, + c.function.arguments.clone().unwrap_or_default() + ) + }) + .collect::>() + .join("\n"); + if !tools_text.is_empty() { + blocks.push(ContentBlock { text: tools_text }); + } + } + if !blocks.is_empty() { + messages.push(BedrockMessage { + role: "assistant".to_string(), + content: blocks, + }); + } + } + ChatMessage::Tool { + content, + tool_call_id, + } => { + let mut text = content.to_simple_string(); + if !tool_call_id.is_empty() { + text = format!("tool_result {tool_call_id}: {text}"); + } + if !text.is_empty() { + messages.push(BedrockMessage { + role: "user".to_string(), + content: vec![ContentBlock { text }], + }); + } + } + ChatMessage::Function { content, name } => { + if !content.is_empty() || !name.is_empty() { + messages.push(BedrockMessage { + role: "user".to_string(), + content: vec![ContentBlock { + text: format!("function {name}: 
{content}"), + }], + }); + } + } + } + } + + if messages.is_empty() { + messages.push(BedrockMessage { + role: "user".to_string(), + content: vec![ContentBlock { + text: "[no input]".to_string(), + }], + }); + } + + let inference_config = Some(InferenceConfig { + max_tokens: max_tokens(request), + temperature: request.temperature, + top_p: request.top_p, + stop_sequences: stop_sequences_from_value(request.stop.as_ref()), + }); + + BedrockConverseRequest { + system, + messages, + inference_config, + } +} + +#[expect(deprecated)] +fn max_tokens(request: &ChatCompletionRequest) -> Option { + request.max_completion_tokens.or(request.max_tokens) +} + +fn push_text_blocks(dst: &mut Vec, content: &MessageContent) { + dst.extend(content_blocks_from_content(content)); +} + +fn content_blocks_from_content(content: &MessageContent) -> Vec { + let text = content.to_simple_string(); + let trimmed = text.trim(); + if trimmed.is_empty() { + Vec::new() + } else { + vec![ContentBlock { + text: trimmed.to_string(), + }] + } +} + +fn stop_sequences_from_value( + v: Option<&openai_protocol::common::StringOrArray>, +) -> Option> { + match v { + Some(openai_protocol::common::StringOrArray::String(s)) if !s.is_empty() => { + Some(vec![s.clone()]) + } + Some(openai_protocol::common::StringOrArray::Array(a)) => { + let list = a + .iter() + .filter(|s| !s.is_empty()) + .cloned() + .collect::>(); + if list.is_empty() { + None + } else { + Some(list) + } + } + _ => None, + } +} diff --git a/model_gateway/src/routers/bedrock/response_map.rs b/model_gateway/src/routers/bedrock/response_map.rs new file mode 100644 index 000000000..41025037a --- /dev/null +++ b/model_gateway/src/routers/bedrock/response_map.rs @@ -0,0 +1,140 @@ +use std::time::{SystemTime, UNIX_EPOCH}; + +use serde::Deserialize; +use serde_json::{json, Value}; + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct ConverseResponse { + #[serde(default)] + output: Option, + #[serde(default)] + usage: Option, + 
#[serde(default)] + stop_reason: Option, +} + +#[derive(Debug, Deserialize)] +struct ConverseOutput { + message: ConverseMessage, +} + +#[derive(Debug, Deserialize)] +struct ConverseMessage { + #[serde(default)] + content: Vec, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct ConverseContent { + #[serde(default)] + text: Option, + #[serde(default)] + tool_use: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct ToolUse { + tool_use_id: String, + name: String, + #[serde(default)] + input: Value, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct Usage { + #[serde(default)] + input_tokens: Option, + #[serde(default)] + output_tokens: Option, + #[serde(default)] + total_tokens: Option, +} + +pub(crate) fn map_non_stream_response(raw: &[u8], model: &str) -> Result { + let parsed: ConverseResponse = serde_json::from_slice(raw)?; + let created = now_epoch(); + + let mut text = String::new(); + let mut tool_calls = Vec::new(); + if let Some(output) = parsed.output.as_ref() { + for block in &output.message.content { + if let Some(t) = block.text.as_deref() { + text.push_str(t); + } + if let Some(tu) = block.tool_use.as_ref() { + // OpenAI expects `function.arguments` as a JSON-encoded string. 
/// Maps a Bedrock Converse `stopReason` onto an OpenAI `finish_reason`.
///
/// * `max_tokens`, `model_context_window_exceeded` -> `"length"` (the output was cut short)
/// * `tool_use` -> `"tool_calls"`
/// * `content_filtered`, `guardrail_intervened` -> `"content_filter"`
/// * anything else, including `end_turn`, `stop_sequence` and reasons not
///   listed here -> `"stop"`
fn map_stop_reason(stop: &str) -> &'static str {
    match stop {
        // Both mean the response was truncated, which OpenAI clients detect via "length".
        "max_tokens" | "model_context_window_exceeded" => "length",
        "tool_use" => "tool_calls",
        "content_filtered" | "guardrail_intervened" => "content_filter",
        // `end_turn`, `stop_sequence` and unrecognised reasons are reported as a normal stop.
        _ => "stop",
    }
}
body::Body, + http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode}, + response::Response, +}; +use openai_protocol::chat::ChatCompletionRequest; +use reqwest::Url; +use serde_json::to_vec; +use tokio::sync::mpsc; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tracing::error; + +use super::{ + context::RouterContext, converse_stream::forward_bedrock_converse_stream_as_sse, errors, + request_map::map_chat_request, response_map::map_non_stream_response, signing::AwsSigner, +}; +use crate::{ + app_context::AppContext, + config::types::RetryConfig, + middleware::TenantRequestMeta, + observability::metrics::{bool_to_static_str, metrics_labels, Metrics}, + routers::{ + common::{ + retry::{is_retryable_status, RetryExecutor}, + worker_selection::{SelectWorkerRequest, WorkerSelector}, + }, + RouterTrait, + }, + worker::ProviderType, +}; + +pub struct BedrockRouter { + ctx: RouterContext, + retry_config: RetryConfig, + request_timeout: Duration, +} + +impl BedrockRouter { + pub fn new(context: Arc) -> Self { + let bedrock = context.router_config.bedrock.clone(); + let region = bedrock + .resolved_signing_region() + .unwrap_or_else(|| "us-east-1".to_string()); + let service = if bedrock.service.is_empty() { + "bedrock".to_string() + } else { + bedrock.service.clone() + }; + let signer = AwsSigner::new(region, service); + let retry_config = context.router_config.effective_retry_config(); + let request_timeout = Duration::from_secs(context.router_config.request_timeout_secs); + + Self { + ctx: RouterContext::new( + context.worker_registry.clone(), + context.client.clone(), + bedrock, + signer, + ), + retry_config, + request_timeout, + } + } + + fn resolve_model(&self, incoming_model: &str) -> String { + self.ctx + .bedrock + .model_map + .get(incoming_model) + .cloned() + .unwrap_or_else(|| incoming_model.to_string()) + } +} + +impl std::fmt::Debug for BedrockRouter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + 
f.debug_struct("BedrockRouter").finish() + } +} + +#[async_trait] +impl RouterTrait for BedrockRouter { + fn as_any(&self) -> &dyn Any { + self + } + + async fn route_chat( + &self, + headers: Option<&HeaderMap>, + _tenant_meta: &TenantRequestMeta, + body: &ChatCompletionRequest, + model_id: &str, + ) -> Response { + let stream = body.stream; + Metrics::record_router_request( + metrics_labels::ROUTER_OPENAI, + metrics_labels::BACKEND_EXTERNAL, + metrics_labels::CONNECTION_HTTP, + model_id, + metrics_labels::ENDPOINT_CHAT, + bool_to_static_str(stream), + ); + + let bedrock_model = self.resolve_model(model_id); + let selector = WorkerSelector::new(&self.ctx.worker_registry, &self.ctx.http_client); + let worker = match selector + .select_worker(&SelectWorkerRequest { + model_id: &bedrock_model, + headers, + provider: Some(ProviderType::Bedrock), + ..Default::default() + }) + .await + { + Ok(w) => w, + Err(resp) => return resp, + }; + let payload = map_chat_request(body); + let payload_bytes = match to_vec(&payload) { + Ok(p) => Arc::new(p), + Err(e) => return errors::map_bad_mapping_error(e), + }; + + let stream_path = if stream { + "converse-stream" + } else { + "converse" + }; + let endpoint = format!("{}/model/{}/{}", worker.url(), bedrock_model, stream_path); + let parsed_url = match Url::parse(&endpoint) { + Ok(u) => u, + Err(e) => return errors::map_bad_mapping_error(e), + }; + + let client = self.ctx.http_client.clone(); + let signer = self.ctx.signer.clone(); + let worker_for_retry = Arc::clone(&worker); + let openai_model_id = model_id.to_string(); + let request_timeout = self.request_timeout; + + RetryExecutor::execute_response_with_retry( + &self.retry_config, + |_attempt| { + let client = client.clone(); + let body = Arc::clone(&payload_bytes); + let url = parsed_url.clone(); + let endpoint = endpoint.clone(); + let worker = Arc::clone(&worker_for_retry); + let signer = signer.clone(); + let openai_model_id = openai_model_id.clone(); + async move { + let 
signed = match signer.sign("POST", &url, &body).await { + Ok(s) => s, + Err(e) => return errors::map_signing_error(e), + }; + + let mut req = client + .post(&endpoint) + .header("Authorization", signed.authorization) + .header("X-Amz-Date", signed.amz_date) + .header("X-Amz-Content-Sha256", signed.payload_hash) + .header(CONTENT_TYPE, HeaderValue::from_static("application/json")) + .body((*body).clone()); + if stream { + req = req.header( + "Accept", + HeaderValue::from_static("application/vnd.amazon.eventstream"), + ); + } else { + req = req.timeout(request_timeout); + } + if let Some(token) = signed.security_token { + req = req.header("X-Amz-Security-Token", token); + } + + let resp = match req.send().await { + Ok(r) => r, + Err(e) => { + worker.record_outcome(503); + return errors::map_send_error(e); + } + }; + let status = StatusCode::from_u16(resp.status().as_u16()) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + worker.record_outcome(status.as_u16()); + + if !status.is_success() { + let bytes = match resp.bytes().await { + Ok(b) => b, + Err(e) => return errors::map_send_error(e), + }; + return errors::map_upstream_error(status, &bytes); + } + + if stream { + let byte_stream = resp.bytes_stream(); + let (tx, rx) = mpsc::unbounded_channel(); + #[expect( + clippy::disallowed_methods, + reason = "fire-and-forget Bedrock stream translation; same pattern as openai chat streaming" + )] + tokio::spawn(async move { + forward_bedrock_converse_stream_as_sse( + byte_stream, + openai_model_id, + tx, + ) + .await; + }); + let mut response = + Response::new(Body::from_stream(UnboundedReceiverStream::new(rx))); + *response.status_mut() = StatusCode::OK; + response.headers_mut().insert( + CONTENT_TYPE, + HeaderValue::from_static("text/event-stream"), + ); + response + } else { + let bytes = match resp.bytes().await { + Ok(b) => b, + Err(e) => return errors::map_send_error(e), + }; + match map_non_stream_response(&bytes, &openai_model_id) { + Ok(mapped) => match 
to_vec(&mapped) { + Ok(serialized) => { + let mut response = Response::new(Body::from(serialized)); + *response.status_mut() = StatusCode::OK; + response.headers_mut().insert( + CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + response + } + Err(e) => errors::map_bad_mapping_error(e), + }, + Err(e) => errors::map_bad_mapping_error(e), + } + } + } + }, + |res, _| is_retryable_status(res.status()), + |delay, attempt| { + Metrics::record_worker_retry( + metrics_labels::BACKEND_EXTERNAL, + metrics_labels::ENDPOINT_CHAT, + ); + Metrics::record_worker_retry_backoff(attempt, delay); + }, + || { + Metrics::record_worker_retries_exhausted( + metrics_labels::BACKEND_EXTERNAL, + metrics_labels::ENDPOINT_CHAT, + ); + }, + ) + .await + } + + async fn route_responses( + &self, + _headers: Option<&HeaderMap>, + _tenant_meta: &TenantRequestMeta, + _body: &openai_protocol::responses::ResponsesRequest, + _model_id: &str, + ) -> Response { + error!("Responses endpoint is not implemented in Bedrock router yet"); + errors::unsupported_endpoint() + } + + fn router_type(&self) -> &'static str { + "bedrock" + } +} diff --git a/model_gateway/src/routers/bedrock/signing.rs b/model_gateway/src/routers/bedrock/signing.rs new file mode 100644 index 000000000..50360a462 --- /dev/null +++ b/model_gateway/src/routers/bedrock/signing.rs @@ -0,0 +1,536 @@ +use std::{sync::Arc, time::Duration}; + +use chrono::DateTime; +use hmac::{Hmac, KeyInit, Mac}; +use reqwest::Url; +use sha2::{Digest, Sha256}; +use tokio::sync::RwLock; + +type HmacSha256 = Hmac; +const AWS_METADATA_HTTP_TIMEOUT_SECS: u64 = 5; +/// Refresh temporary credentials this many seconds before their expiry. 
+const CREDENTIAL_REFRESH_WINDOW_SECS: i64 = 300; + +#[derive(Clone)] +pub(crate) struct AwsSigner { + region: String, + service: String, + credentials: Arc>>, + http_client: reqwest::Client, +} + +pub(crate) struct SignedHeaders { + pub authorization: String, + pub amz_date: String, + pub payload_hash: String, + pub security_token: Option, +} + +#[derive(Clone, Debug)] +struct AwsCredentials { + access_key_id: String, + secret_access_key: String, + session_token: Option, +} + +#[derive(Clone, Debug)] +struct CachedCredentials { + creds: AwsCredentials, + /// None for static credentials (env/profile) that never expire. + expires_at: Option>, +} + +impl CachedCredentials { + fn needs_refresh(&self) -> bool { + match self.expires_at { + Some(expires_at) => { + let refresh_at = + expires_at - chrono::Duration::seconds(CREDENTIAL_REFRESH_WINDOW_SECS); + chrono::Utc::now() >= refresh_at + } + None => false, + } + } +} + +impl AwsSigner { + pub fn new(region: String, service: String) -> Self { + Self { + region, + service, + credentials: Arc::new(RwLock::new(None)), + http_client: build_metadata_http_client(), + } + } + + pub async fn sign( + &self, + method: &str, + url: &Url, + body: &[u8], + ) -> Result { + let creds = self.get_or_refresh_credentials().await?; + + let now = chrono::Utc::now(); + let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string(); + let date_stamp = now.format("%Y%m%d").to_string(); + let payload_hash = hex_sha256(body); + let host = url + .host_str() + .ok_or_else(|| "Bedrock URL missing host".to_string())?; + + let canonical_uri = canonical_uri(url.path()); + let canonical_query = url.query().unwrap_or(""); + + let mut canonical_headers = + format!("host:{host}\nx-amz-content-sha256:{payload_hash}\nx-amz-date:{amz_date}\n"); + let mut signed_headers = String::from("host;x-amz-content-sha256;x-amz-date"); + + if let Some(token) = &creds.session_token { + canonical_headers.push_str(&format!("x-amz-security-token:{token}\n")); + 
signed_headers.push_str(";x-amz-security-token"); + } + + let canonical_request = format!( + "{method}\n{canonical_uri}\n{canonical_query}\n{canonical_headers}\n{signed_headers}\n{payload_hash}" + ); + let hashed_request = hex_sha256(canonical_request.as_bytes()); + let scope = format!("{date_stamp}/{}/{}/aws4_request", self.region, self.service); + let string_to_sign = format!("AWS4-HMAC-SHA256\n{amz_date}\n{scope}\n{hashed_request}"); + + let k_date = hmac_sha256( + format!("AWS4{}", creds.secret_access_key).as_bytes(), + date_stamp.as_bytes(), + )?; + let k_region = hmac_sha256(&k_date, self.region.as_bytes())?; + let k_service = hmac_sha256(&k_region, self.service.as_bytes())?; + let k_signing = hmac_sha256(&k_service, b"aws4_request")?; + let signature = hex::encode(hmac_sha256(&k_signing, string_to_sign.as_bytes())?); + + let authorization = format!( + "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}", + creds.access_key_id, scope, signed_headers, signature + ); + + Ok(SignedHeaders { + authorization, + amz_date, + payload_hash, + security_token: creds.session_token.clone(), + }) + } + + async fn get_or_refresh_credentials(&self) -> Result { + { + let guard = self.credentials.read().await; + if let Some(cached) = guard.as_ref() { + if !cached.needs_refresh() { + return Ok(cached.creds.clone()); + } + } + } + // Acquire write lock and double-check before refreshing. 
+ let mut guard = self.credentials.write().await; + if let Some(cached) = guard.as_ref() { + if !cached.needs_refresh() { + return Ok(cached.creds.clone()); + } + } + let cached = self.load_credentials().await?; + let creds = cached.creds.clone(); + *guard = Some(cached); + Ok(creds) + } + + async fn load_credentials(&self) -> Result { + if let Some(c) = load_env_credentials() { + return Ok(CachedCredentials { + creds: c, + expires_at: None, + }); + } + if let Some(c) = load_profile_credentials() { + return Ok(CachedCredentials { + creds: c, + expires_at: None, + }); + } + if let Some(c) = self.load_web_identity_credentials().await? { + return Ok(c); + } + if let Some(c) = self.load_ecs_credentials().await? { + return Ok(c); + } + if let Some(c) = self.load_imds_credentials().await? { + return Ok(c); + } + Err( + "Unable to resolve AWS credentials from env, profile, web identity, ECS, or IMDS" + .to_string(), + ) + } + + async fn load_web_identity_credentials(&self) -> Result, String> { + let token_file = match std::env::var("AWS_WEB_IDENTITY_TOKEN_FILE").ok() { + Some(f) if !f.is_empty() => f, + _ => return Ok(None), + }; + let role_arn = match std::env::var("AWS_ROLE_ARN").ok() { + Some(r) if !r.is_empty() => r, + _ => return Ok(None), + }; + let session_name = std::env::var("AWS_ROLE_SESSION_NAME") + .unwrap_or_else(|_| "smg-bedrock-session".to_string()); + + let token = std::fs::read_to_string(&token_file) + .map_err(|e| format!("Failed to read web identity token file {token_file}: {e}"))?; + let token = token.trim().to_string(); + + let region = self.region.clone(); + let sts_url = format!("https://sts.{region}.amazonaws.com/"); + let body = format!( + "Action=AssumeRoleWithWebIdentity&Version=2011-06-15&RoleArn={}&RoleSessionName={}&WebIdentityToken={}", + urlencoding_simple(&role_arn), + urlencoding_simple(&session_name), + urlencoding_simple(&token), + ); + + let resp = self + .http_client + .post(&sts_url) + .header("Content-Type", 
"application/x-www-form-urlencoded") + .body(body) + .send() + .await + .map_err(|e| format!("STS AssumeRoleWithWebIdentity request failed: {e}"))?; + + if !resp.status().is_success() { + return Ok(None); + } + + let xml = resp + .text() + .await + .map_err(|e| format!("STS response read failed: {e}"))?; + + Ok(extract_sts_credentials(&xml)) + } + + async fn load_ecs_credentials(&self) -> Result, String> { + let relative = std::env::var("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI").ok(); + let full = std::env::var("AWS_CONTAINER_CREDENTIALS_FULL_URI").ok(); + let (uri, auth_token) = match (full, relative) { + (Some(full), _) => (full, read_container_authorization_token()?), + (None, Some(rel)) => (format!("http://169.254.170.2{rel}"), None), + (None, None) => return Ok(None), + }; + + let mut req = self.http_client.get(&uri); + if let Some(token) = auth_token { + req = req.header("Authorization", token); + } + let resp = req + .send() + .await + .map_err(|e| format!("ECS credential request failed: {e}"))?; + if !resp.status().is_success() { + return Ok(None); + } + let value: serde_json::Value = resp + .json() + .await + .map_err(|e| format!("Invalid ECS credential response: {e}"))?; + Ok(extract_credential_fields(&value)) + } + + async fn load_imds_credentials(&self) -> Result, String> { + let disabled = std::env::var("AWS_EC2_METADATA_DISABLED") + .ok() + .is_some_and(|v| v.eq_ignore_ascii_case("true")); + if disabled { + return Ok(None); + } + + let token_resp = self + .http_client + .put("http://169.254.169.254/latest/api/token") + .header("X-aws-ec2-metadata-token-ttl-seconds", "21600") + .send() + .await + .map_err(|e| format!("IMDS token request failed: {e}"))?; + if !token_resp.status().is_success() { + return Ok(None); + } + let token = token_resp + .text() + .await + .map_err(|e| format!("IMDS token read failed: {e}"))?; + + let role_resp = self + .http_client + .get("http://169.254.169.254/latest/meta-data/iam/security-credentials/") + 
/// Percent-encodes a URL path for use as the SigV4 canonical URI.
///
/// ASCII letters, digits, `-`, `_`, `.`, `~` and `/` are kept as they are.
/// Every other byte, including `%`, becomes `%XX` with uppercase hex digits.
/// An empty path is treated as `/`.
fn canonical_uri(path: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let source = if path.is_empty() { "/" } else { path };
    let mut encoded = String::with_capacity(source.len());

    for byte in source.bytes() {
        let keep =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~' | b'/');
        if keep {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }

    encoded
}
+ let access_key_id = std::env::var("AWS_ACCESS_KEY_ID") + .ok() + .filter(|s| !s.is_empty())?; + let secret_access_key = std::env::var("AWS_SECRET_ACCESS_KEY") + .ok() + .filter(|s| !s.is_empty())?; + let session_token = std::env::var("AWS_SESSION_TOKEN") + .ok() + .filter(|s| !s.is_empty()); + Some(AwsCredentials { + access_key_id, + secret_access_key, + session_token, + }) +} + +fn load_profile_credentials() -> Option { + let profile = std::env::var("AWS_PROFILE").unwrap_or_else(|_| "default".to_string()); + let path = std::env::var("AWS_SHARED_CREDENTIALS_FILE") + .ok() + .or_else(|| { + std::env::var("HOME") + .ok() + .map(|home| format!("{home}/.aws/credentials")) + })?; + let content = std::fs::read_to_string(path).ok()?; + + let mut current_section = String::new(); + let mut access_key_id = None; + let mut secret_access_key = None; + let mut session_token = None; + + for line in content.lines() { + let trimmed = line.trim(); + if trimmed.is_empty() || trimmed.starts_with('#') || trimmed.starts_with(';') { + continue; + } + if trimmed.starts_with('[') && trimmed.ends_with(']') { + current_section = trimmed[1..trimmed.len() - 1].trim().to_string(); + continue; + } + if current_section != profile { + continue; + } + let Some((k, v)) = trimmed.split_once('=') else { + continue; + }; + match k.trim() { + "aws_access_key_id" => access_key_id = Some(v.trim().to_string()), + "aws_secret_access_key" => secret_access_key = Some(v.trim().to_string()), + "aws_session_token" => session_token = Some(v.trim().to_string()), + _ => {} + } + } + + Some(AwsCredentials { + access_key_id: access_key_id?, + secret_access_key: secret_access_key?, + session_token, + }) +} + +/// Minimal percent-encoding for form values (encodes everything except unreserved chars). +fn urlencoding_simple(s: &str) -> String { + let mut out = String::with_capacity(s.len()); + for byte in s.bytes() { + match byte { + b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' 
/// Returns the trimmed text between the first `open` marker and the first
/// `close` marker that follows it, or `None` when either marker is missing.
fn xml_text_between(xml: &str, open: &str, close: &str) -> Option<String> {
    let (_, after_open) = xml.split_once(open)?;
    let (inner, _) = after_open.split_once(close)?;
    Some(inner.trim().to_owned())
}
+ .to_string(); + let session_token = value + .get("Token") + .or_else(|| value.get("sessionToken")) + .and_then(serde_json::Value::as_str) + .map(ToOwned::to_owned); + let expires_at = value + .get("Expiration") + .or_else(|| value.get("expiration")) + .and_then(serde_json::Value::as_str) + .and_then(|s| DateTime::parse_from_rfc3339(s).ok()) + .map(|dt| dt.with_timezone(&chrono::Utc)); + + Some(CachedCredentials { + creds: AwsCredentials { + access_key_id, + secret_access_key, + session_token, + }, + expires_at, + }) +} + +#[cfg(test)] +mod tests { + use super::canonical_uri; + + #[test] + fn canonical_uri_encodes_model_id_colon_once() { + assert_eq!( + canonical_uri("/model/us.anthropic.claude-opus-4-5-20251101-v1:0/converse"), + "/model/us.anthropic.claude-opus-4-5-20251101-v1%3A0/converse" + ); + } + + #[test] + fn canonical_uri_encodes_converse_stream_path() { + assert_eq!( + canonical_uri("/model/us.anthropic.claude-opus-4-5-20251101-v1:0/converse-stream"), + "/model/us.anthropic.claude-opus-4-5-20251101-v1%3A0/converse-stream" + ); + } +} diff --git a/model_gateway/src/routers/bedrock/tests.rs b/model_gateway/src/routers/bedrock/tests.rs new file mode 100644 index 000000000..a26e6940c --- /dev/null +++ b/model_gateway/src/routers/bedrock/tests.rs @@ -0,0 +1,99 @@ +use openai_protocol::chat::{ChatCompletionRequest, ChatMessage, MessageContent}; +use reqwest::Url; +use serial_test::serial; + +use super::{ + request_map::map_chat_request, response_map::map_non_stream_response, signing::AwsSigner, +}; + +#[test] +fn maps_chat_request_to_bedrock_shape() { + let req = ChatCompletionRequest { + model: "us.anthropic.claude-opus-4-5-20251101-v1:0".to_string(), + messages: vec![ + ChatMessage::System { + content: MessageContent::Text("You are concise".to_string()), + name: None, + }, + ChatMessage::User { + content: MessageContent::Text("Hello".to_string()), + name: None, + }, + ], + ..Default::default() + }; + + let mapped = map_chat_request(&req); + 
assert_eq!(mapped.system.len(), 1); + assert_eq!(mapped.messages.len(), 1); + assert_eq!(mapped.messages[0].role, "user"); +} + +#[test] +fn maps_converse_response_to_openai() { + let raw = br#"{ + "output": {"message": {"content": [{"text": "Hi there"}]}}, + "usage": {"inputTokens": 12, "outputTokens": 3, "totalTokens": 15}, + "stopReason": "end_turn" + }"#; + let mapped = + map_non_stream_response(raw, "us.anthropic.claude-opus-4-5-20251101-v1:0").expect("maps"); + assert_eq!( + mapped["choices"][0]["message"]["content"].as_str(), + Some("Hi there") + ); + assert_eq!(mapped["usage"]["total_tokens"].as_u64(), Some(15)); +} + +#[tokio::test] +#[serial] +async fn signer_builds_authorization_header_with_env_credentials() { + struct EnvRestore { + saved: Vec<(&'static str, Option)>, + } + impl EnvRestore { + fn new(keys: &[&'static str]) -> Self { + let saved = keys + .iter() + .map(|k| (*k, std::env::var_os(k))) + .collect::>(); + Self { saved } + } + } + impl Drop for EnvRestore { + fn drop(&mut self) { + for (k, v) in &self.saved { + match v { + Some(val) => std::env::set_var(k, val), + None => std::env::remove_var(k), + } + } + } + } + + let _guard = EnvRestore::new(&[ + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_REGION", + "AWS_EC2_METADATA_DISABLED", + ]); + std::env::set_var("AWS_ACCESS_KEY_ID", "AKIAEXAMPLE"); + std::env::set_var("AWS_SECRET_ACCESS_KEY", "secret-example"); + std::env::set_var("AWS_REGION", "us-east-1"); + std::env::set_var("AWS_EC2_METADATA_DISABLED", "true"); + + let signer = AwsSigner::new("us-east-1".to_string(), "bedrock".to_string()); + let url = Url::parse("https://bedrock-runtime.us-east-1.amazonaws.com/model/test/converse") + .expect("url"); + let signed = signer + .sign("POST", &url, br#"{"a":1}"#) + .await + .expect("signed"); + + assert!(signed + .authorization + .starts_with("AWS4-HMAC-SHA256 Credential=")); + assert!(signed.authorization.contains("SignedHeaders=")); + assert!(!signed.amz_date.is_empty()); + 
assert_eq!(signed.payload_hash.len(), 64); +} diff --git a/model_gateway/src/routers/factory.rs b/model_gateway/src/routers/factory.rs index e4be743d4..2508a694e 100644 --- a/model_gateway/src/routers/factory.rs +++ b/model_gateway/src/routers/factory.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use super::{ anthropic::AnthropicRouter, + bedrock::BedrockRouter, gemini::GeminiRouter, grpc::{pd_router::GrpcPDRouter, router::GrpcRouter}, http::{pd_router::PDRouter, router::Router}, @@ -39,6 +40,7 @@ pub mod router_ids { pub const HTTP_OPENAI: RouterId = RouterId::new("http-openai"); pub const HTTP_ANTHROPIC: RouterId = RouterId::new("http-anthropic"); pub const HTTP_GEMINI: RouterId = RouterId::new("http-gemini"); + pub const HTTP_BEDROCK: RouterId = RouterId::new("http-bedrock"); pub const GRPC_REGULAR: RouterId = RouterId::new("grpc-regular"); pub const GRPC_PD: RouterId = RouterId::new("grpc-pd"); } @@ -71,6 +73,9 @@ impl RouterFactory { RoutingMode::Gemini { .. } => { Err("Gemini mode requires HTTP connection_mode".to_string()) } + RoutingMode::Bedrock { .. } => { + Err("Bedrock mode requires HTTP connection_mode".to_string()) + } }, ConnectionMode::Http => match &ctx.router_config.mode { RoutingMode::Regular { .. } => Self::create_regular_router(ctx).await, @@ -90,6 +95,7 @@ impl RouterFactory { RoutingMode::OpenAI { .. } => Self::create_openai_router(ctx).await, RoutingMode::Anthropic { .. } => Self::create_anthropic_router(ctx).await, RoutingMode::Gemini { .. } => Self::create_gemini_router(ctx).await, + RoutingMode::Bedrock { .. 
} => Self::create_bedrock_router(ctx).await, }, } } @@ -201,6 +207,18 @@ impl RouterFactory { Ok(Box::new(router)) } + /// Create a Bedrock router + #[expect( + clippy::unused_async, + reason = "async for API consistency with other create_* factory methods" + )] + pub async fn create_bedrock_router( + ctx: &Arc, + ) -> Result, String> { + let router = BedrockRouter::new(ctx.clone()); + Ok(Box::new(router)) + } + /// Create all routers for IGW (multi-router) mode. /// /// Returns a list of (router_id, label, creation_result) tuples. @@ -245,6 +263,11 @@ impl RouterFactory { "Gemini", Self::create_gemini_router(ctx).await, ), + ( + router_ids::HTTP_BEDROCK, + "Bedrock", + Self::create_bedrock_router(ctx).await, + ), ] } } diff --git a/model_gateway/src/routers/mod.rs b/model_gateway/src/routers/mod.rs index 3bc036b7c..08201728d 100644 --- a/model_gateway/src/routers/mod.rs +++ b/model_gateway/src/routers/mod.rs @@ -29,6 +29,7 @@ use openai_protocol::{ use crate::middleware::TenantRequestMeta; pub mod anthropic; +pub mod bedrock; pub mod common; pub mod conversations; pub mod error; diff --git a/model_gateway/src/routers/openai/provider/bedrock.rs b/model_gateway/src/routers/openai/provider/bedrock.rs new file mode 100644 index 000000000..e39027f24 --- /dev/null +++ b/model_gateway/src/routers/openai/provider/bedrock.rs @@ -0,0 +1,10 @@ +use super::Provider; +use crate::worker::ProviderType; + +pub struct BedrockProvider; + +impl Provider for BedrockProvider { + fn provider_type(&self) -> ProviderType { + ProviderType::Bedrock + } +} diff --git a/model_gateway/src/routers/openai/provider/mod.rs b/model_gateway/src/routers/openai/provider/mod.rs index 74e208066..21538876c 100644 --- a/model_gateway/src/routers/openai/provider/mod.rs +++ b/model_gateway/src/routers/openai/provider/mod.rs @@ -1,6 +1,7 @@ //! Provider abstractions for vendor-specific API transformations. 
mod anthropic; +mod bedrock; mod gemini; mod openai; mod provider_trait; @@ -12,6 +13,7 @@ mod types; mod xai; pub use anthropic::AnthropicProvider; +pub use bedrock::BedrockProvider; pub use gemini::GeminiProvider; pub use openai::OpenAIProvider; pub use provider_trait::Provider; diff --git a/model_gateway/src/routers/openai/provider/registry.rs b/model_gateway/src/routers/openai/provider/registry.rs index a0deac56d..fb23c7d50 100644 --- a/model_gateway/src/routers/openai/provider/registry.rs +++ b/model_gateway/src/routers/openai/provider/registry.rs @@ -1,7 +1,8 @@ use std::{collections::HashMap, sync::Arc}; use super::{ - AnthropicProvider, GeminiProvider, OpenAIProvider, Provider, SGLangProvider, XAIProvider, + AnthropicProvider, BedrockProvider, GeminiProvider, OpenAIProvider, Provider, SGLangProvider, + XAIProvider, }; use crate::worker::ProviderType; @@ -36,6 +37,10 @@ impl ProviderRegistry { ProviderType::Anthropic, Arc::new(AnthropicProvider) as Arc, ); + providers.insert( + ProviderType::Bedrock, + Arc::new(BedrockProvider) as Arc, + ); Self { providers, diff --git a/model_gateway/src/routers/router_manager.rs b/model_gateway/src/routers/router_manager.rs index 4218e7e4a..0b6d5f7ea 100644 --- a/model_gateway/src/routers/router_manager.rs +++ b/model_gateway/src/routers/router_manager.rs @@ -157,11 +157,13 @@ impl RouterManager { (ConnectionMode::Http, RoutingMode::PrefillDecode { .. }) => router_ids::HTTP_PD, (ConnectionMode::Http, RoutingMode::OpenAI { .. }) => router_ids::HTTP_OPENAI, (ConnectionMode::Http, RoutingMode::Anthropic { .. }) => router_ids::HTTP_ANTHROPIC, + (ConnectionMode::Http, RoutingMode::Bedrock { .. }) => router_ids::HTTP_BEDROCK, (ConnectionMode::Grpc, RoutingMode::Regular { .. }) => router_ids::GRPC_REGULAR, (ConnectionMode::Grpc, RoutingMode::PrefillDecode { .. }) => router_ids::GRPC_PD, (ConnectionMode::Http, RoutingMode::Gemini { .. }) => router_ids::HTTP_GEMINI, (ConnectionMode::Grpc, RoutingMode::OpenAI { .. 
}) => router_ids::GRPC_REGULAR, (ConnectionMode::Grpc, RoutingMode::Anthropic { .. }) => router_ids::GRPC_REGULAR, + (ConnectionMode::Grpc, RoutingMode::Bedrock { .. }) => router_ids::GRPC_REGULAR, (ConnectionMode::Grpc, RoutingMode::Gemini { .. }) => router_ids::GRPC_REGULAR, } } @@ -239,6 +241,7 @@ impl RouterManager { let router_id = match w.provider_for_model(model) { Some(ProviderType::Gemini) => &router_ids::HTTP_GEMINI, Some(ProviderType::Anthropic) => &router_ids::HTTP_ANTHROPIC, + Some(ProviderType::Bedrock) => &router_ids::HTTP_BEDROCK, _ => &router_ids::HTTP_OPENAI, }; return self.routers.get(router_id).map(|r| r.clone()); diff --git a/model_gateway/src/routers/skills/handlers.rs b/model_gateway/src/routers/skills/handlers.rs index 7ad2de307..e65cd6d27 100644 --- a/model_gateway/src/routers/skills/handlers.rs +++ b/model_gateway/src/routers/skills/handlers.rs @@ -805,7 +805,7 @@ where hasher.update(updated_at.as_bytes()); hasher.update([0xff]); } - Ok(format!("W/\"{:x}\"", hasher.finalize())) + Ok(format!("W/\"{}\"", hex::encode(hasher.finalize()))) } fn build_resource_etag(value: &T, weak: bool) -> Result { @@ -817,9 +817,9 @@ fn build_resource_etag(value: &T, weak: bool) -> Result>) -> Response { RoutingMode::OpenAI { .. } => !healthy_workers.is_empty(), RoutingMode::Anthropic { .. } => !healthy_workers.is_empty(), RoutingMode::Gemini { .. } => !healthy_workers.is_empty(), + RoutingMode::Bedrock { .. 
} => !healthy_workers.is_empty(), } }; diff --git a/model_gateway/src/workflow/job_queue.rs b/model_gateway/src/workflow/job_queue.rs index 6d15253b0..6aed80170 100644 --- a/model_gateway/src/workflow/job_queue.rs +++ b/model_gateway/src/workflow/job_queue.rs @@ -494,7 +494,8 @@ impl JobQueue { } RoutingMode::OpenAI { worker_urls } | RoutingMode::Anthropic { worker_urls } - | RoutingMode::Gemini { worker_urls } => { + | RoutingMode::Gemini { worker_urls } + | RoutingMode::Bedrock { worker_urls } => { let provider_name = router_config.mode_type(); return submit_external_worker_jobs( worker_urls, diff --git a/model_gateway/src/workflow/steps/external/discover_models.rs b/model_gateway/src/workflow/steps/external/discover_models.rs index b57ed2c40..92a54cb61 100644 --- a/model_gateway/src/workflow/steps/external/discover_models.rs +++ b/model_gateway/src/workflow/steps/external/discover_models.rs @@ -177,6 +177,15 @@ fn infer_provider_from_id(id: &str) -> Option { return Some(ProviderType::Gemini); } + // Bedrock-hosted model IDs (common prefixes) + if id_lower.starts_with("anthropic.claude") + || id_lower.starts_with("amazon.") + || id_lower.starts_with("meta.") + || id_lower.starts_with("mistral.") + { + return Some(ProviderType::Bedrock); + } + None } @@ -273,6 +282,15 @@ impl StepExecutor for DiscoverModelsStep { let config = &context.data.config; let provider = ProviderType::from_url(&config.url); + // Bedrock Runtime has no /v1/models endpoint and requires SigV4 auth — skip discovery. + if provider.as_ref() == Some(&ProviderType::Bedrock) { + info!( + "Skipping model discovery for Bedrock worker {} — Bedrock Runtime does not expose /v1/models", + config.url + ); + return Ok(StepResult::Skip); + } + // Resolve discovery API key: env var admin key > config.api_key > None (wildcard) let discovery_key = resolve_discovery_api_key(provider.as_ref(), &config.url, config.api_key.as_deref());