relax Sender and Receiver to not require &mut self
brunocodutra committed Oct 10, 2023
1 parent 119eb61 commit 6068058
Showing 2 changed files with 66 additions and 60 deletions.
4 changes: 2 additions & 2 deletions benches/throughput.rs
@@ -73,12 +73,12 @@ impl<T: 'static + Send + Default> Routine<T> for Async {
struct Block;

impl<T: 'static + Send + Default> Routine<T> for Block {
-fn produce(mut tx: RingSender<T>, limit: usize) -> JoinHandle<usize> {
+fn produce(tx: RingSender<T>, limit: usize) -> JoinHandle<usize> {
let producer = iter::from_fn(move || tx.send(T::default()).ok());
task::spawn_blocking(move || producer.take(limit).count())
}

-fn consume(mut rx: RingReceiver<T>, limit: usize) -> JoinHandle<usize> {
+fn consume(rx: RingReceiver<T>, limit: usize) -> JoinHandle<usize> {
let consumer = iter::from_fn(move || loop {
match rx.try_recv() {
Ok(m) => return Some(m),
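With `send` and `try_recv` taking `&self`, the benchmark above no longer needs `mut` bindings for either endpoint. The following is a minimal sketch of the same pattern outside the benchmark harness — it is not part of this commit and assumes the crate-root exports (`ring_channel`, `TryRecvError`) used elsewhere in this diff; the capacity and message count are illustrative only:

use ring_channel::{ring_channel, TryRecvError};
use std::num::NonZeroUsize;
use std::thread;

fn main() {
    // Neither endpoint needs a `mut` binding anymore.
    let (tx, rx) = ring_channel(NonZeroUsize::new(8).unwrap());

    let producer = thread::spawn(move || {
        for i in 0..100 {
            tx.send(i).unwrap();
        }
        // Dropping `tx` here disconnects the channel.
    });

    let consumer = thread::spawn(move || {
        let mut received = 0;
        loop {
            match rx.try_recv() {
                Ok(_) => received += 1,
                Err(TryRecvError::Empty) => continue, // spin, like the Block routine above
                Err(TryRecvError::Disconnected) => break,
            }
        }
        received
    });

    producer.join().unwrap();
    println!("received {} messages", consumer.join().unwrap());
}

Because the ring buffer overwrites old messages, the consumer may observe fewer than 100 of them; that is the expected trade-off of this channel.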
122 changes: 64 additions & 58 deletions src/channel.rs
@@ -41,7 +41,7 @@ impl<T> RingSender<T> {
/// * If the internal ring buffer is full, the oldest pending message is overwritten
/// and returned as `Ok(Some(_))`, otherwise `Ok(None)` is returned.
/// * If the channel is disconnected, [`SendError::Disconnected`] is returned.
-pub fn send(&mut self, message: T) -> Result<Option<T>, SendError<T>> {
+pub fn send(&self, message: T) -> Result<Option<T>, SendError<T>> {
if self.handle.receivers.load(Ordering::Acquire) > 0 {
let overwritten = self.handle.buffer.push(message);

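The overwrite semantics documented above can be illustrated with a capacity-1 channel. This is a hedged sketch based purely on the doc comment and the assertions in the tests further down, not code from this commit:

use ring_channel::{ring_channel, SendError};
use std::num::NonZeroUsize;

fn main() {
    let (tx, rx) = ring_channel(NonZeroUsize::new(1).unwrap());

    // Room in the buffer: nothing is overwritten.
    assert_eq!(tx.send(1), Ok(None));

    // Buffer full: the oldest pending message is overwritten and handed back.
    assert_eq!(tx.send(2), Ok(Some(1)));

    // Only the newest message remains for the receiver.
    assert_eq!(rx.try_recv(), Ok(2));

    // Without any receiver left, the message is returned inside the error.
    drop(rx);
    assert_eq!(tx.send(3), Err(SendError::Disconnected(3)));
}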
@@ -154,27 +154,13 @@ impl<T> RingReceiver<T> {
}
}

-/// Receives a message through the channel (requires [feature] `"futures_api"`).
-///
-/// * If the internal ring buffer isn't empty, the oldest pending message is returned.
-/// * If the internal ring buffer is empty, the call blocks until a message is sent
-/// or the channel disconnects.
-/// * If the channel is disconnected and the internal ring buffer is empty,
-/// [`RecvError::Disconnected`] is returned.
-///
-/// [feature]: index.html#optional-features
-#[cfg(feature = "futures_api")]
-pub fn recv(&mut self) -> Result<T, RecvError> {
-futures::executor::block_on(futures::StreamExt::next(self)).ok_or(RecvError::Disconnected)
-}
-
/// Receives a message through the channel without blocking.
///
/// * If the internal ring buffer isn't empty, the oldest pending message is returned.
/// * If the internal ring buffer is empty, [`TryRecvError::Empty`] is returned.
/// * If the channel is disconnected and the internal ring buffer is empty,
/// [`TryRecvError::Disconnected`] is returned.
-pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
+pub fn try_recv(&self) -> Result<T, TryRecvError> {
// We must check whether the channel is connected using acquire ordering before we look at
// the buffer, in order to ensure that the loads associated with popping from the buffer
// happen after the stores associated with a push into the buffer that may have happened
@@ -185,6 +171,49 @@ impl<T> RingReceiver<T> {
self.handle.buffer.pop().ok_or(TryRecvError::Disconnected)
}
}

+/// Receives a message through the channel (requires [feature] `"futures_api"`).
+///
+/// * If the internal ring buffer isn't empty, the oldest pending message is returned.
+/// * If the internal ring buffer is empty, the call blocks until a message is sent
+/// or the channel disconnects.
+/// * If the channel is disconnected and the internal ring buffer is empty,
+/// [`RecvError::Disconnected`] is returned.
+///
+/// [feature]: index.html#optional-features
+#[cfg(feature = "futures_api")]
+pub fn recv(&self) -> Result<T, RecvError> {
+futures::executor::block_on(futures::future::poll_fn(|ctx| self.poll(ctx)))
+.ok_or(RecvError::Disconnected)
+}
+
+#[cfg(feature = "futures_api")]
+fn poll(&self, ctx: &mut Context<'_>) -> Poll<Option<T>> {
+match self.try_recv() {
+result @ Ok(_) | result @ Err(TryRecvError::Disconnected) => {
+self.handle.waitlist.remove(self.slot);
+Poll::Ready(result.ok())
+}
+
+Err(TryRecvError::Empty) => {
+self.handle.waitlist.insert(self.slot, ctx.waker().clone());
+
+// A full memory barrier is necessary to ensure that storing the waker
+// happens before attempting to retrieve a message from the buffer.
+fence(Ordering::SeqCst);
+
+// Look at the buffer again in case a new message has been sent in the meantime.
+match self.try_recv() {
+result @ Ok(_) | result @ Err(TryRecvError::Disconnected) => {
+self.handle.waitlist.remove(self.slot);
+Poll::Ready(result.ok())
+}
+
+Err(TryRecvError::Empty) => Poll::Pending,
+}
+}
+}
+}
+}

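Moving the blocking logic into `poll` lets `recv` be expressed as `block_on(poll_fn(..))` over a shared reference. A hedged usage sketch follows — not part of this commit, assuming the `"futures_api"` feature is enabled and the crate-root exports used in the tests below; the capacity and message count are illustrative:

use ring_channel::ring_channel;
use std::num::NonZeroUsize;
use std::thread;

fn main() {
    let (tx, rx) = ring_channel(NonZeroUsize::new(4).unwrap());

    thread::spawn(move || {
        for i in 0..3 {
            tx.send(i).unwrap();
        }
        // Dropping the last sender disconnects the channel.
    });

    // `recv` blocks while the buffer is empty and the channel is connected,
    // and returns `Err(RecvError::Disconnected)` once it is empty *and*
    // every sender has been dropped.
    loop {
        match rx.recv() {
            Ok(msg) => println!("got {}", msg),
            Err(_) => break,
        }
    }
}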
impl<T> Clone for RingReceiver<T> {
@@ -221,31 +250,8 @@ impl<T> Drop for RingReceiver<T> {
impl<T> Stream for RingReceiver<T> {
type Item = T;

-fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-match self.try_recv() {
-result @ Ok(_) | result @ Err(TryRecvError::Disconnected) => {
-self.handle.waitlist.remove(self.slot);
-Poll::Ready(result.ok())
-}
-
-Err(TryRecvError::Empty) => {
-self.handle.waitlist.insert(self.slot, ctx.waker().clone());
-
-// A full memory barrier is necessary to ensure that storing the waker
-// happens before attempting to retrieve a message from the buffer.
-fence(Ordering::SeqCst);
-
-// Look at the buffer again in case a new message has been sent in the meantime.
-match self.try_recv() {
-result @ Ok(_) | result @ Err(TryRecvError::Disconnected) => {
-self.handle.waitlist.remove(self.slot);
-Poll::Ready(result.ok())
-}
-
-Err(TryRecvError::Empty) => Poll::Pending,
-}
-}
-}
+fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+self.poll(ctx)
}
}
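Since `poll_next` now simply delegates to `poll`, the receiver still works as a `futures::Stream`. A small sketch, not part of this commit, again assuming the `"futures_api"` feature and the public API used above:

use futures::StreamExt;
use ring_channel::ring_channel;
use std::num::NonZeroUsize;

fn main() {
    let (tx, rx) = ring_channel(NonZeroUsize::new(2).unwrap());

    tx.send("hello").unwrap();
    tx.send("world").unwrap();
    drop(tx); // disconnect so the stream terminates

    // `collect` keeps polling until `poll_next` yields `Poll::Ready(None)`,
    // i.e. once the buffer is empty and the channel is disconnected.
    let received: Vec<&str> = futures::executor::block_on(rx.collect());
    assert_eq!(received, ["hello", "world"]);
}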

@@ -270,7 +276,7 @@ impl<T> Stream for RingReceiver<T> {
///
/// // Open a channel to transmit the time elapsed since the beginning of the countdown.
/// // We only need a buffer of size 1, since we're only interested in the current value.
-/// let (mut tx, mut rx) = ring_channel(NonZeroUsize::try_from(1)?);
+/// let (tx, rx) = ring_channel(NonZeroUsize::try_from(1)?);
///
/// thread::spawn(move || {
/// let countdown = Instant::now() + Duration::from_secs(10);
@@ -472,7 +478,7 @@ mod tests {
let (tx, _rx) = ring_channel(NonZeroUsize::try_from(cap)?);

rt.block_on(iter(msgs).map(Ok).try_for_each_concurrent(None, |msg| {
-let mut tx = tx.clone();
+let tx = tx.clone();
spawn_blocking(move || assert!(tx.send(msg).is_ok()))
}))?;
}
@@ -486,7 +492,7 @@
let (tx, _) = ring_channel(NonZeroUsize::try_from(cap)?);

rt.block_on(iter(msgs).map(Ok).try_for_each_concurrent(None, |msg| {
-let mut tx = tx.clone();
+let tx = tx.clone();
spawn_blocking(move || assert_eq!(tx.send(msg), Err(SendError::Disconnected(msg))))
}))?;
}
@@ -496,7 +502,7 @@
#[strategy(1..=10usize)] cap: usize,
#[any(size_range(#cap..=10).lift())] msgs: Vec<u8>,
) {
-let (mut tx, rx) = ring_channel(NonZeroUsize::try_from(cap)?);
+let (tx, rx) = ring_channel(NonZeroUsize::try_from(cap)?);
let overwritten = msgs.len() - min(msgs.len(), cap);

for &msg in &msgs[..cap] {
@@ -528,7 +534,7 @@ mod tests {
try_join_all(
iter::repeat(rx)
.take(msgs.len())
-.map(|mut rx| spawn_blocking(move || rx.try_recv().unwrap())),
+.map(|rx| spawn_blocking(move || rx.try_recv().unwrap())),
)
.await
})?;
@@ -552,7 +558,7 @@
try_join_all(
iter::repeat(rx)
.take(msgs.len())
-.map(|mut rx| spawn_blocking(move || rx.try_recv().unwrap())),
+.map(|rx| spawn_blocking(move || rx.try_recv().unwrap())),
)
.await
})?;
@@ -573,7 +579,7 @@
repeat(rx)
.take(n)
.map(Ok)
-.try_for_each_concurrent(None, |mut rx| {
+.try_for_each_concurrent(None, |rx| {
spawn_blocking(move || assert_eq!(rx.try_recv(), Err(TryRecvError::Empty)))
}),
)?;
@@ -591,7 +597,7 @@
repeat(rx)
.take(n)
.map(Ok)
-.try_for_each_concurrent(None, |mut rx| {
+.try_for_each_concurrent(None, |rx| {
spawn_blocking(move || {
assert_eq!(rx.try_recv(), Err(TryRecvError::Disconnected))
})
@@ -615,7 +621,7 @@
try_join_all(
iter::repeat(rx)
.take(msgs.len())
-.map(|mut rx| spawn_blocking(move || rx.recv().unwrap())),
+.map(|rx| spawn_blocking(move || rx.recv().unwrap())),
)
.await
})?;
@@ -640,7 +646,7 @@
try_join_all(
iter::repeat(rx)
.take(msgs.len())
-.map(|mut rx| spawn_blocking(move || rx.recv().unwrap())),
+.map(|rx| spawn_blocking(move || rx.recv().unwrap())),
)
.await
})?;
@@ -662,7 +668,7 @@
repeat(rx)
.take(n)
.map(Ok)
-.try_for_each_concurrent(None, |mut rx| {
+.try_for_each_concurrent(None, |rx| {
spawn_blocking(move || assert_eq!(rx.recv(), Err(RecvError::Disconnected)))
}),
)?;
@@ -685,7 +691,7 @@
let consumer = repeat(rx)
.take(n)
.map(Ok)
-.try_for_each_concurrent(None, |mut rx| {
+.try_for_each_concurrent(None, |rx| {
spawn_blocking(move || assert_eq!(rx.recv(), Err(RecvError::Disconnected)))
});

@@ -704,14 +710,14 @@
let producer = repeat(tx)
.take(n)
.map(Ok)
-.try_for_each_concurrent(None, |mut tx| {
+.try_for_each_concurrent(None, |tx| {
spawn_blocking(move || assert!(tx.send(()).is_ok()))
});

let consumer = repeat(rx)
.take(n)
.map(Ok)
-.try_for_each_concurrent(None, |mut rx| {
+.try_for_each_concurrent(None, |rx| {
spawn_blocking(move || assert_eq!(rx.recv(), Ok(())))
});

Expand All @@ -727,7 +733,7 @@ mod tests {
#[any(size_range(1..=10).lift())] msgs: Vec<u8>,
) {
let rt = runtime::Builder::new_multi_thread().build()?;
-let (mut tx, mut rx) = ring_channel(NonZeroUsize::try_from(cap)?);
+let (mut tx, rx) = ring_channel(NonZeroUsize::try_from(cap)?);
let overwritten = msgs.len() - min(msgs.len(), cap);

assert_eq!(rt.block_on(iter(&msgs).map(Ok).forward(&mut tx)), Ok(()));
@@ -773,7 +779,7 @@ mod tests {
#[any(size_range(1..=10).lift())] msgs: Vec<u8>,
) {
let rt = runtime::Builder::new_multi_thread().build()?;
-let (mut tx, rx) = ring_channel(NonZeroUsize::try_from(cap)?);
+let (tx, rx) = ring_channel(NonZeroUsize::try_from(cap)?);
let overwritten = msgs.len() - min(msgs.len(), cap);

for &msg in &msgs {
@@ -809,7 +815,7 @@ mod tests {
#[cfg(not(miri))] // will_wake sometimes returns false under miri
#[proptest]
fn receiver_withdraws_waker_if_channel_not_empty(#[strategy(1..=10usize)] cap: usize, msg: u8) {
-let (mut tx, mut rx) = ring_channel(NonZeroUsize::try_from(cap)?);
+let (tx, mut rx) = ring_channel(NonZeroUsize::try_from(cap)?);

let waker = Arc::new(MockWaker).into();
let mut ctx = Context::from_waker(&waker);
