From a7f19522aa4884c4083dceeb9c895c427d170425 Mon Sep 17 00:00:00 2001 From: Timo Glane Date: Sat, 14 Sep 2024 10:24:40 +0200 Subject: [PATCH] sync, coop: apply cooperative scheduling to `sync::watch::Receiver::changed` --- tokio/src/runtime/coop.rs | 41 ++++++++++++++++++++++++++++++++++++++- tokio/src/sync/watch.rs | 2 +- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/tokio/src/runtime/coop.rs b/tokio/src/runtime/coop.rs index aaca8b6baa2..0ef37537362 100644 --- a/tokio/src/runtime/coop.rs +++ b/tokio/src/runtime/coop.rs @@ -135,8 +135,11 @@ cfg_rt! { } cfg_coop! { + use pin_project_lite::pin_project; use std::cell::Cell; - use std::task::{Context, Poll}; + use std::future::Future; + use std::pin::Pin; + use std::task::{ready, Context, Poll}; #[must_use] pub(crate) struct RestoreOnPending(Cell); @@ -240,6 +243,42 @@ cfg_coop! { self.0.is_none() } } + + pin_project! { + /// A future type that calls `poll_proceed` before polling the inner future to check if the + /// inner future has exceeded its budget. If the inner future resolves, this will + /// automatically call `RestoreOnPending::made_progress` before resolving this future with + /// the result of the inner one. If polling the inner future is pending, polling this future + /// type will also return a `Poll::Pending`. + #[must_use = "futures do nothing unless polled"] + pub(crate) struct BudgetConstraintFuture { + #[pin] + pub(crate) fut: F, + } + } + + impl Future for BudgetConstraintFuture { + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let coop = ready!(poll_proceed(cx)); + let me = self.project(); + if let Poll::Ready(ret) = me.fut.poll(cx) { + coop.made_progress(); + Poll::Ready(ret) + } else { + Poll::Pending + } + } + } + + /// Run a future with a budget constraint for cooperative scheduling. + /// If the future exceeds its budget while being polled, control is yielded back to the + /// runtime. + #[inline] + pub(crate) async fn budget_constraint(fut: F) -> F::Output { + BudgetConstraintFuture { fut }.await + } } #[cfg(all(test, not(loom)))] diff --git a/tokio/src/sync/watch.rs b/tokio/src/sync/watch.rs index 490b9e4df88..eccc8c5e8ea 100644 --- a/tokio/src/sync/watch.rs +++ b/tokio/src/sync/watch.rs @@ -743,7 +743,7 @@ impl Receiver { /// } /// ``` pub async fn changed(&mut self) -> Result<(), error::RecvError> { - changed_impl(&self.shared, &mut self.version).await + crate::runtime::coop::budget_constraint(changed_impl(&self.shared, &mut self.version)).await } /// Waits for a value that satisfies the provided condition.