diff --git a/Cargo.lock b/Cargo.lock index e59580e4..76eacbc8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -253,6 +253,39 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "asynk-strim" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52697735bdaac441a29391a9e97102c74c6ef0f9b60a40cf109b1b404e29d2f6" +dependencies = [ + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "asynk-strim-attr" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6ccb67be092524ce594e599332719f1cd6d64dcaed8d46f1e8726d466c10bcb" +dependencies = [ + "asynk-strim", + "asynk-strim-attr-macro", + "futures-core", +] + +[[package]] +name = "asynk-strim-attr-macro" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34e40a2181bb16fb68e25c49c8b3e25bbb9a808bf8f9f83bc596ac4ad70c86a1" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -1380,28 +1413,6 @@ dependencies = [ "futures-util", ] -[[package]] -name = "futures-async-stream" -version = "0.2.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1890317fad72bd60d4254c0a06140676bcc321889e73b8be23ddf1e7adb2b62d" -dependencies = [ - "futures-async-stream-macro", - "futures-core", - "pin-project", -] - -[[package]] -name = "futures-async-stream-macro" -version = "0.2.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fc954740d4b5ce849e6d8e01b79dbdb0e156b02737f7fb8d17130e25f36ff00" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "futures-channel" version = "0.3.32" @@ -3097,6 +3108,15 @@ dependencies = [ "num-integer", ] +[[package]] +name = "proc-macro-crate" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e67ba7e9b2b56446f1d419b1d807906278ffa1a658a8a5d8a39dcb1f5a78614f" +dependencies = [ + "toml_edit", +] + [[package]] name = "proc-macro-error-attr2" version = "2.0.0" @@ -4791,6 +4811,36 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml_datetime" +version = "1.1.1+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3165f65f62e28e0115a00b2ebdd37eb6f3b641855f9d636d3cd4103767159ad7" +dependencies = [ + "serde_core", +] + +[[package]] +name = "toml_edit" +version = "0.25.11+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b59c4d22ed448339746c59b905d24568fcbb3ab65a500494f7b8c3e97739f2b" +dependencies = [ + "indexmap 2.13.0", + "toml_datetime", + "toml_parser", + "winnow", +] + +[[package]] +name = "toml_parser" +version = "1.1.2+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2abe9b86193656635d2411dc43050282ca48aa31c2451210f4202550afb7526" +dependencies = [ + "winnow", +] + [[package]] name = "tonic" version = "0.14.5" @@ -5340,13 +5390,13 @@ name = "vllm-chat" version = "0.1.0" dependencies = [ "anyhow", + "asynk-strim-attr", "bytes", "clap", "criterion", "easy-ext", "expect-test", "futures", - "futures-async-stream", "minijinja", "minijinja-contrib", "openai-harmony", @@ -5469,12 +5519,12 @@ version = "0.1.0" dependencies = [ "anyhow", "async-openai", + "asynk-strim-attr", "axum", "bytes", "clap", "expect-test", "futures", - "futures-async-stream", "http-body", "libc", "prost", @@ -5513,6 +5563,7 @@ name = "vllm-text" version = "0.1.0" dependencies = [ 
"anyhow", + "asynk-strim-attr", "base64 0.22.1", "criterion", "easy-ext", @@ -5520,7 +5571,6 @@ dependencies = [ "expect-test", "fastokens", "futures", - "futures-async-stream", "hf-hub 0.5.0", "rustc-hash 1.1.0", "serde", diff --git a/Cargo.toml b/Cargo.toml index c9938885..da999b67 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ anyhow = "1.0.100" arc-swap = "1.9.0" async-openai = "0.33.1" async-trait = "0.1.89" +asynk-strim-attr = "0.1.0" axum = "0.8.8" base64 = "0.22.1" byteorder = "1.5.0" @@ -31,7 +32,6 @@ enum-as-inner = "0.7.0" expect-test = "1.5.1" fastokens = { git = "https://github.com/BugenZhao/fastokens.git", rev = "12a865a1f13aaae8f7b14bab1f177bba30577ad7" } futures = "0.3.31" -futures-async-stream = "0.2.13" hex = "0.4.3" hf-hub = { version = "0.5.0", features = ["tokio"] } http-body = "1.0.1" diff --git a/src/chat/Cargo.toml b/src/chat/Cargo.toml index 572078fc..601a6fd7 100644 --- a/src/chat/Cargo.toml +++ b/src/chat/Cargo.toml @@ -8,9 +8,9 @@ test-util = [] [dependencies] anyhow.workspace = true +asynk-strim-attr.workspace = true easy-ext.workspace = true futures.workspace = true -futures-async-stream.workspace = true minijinja.workspace = true minijinja-contrib.workspace = true openai-harmony.workspace = true diff --git a/src/chat/src/lib.rs b/src/chat/src/lib.rs index 9c319563..6486a95b 100644 --- a/src/chat/src/lib.rs +++ b/src/chat/src/lib.rs @@ -1,4 +1,3 @@ -#![feature(coroutines)] #![feature(trait_alias)] //! Minimal chat facade above [`vllm_text`]. diff --git a/src/chat/src/output/default/reasoning.rs b/src/chat/src/output/default/reasoning.rs index a6a99440..7028be23 100644 --- a/src/chat/src/output/default/reasoning.rs +++ b/src/chat/src/output/default/reasoning.rs @@ -5,13 +5,14 @@ //! separation: `decoded.rs` still only produces plain text deltas, while later //! stages consume the semantic `Text` / `Reasoning` split emitted here. +use asynk_strim_attr::{TryYielder, try_stream}; use futures::{StreamExt as _, pin_mut}; -use futures_async_stream::try_stream; use thiserror_ext::AsReport; use tracing::warn; use vllm_text::output::DecodedTextEvent; use super::ContentEvent; +use crate::Result; use crate::error::Error; use crate::event::AssistantBlockKind; use crate::output::DecodedTextEventStream; @@ -122,18 +123,19 @@ fn push_reasoning_delta(events: &mut Vec, delta: ReasoningDelta) { } /// Wrap one decoded-text stream into the internal reasoning event stream. -#[try_stream(ok = ContentEvent, error = Error)] +#[try_stream] pub(crate) async fn reasoning_event_stream( decoded_stream: impl DecodedTextEventStream, reasoning_parser: Option>, -) { + mut y: TryYielder, +) -> Result<()> { pin_mut!(decoded_stream); // Without a parser, pass through as plain text deltas. let Some(reasoning_parser) = reasoning_parser else { while let Some(event) = decoded_stream.next().await.transpose()? 
{ for next in ContentEvent::from_decoded_plain_text(event) { - yield next; + y.yield_ok(next).await; } } return Ok(()); @@ -148,10 +150,11 @@ pub(crate) async fn reasoning_event_stream( prompt_logprobs, } => { state.initialize(&prompt_token_ids); - yield ContentEvent::Start { + y.yield_ok(ContentEvent::Start { prompt_token_ids, prompt_logprobs, - } + }) + .await; } DecodedTextEvent::TextDelta { delta, @@ -160,28 +163,31 @@ pub(crate) async fn reasoning_event_stream( finished, } => { for next in state.process_delta(delta) { - yield next; + y.yield_ok(next).await; } if logprobs.is_some() || !token_ids.is_empty() { - yield ContentEvent::LogprobsDelta { + y.yield_ok(ContentEvent::LogprobsDelta { logprobs, token_ids, - }; + }) + .await; } if let Some(finished) = finished { for next in state.finish() { - yield next; + y.yield_ok(next).await; } - yield ContentEvent::Done { + y.yield_ok(ContentEvent::Done { prompt_token_count: finished.prompt_token_count, output_token_count: finished.output_token_count, finish_reason: finished.finish_reason, kv_transfer_params: finished.kv_transfer_params, - }; + }) + .await; } } } } + Ok(()) } #[cfg(test)] diff --git a/src/chat/src/output/default/tool.rs b/src/chat/src/output/default/tool.rs index e2f97729..4b5e7a7e 100644 --- a/src/chat/src/output/default/tool.rs +++ b/src/chat/src/output/default/tool.rs @@ -5,8 +5,8 @@ //! and translates incremental `tool-parser` output into internal tool-call //! events while preserving plain-text fallback behavior. +use asynk_strim_attr::{TryYielder, try_stream}; use futures::{StreamExt as _, pin_mut}; -use futures_async_stream::try_stream; use thiserror_ext::AsReport; use tracing::warn; @@ -187,11 +187,12 @@ fn push_text_delta(events: &mut Vec, kind: AssistantBlockKind, d /// We keep this separate because some adaptor-backed parsers may not correctly handle the full text /// passed to incremental `push` interface, but override `parse_complete()` with a dedicated /// one-shot implementation to ensure correctness. 
-#[try_stream(ok = AssistantEvent, error = Error)] +#[try_stream] async fn final_only_tool_event_stream( stream: impl ContentEventStream, mut parser: Box, -) { + mut y: TryYielder, +) -> Result<()> { pin_mut!(stream); let mut final_text = String::new(); @@ -202,26 +203,28 @@ async fn final_only_tool_event_stream( prompt_token_ids, prompt_logprobs, } => { - yield AssistantEvent::Start { + y.yield_ok(AssistantEvent::Start { prompt_token_ids, prompt_logprobs, - } + }) + .await; } ContentEvent::TextDelta { kind, delta } => { if kind == AssistantBlockKind::Text { final_text.push_str(&delta); } else { - yield AssistantEvent::TextDelta { kind, delta }; + y.yield_ok(AssistantEvent::TextDelta { kind, delta }).await; } } ContentEvent::LogprobsDelta { logprobs, token_ids, } => { - yield AssistantEvent::LogprobsDelta { + y.yield_ok(AssistantEvent::LogprobsDelta { logprobs, token_ids, - }; + }) + .await; } ContentEvent::Done { prompt_token_count, @@ -232,10 +235,11 @@ async fn final_only_tool_event_stream( match parser.parse_complete(&final_text) { Ok(ToolParseResult { normal_text, calls }) => { if !normal_text.is_empty() { - yield AssistantEvent::TextDelta { + y.yield_ok(AssistantEvent::TextDelta { kind: AssistantBlockKind::Text, delta: normal_text, - }; + }) + .await; } // `parse_complete` currently returns one complete delta // per tool call, so we can surface each one as a start @@ -249,14 +253,16 @@ async fn final_only_tool_event_stream( ), }); }; - yield AssistantEvent::ToolCallStart { + y.yield_ok(AssistantEvent::ToolCallStart { id: generate_tool_call_id(), name, - }; + }) + .await; if !tool_call.arguments.is_empty() { - yield AssistantEvent::ToolCallArgumentsDelta { + y.yield_ok(AssistantEvent::ToolCallArgumentsDelta { delta: tool_call.arguments, - }; + }) + .await; } } } @@ -265,38 +271,42 @@ async fn final_only_tool_event_stream( error = %error.as_report(), "tool parser full-output parse failed; falling back to plain text" ); - yield AssistantEvent::TextDelta { + y.yield_ok(AssistantEvent::TextDelta { kind: AssistantBlockKind::Text, delta: final_text, - }; + }) + .await; } } - yield AssistantEvent::Done { + y.yield_ok(AssistantEvent::Done { prompt_token_count, output_token_count, finish_reason, kv_transfer_params, - }; + }) + .await; return Ok(()); } } } + Ok(()) } /// Wrap one semantic assistant stream into the internal tool-aware assistant /// stream. -#[try_stream(ok = AssistantEvent, error = Error)] +#[try_stream] pub(crate) async fn tool_event_stream( stream: impl ContentEventStream, intermediate: bool, parser: Option>, -) { + mut y: TryYielder, +) -> Result<()> { // Without a parser, pass through the input stream unchanged. let Some(parser) = parser else { pin_mut!(stream); while let Some(event) = stream.next().await.transpose()? { - yield event.into(); + y.yield_ok(event.into()).await; } return Ok(()); }; @@ -306,7 +316,7 @@ pub(crate) async fn tool_event_stream( let final_stream = final_only_tool_event_stream(stream, parser); pin_mut!(final_stream); while let Some(event) = final_stream.next().await.transpose()? { - yield event; + y.yield_ok(event).await; } return Ok(()); } @@ -320,24 +330,26 @@ pub(crate) async fn tool_event_stream( prompt_token_ids, prompt_logprobs, } => { - yield AssistantEvent::Start { + y.yield_ok(AssistantEvent::Start { prompt_token_ids, prompt_logprobs, - } + }) + .await; } ContentEvent::TextDelta { kind, delta } => { for next in state.process_text_delta(kind, delta)? 
{ - yield next; + y.yield_ok(next).await; } } ContentEvent::LogprobsDelta { logprobs, token_ids, } => { - yield AssistantEvent::LogprobsDelta { + y.yield_ok(AssistantEvent::LogprobsDelta { logprobs, token_ids, - }; + }) + .await; } ContentEvent::Done { prompt_token_count, @@ -346,18 +358,20 @@ pub(crate) async fn tool_event_stream( kv_transfer_params, } => { for next in state.finish()? { - yield next; + y.yield_ok(next).await; } - yield AssistantEvent::Done { + y.yield_ok(AssistantEvent::Done { prompt_token_count, output_token_count, finish_reason, kv_transfer_params, - }; + }) + .await; } } } + Ok(()) } #[cfg(test)] diff --git a/src/chat/src/output/harmony/mod.rs b/src/chat/src/output/harmony/mod.rs index d2cf4e20..66b7b113 100644 --- a/src/chat/src/output/harmony/mod.rs +++ b/src/chat/src/output/harmony/mod.rs @@ -7,8 +7,8 @@ use std::sync::LazyLock; use anyhow::Context; +use asynk_strim_attr::{TryYielder, try_stream}; use futures::StreamExt as _; -use futures_async_stream::try_stream; use openai_harmony::chat::{Content as HarmonyContent, Message as HarmonyMessage, Role}; use openai_harmony::{ HarmonyEncoding, HarmonyEncodingName, StreamableParser, load_harmony_encoding, @@ -324,12 +324,13 @@ impl HarmonyState { } /// Convert decoded token updates into internal assistant events with Harmony parsing. -#[try_stream(ok = AssistantEvent, error = Error)] +#[try_stream] async fn harmony_assistant_event_stream( decoded: DynDecodedTextEventStream, encoding: &'static HarmonyEncoding, tool_calls_enabled: bool, -) { + mut y: TryYielder, +) -> Result<()> { let mut state = HarmonyState::new(encoding.clone(), tool_calls_enabled)?; futures::pin_mut!(decoded); @@ -339,10 +340,11 @@ async fn harmony_assistant_event_stream( prompt_token_ids, prompt_logprobs, } => { - yield AssistantEvent::Start { + y.yield_ok(AssistantEvent::Start { prompt_token_ids, prompt_logprobs, - }; + }) + .await; } DecodedTextEvent::TextDelta { delta: _, // harmony takes raw token IDs as input, so we ignore text deltas here @@ -351,33 +353,36 @@ async fn harmony_assistant_event_stream( finished, } => { for event in state.process_token_ids(&token_ids)? { - yield event; + y.yield_ok(event).await; } if finished.is_some() { for event in state.process_eos()? { - yield event; + y.yield_ok(event).await; } } if logprobs.is_some() || !token_ids.is_empty() { - yield AssistantEvent::LogprobsDelta { + y.yield_ok(AssistantEvent::LogprobsDelta { logprobs, token_ids, - }; + }) + .await; } if let Some(finished) = finished { - yield AssistantEvent::Done { + y.yield_ok(AssistantEvent::Done { prompt_token_count: finished.prompt_token_count, output_token_count: finished.output_token_count, finish_reason: finished.finish_reason, kv_transfer_params: finished.kv_transfer_params, - }; + }) + .await; } } } } + Ok(()) } /// Lazily load the shared GPT-OSS Harmony encoding once per process. diff --git a/src/chat/src/output/structured.rs b/src/chat/src/output/structured.rs index 69bf5199..5559c8e1 100644 --- a/src/chat/src/output/structured.rs +++ b/src/chat/src/output/structured.rs @@ -5,8 +5,8 @@ //! parsing are handled earlier by their own adapters. This stage consumes those //! parsed deltas and assembles higher-level assistant content blocks. 
+use asynk_strim_attr::{TryYielder, try_stream}; use futures::{StreamExt as _, pin_mut}; -use futures_async_stream::try_stream; use vllm_text::DecodedLogprobs; use super::{AssistantEvent, AssistantEventStream}; @@ -227,8 +227,11 @@ impl StructuredEventState { } /// Wrap one parsed assistant stream into the public structured chat event stream. -#[try_stream(ok = ChatEvent, error = Error)] -pub(crate) async fn structured_chat_event_stream(stream: impl AssistantEventStream) { +#[try_stream] +pub(crate) async fn structured_chat_event_stream( + stream: impl AssistantEventStream, + mut y: TryYielder, +) -> Result<()> { pin_mut!(stream); let mut state = StructuredEventState::new(); @@ -239,14 +242,15 @@ pub(crate) async fn structured_chat_event_stream(stream: impl AssistantEventStre prompt_token_ids, prompt_logprobs, } => { - yield ChatEvent::Start { + y.yield_ok(ChatEvent::Start { prompt_token_ids, prompt_logprobs, - } + }) + .await; } AssistantEvent::TextDelta { kind, delta } => { for next in state.process_text_delta(kind, delta)? { - yield next; + y.yield_ok(next).await; } } AssistantEvent::LogprobsDelta { @@ -254,17 +258,17 @@ pub(crate) async fn structured_chat_event_stream(stream: impl AssistantEventStre token_ids, } => { for next in state.process_logprobs_delta(logprobs, token_ids)? { - yield next; + y.yield_ok(next).await; } } AssistantEvent::ToolCallStart { id, name } => { for next in state.start_tool_call(id, name)? { - yield next; + y.yield_ok(next).await; } } AssistantEvent::ToolCallArgumentsDelta { delta } => { for next in state.push_tool_call_arguments(delta)? { - yield next; + y.yield_ok(next).await; } } AssistantEvent::Done { @@ -279,11 +283,12 @@ pub(crate) async fn structured_chat_event_stream(stream: impl AssistantEventStre finish_reason, kv_transfer_params, )? { - yield next; + y.yield_ok(next).await; } } } } + Ok(()) } #[cfg(test)] diff --git a/src/server/Cargo.toml b/src/server/Cargo.toml index d35bd00a..7269dde1 100644 --- a/src/server/Cargo.toml +++ b/src/server/Cargo.toml @@ -5,9 +5,9 @@ edition.workspace = true [dependencies] anyhow.workspace = true +asynk-strim-attr.workspace = true axum.workspace = true futures.workspace = true -futures-async-stream.workspace = true http-body.workspace = true libc.workspace = true prost.workspace = true diff --git a/src/server/src/lib.rs b/src/server/src/lib.rs index 0f097ea2..3a9d26d6 100644 --- a/src/server/src/lib.rs +++ b/src/server/src/lib.rs @@ -1,4 +1,3 @@ -#![feature(coroutines)] #![feature(iterator_try_collect)] //! Minimal OpenAI-compatible HTTP server above [`vllm_chat`]. diff --git a/src/server/src/routes/openai/chat_completions.rs b/src/server/src/routes/openai/chat_completions.rs index fa55d193..106dceaa 100644 --- a/src/server/src/routes/openai/chat_completions.rs +++ b/src/server/src/routes/openai/chat_completions.rs @@ -3,15 +3,16 @@ mod types; mod validate; use std::convert::Infallible; +use std::result::Result; use std::sync::Arc; +use asynk_strim_attr::{TryYielder, try_stream}; use axum::Json; use axum::extract::State; use axum::http::HeaderMap; use axum::response::sse::{Event, Sse}; use axum::response::{IntoResponse, Response}; use futures::{Stream, StreamExt as _, pin_mut}; -use futures_async_stream::try_stream; use serde_json::Value; use thiserror_ext::AsReport as _; use tracing::{debug, error, info, trace}; @@ -223,7 +224,7 @@ async fn collect_chat_completion( } /// Convert one internal chat event stream into OpenAI chat-completion chunks. 
-#[try_stream(ok = ChatCompletionStreamResponse, error = ApiError)] +#[try_stream] async fn chat_completion_chunk_stream( mut stream: impl ChatEventStreamTrait + Unpin, request_id: String, @@ -235,7 +236,8 @@ async fn chat_completion_chunk_stream( echo: Option, return_token_ids: bool, return_tokens_as_token_ids: bool, -) { + mut y: TryYielder, +) -> Result<(), ApiError> { let mut saw_tool_calls = false; // If the client requested logprobs or token_ids, we need to buffer chunks until we receive @@ -253,23 +255,31 @@ async fn chat_completion_chunk_stream( if return_token_ids { chunk.prompt_token_ids = Some(prompt_token_ids.to_vec()); } - yield chunk; + y.yield_ok(chunk).await; // When echo=true, emit the last assistant message content as a delta chunk. if let Some(echo_text) = &echo { - yield block_delta_chunk( + y.yield_ok(block_delta_chunk( &request_id, &response_model, created, AssistantBlockKind::Text, echo_text.clone(), - ); + )) + .await; } } Ok(ChatEvent::BlockDelta { kind, delta, .. }) => { if let Some(pending_chunk) = pending_chunk.as_mut() { pending_chunk.push_block_delta(kind, delta); } else { - yield block_delta_chunk(&request_id, &response_model, created, kind, delta) + y.yield_ok(block_delta_chunk( + &request_id, + &response_model, + created, + kind, + delta, + )) + .await; } } Ok(ChatEvent::LogprobsDelta { @@ -289,10 +299,16 @@ async fn chat_completion_chunk_stream( if let Some(chunk) = pending_chunk.take_chunk(&request_id, &response_model, created) { - yield chunk; + y.yield_ok(chunk).await; } } else if let Some(logprobs) = openai_logprobs { - yield logprobs_only_chunk(&request_id, &response_model, created, logprobs); + y.yield_ok(logprobs_only_chunk( + &request_id, + &response_model, + created, + logprobs, + )) + .await; } } Ok(ChatEvent::BlockStart { kind, .. }) => { @@ -312,14 +328,15 @@ async fn chat_completion_chunk_stream( if let Some(pending_chunk) = pending_chunk.as_mut() { pending_chunk.push_tool_call_start(tool_index, id, name); } else { - yield tool_call_start_chunk( + y.yield_ok(tool_call_start_chunk( &request_id, &response_model, created, tool_index, id, name, - ); + )) + .await; } } Ok(ChatEvent::ToolCallArgumentsDelta { index, delta }) => { @@ -327,13 +344,14 @@ async fn chat_completion_chunk_stream( if let Some(pending_chunk) = pending_chunk.as_mut() { pending_chunk.push_tool_call_arguments(tool_index, delta); } else { - yield tool_call_arguments_chunk( + y.yield_ok(tool_call_arguments_chunk( &request_id, &response_model, created, tool_index, delta, - ); + )) + .await; } } Ok(ChatEvent::ToolCallEnd { .. 
}) => { @@ -360,7 +378,7 @@ async fn chat_completion_chunk_stream( && let Some(chunk) = pending_chunk.take_chunk(&request_id, &response_model, created) { - yield chunk; + y.yield_ok(chunk).await; } match final_chunk( @@ -370,7 +388,7 @@ async fn chat_completion_chunk_stream( finish_reason, saw_tool_calls, ) { - Ok(chunk) => yield chunk, + Ok(chunk) => y.yield_ok(chunk).await, Err(error) => { error!( error = %error.to_error_response().error.message, @@ -381,12 +399,13 @@ async fn chat_completion_chunk_stream( } if include_usage { - yield usage_chunk( + y.yield_ok(usage_chunk( &request_id, &response_model, created, Usage::from_counts(prompt_token_count as u32, output_token_count as u32), - ); + )) + .await; } return Ok(()); @@ -400,6 +419,7 @@ async fn chat_completion_chunk_stream( } } } + Ok(()) } fn usage_chunk( @@ -540,23 +560,25 @@ fn append_delta_text(slot: &mut Option, delta: String) { /// OpenAI-style streaming errors are encoded as ordinary `data: {"error": ...}` /// events followed by `data: [DONE]`, so the transport stream itself stays /// infallible even when generation fails after the HTTP response has started. -#[try_stream(ok = Event, error = Infallible)] +#[try_stream] async fn chat_completion_sse_stream( stream: impl Stream>, -) { + mut y: TryYielder, +) -> Result<(), Infallible> { pin_mut!(stream); while let Some(next) = stream.next().await { match next { - Ok(chunk) => yield to_sse_event(&chunk), + Ok(chunk) => y.yield_ok(to_sse_event(&chunk)).await, Err(error) => { - yield to_error_sse_event(&error); + y.yield_ok(to_error_sse_event(&error)).await; break; } } } - yield done_sse_event(); + y.yield_ok(done_sse_event()).await; + Ok(()) } /// Serialize one OpenAI chunk payload into one SSE `data:` event. diff --git a/src/server/src/routes/openai/completions.rs b/src/server/src/routes/openai/completions.rs index bacab9a1..3a3709a5 100644 --- a/src/server/src/routes/openai/completions.rs +++ b/src/server/src/routes/openai/completions.rs @@ -3,15 +3,16 @@ mod types; mod validate; use std::convert::Infallible; +use std::result::Result; use std::sync::Arc; +use asynk_strim_attr::{TryYielder, try_stream}; use axum::Json; use axum::extract::State; use axum::http::HeaderMap; use axum::response::sse::{Event, Sse}; use axum::response::{IntoResponse, Response}; use futures::{Stream, StreamExt as _, pin_mut}; -use futures_async_stream::try_stream; use thiserror_ext::AsReport as _; use tracing::{debug, error, info, trace}; use tracing_futures::Instrument as _; @@ -205,7 +206,7 @@ async fn collect_completion( } /// Convert one internal decoded-text stream into OpenAI completions chunks. 
-#[try_stream(ok = CompletionSseChunk, error = ApiError)] +#[try_stream] async fn completion_chunk_stream( stream: impl TextOutputStream, request_id: String, @@ -217,7 +218,8 @@ async fn completion_chunk_stream( requested_logprobs: Option, return_token_ids: bool, return_tokens_as_token_ids: bool, -) { + mut y: TryYielder, +) -> Result<(), ApiError> { pin_mut!(stream); let mut visible_text_len = 0_u32; let mut first_chunk = true; @@ -238,7 +240,7 @@ async fn completion_chunk_stream( } first_chunk = false; } - yield CompletionSseChunk::Chunk(chunk); + y.yield_ok(CompletionSseChunk::Chunk(chunk)).await; } else if return_token_ids { // Emit a chunk with prompt_token_ids in the first streaming response let mut chunk = @@ -247,7 +249,7 @@ async fn completion_chunk_stream( choice.prompt_token_ids = Some(prompt_token_ids.to_vec()); } first_chunk = false; - yield CompletionSseChunk::Chunk(chunk); + y.yield_ok(CompletionSseChunk::Chunk(chunk)).await; } } Ok(DecodedTextEvent::TextDelta { @@ -275,7 +277,7 @@ async fn completion_chunk_stream( if return_token_ids && let Some(choice) = chunk.choices.first_mut() { choice.token_ids = Some(token_ids); } - yield CompletionSseChunk::Chunk(chunk); + y.yield_ok(CompletionSseChunk::Chunk(chunk)).await; visible_text_len = visible_text_len.saturating_add(delta_text_len); if let Some(finished) = finished { @@ -289,15 +291,16 @@ async fn completion_chunk_stream( "completion finished" ); } - yield CompletionSseChunk::Chunk(final_chunk( + y.yield_ok(CompletionSseChunk::Chunk(final_chunk( &request_id, &response_model, created, finished.finish_reason, - )?); + )?)) + .await; if include_usage { - yield CompletionSseChunk::Usage(usage_chunk( + y.yield_ok(CompletionSseChunk::Usage(usage_chunk( &request_id, &response_model, created, @@ -305,7 +308,8 @@ async fn completion_chunk_stream( finished.prompt_token_count as u32, finished.output_token_count as u32, ), - )); + ))) + .await; } } } @@ -318,6 +322,7 @@ async fn completion_chunk_stream( } } } + Ok(()) } fn delta_chunk( @@ -381,21 +386,25 @@ fn usage_chunk( /// OpenAI-style streaming errors are encoded as ordinary `data: {"error": ...}` /// events followed by `data: [DONE]`, so the transport stream itself stays /// infallible even when generation fails after the HTTP response has started. -#[try_stream(ok = Event, error = Infallible)] -async fn completion_sse_stream(stream: impl Stream>) { +#[try_stream] +async fn completion_sse_stream( + stream: impl Stream>, + mut y: TryYielder, +) -> Result<(), Infallible> { pin_mut!(stream); while let Some(next) = stream.next().await { match next { - Ok(chunk) => yield to_sse_event(&chunk), + Ok(chunk) => y.yield_ok(to_sse_event(&chunk)).await, Err(error) => { - yield to_error_sse_event(&error); + y.yield_ok(to_error_sse_event(&error)).await; break; } } } - yield done_sse_event(); + y.yield_ok(done_sse_event()).await; + Ok(()) } /// Serialize one OpenAI chunk payload into one SSE `data:` event. 
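The rewrite above is mechanical across every file: each `#[try_stream(ok = T, error = E)]` function loses its `ok`/`error` attribute arguments, gains a trailing `mut y: TryYielder` parameter and an explicit `-> Result<(), E>` return type, and each `yield expr;` becomes `y.yield_ok(expr).await;` with a final `Ok(())`. Below is a minimal self-contained sketch of that before/after shape. It assumes `asynk-strim-attr` 0.1 behaves as these hunks imply (the attribute strips the trailing `TryYielder` from the public signature and returns an `impl Stream<Item = Result<T, E>>`, so call sites stay unchanged); the `doubled` function and `Error` type are hypothetical, not part of this diff.

```rust
use asynk_strim_attr::{TryYielder, try_stream};
use futures::{Stream, StreamExt as _, pin_mut};

#[derive(Debug)]
struct Error;

// Before (futures-async-stream; required nightly `#![feature(coroutines)]`):
//
//     #[try_stream(ok = u32, error = Error)]
//     async fn doubled(input: impl Stream<Item = Result<u32, Error>>) {
//         pin_mut!(input);
//         while let Some(n) = input.next().await.transpose()? {
//             yield n * 2;
//         }
//     }

// After (asynk-strim-attr): the yield point is an ordinary `.await` on an
// explicit yielder handle, so no compiler coroutine support is needed.
#[try_stream]
async fn doubled(
    input: impl Stream<Item = Result<u32, Error>>,
    mut y: TryYielder,
) -> Result<(), Error> {
    pin_mut!(input);
    while let Some(n) = input.next().await.transpose()? {
        y.yield_ok(n * 2).await;
    }
    Ok(())
}

#[tokio::main]
async fn main() -> Result<(), Error> {
    let input = futures::stream::iter([Ok(1_u32), Ok(2), Ok(3)]);
    let stream = doubled(input); // call site unchanged: no yielder argument
    pin_mut!(stream);
    while let Some(n) = stream.next().await.transpose()? {
        println!("{n}"); // 2, 4, 6
    }
    Ok(())
}
```

This is also what lets the three `lib.rs` files drop `#![feature(coroutines)]` while keeping their other feature gates (`trait_alias`, `iterator_try_collect`) intact.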
diff --git a/src/text/Cargo.toml b/src/text/Cargo.toml index bcd39fb0..b5704cac 100644 --- a/src/text/Cargo.toml +++ b/src/text/Cargo.toml @@ -5,12 +5,12 @@ edition.workspace = true [dependencies] anyhow.workspace = true +asynk-strim-attr.workspace = true base64.workspace = true easy-ext.workspace = true enum-as-inner.workspace = true fastokens.workspace = true futures.workspace = true -futures-async-stream.workspace = true hf-hub.workspace = true rustc-hash.workspace = true serde.workspace = true diff --git a/src/text/src/lib.rs b/src/text/src/lib.rs index f39f5d4c..230e3005 100644 --- a/src/text/src/lib.rs +++ b/src/text/src/lib.rs @@ -1,4 +1,3 @@ -#![feature(coroutines)] #![feature(trait_alias)] #![feature(iterator_try_collect)] diff --git a/src/text/src/output/decoded.rs b/src/text/src/output/decoded.rs index 376dbf72..02c4ccc6 100644 --- a/src/text/src/output/decoded.rs +++ b/src/text/src/output/decoded.rs @@ -1,7 +1,7 @@ use std::sync::Arc; +use asynk_strim_attr::{TryYielder, try_stream}; use futures::{Stream, StreamExt}; -use futures_async_stream::try_stream; use serde::{Deserialize, Serialize}; use tracing::{Level, debug, trace}; use vllm_engine_core_client::AbortCause; @@ -81,14 +81,15 @@ pub enum DecodedTextEvent { } /// Convert the output token stream from the `vllm_llm` layer into incrementally decoded text. -#[try_stream(ok = DecodedTextEvent, error = Error)] +#[try_stream] pub async fn decoded_text_event_stream( request_id: String, tokenizer: DynTokenizer, mut raw_stream: impl Stream> + Unpin, mut decode_options: TextDecodeOptions, intermediate: bool, -) { + mut y: TryYielder, +) -> crate::Result<()> { let mut decoder: Option> = None; let mut prompt_token_count = 0_usize; let mut token_ids = Vec::new(); @@ -125,7 +126,7 @@ pub async fn decoded_text_event_stream( ); decoder = Some(dec); - yield DecodedTextEvent::Start { + y.yield_ok(DecodedTextEvent::Start { prompt_token_ids: prompt_token_ids.clone(), prompt_logprobs: output .prompt_logprobs() @@ -138,7 +139,8 @@ pub async fn decoded_text_event_stream( ) }) .transpose()?, - }; + }) + .await; }; let decoder = decoder.as_mut().unwrap(); @@ -258,7 +260,7 @@ pub async fn decoded_text_event_stream( AbortCause::StopStringMatched.drop_as(raw_stream); } - yield DecodedTextEvent::TextDelta { + y.yield_ok(DecodedTextEvent::TextDelta { delta: text, token_ids, logprobs, @@ -268,21 +270,23 @@ pub async fn decoded_text_event_stream( finish_reason: reason, kv_transfer_params, }), - }; + }) + .await; return Ok(()); } if intermediate { - yield DecodedTextEvent::TextDelta { + y.yield_ok(DecodedTextEvent::TextDelta { delta: delta.unwrap_or_default(), token_ids: new_token_ids, logprobs: decoded_logprobs, finished: None, - }; + }) + .await; } } - Err(Error::StreamClosedBeforeTerminalOutput { request_id })?; + Err(Error::StreamClosedBeforeTerminalOutput { request_id }) } /// If stop string matches, returns tuple
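One subtlety visible in the final `decoded.rs` hunk: the old code ended with `Err(Error::StreamClosedBeforeTerminalOutput { request_id })?;`, propagating through `?`, while the new code returns the `Err` directly as the tail expression. Assuming `asynk-strim`'s try-stream mirrors `futures-async-stream` semantics here, a returned `Err` surfaces to consumers as the stream's final `Err` item and the stream then ends. A sketch of that assumed contract, with a hypothetical `Error` type:

```rust
use asynk_strim_attr::{TryYielder, try_stream};
use futures::{StreamExt as _, pin_mut};

#[derive(Debug, PartialEq)]
struct Error;

// Yields one item, then terminates the stream by returning an error.
#[try_stream]
async fn fails_after_one(mut y: TryYielder) -> Result<(), Error> {
    y.yield_ok(1_u32).await;
    Err(Error) // assumed to surface as the stream's last item
}

#[tokio::main]
async fn main() {
    let stream = fails_after_one();
    pin_mut!(stream);
    assert_eq!(stream.next().await, Some(Ok(1)));
    assert_eq!(stream.next().await, Some(Err(Error)));
    assert_eq!(stream.next().await, None);
}
```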