diff --git a/benches/throughput.rs b/benches/throughput.rs index 1ee6aef..0973148 100644 --- a/benches/throughput.rs +++ b/benches/throughput.rs @@ -73,12 +73,12 @@ impl Routine for Async { struct Block; impl Routine for Block { - fn produce(mut tx: RingSender, limit: usize) -> JoinHandle { + fn produce(tx: RingSender, limit: usize) -> JoinHandle { let producer = iter::from_fn(move || tx.send(T::default()).ok()); task::spawn_blocking(move || producer.take(limit).count()) } - fn consume(mut rx: RingReceiver, limit: usize) -> JoinHandle { + fn consume(rx: RingReceiver, limit: usize) -> JoinHandle { let consumer = iter::from_fn(move || loop { match rx.try_recv() { Ok(m) => return Some(m), diff --git a/src/channel.rs b/src/channel.rs index 9c662aa..a965814 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -41,7 +41,7 @@ impl RingSender { /// * If the internal ring buffer is full, the oldest pending message is overwritten /// and returned as `Ok(Some(_))`, otherwise `Ok(None)` is returned. /// * If the channel is disconnected, [`SendError::Disconnected`] is returned. - pub fn send(&mut self, message: T) -> Result, SendError> { + pub fn send(&self, message: T) -> Result, SendError> { if self.handle.receivers.load(Ordering::Acquire) > 0 { let overwritten = self.handle.buffer.push(message); @@ -154,27 +154,13 @@ impl RingReceiver { } } - /// Receives a message through the channel (requires [feature] `"futures_api"`). - /// - /// * If the internal ring buffer isn't empty, the oldest pending message is returned. - /// * If the internal ring buffer is empty, the call blocks until a message is sent - /// or the channel disconnects. - /// * If the channel is disconnected and the internal ring buffer is empty, - /// [`RecvError::Disconnected`] is returned. - /// - /// [feature]: index.html#optional-features - #[cfg(feature = "futures_api")] - pub fn recv(&mut self) -> Result { - futures::executor::block_on(futures::StreamExt::next(self)).ok_or(RecvError::Disconnected) - } - /// Receives a message through the channel without blocking. /// /// * If the internal ring buffer isn't empty, the oldest pending message is returned. /// * If the internal ring buffer is empty, [`TryRecvError::Empty`] is returned. /// * If the channel is disconnected and the internal ring buffer is empty, /// [`TryRecvError::Disconnected`] is returned. - pub fn try_recv(&mut self) -> Result { + pub fn try_recv(&self) -> Result { // We must check whether the channel is connected using acquire ordering before we look at // the buffer, in order to ensure that the loads associated with popping from the buffer // happen after the stores associated with a push into the buffer that may have happened @@ -185,6 +171,49 @@ impl RingReceiver { self.handle.buffer.pop().ok_or(TryRecvError::Disconnected) } } + + /// Receives a message through the channel (requires [feature] `"futures_api"`). + /// + /// * If the internal ring buffer isn't empty, the oldest pending message is returned. + /// * If the internal ring buffer is empty, the call blocks until a message is sent + /// or the channel disconnects. + /// * If the channel is disconnected and the internal ring buffer is empty, + /// [`RecvError::Disconnected`] is returned. + /// + /// [feature]: index.html#optional-features + #[cfg(feature = "futures_api")] + pub fn recv(&self) -> Result { + futures::executor::block_on(futures::future::poll_fn(|ctx| self.poll(ctx))) + .ok_or(RecvError::Disconnected) + } + + #[cfg(feature = "futures_api")] + fn poll(&self, ctx: &mut Context<'_>) -> Poll> { + match self.try_recv() { + result @ Ok(_) | result @ Err(TryRecvError::Disconnected) => { + self.handle.waitlist.remove(self.slot); + Poll::Ready(result.ok()) + } + + Err(TryRecvError::Empty) => { + self.handle.waitlist.insert(self.slot, ctx.waker().clone()); + + // A full memory barrier is necessary to ensure that storing the waker + // happens before attempting to retrieve a message from the buffer. + fence(Ordering::SeqCst); + + // Look at the buffer again in case a new message has been sent in the meantime. + match self.try_recv() { + result @ Ok(_) | result @ Err(TryRecvError::Disconnected) => { + self.handle.waitlist.remove(self.slot); + Poll::Ready(result.ok()) + } + + Err(TryRecvError::Empty) => Poll::Pending, + } + } + } + } } impl Clone for RingReceiver { @@ -221,31 +250,8 @@ impl Drop for RingReceiver { impl Stream for RingReceiver { type Item = T; - fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { - match self.try_recv() { - result @ Ok(_) | result @ Err(TryRecvError::Disconnected) => { - self.handle.waitlist.remove(self.slot); - Poll::Ready(result.ok()) - } - - Err(TryRecvError::Empty) => { - self.handle.waitlist.insert(self.slot, ctx.waker().clone()); - - // A full memory barrier is necessary to ensure that storing the waker - // happens before attempting to retrieve a message from the buffer. - fence(Ordering::SeqCst); - - // Look at the buffer again in case a new message has been sent in the meantime. - match self.try_recv() { - result @ Ok(_) | result @ Err(TryRecvError::Disconnected) => { - self.handle.waitlist.remove(self.slot); - Poll::Ready(result.ok()) - } - - Err(TryRecvError::Empty) => Poll::Pending, - } - } - } + fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { + self.poll(ctx) } } @@ -270,7 +276,7 @@ impl Stream for RingReceiver { /// /// // Open a channel to transmit the time elapsed since the beginning of the countdown. /// // We only need a buffer of size 1, since we're only interested in the current value. -/// let (mut tx, mut rx) = ring_channel(NonZeroUsize::try_from(1)?); +/// let (tx, rx) = ring_channel(NonZeroUsize::try_from(1)?); /// /// thread::spawn(move || { /// let countdown = Instant::now() + Duration::from_secs(10); @@ -472,7 +478,7 @@ mod tests { let (tx, _rx) = ring_channel(NonZeroUsize::try_from(cap)?); rt.block_on(iter(msgs).map(Ok).try_for_each_concurrent(None, |msg| { - let mut tx = tx.clone(); + let tx = tx.clone(); spawn_blocking(move || assert!(tx.send(msg).is_ok())) }))?; } @@ -486,7 +492,7 @@ mod tests { let (tx, _) = ring_channel(NonZeroUsize::try_from(cap)?); rt.block_on(iter(msgs).map(Ok).try_for_each_concurrent(None, |msg| { - let mut tx = tx.clone(); + let tx = tx.clone(); spawn_blocking(move || assert_eq!(tx.send(msg), Err(SendError::Disconnected(msg)))) }))?; } @@ -496,7 +502,7 @@ mod tests { #[strategy(1..=10usize)] cap: usize, #[any(size_range(#cap..=10).lift())] msgs: Vec, ) { - let (mut tx, rx) = ring_channel(NonZeroUsize::try_from(cap)?); + let (tx, rx) = ring_channel(NonZeroUsize::try_from(cap)?); let overwritten = msgs.len() - min(msgs.len(), cap); for &msg in &msgs[..cap] { @@ -528,7 +534,7 @@ mod tests { try_join_all( iter::repeat(rx) .take(msgs.len()) - .map(|mut rx| spawn_blocking(move || rx.try_recv().unwrap())), + .map(|rx| spawn_blocking(move || rx.try_recv().unwrap())), ) .await })?; @@ -552,7 +558,7 @@ mod tests { try_join_all( iter::repeat(rx) .take(msgs.len()) - .map(|mut rx| spawn_blocking(move || rx.try_recv().unwrap())), + .map(|rx| spawn_blocking(move || rx.try_recv().unwrap())), ) .await })?; @@ -573,7 +579,7 @@ mod tests { repeat(rx) .take(n) .map(Ok) - .try_for_each_concurrent(None, |mut rx| { + .try_for_each_concurrent(None, |rx| { spawn_blocking(move || assert_eq!(rx.try_recv(), Err(TryRecvError::Empty))) }), )?; @@ -591,7 +597,7 @@ mod tests { repeat(rx) .take(n) .map(Ok) - .try_for_each_concurrent(None, |mut rx| { + .try_for_each_concurrent(None, |rx| { spawn_blocking(move || { assert_eq!(rx.try_recv(), Err(TryRecvError::Disconnected)) }) @@ -615,7 +621,7 @@ mod tests { try_join_all( iter::repeat(rx) .take(msgs.len()) - .map(|mut rx| spawn_blocking(move || rx.recv().unwrap())), + .map(|rx| spawn_blocking(move || rx.recv().unwrap())), ) .await })?; @@ -640,7 +646,7 @@ mod tests { try_join_all( iter::repeat(rx) .take(msgs.len()) - .map(|mut rx| spawn_blocking(move || rx.recv().unwrap())), + .map(|rx| spawn_blocking(move || rx.recv().unwrap())), ) .await })?; @@ -662,7 +668,7 @@ mod tests { repeat(rx) .take(n) .map(Ok) - .try_for_each_concurrent(None, |mut rx| { + .try_for_each_concurrent(None, |rx| { spawn_blocking(move || assert_eq!(rx.recv(), Err(RecvError::Disconnected))) }), )?; @@ -685,7 +691,7 @@ mod tests { let consumer = repeat(rx) .take(n) .map(Ok) - .try_for_each_concurrent(None, |mut rx| { + .try_for_each_concurrent(None, |rx| { spawn_blocking(move || assert_eq!(rx.recv(), Err(RecvError::Disconnected))) }); @@ -704,14 +710,14 @@ mod tests { let producer = repeat(tx) .take(n) .map(Ok) - .try_for_each_concurrent(None, |mut tx| { + .try_for_each_concurrent(None, |tx| { spawn_blocking(move || assert!(tx.send(()).is_ok())) }); let consumer = repeat(rx) .take(n) .map(Ok) - .try_for_each_concurrent(None, |mut rx| { + .try_for_each_concurrent(None, |rx| { spawn_blocking(move || assert_eq!(rx.recv(), Ok(()))) }); @@ -727,7 +733,7 @@ mod tests { #[any(size_range(1..=10).lift())] msgs: Vec, ) { let rt = runtime::Builder::new_multi_thread().build()?; - let (mut tx, mut rx) = ring_channel(NonZeroUsize::try_from(cap)?); + let (mut tx, rx) = ring_channel(NonZeroUsize::try_from(cap)?); let overwritten = msgs.len() - min(msgs.len(), cap); assert_eq!(rt.block_on(iter(&msgs).map(Ok).forward(&mut tx)), Ok(())); @@ -773,7 +779,7 @@ mod tests { #[any(size_range(1..=10).lift())] msgs: Vec, ) { let rt = runtime::Builder::new_multi_thread().build()?; - let (mut tx, rx) = ring_channel(NonZeroUsize::try_from(cap)?); + let (tx, rx) = ring_channel(NonZeroUsize::try_from(cap)?); let overwritten = msgs.len() - min(msgs.len(), cap); for &msg in &msgs { @@ -809,7 +815,7 @@ mod tests { #[cfg(not(miri))] // will_wake sometimes returns false under miri #[proptest] fn receiver_withdraws_waker_if_channel_not_empty(#[strategy(1..=10usize)] cap: usize, msg: u8) { - let (mut tx, mut rx) = ring_channel(NonZeroUsize::try_from(cap)?); + let (tx, mut rx) = ring_channel(NonZeroUsize::try_from(cap)?); let waker = Arc::new(MockWaker).into(); let mut ctx = Context::from_waker(&waker);