diff --git a/benches/benchmarks.rs b/benches/benchmarks.rs index 93933c1..e83b216 100644 --- a/benches/benchmarks.rs +++ b/benches/benchmarks.rs @@ -1,15 +1,9 @@ -use criterion::{black_box, criterion_group, criterion_main, Criterion, BenchmarkId}; -use key_cycle_proxy::{ - config::{ApiKeyInfo, UpstreamConfig}, - proxy::{KeyPool, ProxyEngine, ProxyHandler, UpstreamClient}, - routes::create_router, - types::OpenAIRequest, -}; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; +use key_cycle_proxy::{config::ApiKeyInfo, proxy::KeyPool, types::OpenAIRequest}; use secrecy::SecretString; use serde_json::json; use std::{sync::Arc, time::Duration}; use tokio::runtime::Runtime; -use wiremock::{matchers::method, Mock, MockServer, ResponseTemplate}; fn create_test_keys(count: usize) -> Vec { (0..count) @@ -25,48 +19,41 @@ fn create_test_keys(count: usize) -> Vec { fn bench_key_selection(c: &mut Criterion) { let mut group = c.benchmark_group("key_selection"); - + for key_count in [5, 10, 25, 50, 100].iter() { let keys = create_test_keys(*key_count); let pool = KeyPool::new(keys, "round_robin"); - + group.bench_with_input( BenchmarkId::new("round_robin", key_count), key_count, - |b, _| { - b.iter(|| { - black_box(pool.get_key_for_model("gpt-3.5-turbo")) - }) - }, + |b, _| b.iter(|| black_box(pool.get_key_for_model("gpt-3.5-turbo"))), ); - + let keys = create_test_keys(*key_count); let pool = KeyPool::new(keys, "least_latency"); - + group.bench_with_input( BenchmarkId::new("least_latency", key_count), key_count, - |b, _| { - b.iter(|| { - black_box(pool.get_key_for_model("gpt-3.5-turbo")) - }) - }, + |b, _| b.iter(|| black_box(pool.get_key_for_model("gpt-3.5-turbo"))), ); } - + group.finish(); } fn bench_json_parsing(c: &mut Criterion) { let mut group = c.benchmark_group("json_parsing"); - + let simple_request = json!({ "model": "gpt-3.5-turbo", "messages": [ {"role": "user", "content": "Hello!"} ] - }).to_string(); - + }) + .to_string(); + let complex_request = json!({ "model": "gpt-4-turbo", "messages": [ @@ -84,21 +71,21 @@ fn bench_json_parsing(c: &mut Criterion) { "stream": false, "user": "benchmark-user-123" }).to_string(); - + group.bench_function("simple_request", |b| { b.iter(|| { let request: OpenAIRequest = serde_json::from_str(black_box(&simple_request)).unwrap(); black_box(request) }) }); - + group.bench_function("complex_request", |b| { b.iter(|| { let request: OpenAIRequest = serde_json::from_str(black_box(&complex_request)).unwrap(); black_box(request) }) }); - + group.finish(); } @@ -106,10 +93,10 @@ fn bench_concurrent_key_access(c: &mut Criterion) { let rt = Runtime::new().unwrap(); let mut group = c.benchmark_group("concurrent_access"); group.sample_size(50); // Reduce sample size for concurrent tests - + let keys = create_test_keys(10); let pool = Arc::new(KeyPool::new(keys, "round_robin")); - + group.bench_function("sequential_access", |b| { b.iter(|| { for _ in 0..100 { @@ -117,25 +104,31 @@ fn bench_concurrent_key_access(c: &mut Criterion) { } }) }); - + group.bench_function("concurrent_access", |b| { - b.to_async(&rt).iter(|| async { - let mut handles = vec![]; - - for _ in 0..10 { - let pool_clone = pool.clone(); - let handle = tokio::spawn(async move { + b.iter_custom(|iters| { + let start = std::time::Instant::now(); + rt.block_on(async { + for _ in 0..iters { + let mut handles = vec![]; + for _ in 0..10 { - black_box(pool_clone.get_key_for_model("gpt-3.5-turbo")); + let pool_clone = pool.clone(); + let handle = tokio::spawn(async move { + for _ in 0..10 { + black_box(pool_clone.get_key_for_model("gpt-3.5-turbo")); + } + }); + handles.push(handle); } - }); - handles.push(handle); - } - - futures::future::join_all(handles).await + + futures::future::join_all(handles).await; + } + }); + start.elapsed() }) }); - + group.finish(); } @@ -145,4 +138,4 @@ criterion_group!( bench_json_parsing, bench_concurrent_key_access ); -criterion_main!(benches); \ No newline at end of file +criterion_main!(benches); diff --git a/src/config.rs b/src/config.rs index 3ff94c6..f8738d8 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,14 +1,19 @@ -use serde::{Deserialize, Serialize}; +use anyhow::{Context, Result}; use secrecy::SecretString; +use serde::{Deserialize, Serialize}; use std::time::Duration; -use anyhow::{Context, Result}; -#[derive(Debug, Clone, Deserialize, Serialize)] +#[derive(Debug, Clone, Deserialize, Serialize, Default)] pub struct Config { + #[serde(default)] pub server: ServerConfig, + #[serde(default)] pub upstream: UpstreamConfig, + #[serde(default)] pub keys: KeysConfig, + #[serde(default)] pub rate_limit: RateLimitConfig, + #[serde(default)] pub observability: ObservabilityConfig, } @@ -65,6 +70,7 @@ pub struct ObservabilityConfig { } #[derive(Debug, Clone, Deserialize)] +#[allow(dead_code)] pub struct ApiKeyInfo { #[serde(skip_serializing)] pub key: SecretString, @@ -95,18 +101,6 @@ pub struct LegacyApiKeyInfo { pub models: Vec, } -impl Default for Config { - fn default() -> Self { - Self { - server: ServerConfig::default(), - upstream: UpstreamConfig::default(), - keys: KeysConfig::default(), - rate_limit: RateLimitConfig::default(), - observability: ObservabilityConfig::default(), - } - } -} - impl Default for ServerConfig { fn default() -> Self { Self { @@ -159,36 +153,67 @@ impl Default for ObservabilityConfig { } // Default value functions -fn default_bind_addr() -> String { "0.0.0.0:8080".to_string() } -fn default_request_body_limit() -> usize { 262_144 } -fn default_graceful_shutdown_seconds() -> u64 { 10 } -fn default_base_url() -> String { "https://api.openai.com/v1".to_string() } -fn default_connect_timeout() -> u64 { 800 } -fn default_request_timeout() -> u64 { 60_000 } -fn default_retry_initial_backoff() -> u64 { 50 } -fn default_retry_max_backoff() -> u64 { 2000 } -fn default_max_retries() -> u32 { 3 } -fn default_rotation_strategy() -> String { "round_robin_health_weighted".to_string() } -fn default_unhealthy_penalty() -> u32 { 5 } -fn default_per_key_rps() -> u32 { 3 } -fn default_global_rps() -> u32 { 50 } -fn default_burst() -> u32 { 10 } -fn default_metrics_bind() -> String { "0.0.0.0:9090".to_string() } -fn default_tracing_level() -> String { "info".to_string() } +fn default_bind_addr() -> String { + "0.0.0.0:8080".to_string() +} +fn default_request_body_limit() -> usize { + 262_144 +} +fn default_graceful_shutdown_seconds() -> u64 { + 10 +} +fn default_base_url() -> String { + "https://api.openai.com/v1".to_string() +} +fn default_connect_timeout() -> u64 { + 800 +} +fn default_request_timeout() -> u64 { + 60_000 +} +fn default_retry_initial_backoff() -> u64 { + 50 +} +fn default_retry_max_backoff() -> u64 { + 2000 +} +fn default_max_retries() -> u32 { + 3 +} +fn default_rotation_strategy() -> String { + "round_robin_health_weighted".to_string() +} +fn default_unhealthy_penalty() -> u32 { + 5 +} +fn default_per_key_rps() -> u32 { + 3 +} +fn default_global_rps() -> u32 { + 50 +} +fn default_burst() -> u32 { + 10 +} +fn default_metrics_bind() -> String { + "0.0.0.0:9090".to_string() +} +fn default_tracing_level() -> String { + "info".to_string() +} pub fn load_config() -> Result<(Config, Vec)> { // Load main config let mut config = Config::default(); - + // Try to load from config file if it exists if let Ok(config_str) = std::fs::read_to_string("config.toml") { - config = toml::from_str(&config_str) - .context("Failed to parse config.toml")?; + config = toml::from_str(&config_str).context("Failed to parse config.toml")?; } - + // Load API keys from environment or legacy config.json let api_keys = load_api_keys()?; - + Ok((config, api_keys)) } @@ -196,29 +221,36 @@ fn load_api_keys() -> Result> { // First try environment variable if let Ok(keys_env) = std::env::var("OPENAI_KEYS") { let keys: Vec<&str> = keys_env.split(',').collect(); - return Ok(keys.into_iter().map(|key| ApiKeyInfo { - key: SecretString::new(key.trim().to_string()), - url: default_base_url(), - models: vec!["others".to_string()], - latency: None, - health_score: 1.0, - }).collect()); + return Ok(keys + .into_iter() + .map(|key| ApiKeyInfo { + key: SecretString::new(key.trim().to_string()), + url: default_base_url(), + models: vec!["others".to_string()], + latency: None, + health_score: 1.0, + }) + .collect()); } - + // Fallback to legacy config.json if let Ok(config_content) = std::fs::read_to_string("config.json") { - let legacy_config: LegacyConfig = serde_json::from_str(&config_content) - .context("Failed to parse config.json")?; - - return Ok(legacy_config.api_keys.into_iter().map(|key_info| ApiKeyInfo { - key: SecretString::new(key_info.key), - url: key_info.url, - models: key_info.models, - latency: None, - health_score: 1.0, - }).collect()); + let legacy_config: LegacyConfig = + serde_json::from_str(&config_content).context("Failed to parse config.json")?; + + return Ok(legacy_config + .api_keys + .into_iter() + .map(|key_info| ApiKeyInfo { + key: SecretString::new(key_info.key), + url: key_info.url, + models: key_info.models, + latency: None, + health_score: 1.0, + }) + .collect()); } - + anyhow::bail!("No API keys found. Set OPENAI_KEYS environment variable or create config.json"); } @@ -226,15 +258,16 @@ impl UpstreamConfig { pub fn connect_timeout(&self) -> Duration { Duration::from_millis(self.connect_timeout_ms) } - + pub fn request_timeout(&self) -> Duration { Duration::from_millis(self.request_timeout_ms) } - + + #[allow(dead_code)] pub fn retry_initial_backoff(&self) -> Duration { Duration::from_millis(self.retry_initial_backoff_ms) } - + pub fn retry_max_backoff(&self) -> Duration { Duration::from_millis(self.retry_max_backoff_ms) } @@ -244,4 +277,4 @@ impl ServerConfig { pub fn graceful_shutdown_duration(&self) -> Duration { Duration::from_secs(self.graceful_shutdown_seconds) } -} \ No newline at end of file +} diff --git a/src/lib.rs b/src/lib.rs index bebf25d..75c575c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,4 +2,4 @@ pub mod config; pub mod proxy; pub mod routes; pub mod types; -pub mod util; \ No newline at end of file +pub mod util; diff --git a/src/main.rs b/src/main.rs index a0babb7..99f48e8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -40,7 +40,7 @@ async fn main() -> Result<()> { // Load configuration let (config, api_keys) = load_config().context("Failed to load configuration")?; - + if api_keys.is_empty() { anyhow::bail!("No API keys configured. Please set OPENAI_KEYS environment variable or create config.json"); } @@ -52,8 +52,8 @@ async fn main() -> Result<()> { // Initialize components let key_pool = Arc::new(KeyPool::new(api_keys, &config.keys.rotation_strategy)); - let upstream_client = UpstreamClient::new(config.upstream.clone()) - .context("Failed to create upstream client")?; + let upstream_client = + UpstreamClient::new(config.upstream.clone()).context("Failed to create upstream client")?; let engine = Arc::new(ProxyEngine::new( key_pool.clone(), upstream_client, @@ -145,7 +145,10 @@ async fn shutdown_signal(grace_period: Duration) { // Give the server some time to finish ongoing requests if grace_period > Duration::ZERO { - info!("Waiting {}s for ongoing requests to complete...", grace_period.as_secs()); + info!( + "Waiting {}s for ongoing requests to complete...", + grace_period.as_secs() + ); tokio::time::sleep(grace_period).await; } -} \ No newline at end of file +} diff --git a/src/proxy/engine.rs b/src/proxy/engine.rs index 135e153..56f248b 100644 --- a/src/proxy/engine.rs +++ b/src/proxy/engine.rs @@ -4,7 +4,10 @@ use crate::proxy::{ upstream::{should_rotate_key, UpstreamClient}, }; use crate::types::OpenAIRequest; -use crate::util::{convert_axum_headers_to_reqwest, convert_axum_method_to_reqwest, convert_reqwest_headers_to_axum}; +use crate::util::{ + convert_axum_headers_to_reqwest, convert_axum_method_to_reqwest, + convert_reqwest_headers_to_axum, +}; use axum::body::Body; use axum::http::{HeaderMap, Method, StatusCode}; use axum::response::Response; @@ -21,11 +24,7 @@ pub struct ProxyEngine { } impl ProxyEngine { - pub fn new( - key_pool: Arc, - upstream_client: UpstreamClient, - max_retries: u32, - ) -> Self { + pub fn new(key_pool: Arc, upstream_client: UpstreamClient, max_retries: u32) -> Self { Self { key_pool, upstream_client, @@ -55,15 +54,28 @@ impl ProxyEngine { // Attempt the request with retries let mut attempt_count = 0; let mut last_error = None; + let mut use_next_key = false; while attempt_count <= self.max_retries { // Get appropriate API key for the model - let key_info = match self.key_pool.get_key_for_model(&model) { - Some(key) => key, - None => { - return Err(ProxyError::NoKeyAvailable { - model: model.clone(), - }) + // On first attempt, use model-specific key; on retries, rotate through all keys + let key_info = if use_next_key { + match self.key_pool.get_next_key() { + Some(key) => key, + None => { + return Err(ProxyError::NoKeyAvailable { + model: model.clone(), + }) + } + } + } else { + match self.key_pool.get_key_for_model(&model) { + Some(key) => key, + None => { + return Err(ProxyError::NoKeyAvailable { + model: model.clone(), + }) + } } }; @@ -96,14 +108,10 @@ impl ProxyEngine { status ); attempt_count += 1; - - // Get next key for retry - if let Some(next_key) = self.key_pool.get_next_key() { - info!("Forwarding to {} with next API key", next_key.url); - } - + use_next_key = true; // Switch to using get_next_key for retries + last_error = Some(ProxyError::UpstreamFailed { - source: reqwest::Error::from(response.error_for_status_ref().unwrap_err()), + source: response.error_for_status_ref().unwrap_err(), }); continue; } @@ -115,6 +123,7 @@ impl ProxyEngine { error!("Error sending request to upstream: {}", e); last_error = Some(e); attempt_count += 1; + use_next_key = true; // Switch to using get_next_key for retries } } } @@ -146,9 +155,7 @@ impl ProxyEngine { // Handle streaming response body let body_stream = response.bytes_stream(); - let body = Body::from_stream(body_stream.map(|chunk| { - chunk.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) - })); + let body = Body::from_stream(body_stream.map(|chunk| chunk.map_err(std::io::Error::other))); builder .body(body) @@ -191,4 +198,4 @@ mod tests { let result = engine.extract_model_from_body(&invalid_body); assert!(result.is_err()); } -} \ No newline at end of file +} diff --git a/src/proxy/error.rs b/src/proxy/error.rs index 9cd378b..282dcbe 100644 --- a/src/proxy/error.rs +++ b/src/proxy/error.rs @@ -1,47 +1,48 @@ +use crate::types::ErrorResponse; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; use axum::Json; -use crate::types::ErrorResponse; use thiserror::Error; #[derive(Error, Debug)] +#[allow(dead_code)] pub enum ProxyError { #[error("No API key available for model '{model}'")] NoKeyAvailable { model: String }, - + #[error("No API key found")] NoKeyFound, - + #[error("Invalid API key")] InvalidApiKey, - + #[error("Upstream request failed: {source}")] UpstreamFailed { #[from] source: reqwest::Error, }, - + #[error("Request timeout")] Timeout, - + #[error("Rate limit exceeded")] RateLimited, - + #[error("Invalid JSON payload: {source}")] InvalidJson { #[from] source: serde_json::Error, }, - + #[error("Request body too large")] PayloadTooLarge, - + #[error("Method not allowed")] MethodNotAllowed, - + #[error("All retries exhausted")] AllRetriesExhausted, - + #[error("Internal server error: {message}")] Internal { message: String }, } @@ -52,7 +53,7 @@ impl ProxyError { message: message.into(), } } - + pub fn status_code(&self) -> StatusCode { match self { ProxyError::NoKeyAvailable { .. } => StatusCode::INTERNAL_SERVER_ERROR, @@ -62,8 +63,6 @@ impl ProxyError { // Map specific reqwest errors to appropriate status codes if source.is_timeout() { StatusCode::GATEWAY_TIMEOUT - } else if source.is_connect() { - StatusCode::BAD_GATEWAY } else { StatusCode::BAD_GATEWAY } @@ -83,11 +82,11 @@ impl IntoResponse for ProxyError { fn into_response(self) -> Response { let status = self.status_code(); let error_response = ErrorResponse::new(self.to_string()); - + tracing::error!("Proxy error: {} (status: {})", self, status); - + (status, Json(error_response)).into_response() } } -pub type ProxyResult = Result; \ No newline at end of file +pub type ProxyResult = Result; diff --git a/src/proxy/handler.rs b/src/proxy/handler.rs index 34c8c4f..4d1b07f 100644 --- a/src/proxy/handler.rs +++ b/src/proxy/handler.rs @@ -46,6 +46,7 @@ impl ProxyHandler { } /// Handle requests with path parameters (for /v1/* routes) + #[allow(dead_code)] pub async fn handle_v1_request( State(handler): State>, Path(path): Path, @@ -76,6 +77,7 @@ impl ProxyHandler { } /// Extract the request body as bytes +#[allow(dead_code)] pub async fn extract_body(request: Request) -> Result { match axum::body::to_bytes(request.into_body(), usize::MAX).await { Ok(bytes) => Ok(bytes), @@ -89,27 +91,6 @@ pub async fn extract_body(request: Request) -> Result { #[cfg(test)] mod tests { use super::*; - use crate::config::{ApiKeyInfo, UpstreamConfig}; - use crate::proxy::{KeyPool, UpstreamClient}; - use axum::http::Method; - use secrecy::SecretString; - - fn create_test_handler() -> ProxyHandler { - let key = ApiKeyInfo { - key: SecretString::new("test-key".to_string()), - url: "https://api.test.com".to_string(), - models: vec!["gpt-3.5-turbo".to_string()], - latency: None, - health_score: 1.0, - }; - - let key_pool = Arc::new(KeyPool::new(vec![key], "round_robin")); - let upstream_config = UpstreamConfig::default(); - let upstream_client = UpstreamClient::new(upstream_config).unwrap(); - let engine = Arc::new(ProxyEngine::new(key_pool, upstream_client, 3)); - - ProxyHandler::new(engine) - } #[tokio::test] async fn test_health_check() { @@ -117,4 +98,4 @@ mod tests { assert!(result.is_ok()); assert_eq!(result.unwrap(), "OK"); } -} \ No newline at end of file +} diff --git a/src/proxy/key_pool.rs b/src/proxy/key_pool.rs index 554334a..f8e10b2 100644 --- a/src/proxy/key_pool.rs +++ b/src/proxy/key_pool.rs @@ -78,11 +78,13 @@ impl KeyPool { /// Update latency measurement for a key pub fn update_latency(&self, key_index: usize, latency: Duration) { - self.latency_cache.insert(key_index, (latency, Instant::now())); + self.latency_cache + .insert(key_index, (latency, Instant::now())); debug!("Updated latency for key {}: {:?}", key_index, latency); } /// Get all keys for health checking + #[allow(dead_code)] pub fn get_all_keys(&self) -> &[Arc] { &self.keys } @@ -90,19 +92,19 @@ impl KeyPool { /// Measure latency for all keys by making HEAD requests pub async fn update_all_latencies(&self) { info!("Starting latency measurements for {} keys", self.keys.len()); - + let client = reqwest::Client::new(); let mut tasks = vec![]; for (index, key_info) in self.keys.iter().enumerate() { let client = client.clone(); let url = key_info.url.clone(); - + let task = tokio::spawn(async move { let latency = measure_key_latency(&client, &url).await; (index, latency) }); - + tasks.push(task); } @@ -125,13 +127,16 @@ impl KeyPool { if keys.is_empty() { return None; } - + let current = self.current_index.fetch_add(1, Ordering::SeqCst); let index = current % keys.len(); Some(keys[index].1.clone()) } - fn health_weighted_selection(&self, keys: &[(usize, &Arc)]) -> Option> { + fn health_weighted_selection( + &self, + keys: &[(usize, &Arc)], + ) -> Option> { if keys.is_empty() { return None; } @@ -141,7 +146,10 @@ impl KeyPool { self.round_robin_selection(keys) } - fn least_latency_selection(&self, keys: &[(usize, &Arc)]) -> Option> { + fn least_latency_selection( + &self, + keys: &[(usize, &Arc)], + ) -> Option> { if keys.is_empty() { return None; } @@ -179,11 +187,11 @@ impl KeyPool { async fn measure_key_latency(client: &reqwest::Client, url: &str) -> Duration { let start = Instant::now(); - + // Make a HEAD request with timeout let request_timeout = Duration::from_secs(5); let result = timeout(request_timeout, client.head(url).send()).await; - + match result { Ok(Ok(_)) => { let duration = start.elapsed(); @@ -229,7 +237,7 @@ mod tests { // Test round-robin for gpt-3.5-turbo (should cycle between keys 1 and 2) let key1 = pool.get_key_for_model("gpt-3.5-turbo").unwrap(); let key2 = pool.get_key_for_model("gpt-3.5-turbo").unwrap(); - + // Keys should be different (round-robin) assert_ne!(key1.url, key2.url); } @@ -269,4 +277,4 @@ mod tests { let result = pool.get_key_for_model("claude-1"); assert!(result.is_none()); } -} \ No newline at end of file +} diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 5173b32..f669a1d 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -5,7 +5,8 @@ pub mod key_pool; pub mod upstream; pub use engine::ProxyEngine; +#[allow(unused_imports)] pub use error::ProxyError; pub use handler::ProxyHandler; pub use key_pool::KeyPool; -pub use upstream::UpstreamClient; \ No newline at end of file +pub use upstream::UpstreamClient; diff --git a/src/proxy/upstream.rs b/src/proxy/upstream.rs index 2a39aa7..8d9719d 100644 --- a/src/proxy/upstream.rs +++ b/src/proxy/upstream.rs @@ -34,7 +34,7 @@ impl UpstreamClient { headers: Option, ) -> ProxyResult { let url = format!("{}{}", key_info.url, path); - + debug!( "Making {} request to {} with API key (redacted)", method, url @@ -65,7 +65,9 @@ impl UpstreamClient { match timeout(self.config.request_timeout(), request.send()).await { Ok(Ok(response)) => { // Check for retryable HTTP status codes - if self.should_retry_status(response.status()) && attempt < self.config.max_retries { + if self.should_retry_status(response.status()) + && attempt < self.config.max_retries + { warn!( "Received retryable status {} from upstream {}, attempt {}/{}", response.status(), @@ -73,21 +75,30 @@ impl UpstreamClient { attempt + 1, self.config.max_retries + 1 ); - + // Wait before retry - let wait_time = Duration::from_millis(self.config.retry_initial_backoff_ms * (2_u64.pow(attempt))); + let wait_time = Duration::from_millis( + self.config.retry_initial_backoff_ms * (2_u64.pow(attempt)), + ); tokio::time::sleep(wait_time.min(self.config.retry_max_backoff())).await; continue; } - + return Ok(response); } Ok(Err(e)) => { if self.should_retry_error(&e) && attempt < self.config.max_retries { - warn!("Request error: {}, retrying attempt {}/{}", e, attempt + 1, self.config.max_retries + 1); - + warn!( + "Request error: {}, retrying attempt {}/{}", + e, + attempt + 1, + self.config.max_retries + 1 + ); + // Wait before retry - let wait_time = Duration::from_millis(self.config.retry_initial_backoff_ms * (2_u64.pow(attempt))); + let wait_time = Duration::from_millis( + self.config.retry_initial_backoff_ms * (2_u64.pow(attempt)), + ); tokio::time::sleep(wait_time.min(self.config.retry_max_backoff())).await; continue; } else { @@ -96,10 +107,16 @@ impl UpstreamClient { } Err(_) => { if attempt < self.config.max_retries { - warn!("Request timeout, retrying attempt {}/{}", attempt + 1, self.config.max_retries + 1); - + warn!( + "Request timeout, retrying attempt {}/{}", + attempt + 1, + self.config.max_retries + 1 + ); + // Wait before retry - let wait_time = Duration::from_millis(self.config.retry_initial_backoff_ms * (2_u64.pow(attempt))); + let wait_time = Duration::from_millis( + self.config.retry_initial_backoff_ms * (2_u64.pow(attempt)), + ); tokio::time::sleep(wait_time.min(self.config.retry_max_backoff())).await; continue; } else { @@ -137,7 +154,7 @@ impl UpstreamClient { 418 | // I'm a teapot (some APIs use this for rate limiting) 502 | // Bad Gateway 503 | // Service Unavailable - 504 // Gateway Timeout + 504 // Gateway Timeout ) } } @@ -149,7 +166,7 @@ pub fn should_rotate_key(status: reqwest::StatusCode) -> bool { 429 | // Too Many Requests 418 | // I'm a teapot 502 | // Bad Gateway - 400 // Bad Request (potentially invalid key) + 400 // Bad Request (potentially invalid key) ) } @@ -157,7 +174,6 @@ pub fn should_rotate_key(status: reqwest::StatusCode) -> bool { mod tests { use super::*; use reqwest::StatusCode; - use secrecy::SecretString; fn create_test_config() -> UpstreamConfig { UpstreamConfig { @@ -170,16 +186,6 @@ mod tests { } } - fn create_test_key() -> Arc { - Arc::new(ApiKeyInfo { - key: SecretString::new("test-key".to_string()), - url: "https://api.test.com".to_string(), - models: vec!["gpt-3.5-turbo".to_string()], - latency: None, - health_score: 1.0, - }) - } - #[test] fn test_should_retry_status() { let config = create_test_config(); @@ -208,4 +214,4 @@ mod tests { assert!(!should_rotate_key(StatusCode::CREATED)); assert!(!should_rotate_key(StatusCode::NOT_FOUND)); } -} \ No newline at end of file +} diff --git a/src/routes.rs b/src/routes.rs index 11dfee3..a3c2a37 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -8,27 +8,24 @@ use axum::{ Router, }; use std::sync::Arc; +use std::time::Duration; use tower_http::cors::{Any, CorsLayer}; use tower_http::limit::RequestBodyLimitLayer; use tower_http::timeout::TimeoutLayer; use tower_http::trace::TraceLayer; -use std::time::Duration; pub fn create_router( - handler: Arc, + handler: Arc, body_limit: usize, request_timeout: Duration, ) -> Router { Router::new() // Health check endpoint .route("/health", get(ProxyHandler::health_check)) - // Catch-all for OpenAI API requests (maintaining compatibility with existing paths) .route("/*path", any(proxy_request_handler)) - // Add the handler as application state .with_state(handler) - // Add middleware layers .layer( CorsLayer::new() @@ -73,8 +70,8 @@ mod tests { use super::*; use crate::config::{ApiKeyInfo, UpstreamConfig}; use crate::proxy::{KeyPool, ProxyEngine, UpstreamClient}; - use axum::http::{Request, StatusCode}; use axum::body::Body; + use axum::http::{Request, StatusCode}; use secrecy::SecretString; use tower::ServiceExt; @@ -123,4 +120,4 @@ mod tests { // Should return method not allowed since we only accept POST assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED); } -} \ No newline at end of file +} diff --git a/src/types/mod.rs b/src/types/mod.rs index 2be1bd6..ce008b7 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -1,3 +1,3 @@ pub mod openai; -pub use openai::*; \ No newline at end of file +pub use openai::*; diff --git a/src/types/openai.rs b/src/types/openai.rs index 3aa53ae..bca1884 100644 --- a/src/types/openai.rs +++ b/src/types/openai.rs @@ -10,11 +10,13 @@ pub struct OpenAIRequest { } #[derive(Debug, Clone, Deserialize, Serialize)] +#[allow(dead_code)] pub struct OpenAIError { pub error: OpenAIErrorDetails, } #[derive(Debug, Clone, Deserialize, Serialize)] +#[allow(dead_code)] pub struct OpenAIErrorDetails { pub message: String, #[serde(rename = "type")] @@ -34,4 +36,4 @@ impl ErrorResponse { error: message.into(), } } -} \ No newline at end of file +} diff --git a/src/util/mod.rs b/src/util/mod.rs index 722d318..2aa7e3d 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -1,5 +1,3 @@ -use std::time::Duration; - pub fn convert_axum_method_to_reqwest(method: &axum::http::Method) -> reqwest::Method { match *method { axum::http::Method::GET => reqwest::Method::GET, @@ -18,7 +16,7 @@ pub fn convert_axum_headers_to_reqwest( headers: &axum::http::HeaderMap, ) -> reqwest::header::HeaderMap { let mut reqwest_headers = reqwest::header::HeaderMap::new(); - + for (name, value) in headers { if let (Ok(req_name), Ok(req_value)) = ( reqwest::header::HeaderName::from_bytes(name.as_str().as_bytes()), @@ -27,7 +25,7 @@ pub fn convert_axum_headers_to_reqwest( reqwest_headers.insert(req_name, req_value); } } - + reqwest_headers } @@ -35,7 +33,7 @@ pub fn convert_reqwest_headers_to_axum( headers: &reqwest::header::HeaderMap, ) -> axum::http::HeaderMap { let mut axum_headers = axum::http::HeaderMap::new(); - + for (name, value) in headers { if let (Ok(axum_name), Ok(axum_value)) = ( axum::http::HeaderName::from_bytes(name.as_str().as_bytes()), @@ -44,6 +42,6 @@ pub fn convert_reqwest_headers_to_axum( axum_headers.insert(axum_name, axum_value); } } - + axum_headers -} \ No newline at end of file +} diff --git a/tests/api_integration_tests.rs b/tests/api_integration_tests.rs index 25b4c75..9e33785 100644 --- a/tests/api_integration_tests.rs +++ b/tests/api_integration_tests.rs @@ -3,7 +3,6 @@ use axum::{ http::{Request, StatusCode}, Router, }; -use bytes::Bytes; use key_cycle_proxy::{ config::{ApiKeyInfo, UpstreamConfig}, proxy::{KeyPool, ProxyEngine, ProxyHandler, UpstreamClient}, @@ -12,10 +11,9 @@ use key_cycle_proxy::{ use secrecy::SecretString; use serde_json::json; use std::{sync::Arc, time::Duration}; -use tokio::net::TcpListener; use tower::ServiceExt; use wiremock::{ - matchers::{header, method, path, query_param}, + matchers::{header, method, path}, Mock, MockServer, ResponseTemplate, }; @@ -155,22 +153,20 @@ async fn test_api_key_rotation_on_rate_limit() { Mock::given(method("POST")) .and(path("/v1/chat/completions")) .and(header("authorization", "Bearer sk-test-key-2")) - .respond_with( - ResponseTemplate::new(200).set_body_json(json!({ - "id": "chatcmpl-fallback123", - "object": "chat.completion", - "created": 1234567890, - "model": "gpt-3.5-turbo", - "choices": [{ - "index": 0, - "message": { - "role": "assistant", - "content": "Fallback response from second key" - }, - "finish_reason": "stop" - }] - })), - ) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "id": "chatcmpl-fallback123", + "object": "chat.completion", + "created": 1234567890, + "model": "gpt-3.5-turbo", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "Fallback response from second key" + }, + "finish_reason": "stop" + }] + }))) .mount(&mock_server_2) .await; @@ -215,22 +211,20 @@ async fn test_api_model_routing() { Mock::given(method("POST")) .and(path("/v1/chat/completions")) .and(header("authorization", "Bearer sk-test-key-2")) - .respond_with( - ResponseTemplate::new(200).set_body_json(json!({ - "id": "chatcmpl-claude123", - "object": "chat.completion", - "created": 1234567890, - "model": "claude-2", - "choices": [{ - "index": 0, - "message": { - "role": "assistant", - "content": "Response from Claude model" - }, - "finish_reason": "stop" - }] - })), - ) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "id": "chatcmpl-claude123", + "object": "chat.completion", + "created": 1234567890, + "model": "claude-2", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "Response from Claude model" + }, + "finish_reason": "stop" + }] + }))) .mount(&mock_server_2) .await; @@ -267,9 +261,30 @@ async fn test_api_model_routing() { #[tokio::test] async fn test_api_error_handling_no_available_keys() { - let (app, _mock_server_1, _mock_server_2) = create_test_app_with_mocks().await; + let (app, _mock_server_1, mock_server_2) = create_test_app_with_mocks().await; - // Test with a model that has no matching keys (no setup for a specific model) + // Setup mock response for the "others" key (key 2) to handle any model + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .and(header("authorization", "Bearer sk-test-key-2")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "id": "chatcmpl-others123", + "object": "chat.completion", + "created": 1234567890, + "model": "nonexistent-model", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "Response from others fallback" + }, + "finish_reason": "stop" + }] + }))) + .mount(&mock_server_2) + .await; + + // Test with a model that has no matching keys (but "others" should catch it) let request = Request::builder() .method("POST") .uri("/v1/chat/completions") @@ -278,7 +293,7 @@ async fn test_api_error_handling_no_available_keys() { json!({ "model": "nonexistent-model", "messages": [ - {"role": "user", "content": "This should fail"} + {"role": "user", "content": "This should be handled by others fallback"} ] }) .to_string(), @@ -287,10 +302,8 @@ async fn test_api_error_handling_no_available_keys() { let response = app.oneshot(request).await.unwrap(); - // Should return error when no keys match and no "others" fallback - // Note: In our setup, we have an "others" key, so this will actually work - // Let's create a scenario where truly no keys are available - assert!(response.status().is_success() || response.status().is_server_error()); + // Should succeed because we have an "others" key that matches all models + assert_eq!(response.status(), StatusCode::OK); } #[tokio::test] @@ -303,9 +316,9 @@ async fn test_api_streaming_response() { .and(header("authorization", "Bearer sk-test-key-1")) .respond_with( ResponseTemplate::new(200) - .set_body_string("data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\ndata: {\"choices\":[{\"delta\":{\"content\":\" world!\"}}]}\n\ndata: [DONE]\n\n") - .insert_header("content-type", "text/event-stream") - .insert_header("cache-control", "no-cache"), + .set_body_bytes("data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\ndata: {\"choices\":[{\"delta\":{\"content\":\" world!\"}}]}\n\ndata: [DONE]\n\n") + .append_header("content-type", "text/event-stream") + .append_header("cache-control", "no-cache"), ) .mount(&mock_server_1) .await; @@ -386,7 +399,10 @@ async fn test_api_malformed_json() { .await .unwrap(); let response_json: serde_json::Value = serde_json::from_slice(&body).unwrap(); - assert!(response_json["error"].as_str().unwrap().contains("Invalid JSON")); + assert!(response_json["error"] + .as_str() + .unwrap() + .contains("Invalid JSON")); } #[tokio::test] @@ -396,43 +412,39 @@ async fn test_api_concurrent_requests() { // Setup both servers to respond successfully Mock::given(method("POST")) .and(path("/v1/chat/completions")) - .respond_with( - ResponseTemplate::new(200).set_body_json(json!({ - "id": "chatcmpl-concurrent", - "object": "chat.completion", - "created": 1234567890, - "model": "gpt-3.5-turbo", - "choices": [{ - "index": 0, - "message": { - "role": "assistant", - "content": "Concurrent response" - }, - "finish_reason": "stop" - }] - })), - ) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "id": "chatcmpl-concurrent", + "object": "chat.completion", + "created": 1234567890, + "model": "gpt-3.5-turbo", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "Concurrent response" + }, + "finish_reason": "stop" + }] + }))) .mount(&mock_server_1) .await; Mock::given(method("POST")) .and(path("/v1/chat/completions")) - .respond_with( - ResponseTemplate::new(200).set_body_json(json!({ - "id": "chatcmpl-concurrent", - "object": "chat.completion", - "created": 1234567890, - "model": "claude-2", - "choices": [{ - "index": 0, - "message": { - "role": "assistant", - "content": "Concurrent response" - }, - "finish_reason": "stop" - }] - })), - ) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "id": "chatcmpl-concurrent", + "object": "chat.completion", + "created": 1234567890, + "model": "claude-2", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "Concurrent response" + }, + "finish_reason": "stop" + }] + }))) .mount(&mock_server_2) .await; @@ -469,4 +481,4 @@ async fn test_api_concurrent_requests() { let response = response.unwrap(); assert_eq!(response.status(), StatusCode::OK); } -} \ No newline at end of file +} diff --git a/tests/enhanced_unit_tests.rs b/tests/enhanced_unit_tests.rs index 998e826..1fe236d 100644 --- a/tests/enhanced_unit_tests.rs +++ b/tests/enhanced_unit_tests.rs @@ -1,25 +1,29 @@ use key_cycle_proxy::{ config::{load_config, ApiKeyInfo, Config, UpstreamConfig}, - proxy::{KeyPool, ProxyEngine, UpstreamClient, ProxyError}, + proxy::{KeyPool, ProxyEngine, ProxyError, UpstreamClient}, types::{ErrorResponse, OpenAIRequest}, }; use secrecy::SecretString; +use std::io::Write; use std::{sync::Arc, time::Duration}; use tempfile::NamedTempFile; -use std::io::Write; #[test] fn test_config_loading_from_environment() { // Test environment variable configuration std::env::set_var("OPENAI_KEYS", "sk-key1,sk-key2,sk-key3"); - + // Clean up any existing config.json to ensure env var is used std::fs::remove_file("config.json").unwrap_or(()); - + let result = load_config(); std::env::remove_var("OPENAI_KEYS"); - - assert!(result.is_ok(), "Config loading should succeed with env var: {:?}", result.err()); + + assert!( + result.is_ok(), + "Config loading should succeed with env var: {:?}", + result.err() + ); let (_config, keys) = result.unwrap(); assert_eq!(keys.len(), 3); assert_eq!(keys[0].models, vec!["others"]); @@ -29,7 +33,9 @@ fn test_config_loading_from_environment() { fn test_config_loading_from_json() { // Create temporary config file let mut temp_file = NamedTempFile::new().unwrap(); - writeln!(temp_file, r#"{{ + writeln!( + temp_file, + r#"{{ "apiKeys": [ {{ "key": "sk-test-key-1", @@ -42,16 +48,18 @@ fn test_config_loading_from_json() { "models": ["claude-2", "others"] }} ] - }}"#).unwrap(); - + }}"# + ) + .unwrap(); + // Copy temp file to config.json in current directory std::fs::copy(temp_file.path(), "config.json").unwrap(); - + let result = load_config(); - + // Clean up std::fs::remove_file("config.json").unwrap_or(()); - + assert!(result.is_ok()); let (_config, keys) = result.unwrap(); assert_eq!(keys.len(), 2); @@ -64,10 +72,13 @@ fn test_config_error_handling_no_keys() { // Ensure no environment variable or config file exists std::env::remove_var("OPENAI_KEYS"); std::fs::remove_file("config.json").unwrap_or(()); - + let result = load_config(); assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("No API keys found")); + assert!(result + .unwrap_err() + .to_string() + .contains("No API keys found")); } #[test] @@ -79,11 +90,11 @@ fn test_api_key_model_support() { latency: None, health_score: 1.0, }; - + assert!(key_info.supports_model("gpt-3.5-turbo")); assert!(key_info.supports_model("gpt-4")); assert!(!key_info.supports_model("claude-2")); - + let fallback_key = ApiKeyInfo { key: SecretString::new("fallback-key".to_string()), url: "https://api.fallback.com".to_string(), @@ -91,7 +102,7 @@ fn test_api_key_model_support() { latency: None, health_score: 1.0, }; - + assert!(fallback_key.supports_model("any-model")); assert!(fallback_key.supports_model("claude-2")); assert!(fallback_key.supports_model("custom-model")); @@ -107,7 +118,7 @@ fn test_upstream_config_duration_conversion() { retry_max_backoff_ms: 2000, max_retries: 3, }; - + assert_eq!(config.connect_timeout(), Duration::from_millis(1000)); assert_eq!(config.request_timeout(), Duration::from_millis(5000)); assert_eq!(config.retry_max_backoff(), Duration::from_millis(2000)); @@ -138,18 +149,18 @@ fn test_key_pool_rotation_strategies() { health_score: 1.0, }, ]; - + // Test round-robin strategy let pool = KeyPool::new(keys.clone(), "round_robin"); let key1 = pool.get_key_for_model("gpt-3.5-turbo").unwrap(); let key2 = pool.get_key_for_model("gpt-3.5-turbo").unwrap(); assert_ne!(key1.url, key2.url); // Should rotate between keys - + // Test health-weighted strategy let pool = KeyPool::new(keys.clone(), "round_robin_health_weighted"); let key = pool.get_key_for_model("gpt-3.5-turbo"); assert!(key.is_some()); - + // Test least-latency strategy let pool = KeyPool::new(keys, "least_latency"); let key = pool.get_key_for_model("gpt-4"); @@ -174,13 +185,13 @@ fn test_key_pool_latency_tracking() { health_score: 1.0, }, ]; - + let pool = KeyPool::new(keys, "least_latency"); - + // Update latency measurements - pool.update_latency(0, Duration::from_millis(50)); // Fast key + pool.update_latency(0, Duration::from_millis(50)); // Fast key pool.update_latency(1, Duration::from_millis(200)); // Slow key - + // Should prefer the faster key in least_latency mode // Note: This tests the latency tracking mechanism } @@ -203,7 +214,7 @@ fn test_openai_request_parsing() { "temperature": 0.7, "max_tokens": 150 }"#; - + let request: OpenAIRequest = serde_json::from_str(json_str).unwrap(); assert_eq!(request.model, "gpt-3.5-turbo"); assert!(request.other.contains_key("messages")); @@ -221,11 +232,29 @@ fn test_openai_request_minimal() { #[test] fn test_proxy_error_status_codes() { - assert_eq!(ProxyError::NoKeyAvailable { model: "test".to_string() }.status_code(), axum::http::StatusCode::INTERNAL_SERVER_ERROR); - assert_eq!(ProxyError::MethodNotAllowed.status_code(), axum::http::StatusCode::METHOD_NOT_ALLOWED); - assert_eq!(ProxyError::Timeout.status_code(), axum::http::StatusCode::GATEWAY_TIMEOUT); - assert_eq!(ProxyError::RateLimited.status_code(), axum::http::StatusCode::TOO_MANY_REQUESTS); - assert_eq!(ProxyError::PayloadTooLarge.status_code(), axum::http::StatusCode::PAYLOAD_TOO_LARGE); + assert_eq!( + ProxyError::NoKeyAvailable { + model: "test".to_string() + } + .status_code(), + axum::http::StatusCode::INTERNAL_SERVER_ERROR + ); + assert_eq!( + ProxyError::MethodNotAllowed.status_code(), + axum::http::StatusCode::METHOD_NOT_ALLOWED + ); + assert_eq!( + ProxyError::Timeout.status_code(), + axum::http::StatusCode::GATEWAY_TIMEOUT + ); + assert_eq!( + ProxyError::RateLimited.status_code(), + axum::http::StatusCode::TOO_MANY_REQUESTS + ); + assert_eq!( + ProxyError::PayloadTooLarge.status_code(), + axum::http::StatusCode::PAYLOAD_TOO_LARGE + ); } #[test] @@ -239,7 +268,7 @@ fn test_config_defaults() { assert_eq!(config.observability.tracing_level, "info"); } -#[test] +#[test] fn test_upstream_client_creation() { let config = UpstreamConfig::default(); let client = UpstreamClient::new(config); @@ -255,12 +284,12 @@ fn test_proxy_engine_creation() { latency: None, health_score: 1.0, }]; - + let key_pool = Arc::new(KeyPool::new(keys, "round_robin")); let upstream_config = UpstreamConfig::default(); let upstream_client = UpstreamClient::new(upstream_config).unwrap(); let _engine = ProxyEngine::new(key_pool, upstream_client, 3); - + // Engine should be created successfully // This is a smoke test for the constructor } @@ -276,10 +305,10 @@ fn test_json_error_handling() { fn test_empty_key_pool() { let empty_keys: Vec = vec![]; let pool = KeyPool::new(empty_keys, "round_robin"); - + let result = pool.get_key_for_model("gpt-3.5-turbo"); assert!(result.is_none()); - + let result = pool.get_next_key(); assert!(result.is_none()); } @@ -288,7 +317,7 @@ fn test_empty_key_pool() { fn test_key_pool_concurrent_access() { use std::sync::Arc; use std::thread; - + let keys = vec![ ApiKeyInfo { key: SecretString::new("key1".to_string()), @@ -305,9 +334,9 @@ fn test_key_pool_concurrent_access() { health_score: 1.0, }, ]; - + let pool = Arc::new(KeyPool::new(keys, "round_robin")); - + // Spawn multiple threads accessing the pool concurrently let mut handles = vec![]; for _ in 0..10 { @@ -320,11 +349,11 @@ fn test_key_pool_concurrent_access() { }); handles.push(handle); } - + for handle in handles { handle.join().unwrap(); } - + // If we get here without deadlocks or panics, concurrent access works } @@ -338,14 +367,17 @@ async fn test_latency_measurement_timeout() { latency: None, health_score: 1.0, }]; - + let pool = KeyPool::new(keys, "round_robin"); - + // This should complete without hanging, even though the URL is unreachable let start = std::time::Instant::now(); pool.update_all_latencies().await; let duration = start.elapsed(); - + // Should complete within reasonable time (timeout mechanism working) - assert!(duration < Duration::from_secs(10), "Latency measurement should timeout quickly"); -} \ No newline at end of file + assert!( + duration < Duration::from_secs(10), + "Latency measurement should timeout quickly" + ); +} diff --git a/tests/performance_tests.rs b/tests/performance_tests.rs index 988e886..1c15ded 100644 --- a/tests/performance_tests.rs +++ b/tests/performance_tests.rs @@ -6,7 +6,10 @@ use key_cycle_proxy::{ }; use secrecy::SecretString; use serde_json::json; -use std::{sync::Arc, time::{Duration, Instant}}; +use std::{ + sync::Arc, + time::{Duration, Instant}, +}; use tokio::time::timeout; use tower::ServiceExt; use wiremock::{matchers::method, Mock, MockServer, ResponseTemplate}; @@ -20,7 +23,7 @@ async fn create_performance_test_app() -> (Router, Vec) { // Create 5 mock servers for load testing for i in 0..5 { let mock_server = MockServer::start().await; - + // Setup each server to respond with a slight delay to simulate real API latency Mock::given(method("POST")) .respond_with( @@ -44,7 +47,7 @@ async fn create_performance_test_app() -> (Router, Vec) { "completion_tokens": 5, "total_tokens": 15 } - })) + })), ) .mount(&mock_server) .await; @@ -82,10 +85,10 @@ async fn create_performance_test_app() -> (Router, Vec) { #[tokio::test] async fn test_load_performance_100_concurrent_requests() { let (app, _mock_servers) = create_performance_test_app().await; - + let start_time = Instant::now(); let num_requests = 100; - + // Create 100 concurrent requests let mut handles = vec![]; for i in 0..num_requests { @@ -109,7 +112,7 @@ async fn test_load_performance_100_concurrent_requests() { let request_start = Instant::now(); let response = app_clone.oneshot(request).await.unwrap(); let request_duration = request_start.elapsed(); - + (response.status(), request_duration) }); handles.push(handle); @@ -121,13 +124,13 @@ async fn test_load_performance_100_concurrent_requests() { .expect("Load test should complete within 30 seconds"); let total_duration = start_time.elapsed(); - + // Analyze results let mut successful_requests = 0; let mut total_request_time = Duration::ZERO; let mut max_request_time = Duration::ZERO; let mut min_request_time = Duration::from_secs(999); - + for result in results { let (status, duration) = result.unwrap(); if status.is_success() { @@ -137,32 +140,46 @@ async fn test_load_performance_100_concurrent_requests() { min_request_time = min_request_time.min(duration); } } - + // Performance assertions - assert!(successful_requests >= num_requests * 95 / 100, - "At least 95% of requests should succeed, got {}/{}", successful_requests, num_requests); - + assert!( + successful_requests >= num_requests * 95 / 100, + "At least 95% of requests should succeed, got {}/{}", + successful_requests, + num_requests + ); + let avg_request_time = total_request_time / successful_requests as u32; let throughput = successful_requests as f64 / total_duration.as_secs_f64(); - + println!("Load Test Results:"); println!(" Total time: {:?}", total_duration); - println!(" Successful requests: {}/{}", successful_requests, num_requests); + println!( + " Successful requests: {}/{}", + successful_requests, num_requests + ); println!(" Throughput: {:.2} req/s", throughput); println!(" Average request time: {:?}", avg_request_time); println!(" Min request time: {:?}", min_request_time); println!(" Max request time: {:?}", max_request_time); - + // Performance benchmarks (these are reasonable expectations) - assert!(throughput > 20.0, "Throughput should be > 20 req/s, got {:.2}", throughput); - assert!(avg_request_time < Duration::from_millis(500), - "Average request time should be < 500ms, got {:?}", avg_request_time); + assert!( + throughput > 20.0, + "Throughput should be > 20 req/s, got {:.2}", + throughput + ); + assert!( + avg_request_time < Duration::from_millis(500), + "Average request time should be < 500ms, got {:?}", + avg_request_time + ); } #[tokio::test] async fn test_latency_key_selection() { let (app, _mock_servers) = create_performance_test_app().await; - + // Make several requests to populate latency measurements for i in 0..10 { let request = Request::builder() @@ -179,14 +196,14 @@ async fn test_latency_key_selection() { .to_string(), )) .unwrap(); - + let response = app.clone().oneshot(request).await.unwrap(); assert!(response.status().is_success()); - + // Small delay to allow latency measurements to update tokio::time::sleep(Duration::from_millis(10)).await; } - + // The key pool should now have latency data and prefer faster servers // This is more of a smoke test since latency selection is internal } @@ -194,14 +211,14 @@ async fn test_latency_key_selection() { #[tokio::test] async fn test_memory_usage_under_load() { let (app, _mock_servers) = create_performance_test_app().await; - + // This test ensures we don't have memory leaks under sustained load let num_batches = 10; let requests_per_batch = 20; - + for batch in 0..num_batches { let mut handles = vec![]; - + for i in 0..requests_per_batch { let app_clone = app.clone(); let handle = tokio::spawn(async move { @@ -219,33 +236,36 @@ async fn test_memory_usage_under_load() { .to_string(), )) .unwrap(); - + app_clone.oneshot(request).await.unwrap() }); handles.push(handle); } - + // Wait for batch to complete let results = futures::future::join_all(handles).await; - + // Verify all requests in batch succeeded for result in results { let response = result.unwrap(); assert!(response.status().is_success()); } - + // Small delay between batches tokio::time::sleep(Duration::from_millis(100)).await; } - + // If we get here without panicking or timing out, memory usage is likely stable - println!("Completed {} batches of {} requests each", num_batches, requests_per_batch); + println!( + "Completed {} batches of {} requests each", + num_batches, requests_per_batch + ); } #[tokio::test] async fn test_error_resilience_under_load() { let (app, mock_servers) = create_performance_test_app().await; - + // Configure some servers to fail intermittently for (i, server) in mock_servers.iter().enumerate() { if i % 2 == 0 { @@ -262,10 +282,10 @@ async fn test_error_resilience_under_load() { .await; } } - + let num_requests = 50; let mut handles = vec![]; - + for i in 0..num_requests { let app_clone = app.clone(); let handle = tokio::spawn(async move { @@ -283,17 +303,17 @@ async fn test_error_resilience_under_load() { .to_string(), )) .unwrap(); - + app_clone.oneshot(request).await.unwrap() }); handles.push(handle); } - + let results = futures::future::join_all(handles).await; - + let mut successful = 0; let mut failed = 0; - + for result in results { let response = result.unwrap(); if response.status().is_success() { @@ -302,15 +322,22 @@ async fn test_error_resilience_under_load() { failed += 1; } } - - println!("Resilience test: {} successful, {} failed", successful, failed); - + + println!( + "Resilience test: {} successful, {} failed", + successful, failed + ); + // Even with some servers failing, we should still have high success rate due to key rotation - assert!(successful >= num_requests * 80 / 100, - "At least 80% requests should succeed due to key rotation, got {}/{}", successful, num_requests); + assert!( + successful >= num_requests * 80 / 100, + "At least 80% requests should succeed due to key rotation, got {}/{}", + successful, + num_requests + ); } -#[tokio::test] +#[tokio::test] async fn test_timeout_handling() { // Create a mock server that responds very slowly let slow_server = MockServer::start().await; @@ -318,11 +345,11 @@ async fn test_timeout_handling() { .respond_with( ResponseTemplate::new(200) .set_delay(Duration::from_secs(10)) // Very slow response - .set_body_json(json!({"response": "too slow"})) + .set_body_json(json!({"response": "too slow"})), ) .mount(&slow_server) .await; - + let keys = vec![ApiKeyInfo { key: SecretString::new("sk-slow-key".to_string()), url: slow_server.uri(), @@ -330,7 +357,7 @@ async fn test_timeout_handling() { latency: None, health_score: 1.0, }]; - + let key_pool = Arc::new(KeyPool::new(keys, "round_robin")); let upstream_config = UpstreamConfig { base_url: "http://timeout-test.com/v1".to_string(), @@ -344,7 +371,7 @@ async fn test_timeout_handling() { let engine = Arc::new(ProxyEngine::new(key_pool, upstream_client, 1)); let handler = Arc::new(ProxyHandler::new(engine)); let app = create_router(handler, 1024 * 1024, Duration::from_secs(2)); - + let request = Request::builder() .method("POST") .uri("/v1/chat/completions") @@ -359,13 +386,19 @@ async fn test_timeout_handling() { .to_string(), )) .unwrap(); - + let start = Instant::now(); let response = app.oneshot(request).await.unwrap(); let duration = start.elapsed(); - + // Should fail due to timeout, and should fail relatively quickly - assert!(response.status().is_server_error() || response.status().is_client_error(), - "Expected error status but got: {}", response.status()); - assert!(duration < Duration::from_secs(5), "Timeout should be handled quickly"); -} \ No newline at end of file + assert!( + response.status().is_server_error() || response.status().is_client_error(), + "Expected error status but got: {}", + response.status() + ); + assert!( + duration < Duration::from_secs(5), + "Timeout should be handled quickly" + ); +}