diff --git a/tokio/tests/rt_poll_callbacks.rs b/tokio/tests/rt_poll_callbacks.rs index fed5b6b00fb..a437b7ab87a 100644 --- a/tokio/tests/rt_poll_callbacks.rs +++ b/tokio/tests/rt_poll_callbacks.rs @@ -2,10 +2,48 @@ #[cfg(tokio_unstable)] mod unstable { - use std::sync::{atomic::AtomicUsize, Arc, Mutex}; + use std::{ + future::Future, + sync::{atomic::AtomicUsize, Arc, Mutex}, + }; use tokio::task::yield_now; + pin_project_lite::pin_project! { + struct PollCounter { + #[pin] + inner: F, + counter: Arc, + } + } + + impl Future for PollCounter { + type Output = F::Output; + + fn poll( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + let this = self.project(); + this.counter + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + this.inner.poll(cx) + } + } + + impl PollCounter<()> { + fn new(future: F) -> (PollCounter, Arc) { + let counter = Arc::new(AtomicUsize::new(0)); + ( + PollCounter { + inner: future, + counter: counter.clone(), + }, + counter, + ) + } + } + #[cfg(not(target_os = "wasi"))] #[test] fn callbacks_fire_multi_thread() { @@ -39,11 +77,12 @@ mod unstable { }) .build() .unwrap(); - let task = rt.spawn(async { + let (task, count) = PollCounter::new(async { yield_now().await; yield_now().await; yield_now().await; }); + let task = rt.spawn(task); let spawned_task_id = task.id(); @@ -57,8 +96,15 @@ mod unstable { after_task_poll_callback_task_id.lock().unwrap().unwrap(), spawned_task_id ); - assert_eq!(poll_start.load(std::sync::atomic::Ordering::SeqCst), 4); - assert_eq!(poll_stop.load(std::sync::atomic::Ordering::SeqCst), 4); + let actual_count = count.load(std::sync::atomic::Ordering::SeqCst); + assert_eq!( + poll_start.load(std::sync::atomic::Ordering::SeqCst), + actual_count + ); + assert_eq!( + poll_stop.load(std::sync::atomic::Ordering::SeqCst), + actual_count + ); } #[test]