diff --git a/Cargo.lock b/Cargo.lock index 031e9790..d26e012b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5464,7 +5464,9 @@ name = "pegainfer-vllm-frontend" version = "0.1.0" dependencies = [ "anyhow", + "async-stream", "axum", + "futures", "log", "pegainfer-engine", "rmp-serde", diff --git a/pegainfer-deepseek-v2-lite/src/engine.rs b/pegainfer-deepseek-v2-lite/src/engine.rs index 4b995984..94550573 100644 --- a/pegainfer-deepseek-v2-lite/src/engine.rs +++ b/pegainfer-deepseek-v2-lite/src/engine.rs @@ -35,6 +35,7 @@ fn handle_request(generator: &mut DeepSeekV2LiteEp2Generator, req: &GenerateRequ queued_at_unix_s: req.queued_at_unix_s.unwrap_or(now), scheduled_at_unix_s: now, prompt_tokens, + cached_tokens: 0, }); if req.echo { let _ = req.token_tx.send(TokenEvent::PromptTokens { diff --git a/pegainfer-deepseek-v4/src/direct/scheduler.rs b/pegainfer-deepseek-v4/src/direct/scheduler.rs index 9605f96c..733dc1e0 100644 --- a/pegainfer-deepseek-v4/src/direct/scheduler.rs +++ b/pegainfer-deepseek-v4/src/direct/scheduler.rs @@ -1179,6 +1179,7 @@ fn handle_request(generator: &mut DeepSeekV4DirectGenerator, req: GenerateReques queued_at_unix_s, scheduled_at_unix_s, prompt_tokens: prompt_len, + cached_tokens: 0, }); if req.echo { let _ = req.token_tx.send(TokenEvent::PromptTokens { @@ -1356,6 +1357,7 @@ fn handle_request_wave(generator: &mut DeepSeekV4DirectGenerator, requests: Vec< queued_at_unix_s, scheduled_at_unix_s, prompt_tokens: prompt_len, + cached_tokens: 0, }); if req.echo { let _ = req.token_tx.send(TokenEvent::PromptTokens { diff --git a/pegainfer-engine/src/engine.rs b/pegainfer-engine/src/engine.rs index 2bb6b9c5..11295ea2 100644 --- a/pegainfer-engine/src/engine.rs +++ b/pegainfer-engine/src/engine.rs @@ -134,6 +134,7 @@ pub enum TokenEvent { queued_at_unix_s: f64, scheduled_at_unix_s: f64, prompt_tokens: usize, + cached_tokens: usize, }, Token { id: u32, diff --git a/pegainfer-kimi-k2/src/runner/scheduler/lifecycle.rs b/pegainfer-kimi-k2/src/runner/scheduler/lifecycle.rs index 328fba27..109b775a 100644 --- a/pegainfer-kimi-k2/src/runner/scheduler/lifecycle.rs +++ b/pegainfer-kimi-k2/src/runner/scheduler/lifecycle.rs @@ -95,6 +95,7 @@ pub(in crate::runner) fn send_scheduled(req: &GenerateRequest) { queued_at_unix_s: req.queued_at_unix_s.unwrap_or(scheduled_at), scheduled_at_unix_s: scheduled_at, prompt_tokens: req.prompt_tokens.len(), + cached_tokens: 0, }); } diff --git a/pegainfer-qwen3-4b/src/scheduler.rs b/pegainfer-qwen3-4b/src/scheduler.rs index 7a108842..85f5faa5 100644 --- a/pegainfer-qwen3-4b/src/scheduler.rs +++ b/pegainfer-qwen3-4b/src/scheduler.rs @@ -11,6 +11,7 @@ mod resolve; use std::collections::{HashSet, VecDeque}; use std::thread; +use std::time::{SystemTime, UNIX_EPOCH}; use anyhow::Result; use log::{info, warn}; @@ -24,7 +25,7 @@ use pegainfer_core::engine::{ }; use pegainfer_core::sampler::SamplingParams; -use self::effects::apply_effects; +use self::effects::{PrefixCacheStats, apply_effects}; use self::plan::{build_next_plan, execute_plan}; use self::resolve::resolve_step; @@ -47,6 +48,8 @@ pub(super) struct ActiveRequestState { pub(super) struct PendingRequest { pub(super) request_id: RequestId, pub(super) lora_adapter: Option, + pub(super) queued_at_unix_s: Option, + pub(super) scheduled_at_unix_s: Option, pub(super) prompt_tokens: Vec, pub(super) params: SamplingParams, pub(super) max_tokens: usize, @@ -60,6 +63,8 @@ impl PendingRequest { Self { request_id, lora_adapter: req.lora_adapter, + queued_at_unix_s: req.queued_at_unix_s, + scheduled_at_unix_s: None, prompt_tokens: req.prompt_tokens, params: req.params, max_tokens: req.max_tokens, @@ -136,6 +141,7 @@ fn scheduler_loop( let mut rng = StdRng::seed_from_u64(seed); let mut active: Vec = Vec::new(); let mut next_request_id = 0u64; + let mut prefix_cache_stats = PrefixCacheStats::default(); // Requests that could not be admitted due to KV budget pressure. // Held here so they aren't lost; re-evaluated every loop iteration. let mut deferred: Vec = Vec::new(); @@ -205,7 +211,7 @@ fn scheduler_loop( } }; let effects = resolve_step(&executor, &active, artifacts); - apply_effects(&mut executor, &mut active, effects); + apply_effects(&mut executor, &mut active, effects, &mut prefix_cache_stats); } } @@ -219,6 +225,7 @@ fn scheduler_loop_with_lora_control( let mut rng = StdRng::seed_from_u64(seed); let mut active: Vec = Vec::new(); let mut next_request_id = 0u64; + let mut prefix_cache_stats = PrefixCacheStats::default(); let mut deferred: Vec = Vec::new(); let mut pending_control: VecDeque = VecDeque::new(); let mut post_control_deferred: Vec = Vec::new(); @@ -327,7 +334,7 @@ fn scheduler_loop_with_lora_control( } }; let effects = resolve_step(&executor, &active, artifacts); - apply_effects(&mut executor, &mut active, effects); + apply_effects(&mut executor, &mut active, effects, &mut prefix_cache_stats); } } @@ -563,6 +570,13 @@ fn send_unknown_lora_rejection(req: &PendingRequest) { }); } +fn now_secs_f64() -> f64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system clock before unix epoch") + .as_secs_f64() +} + fn failure_targets_for( active: &[ActiveRequestState], plan: &self::plan::ExecutionPlan, @@ -648,6 +662,7 @@ mod tests { held_tokens: HashMap, fail_decode_once: bool, decode_delay: Duration, + cached_tokens: usize, loaded_lora_adapters: HashSet, active_lora_adapter: Option, lora_activations: Arc>>>, @@ -666,6 +681,7 @@ mod tests { held_tokens: HashMap::new(), fail_decode_once: false, decode_delay: Duration::ZERO, + cached_tokens: 0, loaded_lora_adapters: HashSet::new(), active_lora_adapter: None, lora_activations: Arc::new(Mutex::new(Vec::new())), @@ -690,6 +706,11 @@ mod tests { self } + fn with_cached_tokens(mut self, cached_tokens: usize) -> Self { + self.cached_tokens = cached_tokens; + self + } + fn with_lora_adapters( mut self, names: &[&str], @@ -808,7 +829,7 @@ mod tests { first_token: 100 + req.request_id.get() as u32, first_token_logprob: None, prompt_logprobs: None, - cached_tokens: 0, + cached_tokens: self.cached_tokens, }) .collect(), }) @@ -888,7 +909,7 @@ mod tests { first_token: 100 + req.request_id.get() as u32, first_token_logprob: None, prompt_logprobs: None, - cached_tokens: 0, + cached_tokens: self.cached_tokens, }) .collect(), decode_requests: plan @@ -1046,11 +1067,21 @@ mod tests { let (fits_exactly, mut rx) = request(16, 1); handle.submit(fits_exactly).expect("submit fits_exactly"); assert!( - matches!(rx.blocking_recv(), Some(TokenEvent::Token { id: 100, .. })), + matches!(rx.blocking_recv(), Some(TokenEvent::Scheduled { .. })), + "prefill should emit the scheduled event first" + ); + assert!( + matches!( + next_non_scheduled(&mut rx), + Some(TokenEvent::Token { id: 100, .. }) + ), "prefill should emit the sampled token" ); assert!( - matches!(rx.blocking_recv(), Some(TokenEvent::Finished { .. })), + matches!( + next_non_scheduled(&mut rx), + Some(TokenEvent::Finished { .. }) + ), "one-token completion should finish without a decode KV page" ); assert!( @@ -1067,9 +1098,13 @@ mod tests { let (long_running, mut long_rx) = request(16, 18); handle.submit(long_running).expect("submit long_running"); + assert!( + matches!(long_rx.blocking_recv(), Some(TokenEvent::Scheduled { .. })), + "first request should report scheduling" + ); assert!( matches!( - long_rx.blocking_recv(), + next_non_scheduled(&mut long_rx), Some(TokenEvent::Token { id: 100, .. }) ), "first request should prefill" @@ -1078,9 +1113,13 @@ mod tests { let (must_wait, mut wait_rx) = request(17, 1); handle.submit(must_wait).expect("submit must_wait"); + assert!( + matches!(wait_rx.blocking_recv(), Some(TokenEvent::Scheduled { .. })), + "waiting request should report scheduling" + ); assert!( matches!( - wait_rx.blocking_recv(), + next_non_scheduled(&mut wait_rx), Some(TokenEvent::Token { id: 101, .. }) ), "waiting request should start once the active request releases its full KV budget" @@ -1090,7 +1129,10 @@ mod tests { "second request was admitted before the first request released KV" ); assert!( - matches!(wait_rx.blocking_recv(), Some(TokenEvent::Finished { .. })), + matches!( + next_non_scheduled(&mut wait_rx), + Some(TokenEvent::Finished { .. }) + ), "waiting request should finish after admission" ); } @@ -1126,6 +1168,34 @@ mod tests { (request, token_rx) } + fn request_with_queue_time( + prompt_len: usize, + max_tokens: usize, + queued_at_unix_s: f64, + ) -> (GenerateRequest, mpsc::UnboundedReceiver) { + let (mut request, token_rx) = request(prompt_len, max_tokens); + request.queued_at_unix_s = Some(queued_at_unix_s); + (request, token_rx) + } + + fn echo_request( + prompt_len: usize, + max_tokens: usize, + ) -> (GenerateRequest, mpsc::UnboundedReceiver) { + let (mut request, token_rx) = request(prompt_len, max_tokens); + request.echo = true; + (request, token_rx) + } + + fn next_non_scheduled(rx: &mut mpsc::UnboundedReceiver) -> Option { + loop { + match rx.blocking_recv() { + Some(TokenEvent::Scheduled { .. }) => continue, + other => return other, + } + } + } + fn wait_until(timeout: Duration, mut predicate: impl FnMut() -> bool) -> bool { let start = Instant::now(); while start.elapsed() < timeout { @@ -1137,6 +1207,147 @@ mod tests { false } + #[test] + fn scheduled_event_preserves_cached_tokens_and_order() { + let dropped = Arc::new(Mutex::new(Vec::new())); + let executor = FakeExecutor::new(4, Arc::clone(&dropped)).with_cached_tokens(12); + let handle = start_with_executor(executor, 42); + + let queued_at = 1234.5; + let (request, mut rx) = request_with_queue_time(16, 1, queued_at); + handle.submit(request).expect("submit cached request"); + + match rx.blocking_recv() { + Some(TokenEvent::Scheduled { + queued_at_unix_s, + scheduled_at_unix_s, + prompt_tokens, + cached_tokens, + }) => { + assert_eq!(queued_at_unix_s, queued_at); + assert!(scheduled_at_unix_s >= queued_at); + assert_eq!(prompt_tokens, 16); + assert_eq!(cached_tokens, 12); + } + _ => panic!("expected Scheduled before token/finish"), + } + assert!( + matches!( + next_non_scheduled(&mut rx), + Some(TokenEvent::Token { id: 100, .. }) + ), + "token should follow Scheduled" + ); + assert!( + matches!( + next_non_scheduled(&mut rx), + Some(TokenEvent::Finished { .. }) + ), + "finish should follow token" + ); + } + + #[test] + fn scheduled_event_queue_time_falls_back_to_scheduled_time() { + let dropped = Arc::new(Mutex::new(Vec::new())); + let executor = FakeExecutor::new(4, Arc::clone(&dropped)); + let handle = start_with_executor(executor, 42); + + let (request, mut rx) = request(16, 1); + handle.submit(request).expect("submit request"); + + match rx.blocking_recv() { + Some(TokenEvent::Scheduled { + queued_at_unix_s, + scheduled_at_unix_s, + cached_tokens, + .. + }) => { + assert_eq!(queued_at_unix_s, scheduled_at_unix_s); + assert_eq!(cached_tokens, 0); + } + _ => panic!("expected Scheduled before token/finish"), + } + } + + #[test] + fn echo_request_emits_scheduled_before_prompt_tokens() { + let dropped = Arc::new(Mutex::new(Vec::new())); + let executor = FakeExecutor::new(4, Arc::clone(&dropped)).with_cached_tokens(8); + let handle = start_with_executor(executor, 42); + + let (request, mut rx) = echo_request(16, 1); + handle.submit(request).expect("submit echo request"); + + match rx.blocking_recv() { + Some(TokenEvent::Scheduled { + prompt_tokens, + cached_tokens, + .. + }) => { + assert_eq!(prompt_tokens, 16); + assert_eq!(cached_tokens, 8); + } + _ => panic!("expected Scheduled before prompt echo"), + } + match rx.blocking_recv() { + Some(TokenEvent::PromptTokens { ids, logprobs }) => { + assert_eq!(ids, vec![1; 16]); + assert_eq!(logprobs, vec![None; 16]); + } + _ => panic!("expected PromptTokens after Scheduled"), + } + assert!( + matches!(rx.blocking_recv(), Some(TokenEvent::Token { id: 100, .. })), + "sampled token should follow prompt echo" + ); + assert!( + matches!(rx.blocking_recv(), Some(TokenEvent::Finished { .. })), + "one-token echo request should finish after token" + ); + } + + #[test] + fn prefix_cache_stats_accumulate_cumulative_engine_totals() { + let mut stats = PrefixCacheStats::default(); + + stats.observe(effects::ScheduledEffect { + queued_at_unix_s: 1.0, + scheduled_at_unix_s: 2.0, + prompt_tokens: 16, + cached_tokens: 0, + }); + stats.observe(effects::ScheduledEffect { + queued_at_unix_s: 3.0, + scheduled_at_unix_s: 4.0, + prompt_tokens: 16, + cached_tokens: 12, + }); + stats.observe(effects::ScheduledEffect { + queued_at_unix_s: 5.0, + scheduled_at_unix_s: 6.0, + prompt_tokens: 8, + cached_tokens: 0, + }); + + let ( + total_requests, + hit_requests, + miss_requests, + hit_rate, + total_prompt_tokens, + total_cached_tokens, + token_hit_rate, + ) = stats.totals(); + assert_eq!(total_requests, 3); + assert_eq!(hit_requests, 1); + assert_eq!(miss_requests, 2); + assert!((hit_rate - (1.0 / 3.0)).abs() < f64::EPSILON); + assert_eq!(total_prompt_tokens, 40); + assert_eq!(total_cached_tokens, 12); + assert!((token_hit_rate - 0.3).abs() < f64::EPSILON); + } + #[test] fn impossible_request_is_rejected_without_blocking_later_work() { let dropped = Arc::new(Mutex::new(Vec::new())); @@ -1145,7 +1356,7 @@ mod tests { let (too_large, mut too_large_rx) = request(16, 34); handle.submit(too_large).expect("submit too_large"); - match too_large_rx.blocking_recv() { + match next_non_scheduled(&mut too_large_rx) { Some(TokenEvent::Rejected { prompt_tokens, completion_tokens, @@ -1160,12 +1371,19 @@ mod tests { let (fits, mut fits_rx) = request(16, 1); handle.submit(fits).expect("submit fits"); - match fits_rx.blocking_recv() { + assert!( + matches!(fits_rx.blocking_recv(), Some(TokenEvent::Scheduled { .. })), + "later fitting request should report scheduling" + ); + match next_non_scheduled(&mut fits_rx) { Some(TokenEvent::Token { id, .. }) => assert_eq!(id, 101), _ => panic!("later fitting request should emit a token"), } assert!( - matches!(fits_rx.blocking_recv(), Some(TokenEvent::Finished { .. })), + matches!( + next_non_scheduled(&mut fits_rx), + Some(TokenEvent::Finished { .. }) + ), "later fitting request should finish" ); } @@ -1281,7 +1499,7 @@ mod tests { handle.submit(unknown).expect("submit unknown adapter"); handle.submit(base).expect("submit base"); - match unknown_rx.blocking_recv() { + match next_non_scheduled(&mut unknown_rx) { Some(TokenEvent::Rejected { message, prompt_tokens, @@ -1294,15 +1512,22 @@ mod tests { _ => panic!("unknown adapter request should be rejected"), } + assert!( + matches!(base_rx.blocking_recv(), Some(TokenEvent::Scheduled { .. })), + "base request should report scheduling" + ); assert!( matches!( - base_rx.blocking_recv(), + next_non_scheduled(&mut base_rx), Some(TokenEvent::Token { id: 101, .. }) ), "base request should still run after unknown adapter rejection" ); assert!( - matches!(base_rx.blocking_recv(), Some(TokenEvent::Finished { .. })), + matches!( + next_non_scheduled(&mut base_rx), + Some(TokenEvent::Finished { .. }) + ), "base request should finish" ); } @@ -1315,14 +1540,18 @@ mod tests { let (will_fail, mut fail_rx) = request(16, 2); handle.submit(will_fail).expect("submit will_fail"); + assert!( + matches!(fail_rx.blocking_recv(), Some(TokenEvent::Scheduled { .. })), + "scheduled event should be emitted before first token" + ); assert!( matches!( - fail_rx.blocking_recv(), + next_non_scheduled(&mut fail_rx), Some(TokenEvent::Token { id: 100, .. }) ), "first token should be emitted before decode failure" ); - match fail_rx.blocking_recv() { + match next_non_scheduled(&mut fail_rx) { Some(TokenEvent::Error { message, prompt_tokens, @@ -1344,15 +1573,22 @@ mod tests { let (after_failure, mut after_rx) = request(16, 1); handle.submit(after_failure).expect("submit after_failure"); + assert!( + matches!(after_rx.blocking_recv(), Some(TokenEvent::Scheduled { .. })), + "request after failure should report scheduling" + ); assert!( matches!( - after_rx.blocking_recv(), + next_non_scheduled(&mut after_rx), Some(TokenEvent::Token { id: 101, .. }) ), "scheduler should accept new work after a decode error" ); assert!( - matches!(after_rx.blocking_recv(), Some(TokenEvent::Finished { .. })), + matches!( + next_non_scheduled(&mut after_rx), + Some(TokenEvent::Finished { .. }) + ), "request after failure should finish" ); } @@ -1367,9 +1603,13 @@ mod tests { handle .submit(will_disconnect) .expect("submit will_disconnect"); + assert!( + matches!(token_rx.blocking_recv(), Some(TokenEvent::Scheduled { .. })), + "prefill should report scheduling" + ); assert!( matches!( - token_rx.blocking_recv(), + next_non_scheduled(&mut token_rx), Some(TokenEvent::Token { id: 100, .. }) ), "prefill should emit the first token" @@ -1409,11 +1649,11 @@ mod tests { .expect("seed fake request state"); } + let mut prefix_cache_stats = PrefixCacheStats::default(); apply_effects( &mut executor, &mut active, effects::StepEffects { - prompt_echoes: Vec::new(), pending: Vec::new(), decode: vec![ effects::DecodeEffect::EmitAndFinish { @@ -1439,6 +1679,7 @@ mod tests { }, ], }, + &mut prefix_cache_stats, ); assert!( @@ -1508,9 +1749,13 @@ mod tests { let (long_running, mut token_rx) = request(16, 3); handle.submit(long_running).expect("submit long_running"); + assert!( + matches!(token_rx.blocking_recv(), Some(TokenEvent::Scheduled { .. })), + "scheduled event should be emitted before first token" + ); assert!( matches!( - token_rx.blocking_recv(), + next_non_scheduled(&mut token_rx), Some(TokenEvent::Token { id: 100, .. }) ), "first token should be emitted before decode" @@ -1538,7 +1783,10 @@ mod tests { "load_lora_adapter should wait while generation is active" ); - while !matches!(token_rx.blocking_recv(), Some(TokenEvent::Finished { .. })) {} + while !matches!( + next_non_scheduled(&mut token_rx), + Some(TokenEvent::Finished { .. }) + ) {} let error = load_thread .join() diff --git a/pegainfer-qwen3-4b/src/scheduler/effects.rs b/pegainfer-qwen3-4b/src/scheduler/effects.rs index 456dc87e..6daecec6 100644 --- a/pegainfer-qwen3-4b/src/scheduler/effects.rs +++ b/pegainfer-qwen3-4b/src/scheduler/effects.rs @@ -1,3 +1,4 @@ +use log::debug; use tokio::sync::mpsc; use crate::executor::RequestId; @@ -6,30 +7,41 @@ use pegainfer_core::engine::{FinishReason, TokenLogprob}; use super::{ActiveRequestState, TokenEvent}; pub(super) struct PromptEchoEffect { - pub(super) token_tx: mpsc::UnboundedSender, pub(super) ids: Vec, pub(super) logprobs: Vec>, } +#[derive(Clone, Copy)] +pub(super) struct ScheduledEffect { + pub(super) queued_at_unix_s: f64, + pub(super) scheduled_at_unix_s: f64, + pub(super) prompt_tokens: usize, + pub(super) cached_tokens: usize, +} + pub(super) enum PendingEffect { Finish { request_id: RequestId, token_tx: mpsc::UnboundedSender, + scheduled: ScheduledEffect, + prompt_echo: Option, finish_reason: FinishReason, - prompt_tokens: usize, completion_tokens: usize, }, EmitAndFinish { request_id: RequestId, token_tx: mpsc::UnboundedSender, + scheduled: ScheduledEffect, + prompt_echo: Option, token: u32, logprob: Option, finish_reason: FinishReason, - prompt_tokens: usize, completion_tokens: usize, }, Promote { state: ActiveRequestState, + scheduled: ScheduledEffect, + prompt_echo: Option, first_token: u32, logprob: Option, }, @@ -57,33 +69,110 @@ pub(super) enum DecodeEffect { } pub(super) struct StepEffects { - pub(super) prompt_echoes: Vec, pub(super) pending: Vec, pub(super) decode: Vec, } +#[derive(Default)] +pub(super) struct PrefixCacheStats { + total_requests: u64, + hit_requests: u64, + miss_requests: u64, + total_prompt_tokens: u64, + total_cached_tokens: u64, +} + +impl PrefixCacheStats { + pub(super) fn observe(&mut self, scheduled: ScheduledEffect) { + self.total_requests += 1; + if scheduled.cached_tokens > 0 { + self.hit_requests += 1; + } else { + self.miss_requests += 1; + } + self.total_prompt_tokens += scheduled.prompt_tokens as u64; + self.total_cached_tokens += scheduled.cached_tokens as u64; + } + + pub(super) fn log_snapshot(&self) { + debug!( + "Qwen3 prefix cache stats: total_requests={}, hit_requests={}, miss_requests={}, hit_rate={:.4}, total_prompt_tokens={}, total_cached_tokens={}, token_hit_rate={:.4}", + self.total_requests, + self.hit_requests, + self.miss_requests, + self.hit_rate(), + self.total_prompt_tokens, + self.total_cached_tokens, + self.token_hit_rate() + ); + } + + fn hit_rate(&self) -> f64 { + if self.total_requests == 0 { + return 0.0; + } + self.hit_requests as f64 / self.total_requests as f64 + } + + fn token_hit_rate(&self) -> f64 { + if self.total_prompt_tokens == 0 { + return 0.0; + } + self.total_cached_tokens as f64 / self.total_prompt_tokens as f64 + } + + #[cfg(test)] + pub(super) fn totals(&self) -> (u64, u64, u64, f64, u64, u64, f64) { + ( + self.total_requests, + self.hit_requests, + self.miss_requests, + self.hit_rate(), + self.total_prompt_tokens, + self.total_cached_tokens, + self.token_hit_rate(), + ) + } +} + impl StepEffects { pub(super) fn empty() -> Self { Self { - prompt_echoes: Vec::new(), pending: Vec::new(), decode: Vec::new(), } } } +fn send_pending_scheduled_and_echo( + token_tx: &mpsc::UnboundedSender, + scheduled: ScheduledEffect, + prompt_echo: Option, + prefix_cache_stats: &mut PrefixCacheStats, +) -> Result<(), mpsc::error::SendError> { + prefix_cache_stats.observe(scheduled); + prefix_cache_stats.log_snapshot(); + token_tx.send(TokenEvent::Scheduled { + queued_at_unix_s: scheduled.queued_at_unix_s, + scheduled_at_unix_s: scheduled.scheduled_at_unix_s, + prompt_tokens: scheduled.prompt_tokens, + cached_tokens: scheduled.cached_tokens, + })?; + if let Some(echo) = prompt_echo { + token_tx.send(TokenEvent::PromptTokens { + ids: echo.ids, + logprobs: echo.logprobs, + })?; + } + Ok(()) +} + pub(super) fn apply_effects( executor: &mut impl crate::executor::ModelExecutor, active: &mut Vec, effects: StepEffects, + prefix_cache_stats: &mut PrefixCacheStats, ) { - for echo in effects.prompt_echoes { - let _ = echo.token_tx.send(TokenEvent::PromptTokens { - ids: echo.ids, - logprobs: echo.logprobs, - }); - } - let mut to_retire = Vec::new(); for effect in effects.decode { match effect { @@ -164,13 +253,20 @@ pub(super) fn apply_effects( PendingEffect::Finish { request_id, token_tx, + scheduled, + prompt_echo, finish_reason, - prompt_tokens, completion_tokens, } => { + let _ = send_pending_scheduled_and_echo( + &token_tx, + scheduled, + prompt_echo, + prefix_cache_stats, + ); let _ = token_tx.send(TokenEvent::Finished { finish_reason, - prompt_tokens, + prompt_tokens: scheduled.prompt_tokens, completion_tokens, }); let _ = executor.drop_request(request_id); @@ -178,19 +274,27 @@ pub(super) fn apply_effects( PendingEffect::EmitAndFinish { request_id, token_tx, + scheduled, + prompt_echo, token, logprob, finish_reason, - prompt_tokens, completion_tokens, } => { - if token_tx - .send(TokenEvent::Token { id: token, logprob }) - .is_ok() + if send_pending_scheduled_and_echo( + &token_tx, + scheduled, + prompt_echo, + prefix_cache_stats, + ) + .is_ok() + && token_tx + .send(TokenEvent::Token { id: token, logprob }) + .is_ok() { let _ = token_tx.send(TokenEvent::Finished { finish_reason, - prompt_tokens, + prompt_tokens: scheduled.prompt_tokens, completion_tokens, }); } @@ -198,16 +302,25 @@ pub(super) fn apply_effects( } PendingEffect::Promote { state, + scheduled, + prompt_echo, first_token, logprob, } => { - if state - .token_tx - .send(TokenEvent::Token { - id: first_token, - logprob, - }) - .is_ok() + if send_pending_scheduled_and_echo( + &state.token_tx, + scheduled, + prompt_echo, + prefix_cache_stats, + ) + .is_ok() + && state + .token_tx + .send(TokenEvent::Token { + id: first_token, + logprob, + }) + .is_ok() { active.push(state); } else { diff --git a/pegainfer-qwen3-4b/src/scheduler/plan.rs b/pegainfer-qwen3-4b/src/scheduler/plan.rs index b222c397..631d88d4 100644 --- a/pegainfer-qwen3-4b/src/scheduler/plan.rs +++ b/pegainfer-qwen3-4b/src/scheduler/plan.rs @@ -8,7 +8,7 @@ use crate::executor::{ PrefillStepItem, UnifiedPlan, UnifiedResult, }; -use super::{ActiveRequestState, PendingRequest}; +use super::{ActiveRequestState, PendingRequest, now_secs_f64}; pub(super) enum ExecutionPlan { Prefill { pending: Vec }, @@ -74,7 +74,8 @@ pub(super) fn execute_plan( rng: &mut StdRng, ) -> Result { match plan { - ExecutionPlan::Prefill { pending } => { + ExecutionPlan::Prefill { mut pending } => { + mark_scheduled(&mut pending); let mut result = PrefillResult { requests: Vec::with_capacity(pending.len()), }; @@ -106,7 +107,8 @@ pub(super) fn execute_plan( sort_decode_results(&mut result.requests); Ok(ExecutionArtifacts::Decode { result }) } - ExecutionPlan::Unified { pending } => { + ExecutionPlan::Unified { mut pending } => { + mark_scheduled(&mut pending); let mut result = UnifiedResult { prefill_requests: Vec::with_capacity(pending.len()), decode_requests: Vec::with_capacity(active.len()), @@ -158,6 +160,13 @@ pub(super) fn execute_plan( } } +fn mark_scheduled(pending: &mut [PendingRequest]) { + let scheduled_at = now_secs_f64(); + for req in pending { + req.scheduled_at_unix_s = Some(scheduled_at); + } +} + fn group_pending_indices(pending: &[PendingRequest]) -> BTreeMap> { let mut groups = BTreeMap::new(); for (index, req) in pending.iter().enumerate() { @@ -252,6 +261,8 @@ mod tests { PendingRequest { request_id: RequestId::new(0), lora_adapter: None, + queued_at_unix_s: None, + scheduled_at_unix_s: None, prompt_tokens: vec![1, 2, 3], params: SamplingParams::default(), max_tokens: 8, diff --git a/pegainfer-qwen3-4b/src/scheduler/resolve.rs b/pegainfer-qwen3-4b/src/scheduler/resolve.rs index acfec650..ab9fc26f 100644 --- a/pegainfer-qwen3-4b/src/scheduler/resolve.rs +++ b/pegainfer-qwen3-4b/src/scheduler/resolve.rs @@ -1,7 +1,7 @@ use crate::executor::{DecodeRequestResult, ModelExecutor, PrefillRequestResult}; use pegainfer_core::engine::FinishReason; -use super::effects::{DecodeEffect, PendingEffect, PromptEchoEffect, StepEffects}; +use super::effects::{DecodeEffect, PendingEffect, PromptEchoEffect, ScheduledEffect, StepEffects}; use super::plan::ExecutionArtifacts; use super::{ActiveRequestState, PendingRequest}; @@ -15,7 +15,6 @@ pub(super) fn resolve_step( resolve_prefill_outputs(executor, pending, result.requests) } ExecutionArtifacts::Decode { result } => StepEffects { - prompt_echoes: Vec::new(), pending: Vec::new(), decode: resolve_decode_outputs(executor, active, &result.requests), }, @@ -36,23 +35,28 @@ fn resolve_prefill_outputs( for (req, result) in pending.into_iter().zip(request_results) { debug_assert_eq!(req.request_id, result.request_id); let prompt_len = req.prompt_tokens.len(); + let scheduled_at_unix_s = req.scheduled_at_unix_s.unwrap_or_else(super::now_secs_f64); + let scheduled = ScheduledEffect { + queued_at_unix_s: req.queued_at_unix_s.unwrap_or(scheduled_at_unix_s), + scheduled_at_unix_s, + prompt_tokens: prompt_len, + cached_tokens: result.cached_tokens, + }; - if req.echo { - effects.prompt_echoes.push(PromptEchoEffect { - token_tx: req.token_tx.clone(), - ids: req.prompt_tokens.clone(), - logprobs: result - .prompt_logprobs - .unwrap_or_else(|| vec![None; req.prompt_tokens.len()]), - }); - } + let prompt_echo = req.echo.then(|| PromptEchoEffect { + ids: req.prompt_tokens.clone(), + logprobs: result + .prompt_logprobs + .unwrap_or_else(|| vec![None; req.prompt_tokens.len()]), + }); if !req.params.ignore_eos && executor.is_stop_token(result.first_token) { effects.pending.push(PendingEffect::Finish { request_id: req.request_id, token_tx: req.token_tx, + scheduled, + prompt_echo, finish_reason: FinishReason::Stop, - prompt_tokens: prompt_len, completion_tokens: 0, }); continue; @@ -62,10 +66,11 @@ fn resolve_prefill_outputs( effects.pending.push(PendingEffect::EmitAndFinish { request_id: req.request_id, token_tx: req.token_tx, + scheduled, + prompt_echo, token: result.first_token, logprob: result.first_token_logprob, finish_reason: FinishReason::Length, - prompt_tokens: prompt_len, completion_tokens: 1, }); continue; @@ -83,6 +88,8 @@ fn resolve_prefill_outputs( params: req.params, logprobs: req.logprobs, }, + scheduled, + prompt_echo, first_token: result.first_token, logprob: result.first_token_logprob, }); diff --git a/pegainfer-sim/src/lib.rs b/pegainfer-sim/src/lib.rs index f947089d..9ca8c2cf 100644 --- a/pegainfer-sim/src/lib.rs +++ b/pegainfer-sim/src/lib.rs @@ -12,6 +12,7 @@ pub struct SimulatedEngineConfig { prefill_tokens_per_ms: f64, tpot_ms: f64, fallback_token_id: u32, + scheduled_cached_tokens: usize, } impl SimulatedEngineConfig { @@ -39,9 +40,16 @@ impl SimulatedEngineConfig { prefill_tokens_per_ms, tpot_ms, fallback_token_id, + scheduled_cached_tokens: 0, }) } + #[must_use] + pub fn with_scheduled_cached_tokens(mut self, cached_tokens: usize) -> Self { + self.scheduled_cached_tokens = cached_tokens; + self + } + fn ttft(&self, prompt_tokens: usize) -> Duration { duration_from_ms(self.base_ttft_ms + prompt_tokens as f64 / self.prefill_tokens_per_ms) } @@ -58,6 +66,7 @@ impl Default for SimulatedEngineConfig { prefill_tokens_per_ms: 100.0, tpot_ms: 12.0, fallback_token_id: 0, + scheduled_cached_tokens: 0, } } } @@ -83,6 +92,7 @@ async fn run_simulated_request(req: GenerateRequest, config: SimulatedEngineConf queued_at_unix_s, scheduled_at_unix_s: now_secs_f64(), prompt_tokens: prompt_len, + cached_tokens: config.scheduled_cached_tokens.min(prompt_len), }) .is_err() { diff --git a/pegainfer-sim/tests/frontend_e2e.rs b/pegainfer-sim/tests/frontend_e2e.rs index 4bb99a45..3e13e8aa 100644 --- a/pegainfer-sim/tests/frontend_e2e.rs +++ b/pegainfer-sim/tests/frontend_e2e.rs @@ -25,7 +25,11 @@ struct SimServer { impl SimServer { async fn spawn() -> Result { - Self::spawn_with_model_dir(TempModelDir::with_minimal_metadata()?).await + Self::spawn_with_config(SimulatedEngineConfig::new(0.0, 1000.0, 0.0, 1)?).await + } + + async fn spawn_with_config(config: SimulatedEngineConfig) -> Result { + Self::spawn_with_model_dir_and_config(TempModelDir::with_minimal_metadata()?, config).await } async fn spawn_with_lora_routes() -> Result { @@ -34,16 +38,40 @@ impl SimServer { } async fn spawn_with_model_dir(model_dir: TempModelDir) -> Result { - Self::spawn_with_model_dir_and_lora_routes(model_dir, false).await + Self::spawn_with_model_dir_and_config( + model_dir, + SimulatedEngineConfig::new(0.0, 1000.0, 0.0, 1)?, + ) + .await + } + + async fn spawn_with_model_dir_and_config( + model_dir: TempModelDir, + config: SimulatedEngineConfig, + ) -> Result { + Self::spawn_with_model_dir_lora_routes_and_config(model_dir, false, config).await } async fn spawn_with_model_dir_and_lora_routes( model_dir: TempModelDir, enable_lora_routes: bool, + ) -> Result { + Self::spawn_with_model_dir_lora_routes_and_config( + model_dir, + enable_lora_routes, + SimulatedEngineConfig::new(0.0, 1000.0, 0.0, 1)?, + ) + .await + } + + async fn spawn_with_model_dir_lora_routes_and_config( + model_dir: TempModelDir, + enable_lora_routes: bool, + config: SimulatedEngineConfig, ) -> Result { let mut last_error = None; for attempt in 1..=SERVER_START_ATTEMPTS { - match Self::spawn_once(&model_dir, enable_lora_routes).await { + match Self::spawn_once(&model_dir, enable_lora_routes, config).await { Ok(started) => { return Ok(Self { base_url: started.base_url, @@ -70,11 +98,12 @@ impl SimServer { async fn spawn_once( model_dir: &TempModelDir, enable_lora_routes: bool, + config: SimulatedEngineConfig, ) -> Result { let port = reserve_loopback_port()?; let base_url = format!("http://127.0.0.1:{port}"); let shutdown = CancellationToken::new(); - let engine = start_engine(SimulatedEngineConfig::new(0.0, 1000.0, 0.0, 1)?); + let engine = start_engine(config); let server_shutdown = shutdown.clone(); let model_path = model_dir.path.to_string_lossy().into_owned(); let mut task = tokio::spawn(async move { @@ -250,6 +279,62 @@ async fn streaming_completion_emits_terminal_done() -> Result<()> { server.shutdown().await } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn non_streaming_completion_reports_default_zero_cached_tokens() -> Result<()> { + let server = SimServer::spawn().await?; + let client = test_client()?; + + let completion = post_completion(&client, &server.base_url, false).await?; + assert_usage_cached_tokens(&completion["usage"], 0)?; + + server.shutdown().await +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn non_streaming_completion_reports_configured_cached_tokens() -> Result<()> { + let cached_tokens = 1; + let server = SimServer::spawn_with_config( + SimulatedEngineConfig::new(0.0, 1000.0, 0.0, 1)? + .with_scheduled_cached_tokens(cached_tokens), + ) + .await?; + let client = test_client()?; + + let completion = post_completion(&client, &server.base_url, false).await?; + assert_usage_cached_tokens(&completion["usage"], cached_tokens)?; + + server.shutdown().await +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn streaming_completion_reports_default_zero_cached_tokens_in_usage_chunk() -> Result<()> { + let server = SimServer::spawn().await?; + let client = test_client()?; + + let stream = post_completion_stream(&client, &server.base_url, true).await?; + let usage = final_usage_chunk(&stream)?; + assert_usage_cached_tokens(&usage, 0)?; + + server.shutdown().await +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn streaming_completion_reports_configured_cached_tokens_in_usage_chunk() -> Result<()> { + let cached_tokens = 1; + let server = SimServer::spawn_with_config( + SimulatedEngineConfig::new(0.0, 1000.0, 0.0, 1)? + .with_scheduled_cached_tokens(cached_tokens), + ) + .await?; + let client = test_client()?; + + let stream = post_completion_stream(&client, &server.base_url, true).await?; + let usage = final_usage_chunk(&stream)?; + assert_usage_cached_tokens(&usage, cached_tokens)?; + + server.shutdown().await +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn simulated_frontend_metadata_contract_is_executable() -> Result<()> { let model_dir = TempModelDir::with_minimal_metadata()?; @@ -308,7 +393,7 @@ async fn assert_non_streaming_completion_has_output(client: &Client, base_url: & } async fn assert_streaming_completion_emits_done(client: &Client, base_url: &str) -> Result<()> { - let stream = post_completion_stream(client, base_url).await?; + let stream = post_completion_stream(client, base_url, false).await?; if !stream.lines().any(|line| line.trim() == "data: [DONE]") { bail!("streaming completion did not emit terminal data: [DONE]: {stream}"); } @@ -320,7 +405,7 @@ async fn post_completion(client: &Client, base_url: &str, stream: bool) -> Resul let response = client .post(format!("{base_url}/v1/completions")) .header(reqwest::header::CONTENT_TYPE, "application/json") - .body(completion_body(stream).to_string()) + .body(completion_body(stream, false).to_string()) .send() .await? .error_for_status()?; @@ -330,11 +415,55 @@ async fn post_completion(client: &Client, base_url: &str, stream: bool) -> Resul .context("failed to parse non-streaming completion response") } -async fn post_completion_stream(client: &Client, base_url: &str) -> Result { +fn assert_usage_cached_tokens(usage: &Value, expected_cached_tokens: usize) -> Result<()> { + let prompt_tokens = usage["prompt_tokens"] + .as_u64() + .ok_or_else(|| anyhow!("usage has no prompt_tokens: {usage}"))?; + if prompt_tokens != 2 { + bail!("expected prompt_tokens=2, got {prompt_tokens}: {usage}"); + } + let cached_tokens = usage["prompt_tokens_details"]["cached_tokens"] + .as_u64() + .ok_or_else(|| anyhow!("usage has no prompt_tokens_details.cached_tokens: {usage}"))?; + if cached_tokens != expected_cached_tokens as u64 { + bail!("expected cached_tokens={expected_cached_tokens}, got {cached_tokens}: {usage}"); + } + + Ok(()) +} + +fn final_usage_chunk(stream: &str) -> Result { + let mut usage_chunks = Vec::new(); + for line in stream.lines() { + let Some(data) = line.strip_prefix("data: ") else { + continue; + }; + if data.trim() == "[DONE]" { + continue; + } + let chunk: Value = serde_json::from_str(data) + .with_context(|| format!("failed to parse SSE chunk JSON: {data}"))?; + if !chunk["usage"].is_null() { + usage_chunks.push(chunk["usage"].clone()); + } + } + + match usage_chunks.as_slice() { + [usage] => Ok(usage.clone()), + [] => bail!("stream did not include a usage chunk: {stream}"), + _ => bail!("stream included multiple usage chunks: {stream}"), + } +} + +async fn post_completion_stream( + client: &Client, + base_url: &str, + include_usage: bool, +) -> Result { client .post(format!("{base_url}/v1/completions")) .header(reqwest::header::CONTENT_TYPE, "application/json") - .body(completion_body(true).to_string()) + .body(completion_body(true, include_usage).to_string()) .send() .await? .error_for_status()? @@ -350,15 +479,19 @@ fn test_client() -> Result { .context("failed to build HTTP test client") } -fn completion_body(stream: bool) -> Value { - json!({ +fn completion_body(stream: bool, include_usage: bool) -> Value { + let mut body = json!({ "model": MODEL_NAME, "prompt": [1, 2], "max_tokens": 3, "temperature": 0.0, "ignore_eos": true, "stream": stream - }) + }); + if stream && include_usage { + body["stream_options"] = json!({ "include_usage": true }); + } + body } async fn wait_for_health(client: &Client, base_url: &str) -> Result<()> { diff --git a/pegainfer-vllm-frontend/Cargo.toml b/pegainfer-vllm-frontend/Cargo.toml index 2718b0dd..eeab9ff5 100644 --- a/pegainfer-vllm-frontend/Cargo.toml +++ b/pegainfer-vllm-frontend/Cargo.toml @@ -6,7 +6,9 @@ edition = "2024" [dependencies] pegainfer-engine = { workspace = true } anyhow = { workspace = true } +async-stream = { workspace = true } axum = { workspace = true } +futures = { workspace = true } log = { workspace = true } rmp-serde = { workspace = true } rmpv = { workspace = true } diff --git a/pegainfer-vllm-frontend/src/lib.rs b/pegainfer-vllm-frontend/src/lib.rs index 13c76bd0..3c746614 100644 --- a/pegainfer-vllm-frontend/src/lib.rs +++ b/pegainfer-vllm-frontend/src/lib.rs @@ -1,6 +1,7 @@ use std::collections::{BTreeSet, HashMap, HashSet}; use std::path::{Path, PathBuf}; use std::sync::Arc; +use std::sync::LazyLock; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use anyhow::{Context, Result, bail}; @@ -39,6 +40,9 @@ use zeromq::prelude::{Socket, SocketRecv, SocketSend}; use zeromq::util::PeerIdentity; use zeromq::{DealerSocket, PushSocket, SocketOptions, ZmqMessage}; +mod patch_usage; + +use patch_usage::{CachedTokenUsageMap, cached_token_usage_routes, external_request_id}; use pegainfer_engine::engine::{ EngineControlError, EngineHandle, FinishReason, GenerateRequest, LoadLoraAdapterRequest, TokenEvent, TokenLogprob, UnloadLoraAdapterRequest, @@ -48,6 +52,8 @@ use pegainfer_engine::sampler::SamplingParams; const ENGINE_INDEX: u32 = 0; const LORA_ROUTE_BODY_LIMIT: usize = 128 * 1024 * 1024; const LORA_ADAPTER_XARG: &str = "pegainfer_lora_adapter"; +static CACHED_TOKENS_BY_REQUEST_ID: LazyLock = + LazyLock::new(|| Arc::new(RwLock::new(HashMap::new()))); #[derive(Clone)] struct LoraRouteState { @@ -426,8 +432,17 @@ impl LocalEngineBridge { let output_tx = output_tx.clone(); let done_tx = done_tx.clone(); let task_request_id = request_id.clone(); + let usage_request_id = external_request_id(&request_id); + let cached_tokens_by_request_id = CACHED_TOKENS_BY_REQUEST_ID.clone(); let task = tokio::spawn(async move { - run_request_stream(task_request_id.clone(), token_rx, output_tx).await; + run_request_stream( + task_request_id.clone(), + usage_request_id, + token_rx, + output_tx, + cached_tokens_by_request_id, + ) + .await; let _ = done_tx.send(task_request_id); }); active.insert(request_id, task); @@ -457,7 +472,7 @@ pub async fn serve( model_path.to_string_lossy().into_owned(), served_model_name .into_iter() - .map(|name| name.to_string()) + .map(std::string::ToString::to_string) .collect(), port, max_model_len, @@ -512,14 +527,18 @@ pub async fn serve_model_with_lora_routes( max_model_len, shutdown, move |router| { + let patched_router = + cached_token_usage_routes(router, CACHED_TOKENS_BY_REQUEST_ID.clone()); let lora_router = lora_routes(handle.clone(), Arc::clone(&adapter_names)); let openai_router = lora_openai_routes( - router.clone(), + patched_router.clone(), base_model_name, served_model_name, Arc::clone(&adapter_names), ); - openai_router.merge(lora_router).fallback_service(router) + openai_router + .merge(lora_router) + .fallback_service(patched_router) }, ) .await @@ -567,7 +586,7 @@ async fn serve_model_on_host( port, max_model_len, shutdown, - |router| router, + |router| cached_token_usage_routes(router, CACHED_TOKENS_BY_REQUEST_ID.clone()), ) .await } @@ -748,8 +767,10 @@ async fn lora_models_response( async fn run_request_stream( request_id: String, + usage_request_id: String, mut token_rx: mpsc::UnboundedReceiver, output_tx: mpsc::UnboundedSender, + cached_tokens_by_request_id: Arc>>, ) { let mut first_token_events = None; let mut first_token_prefill_stats = None; @@ -768,6 +789,7 @@ async fn run_request_stream( queued_at_unix_s, scheduled_at_unix_s, prompt_tokens, + cached_tokens, } => { first_token_events = Some(vec![ EngineCoreEvent { @@ -779,13 +801,19 @@ async fn run_request_stream( timestamp: scheduled_at_unix_s, }, ]); + let cached_tokens = cached_tokens as u32; first_token_prefill_stats = Some(PrefillStats { num_prompt_tokens: prompt_tokens as u32, - num_computed_tokens: prompt_tokens as u32, - num_cached_tokens: 0, - num_local_cached_tokens: 0, + num_computed_tokens: prompt_tokens.saturating_sub(cached_tokens as usize) + as u32, + num_cached_tokens: cached_tokens, + num_local_cached_tokens: cached_tokens, num_external_cached_tokens: 0, }); + cached_tokens_by_request_id + .write() + .await + .insert(usage_request_id.clone(), cached_tokens); } TokenEvent::Token { id, logprob } => { // Keep the first streamed token on the direct path so TTFT @@ -1006,14 +1034,14 @@ fn collect_ready_token_batch( Ok(other) => { return ( token_ids, - has_logprobs.then(|| MaybeWireLogprobs::Direct(Logprobs { positions })), + has_logprobs.then_some(MaybeWireLogprobs::Direct(Logprobs { positions })), Some(other), ); } Err(TryRecvError::Empty | TryRecvError::Disconnected) => { return ( token_ids, - has_logprobs.then(|| MaybeWireLogprobs::Direct(Logprobs { positions })), + has_logprobs.then_some(MaybeWireLogprobs::Direct(Logprobs { positions })), None, ); } @@ -1355,11 +1383,11 @@ mod tests { // ignore_eos=true lowering: _eos_token_id=None while // _all_stop_token_ids still carries the model EOS set. let mut params = EngineCoreSamplingParams::for_test(); - params.all_stop_token_ids = BTreeSet::from([163586]); + params.all_stop_token_ids = BTreeSet::from([163_586]); assert!(convert_sampling(¶ms).ignore_eos); // Normal request: _eos_token_id present. - params.eos_token_id = Some(163586); + params.eos_token_id = Some(163_586); assert!(!convert_sampling(¶ms).ignore_eos); // Explicit client stop tokens keep EOS detection on even when the @@ -1399,7 +1427,14 @@ mod tests { .expect("send rejected event"); drop(token_tx); - run_request_stream("req-1".to_string(), token_rx, output_tx).await; + run_request_stream( + "req-1".to_string(), + "req-1".to_string(), + token_rx, + output_tx, + Arc::new(RwLock::new(HashMap::new())), + ) + .await; let outputs = output_rx.recv().await.expect("terminal output"); assert!( @@ -1427,6 +1462,7 @@ mod tests { token_tx .send(TokenEvent::Scheduled { + cached_tokens: 0, queued_at_unix_s: 1.0, scheduled_at_unix_s: 2.0, prompt_tokens: 16, @@ -1459,7 +1495,14 @@ mod tests { .expect("send finished"); drop(token_tx); - run_request_stream("req-1".to_string(), token_rx, output_tx).await; + run_request_stream( + "req-1".to_string(), + "req-1".to_string(), + token_rx, + output_tx, + Arc::new(RwLock::new(HashMap::new())), + ) + .await; let token_outputs = output_rx.recv().await.expect("token output"); assert_eq!(token_outputs.outputs.len(), 1); @@ -1467,7 +1510,14 @@ mod tests { assert_eq!(token_outputs.outputs[0].new_token_ids, vec![11, 21]); assert!(token_outputs.outputs[0].finish_reason.is_none()); assert!(token_outputs.outputs[0].events.is_some()); - assert!(token_outputs.outputs[0].prefill_stats.is_some()); + let prefill_stats = token_outputs.outputs[0] + .prefill_stats + .as_ref() + .expect("prefill stats"); + assert_eq!(prefill_stats.num_cached_tokens, 0); + assert_eq!(prefill_stats.num_local_cached_tokens, 0); + assert_eq!(prefill_stats.num_external_cached_tokens, 0); + assert_eq!(prefill_stats.num_computed_tokens, 16); let direct = match token_outputs.outputs[0] .new_logprobs @@ -1495,6 +1545,48 @@ mod tests { assert!(output_rx.recv().await.is_none()); } + #[tokio::test] + async fn scheduled_cached_tokens_are_passed_through_to_prefill_stats() { + let (token_tx, token_rx) = mpsc::unbounded_channel(); + let (output_tx, mut output_rx) = mpsc::unbounded_channel(); + + token_tx + .send(TokenEvent::Scheduled { + queued_at_unix_s: 1.0, + scheduled_at_unix_s: 2.0, + prompt_tokens: 10, + cached_tokens: 6, + }) + .expect("send scheduled"); + token_tx + .send(TokenEvent::Token { + id: 1, + logprob: None, + }) + .expect("send first token"); + drop(token_tx); + + run_request_stream( + "req-cached".to_string(), + "req-cached".to_string(), + token_rx, + output_tx, + Arc::new(RwLock::new(HashMap::new())), + ) + .await; + + let first_batch = output_rx.recv().await.expect("first batch"); + let prefill_stats = first_batch.outputs[0] + .prefill_stats + .as_ref() + .expect("prefill stats"); + assert_eq!(prefill_stats.num_prompt_tokens, 10); + assert_eq!(prefill_stats.num_computed_tokens, 4); + assert_eq!(prefill_stats.num_cached_tokens, 6); + assert_eq!(prefill_stats.num_local_cached_tokens, 6); + assert_eq!(prefill_stats.num_external_cached_tokens, 0); + } + #[tokio::test] async fn first_token_metadata_is_only_sent_with_first_batch() { let (token_tx, token_rx) = mpsc::unbounded_channel(); @@ -1502,6 +1594,7 @@ mod tests { token_tx .send(TokenEvent::Scheduled { + cached_tokens: 0, queued_at_unix_s: 1.0, scheduled_at_unix_s: 2.0, prompt_tokens: 8, @@ -1527,7 +1620,14 @@ mod tests { .expect("send second token"); drop(token_tx); - run_request_stream("req-2".to_string(), token_rx, output_tx).await; + run_request_stream( + "req-2".to_string(), + "req-2".to_string(), + token_rx, + output_tx, + Arc::new(RwLock::new(HashMap::new())), + ) + .await; let first_batch = output_rx.recv().await.expect("first batch"); let second_batch = output_rx.recv().await.expect("second batch"); @@ -1562,7 +1662,14 @@ mod tests { .expect("send token with logprob"); drop(token_tx); - run_request_stream("req-3".to_string(), token_rx, output_tx).await; + run_request_stream( + "req-3".to_string(), + "req-3".to_string(), + token_rx, + output_tx, + Arc::new(RwLock::new(HashMap::new())), + ) + .await; let batch = output_rx.recv().await.expect("batched output"); let direct = match batch.outputs[0] diff --git a/pegainfer-vllm-frontend/src/patch_usage.rs b/pegainfer-vllm-frontend/src/patch_usage.rs new file mode 100644 index 00000000..2380af88 --- /dev/null +++ b/pegainfer-vllm-frontend/src/patch_usage.rs @@ -0,0 +1,913 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use axum::body::{Body, Bytes, to_bytes}; +use axum::extract::{Request, State}; +use axum::http::{HeaderValue, StatusCode, header::CONTENT_LENGTH}; +use axum::response::{IntoResponse, Response}; +use axum::routing::post; +use axum::{Json, Router}; +use futures::StreamExt; +use serde::Serialize; +use tokio::sync::RwLock; +use tower::ServiceExt; + +const COMPLETION_USAGE_PATCH_BODY_LIMIT: usize = 128 * 1024 * 1024; + +pub(crate) type CachedTokenUsageMap = Arc>>; + +#[derive(Clone)] +struct UsagePatchState { + vllm_router: Router, + cached_tokens_by_request_id: CachedTokenUsageMap, +} + +#[derive(Debug, Serialize)] +struct ErrorBody { + error: String, +} + +#[derive(Clone, Copy, Debug, Default)] +struct CompletionUsagePatchOptions { + stream: bool, + include_usage: bool, +} + +struct StreamingRequestCleanup { + cached_tokens_by_request_id: CachedTokenUsageMap, + request_id: Option, +} + +impl Drop for StreamingRequestCleanup { + fn drop(&mut self) { + let Some(request_id) = self.request_id.take() else { + return; + }; + let cached_tokens_by_request_id = self.cached_tokens_by_request_id.clone(); + if let Ok(handle) = tokio::runtime::Handle::try_current() { + handle.spawn(async move { + cached_tokens_by_request_id + .write() + .await + .remove(&request_id); + }); + } + } +} + +async fn remove_cached_tokens( + cached_tokens_by_request_id: &CachedTokenUsageMap, + request_id: Option<&str>, +) { + if let Some(request_id) = request_id { + cached_tokens_by_request_id.write().await.remove(request_id); + } +} + +async fn remove_cached_tokens_pair( + cached_tokens_by_request_id: &CachedTokenUsageMap, + first: Option<&str>, + second: Option<&str>, +) { + let mut cached_tokens_by_request_id = cached_tokens_by_request_id.write().await; + if let Some(first) = first { + cached_tokens_by_request_id.remove(first); + } + if second != first { + if let Some(second) = second { + cached_tokens_by_request_id.remove(second); + } + } +} + +async fn take_cached_tokens( + cached_tokens_by_request_id: &CachedTokenUsageMap, + request_id: &str, + fallback_request_id: Option<&str>, +) -> u32 { + let mut cached_tokens_by_request_id = cached_tokens_by_request_id.write().await; + let cached_tokens = cached_tokens_by_request_id.remove(request_id); + if fallback_request_id != Some(request_id) { + if let Some(fallback_request_id) = fallback_request_id { + return cached_tokens + .or_else(|| cached_tokens_by_request_id.remove(fallback_request_id)) + .unwrap_or(0); + } + } + cached_tokens.unwrap_or(0) +} + +// Rust vLLM currently emits OpenAI Usage without copying PrefillStats cached counts; +// this wrapper only patches the existing usage object with the engine-provided value. +pub(crate) fn cached_token_usage_routes( + vllm_router: Router, + cached_tokens_by_request_id: CachedTokenUsageMap, +) -> Router { + Router::new() + .route("/v1/completions", post(forward_cached_token_usage_request)) + .route( + "/v1/chat/completions", + post(forward_cached_token_usage_request), + ) + .with_state(UsagePatchState { + vllm_router: vllm_router.clone(), + cached_tokens_by_request_id, + }) + .fallback_service(vllm_router) +} + +pub(crate) fn external_request_id(engine_request_id: &str) -> String { + if let Some((request_id, suffix)) = engine_request_id.rsplit_once('-') { + if suffix.len() == 8 && suffix.chars().all(|ch| ch.is_ascii_hexdigit()) { + return request_id.to_string(); + } + } + engine_request_id.to_string() +} + +async fn forward_cached_token_usage_request( + State(state): State, + request: Request, +) -> Response { + let (parts, body) = request.into_parts(); + let bytes = match to_bytes(body, COMPLETION_USAGE_PATCH_BODY_LIMIT).await { + Ok(bytes) => bytes, + Err(error) => { + return ( + StatusCode::BAD_REQUEST, + Json(ErrorBody { + error: format!("failed to read completion request body: {error}"), + }), + ) + .into_response(); + } + }; + let options = completion_request_usage_patch_options(&bytes); + let request_id_for_cleanup = request_id_from_completion_request(&parts, &bytes); + let vllm_router = state.vllm_router.clone(); + let request = Request::from_parts(parts, Body::from(bytes)); + match vllm_router.oneshot(request).await { + Ok(response) if options.stream => { + patch_streaming_completion_usage( + response, + state, + options.include_usage, + request_id_for_cleanup, + ) + .await + } + Ok(response) => patch_completion_usage(response, state, request_id_for_cleanup).await, + Err(error) => { + remove_cached_tokens( + &state.cached_tokens_by_request_id, + request_id_for_cleanup.as_deref(), + ) + .await; + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorBody { + error: format!("vLLM router failed to handle completion request: {error:#}"), + }), + ) + .into_response() + } + } +} + +fn completion_request_usage_patch_options(bytes: &Bytes) -> CompletionUsagePatchOptions { + let Ok(value) = serde_json::from_slice::(bytes) else { + return CompletionUsagePatchOptions::default(); + }; + CompletionUsagePatchOptions { + stream: value + .get("stream") + .and_then(serde_json::Value::as_bool) + .unwrap_or(false), + include_usage: value + .get("stream_options") + .and_then(|options| options.get("include_usage")) + .and_then(serde_json::Value::as_bool) + .unwrap_or(false), + } +} + +fn request_id_from_completion_request( + parts: &axum::http::request::Parts, + bytes: &Bytes, +) -> Option { + let value = serde_json::from_slice::(bytes).ok(); + let request_id = parts + .headers + .get("X-Request-Id") + .and_then(|value| value.to_str().ok()) + .map(ToString::to_string) + .or_else(|| { + value + .as_ref() + .and_then(|value| value.get("request_id")) + .and_then(serde_json::Value::as_str) + .map(ToString::to_string) + })?; + + match parts.uri.path() { + "/v1/chat/completions" => Some(format!("chatcmpl-{request_id}")), + "/v1/completions" => Some(format!("cmpl-{request_id}")), + _ => Some(request_id), + } +} + +async fn patch_completion_usage( + response: Response, + state: UsagePatchState, + request_id_for_cleanup: Option, +) -> Response { + let status = response.status(); + let (parts, body) = response.into_parts(); + let bytes = match to_bytes(body, COMPLETION_USAGE_PATCH_BODY_LIMIT).await { + Ok(bytes) => bytes, + Err(error) => { + remove_cached_tokens( + &state.cached_tokens_by_request_id, + request_id_for_cleanup.as_deref(), + ) + .await; + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorBody { + error: format!("failed to read completion response body: {error}"), + }), + ) + .into_response(); + } + }; + + let Ok(mut value) = serde_json::from_slice::(&bytes) else { + remove_cached_tokens( + &state.cached_tokens_by_request_id, + request_id_for_cleanup.as_deref(), + ) + .await; + return Response::from_parts(parts, Body::from(bytes)); + }; + let response_request_id = value + .get("id") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + let cleanup_request_id = response_request_id + .as_deref() + .or(request_id_for_cleanup.as_deref()); + + if !status.is_success() { + remove_cached_tokens_pair( + &state.cached_tokens_by_request_id, + response_request_id.as_deref(), + request_id_for_cleanup.as_deref(), + ) + .await; + return Response::from_parts(parts, Body::from(bytes)); + } + + let Some(request_id) = cleanup_request_id else { + return Response::from_parts(parts, Body::from(bytes)); + }; + let cached_tokens = take_cached_tokens( + &state.cached_tokens_by_request_id, + request_id, + request_id_for_cleanup.as_deref(), + ) + .await; + if response_request_id.is_none() { + return Response::from_parts(parts, Body::from(bytes)); + } + patch_usage_value(&mut value["usage"], cached_tokens); + + response_from_json_parts(parts, &value) +} + +async fn patch_streaming_completion_usage( + response: Response, + state: UsagePatchState, + include_usage: bool, + request_id_for_cleanup: Option, +) -> Response { + let status = response.status(); + if !status.is_success() { + remove_cached_tokens( + &state.cached_tokens_by_request_id, + request_id_for_cleanup.as_deref(), + ) + .await; + return response; + } + let cached_tokens_by_request_id = state.cached_tokens_by_request_id; + let (mut parts, body) = response.into_parts(); + parts.headers.remove(CONTENT_LENGTH); + let stream = body.into_data_stream(); + let mut cleanup = StreamingRequestCleanup { + cached_tokens_by_request_id: cached_tokens_by_request_id.clone(), + request_id: request_id_for_cleanup, + }; + let body = Body::from_stream(async_stream::stream! { + let mut stream = stream; + while let Some(next) = stream.next().await { + let Ok(bytes) = next else { + continue; + }; + yield Ok::(patch_sse_usage_chunk( + bytes, + &cached_tokens_by_request_id, + include_usage, + &mut cleanup.request_id, + ).await); + } + }); + Response::from_parts(parts, body) +} + +async fn patch_sse_usage_chunk( + bytes: Bytes, + cached_tokens_by_request_id: &CachedTokenUsageMap, + include_usage: bool, + request_id_for_cleanup: &mut Option, +) -> Bytes { + let Ok(text) = std::str::from_utf8(&bytes) else { + return bytes; + }; + let mut patched = String::with_capacity(text.len()); + let mut changed = false; + for line in text.lines() { + let Some(data) = line.strip_prefix("data: ") else { + patched.push_str(line); + patched.push('\n'); + continue; + }; + if data.trim() == "[DONE]" { + patched.push_str(line); + patched.push('\n'); + continue; + } + let Ok(mut value) = serde_json::from_str::(data) else { + patched.push_str(line); + patched.push('\n'); + continue; + }; + if request_id_for_cleanup.is_none() { + if let Some(request_id) = value.get("id").and_then(serde_json::Value::as_str) { + *request_id_for_cleanup = Some(request_id.to_string()); + } + } + if include_usage && value.get("usage").is_some_and(|usage| !usage.is_null()) { + if let Some(request_id) = value.get("id").and_then(serde_json::Value::as_str) { + let cached_tokens = take_cached_tokens( + cached_tokens_by_request_id, + request_id, + request_id_for_cleanup.as_deref(), + ) + .await; + request_id_for_cleanup.take(); + patch_usage_value(&mut value["usage"], cached_tokens); + patched.push_str("data: "); + patched.push_str(&value.to_string()); + patched.push('\n'); + changed = true; + continue; + } + } + patched.push_str(line); + patched.push('\n'); + } + if changed { Bytes::from(patched) } else { bytes } +} + +fn patch_usage_value(usage: &mut serde_json::Value, cached_tokens: u32) { + let Some(usage) = usage.as_object_mut() else { + return; + }; + let details = usage + .entry("prompt_tokens_details") + .or_insert_with(|| serde_json::json!({})); + if !details.is_object() { + *details = serde_json::json!({}); + } + details + .as_object_mut() + .expect("prompt_tokens_details must be object") + .insert( + "cached_tokens".to_string(), + serde_json::Value::Number(cached_tokens.into()), + ); +} + +fn response_from_json_parts( + mut parts: axum::http::response::Parts, + value: &serde_json::Value, +) -> Response { + let body = value.to_string(); + parts.headers.insert( + CONTENT_LENGTH, + HeaderValue::from_str(&body.len().to_string()).expect("json body length must be valid"), + ); + Response::from_parts(parts, Body::from(body)) +} + +#[cfg(test)] +mod tests { + use super::*; + use axum::body::BodyDataStream; + use tokio::sync::mpsc; + use tokio::time::{Duration, timeout}; + + async fn read_next_stream_chunk(stream: &mut BodyDataStream) -> Bytes { + stream + .next() + .await + .expect("next stream item") + .expect("stream chunk") + } + + async fn wait_until_cache_entry_removed( + cached_tokens_by_request_id: &CachedTokenUsageMap, + request_id: &str, + ) { + timeout(Duration::from_secs(1), async { + loop { + if !cached_tokens_by_request_id + .read() + .await + .contains_key(request_id) + { + return; + } + tokio::task::yield_now().await; + } + }) + .await + .expect("cache entry removed"); + } + + #[tokio::test] + async fn chat_completion_usage_is_patched_with_cached_tokens() { + let cached_tokens_by_request_id = + Arc::new(RwLock::new(HashMap::from([("chatcmpl-1".to_string(), 7)]))); + let vllm_router = Router::new().route( + "/v1/chat/completions", + post(|| async { + Json(serde_json::json!({ + "id": "chatcmpl-1", + "object": "chat.completion", + "usage": { + "prompt_tokens": 11, + "prompt_tokens_details": {} + } + })) + }), + ); + let router = cached_token_usage_routes(vllm_router, cached_tokens_by_request_id.clone()); + let request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .body(Body::from( + serde_json::json!({ + "model": "model", + "messages": [{"role": "user", "content": "hello"}] + }) + .to_string(), + )) + .expect("request"); + + let response = router.oneshot(request).await.expect("route request"); + + assert_eq!(response.status(), StatusCode::OK); + let bytes = to_bytes(response.into_body(), COMPLETION_USAGE_PATCH_BODY_LIMIT) + .await + .expect("read body"); + let value: serde_json::Value = serde_json::from_slice(&bytes).expect("json response"); + assert_eq!(value["usage"]["prompt_tokens_details"]["cached_tokens"], 7); + assert!(cached_tokens_by_request_id.read().await.is_empty()); + } + + #[tokio::test] + async fn non_streaming_unparseable_response_cleans_known_request_id() { + let cached_tokens_by_request_id = Arc::new(RwLock::new(HashMap::from([ + ("cmpl-known".to_string(), 7), + ("cmpl-other".to_string(), 9), + ]))); + let vllm_router = Router::new().route( + "/v1/completions", + post(|| async { + Response::builder() + .body(Body::from("not json")) + .expect("response") + }), + ); + let router = cached_token_usage_routes(vllm_router, cached_tokens_by_request_id.clone()); + let request = Request::builder() + .method("POST") + .uri("/v1/completions") + .body(Body::from( + serde_json::json!({ + "model": "model", + "prompt": "hello", + "request_id": "known" + }) + .to_string(), + )) + .expect("request"); + + let response = router.oneshot(request).await.expect("route request"); + + assert_eq!(response.status(), StatusCode::OK); + let bytes = to_bytes(response.into_body(), COMPLETION_USAGE_PATCH_BODY_LIMIT) + .await + .expect("read body"); + assert_eq!(&bytes[..], b"not json"); + let map = cached_tokens_by_request_id.read().await; + assert!(!map.contains_key("cmpl-known")); + assert_eq!(map.get("cmpl-other"), Some(&9)); + } + + #[tokio::test] + async fn non_streaming_missing_response_id_cleans_known_request_id_without_patching() { + let cached_tokens_by_request_id = Arc::new(RwLock::new(HashMap::from([ + ("cmpl-known".to_string(), 7), + ("cmpl-other".to_string(), 9), + ]))); + let vllm_router = Router::new().route( + "/v1/completions", + post(|| async { + Json(serde_json::json!({ + "object": "text_completion", + "usage": {"prompt_tokens": 11, "prompt_tokens_details": {}} + })) + }), + ); + let router = cached_token_usage_routes(vllm_router, cached_tokens_by_request_id.clone()); + let request = Request::builder() + .method("POST") + .uri("/v1/completions") + .body(Body::from( + serde_json::json!({ + "model": "model", + "prompt": "hello", + "request_id": "known" + }) + .to_string(), + )) + .expect("request"); + + let response = router.oneshot(request).await.expect("route request"); + + assert_eq!(response.status(), StatusCode::OK); + let bytes = to_bytes(response.into_body(), COMPLETION_USAGE_PATCH_BODY_LIMIT) + .await + .expect("read body"); + let value: serde_json::Value = serde_json::from_slice(&bytes).expect("json response"); + assert!( + value["usage"]["prompt_tokens_details"] + .get("cached_tokens") + .is_none() + ); + let map = cached_tokens_by_request_id.read().await; + assert!(!map.contains_key("cmpl-known")); + assert_eq!(map.get("cmpl-other"), Some(&9)); + } + + #[tokio::test] + async fn non_streaming_unparseable_response_cleans_header_request_id_without_body_id() { + let cached_tokens_by_request_id = Arc::new(RwLock::new(HashMap::from([ + ("chatcmpl-header-known".to_string(), 7), + ("chatcmpl-other".to_string(), 9), + ]))); + let vllm_router = Router::new().route( + "/v1/chat/completions", + post(|| async { + Response::builder() + .body(Body::from("not json")) + .expect("response") + }), + ); + let router = cached_token_usage_routes(vllm_router, cached_tokens_by_request_id.clone()); + let request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("X-Request-Id", "header-known") + .body(Body::from( + serde_json::json!({ + "model": "model", + "messages": [{"role": "user", "content": "hello"}] + }) + .to_string(), + )) + .expect("request"); + + let response = router.oneshot(request).await.expect("route request"); + + assert_eq!(response.status(), StatusCode::OK); + let bytes = to_bytes(response.into_body(), COMPLETION_USAGE_PATCH_BODY_LIMIT) + .await + .expect("read body"); + assert_eq!(&bytes[..], b"not json"); + let map = cached_tokens_by_request_id.read().await; + assert!(!map.contains_key("chatcmpl-header-known")); + assert_eq!(map.get("chatcmpl-other"), Some(&9)); + } + + #[tokio::test] + async fn non_streaming_non_success_response_id_cleans_response_and_fallback_ids() { + let cached_tokens_by_request_id = Arc::new(RwLock::new(HashMap::from([ + ("cmpl-error".to_string(), 7), + ("cmpl-known".to_string(), 8), + ("cmpl-other".to_string(), 9), + ]))); + let vllm_router = Router::new().route( + "/v1/completions", + post(|| async { + Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::from(r#"{"id":"cmpl-error","error":"bad request"}"#)) + .expect("response") + }), + ); + let router = cached_token_usage_routes(vllm_router, cached_tokens_by_request_id.clone()); + let request = Request::builder() + .method("POST") + .uri("/v1/completions") + .body(Body::from( + serde_json::json!({ + "model": "model", + "prompt": "hello", + "request_id": "known" + }) + .to_string(), + )) + .expect("request"); + + let response = router.oneshot(request).await.expect("route request"); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + let bytes = to_bytes(response.into_body(), COMPLETION_USAGE_PATCH_BODY_LIMIT) + .await + .expect("read body"); + assert_eq!(&bytes[..], br#"{"id":"cmpl-error","error":"bad request"}"#); + let map = cached_tokens_by_request_id.read().await; + assert!(!map.contains_key("cmpl-error")); + assert!(!map.contains_key("cmpl-known")); + assert_eq!(map.get("cmpl-other"), Some(&9)); + } + + #[tokio::test] + async fn non_streaming_non_success_without_response_id_cleans_fallback_and_preserves_response() + { + let cached_tokens_by_request_id = Arc::new(RwLock::new(HashMap::from([ + ("chatcmpl-header-known".to_string(), 7), + ("chatcmpl-other".to_string(), 9), + ]))); + let vllm_router = Router::new().route( + "/v1/chat/completions", + post(|| async { + Response::builder() + .status(StatusCode::SERVICE_UNAVAILABLE) + .body(Body::from(r#"{"error":"temporarily unavailable"}"#)) + .expect("response") + }), + ); + let router = cached_token_usage_routes(vllm_router, cached_tokens_by_request_id.clone()); + let request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("X-Request-Id", "header-known") + .body(Body::from( + serde_json::json!({ + "model": "model", + "messages": [{"role": "user", "content": "hello"}] + }) + .to_string(), + )) + .expect("request"); + + let response = router.oneshot(request).await.expect("route request"); + + assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE); + let bytes = to_bytes(response.into_body(), COMPLETION_USAGE_PATCH_BODY_LIMIT) + .await + .expect("read body"); + assert_eq!(&bytes[..], br#"{"error":"temporarily unavailable"}"#); + let map = cached_tokens_by_request_id.read().await; + assert!(!map.contains_key("chatcmpl-header-known")); + assert_eq!(map.get("chatcmpl-other"), Some(&9)); + } + + #[tokio::test] + async fn streaming_chat_completion_usage_chunk_is_patched_when_included() { + let cached_tokens_by_request_id = Arc::new(RwLock::new(HashMap::from([( + "chatcmpl-stream".to_string(), + 5, + )]))); + let vllm_router = Router::new().route( + "/v1/chat/completions", + post(|| async { + Response::builder() + .header(axum::http::header::CONTENT_TYPE, "text/event-stream") + .body(Body::from( + "data: {\"id\":\"chatcmpl-stream\",\"object\":\"chat.completion.chunk\",\"usage\":null}\n\ + data: {\"id\":\"chatcmpl-stream\",\"object\":\"chat.completion.chunk\",\"usage\":{\"prompt_tokens\":11,\"prompt_tokens_details\":{}}}\n\ + data: [DONE]\n", + )) + .expect("streaming response") + }), + ); + let router = cached_token_usage_routes(vllm_router, cached_tokens_by_request_id.clone()); + let request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .body(Body::from( + serde_json::json!({ + "model": "model", + "messages": [{"role": "user", "content": "hello"}], + "stream": true, + "stream_options": {"include_usage": true} + }) + .to_string(), + )) + .expect("request"); + + let response = router.oneshot(request).await.expect("route request"); + + assert_eq!(response.status(), StatusCode::OK); + let bytes = to_bytes(response.into_body(), COMPLETION_USAGE_PATCH_BODY_LIMIT) + .await + .expect("read body"); + let text = std::str::from_utf8(&bytes).expect("utf8 stream"); + let usage = text + .lines() + .filter_map(|line| line.strip_prefix("data: ")) + .filter(|data| data.trim() != "[DONE]") + .map(|data| serde_json::from_str::(data).expect("sse json")) + .find_map(|value| value["usage"].is_object().then(|| value["usage"].clone())) + .expect("usage chunk"); + assert_eq!(usage["prompt_tokens_details"]["cached_tokens"], 5); + assert!(cached_tokens_by_request_id.read().await.is_empty()); + } + + #[tokio::test] + async fn streaming_non_success_cleans_fallback_request_id() { + let cached_tokens_by_request_id = Arc::new(RwLock::new(HashMap::from([ + ("chatcmpl-header-known".to_string(), 5), + ("chatcmpl-other".to_string(), 9), + ]))); + let vllm_router = Router::new().route( + "/v1/chat/completions", + post(|| async { + Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::from(r#"{"error":"bad request"}"#)) + .expect("response") + }), + ); + let router = cached_token_usage_routes(vllm_router, cached_tokens_by_request_id.clone()); + let request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("X-Request-Id", "header-known") + .body(Body::from( + serde_json::json!({ + "model": "model", + "messages": [{"role": "user", "content": "hello"}], + "stream": true, + "stream_options": {"include_usage": true} + }) + .to_string(), + )) + .expect("request"); + + let response = router.oneshot(request).await.expect("route request"); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + let bytes = to_bytes(response.into_body(), COMPLETION_USAGE_PATCH_BODY_LIMIT) + .await + .expect("read body"); + assert_eq!(&bytes[..], br#"{"error":"bad request"}"#); + let map = cached_tokens_by_request_id.read().await; + assert!(!map.contains_key("chatcmpl-header-known")); + assert_eq!(map.get("chatcmpl-other"), Some(&9)); + } + + #[tokio::test] + async fn streaming_body_drop_before_first_id_cleans_fallback_request_id() { + let cached_tokens_by_request_id = Arc::new(RwLock::new(HashMap::from([ + ("chatcmpl-header-known".to_string(), 5), + ("chatcmpl-other".to_string(), 9), + ]))); + let (_chunk_tx, chunk_rx) = + mpsc::unbounded_channel::>(); + let chunk_rx = Arc::new(std::sync::Mutex::new(Some(chunk_rx))); + let vllm_router = Router::new().route( + "/v1/chat/completions", + post(move || { + let chunk_rx = chunk_rx.clone(); + async move { + let mut chunk_rx = chunk_rx + .lock() + .expect("receiver lock") + .take() + .expect("receiver available"); + Response::builder() + .header(axum::http::header::CONTENT_TYPE, "text/event-stream") + .body(Body::from_stream(async_stream::stream! { + while let Some(item) = chunk_rx.recv().await { + yield item; + } + })) + .expect("streaming response") + } + }), + ); + let router = cached_token_usage_routes(vllm_router, cached_tokens_by_request_id.clone()); + let request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("X-Request-Id", "header-known") + .body(Body::from( + serde_json::json!({ + "model": "model", + "messages": [{"role": "user", "content": "hello"}], + "stream": true, + "stream_options": {"include_usage": true} + }) + .to_string(), + )) + .expect("request"); + + let response = router.oneshot(request).await.expect("route request"); + let stream = response.into_body().into_data_stream(); + drop(stream); + + wait_until_cache_entry_removed(&cached_tokens_by_request_id, "chatcmpl-header-known").await; + assert_eq!( + cached_tokens_by_request_id + .read() + .await + .get("chatcmpl-other"), + Some(&9) + ); + } + + #[tokio::test] + async fn streaming_body_drop_cleans_cached_tokens_for_seen_request_id() { + let cached_tokens_by_request_id = Arc::new(RwLock::new(HashMap::from([( + "chatcmpl-drop".to_string(), + 5, + )]))); + let (chunk_tx, chunk_rx) = + mpsc::unbounded_channel::>(); + let chunk_rx = Arc::new(std::sync::Mutex::new(Some(chunk_rx))); + let vllm_router = Router::new().route( + "/v1/chat/completions", + post(move || { + let chunk_rx = chunk_rx.clone(); + async move { + let mut chunk_rx = chunk_rx + .lock() + .expect("receiver lock") + .take() + .expect("receiver available"); + Response::builder() + .header(axum::http::header::CONTENT_TYPE, "text/event-stream") + .body(Body::from_stream(async_stream::stream! { + while let Some(item) = chunk_rx.recv().await { + yield item; + } + })) + .expect("streaming response") + } + }), + ); + let router = cached_token_usage_routes(vllm_router, cached_tokens_by_request_id.clone()); + let request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .body(Body::from( + serde_json::json!({ + "model": "model", + "messages": [{"role": "user", "content": "hello"}], + "stream": true, + "stream_options": {"include_usage": true} + }) + .to_string(), + )) + .expect("request"); + + let response = router.oneshot(request).await.expect("route request"); + let mut stream = response.into_body().into_data_stream(); + chunk_tx + .send(Ok(Bytes::from( + "data: {\"id\":\"chatcmpl-drop\",\"object\":\"chat.completion.chunk\",\"usage\":null}\n", + ))) + .expect("send first chunk"); + let _ = read_next_stream_chunk(&mut stream).await; + + drop(stream); + + wait_until_cache_entry_removed(&cached_tokens_by_request_id, "chatcmpl-drop").await; + } +}