diff --git a/async-nats/Cargo.toml b/async-nats/Cargo.toml index d24573d32..3ef199aa2 100644 --- a/async-nats/Cargo.toml +++ b/async-nats/Cargo.toml @@ -25,6 +25,7 @@ serde = { version = "1.0.184", features = ["derive"] } serde_json = "1.0.104" serde_repr = "0.1.16" tokio = { version = "1.36", features = ["macros", "rt", "fs", "net", "sync", "time", "io-util"] } +tokio-stream = "0.1" url = { version = "2"} tokio-rustls = { version = "0.26", default-features = false } rustls-pemfile = "2" diff --git a/async-nats/src/jetstream/context.rs b/async-nats/src/jetstream/context.rs index b7c4e5185..426d36e2d 100644 --- a/async-nats/src/jetstream/context.rs +++ b/async-nats/src/jetstream/context.rs @@ -22,18 +22,21 @@ use crate::subject::ToSubject; use crate::{header, Client, Command, HeaderMap, HeaderValue, Message, StatusCode}; use bytes::Bytes; use futures::future::BoxFuture; -use futures::{Future, TryFutureExt}; +use futures::{Future, StreamExt, TryFutureExt}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use serde_json::{self, json}; use std::borrow::Borrow; +use std::fmt::Debug; use std::fmt::Display; use std::future::IntoFuture; use std::pin::Pin; use std::str::from_utf8; +use std::sync::Arc; use std::task::Poll; -use std::time::Duration; -use tokio::sync::oneshot; +use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit, TryAcquireError}; +use tokio::time::Duration; +use tokio_stream::wrappers::ReceiverStream; use tracing::debug; use super::consumer::{self, Consumer, FromConsumer, IntoConsumerConfig}; @@ -54,36 +57,200 @@ pub struct Context { pub(crate) client: Client, pub(crate) prefix: String, pub(crate) timeout: Duration, + pub(crate) max_ack_semaphore: Arc<tokio::sync::Semaphore>, + pub(crate) acker_task: Arc<tokio::task::JoinHandle<()>>, + pub(crate) ack_sender: + tokio::sync::mpsc::Sender<(oneshot::Receiver<Message>, OwnedSemaphorePermit)>, } -impl Context { - pub(crate) fn new(client: Client) -> Context { - Context { - client, +fn spawn_acker( + rx: tokio::sync::mpsc::Receiver<(oneshot::Receiver<Message>, OwnedSemaphorePermit)>, + ack_timeout: Duration, +) -> tokio::task::JoinHandle<()> { + tokio::spawn(async move { + let stream = ReceiverStream::new(rx); + stream + .for_each_concurrent(None, |(subscription, permit)| async move { + tokio::time::timeout(ack_timeout, subscription).await.ok(); + drop(permit); + }) + .await; + }) +} + +impl Drop for Context { + fn drop(&mut self) { + self.acker_task.abort(); + } +} + +use std::marker::PhantomData; + +#[derive(Debug, Default)] +pub struct Yes; +#[derive(Debug, Default)] +pub struct No; + +pub trait ToAssign: Debug {} + +impl ToAssign for Yes {} +impl ToAssign for No {} + +/// A builder for [Context]. Beyond what can be set by standard constructor, it allows tweaking +/// pending publish ack backpressure settings. +/// # Examples +/// ```no_run +/// # use async_nats::jetstream::context::ContextBuilder; +/// # use async_nats::Client; +/// # use std::time::Duration; +/// # #[tokio::main] +/// # async fn main() -> Result<(), async_nats::Error> { +/// let client = async_nats::connect("demo.nats.io").await?; +/// let context = ContextBuilder::new() +/// .timeout(Duration::from_secs(5)) +/// .api_prefix("MY.JS.API") +/// .max_ack_inflight(1000) +/// .build(client); +/// # Ok(()) +/// # } +/// ``` +/// +pub struct ContextBuilder<PREFIX: ToAssign> { + prefix: String, + timeout: Duration, + semaphore_capacity: usize, + ack_timeout: Duration, + _phantom: PhantomData<PREFIX>, +} + +impl Default for ContextBuilder<Yes> { + fn default() -> Self { + ContextBuilder { prefix: "$JS.API".to_string(), timeout: Duration::from_secs(5), + semaphore_capacity: 50_000, + ack_timeout: Duration::from_secs(30), + _phantom: PhantomData {}, } } +} - pub fn set_timeout(&mut self, timeout: Duration) { - self.timeout = timeout +impl ContextBuilder<Yes> { + /// Create a new [ContextBuilder] with default settings. + pub fn new() -> ContextBuilder<Yes> { + ContextBuilder::default() } +} - pub(crate) fn with_prefix<T: ToString>(client: Client, prefix: T) -> Context { - Context { - client, - prefix: prefix.to_string(), - timeout: Duration::from_secs(5), +impl ContextBuilder<Yes> { + /// Set the prefix for the JetStream API. + pub fn api_prefix<T: Into<String>>(self, prefix: T) -> ContextBuilder<No> { + ContextBuilder { + prefix: prefix.into(), + timeout: self.timeout, + semaphore_capacity: self.semaphore_capacity, + ack_timeout: self.ack_timeout, + _phantom: PhantomData, } } - pub(crate) fn with_domain<T: AsRef<str>>(client: Client, domain: T) -> Context { + /// Set the domain for the JetStream API. Domain is the middle part of standard API prefix: + /// $JS.{domain}.API. + pub fn domain<T: Into<String>>(self, domain: T) -> ContextBuilder<No> { + ContextBuilder { + prefix: format!("$JS.{}.API", domain.into()), + timeout: self.timeout, + semaphore_capacity: self.semaphore_capacity, + ack_timeout: self.ack_timeout, + _phantom: PhantomData, + } + } +} + +impl<PREFIX> ContextBuilder<PREFIX> +where + PREFIX: ToAssign, +{ + /// Set the timeout for all JetStream API requests. + pub fn timeout(self, timeout: Duration) -> ContextBuilder<Yes> + where + Yes: ToAssign, + { + ContextBuilder { + prefix: self.prefix, + timeout, + semaphore_capacity: self.semaphore_capacity, + ack_timeout: self.ack_timeout, + _phantom: PhantomData, + } + } + + /// Sets the maximum time client waits for acks from the server when default backpressure is + /// used. + pub fn ack_timeout(self, ack_timeout: Duration) -> ContextBuilder<Yes> + where + Yes: ToAssign, + { + ContextBuilder { + prefix: self.prefix, + timeout: self.timeout, + semaphore_capacity: self.semaphore_capacity, + ack_timeout, + _phantom: PhantomData, + } + } + + /// Sets the maximum number of pending acks that can be in flight at any given time. + /// If limit is reached, `publish` throws an error. + pub fn max_ack_inflight(self, capacity: usize) -> ContextBuilder<Yes> + where + Yes: ToAssign, + { + ContextBuilder { + prefix: self.prefix, + timeout: self.timeout, + semaphore_capacity: capacity, + ack_timeout: self.ack_timeout, + _phantom: PhantomData, + } + } + + /// Build the [Context] with the given settings. + pub fn build(self, client: Client) -> Context { + let (tx, rx) = tokio::sync::mpsc::channel::<( + oneshot::Receiver<Message>, + OwnedSemaphorePermit, + )>(self.semaphore_capacity); + let acker_task = Arc::new(spawn_acker(rx, self.ack_timeout)); Context { client, - prefix: format!("$JS.{}.API", domain.as_ref()), - timeout: Duration::from_secs(5), + prefix: self.prefix, + timeout: self.timeout, + max_ack_semaphore: Arc::new(tokio::sync::Semaphore::new(self.semaphore_capacity)), + acker_task, + ack_sender: tx, } } +} + +impl Context { + pub(crate) fn new(client: Client) -> Context { + ContextBuilder::default().build(client) + } + + pub fn set_timeout(&mut self, timeout: Duration) { + self.timeout = timeout + } + + pub(crate) fn with_prefix<T: ToString>(client: Client, prefix: T) -> Context { + ContextBuilder::new() + .api_prefix(prefix.to_string()) + .build(client) + } + + pub(crate) fn with_domain<T: AsRef<str>>(client: Client, domain: T) -> Context { + ContextBuilder::new().domain(domain.as_ref()).build(client) + } /// Publishes [jetstream::Message][super::message::Message] to the [Stream] without waiting for /// acknowledgment from the server that the message has been successfully delivered. @@ -192,6 +359,16 @@ impl Context { subject: S, publish: Publish, ) -> Result<PublishAckFuture, PublishError> { + let permit = + self.max_ack_semaphore + .clone() + .try_acquire_owned() + .map_err(|err| match err { + TryAcquireError::NoPermits => { + PublishError::new(PublishErrorKind::MaxAckPending) + } + _ => PublishError::with_source(PublishErrorKind::Other, err), + })?; let subject = subject.to_subject(); let (sender, receiver) = oneshot::channel(); @@ -215,7 +392,9 @@ impl Context { Ok(PublishAckFuture { timeout: self.timeout, - subscription: receiver, + subscription: Some(receiver), + permit: Some(permit), + tx: self.ack_sender.clone(), }) } @@ -1212,6 +1391,7 @@ pub enum PublishErrorKind { WrongLastSequence, TimedOut, BrokenPipe, + MaxAckPending, Other, } @@ -1224,6 +1404,7 @@ impl Display for PublishErrorKind { Self::BrokenPipe => write!(f, "broken pipe"), Self::WrongLastMessageId => write!(f, "wrong last message id"), Self::WrongLastSequence => write!(f, "wrong last sequence"), + Self::MaxAckPending => write!(f, "max ack pending reached"), } } } @@ -1233,12 +1414,25 @@ pub type PublishError = Error<PublishErrorKind>; #[derive(Debug)] pub struct PublishAckFuture { timeout: Duration, - subscription: oneshot::Receiver<Message>, + subscription: Option<oneshot::Receiver<Message>>, + permit: Option<OwnedSemaphorePermit>, + tx: mpsc::Sender<(oneshot::Receiver<Message>, OwnedSemaphorePermit)>, +} + +impl Drop for PublishAckFuture { + fn drop(&mut self) { + match (self.subscription.take(), self.permit.take()) { + (Some(sub), Some(permit)) => { + self.tx.try_send((sub, permit)).ok(); + } + _ => {} + } + } } impl PublishAckFuture { - async fn next_with_timeout(self) -> Result<PublishAck, PublishError> { - let next = tokio::time::timeout(self.timeout, self.subscription) + async fn next_with_timeout(mut self) -> Result<PublishAck, PublishError> { + let next = tokio::time::timeout(self.timeout, self.subscription.take().unwrap()) .await .map_err(|_| PublishError::new(PublishErrorKind::TimedOut))?; next.map_or_else( diff --git a/async-nats/tests/jetstream_tests.rs b/async-nats/tests/jetstream_tests.rs index b1588852f..aa79fd3db 100755 --- a/async-nats/tests/jetstream_tests.rs +++ b/async-nats/tests/jetstream_tests.rs @@ -3654,4 +3654,28 @@ mod jetstream { .await .expect_err("should fail but not panic because of lack of server info"); } + + #[tokio::test] + async fn test_async_publish_max_ack_pending() { + let server = nats_server::run_server("tests/configs/jetstream.conf"); + let client = async_nats::connect(server.client_url()).await.unwrap(); + + let jetstream = async_nats::jetstream::new(client); + + jetstream + .create_stream(stream::Config { + name: "events".to_string(), + subjects: vec!["events".to_string()], + ..Default::default() + }) + .await + .unwrap(); + + for i in 0..100_000 { + jetstream + .publish("events", format!("{i}").into()) + .await + .unwrap(); + } + } }