diff --git a/tokio/src/runtime/builder.rs b/tokio/src/runtime/builder.rs index b5bf35d69b4..4d35120b1f9 100644 --- a/tokio/src/runtime/builder.rs +++ b/tokio/src/runtime/builder.rs @@ -1,13 +1,16 @@ #![cfg_attr(loom, allow(unused_imports))] use crate::runtime::handle::Handle; -#[cfg(tokio_unstable)] -use crate::runtime::TaskMeta; use crate::runtime::{blocking, driver, Callback, HistogramBuilder, Runtime, TaskCallback}; +#[cfg(tokio_unstable)] +use crate::runtime::{LocalOptions, LocalRuntime, TaskMeta}; use crate::util::rand::{RngSeed, RngSeedGenerator}; +use crate::runtime::blocking::BlockingPool; +use crate::runtime::scheduler::CurrentThread; use std::fmt; use std::io; +use std::thread::ThreadId; use std::time::Duration; /// Builds Tokio Runtime with custom configuration values. @@ -800,6 +803,37 @@ impl Builder { } } + /// Creates the configured `LocalRuntime`. + /// + /// The returned `LocalRuntime` instance is ready to spawn tasks. + /// + /// # Panics + /// This will panic if `current_thread` is not the selected runtime flavor. + /// All other runtime flavors are unsupported by [`LocalRuntime`]. + /// + /// [`LocalRuntime`]: [crate::runtime::LocalRuntime] + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Builder; + /// + /// let rt = Builder::new_current_thread().build_local(&mut Default::default()).unwrap(); + /// + /// rt.block_on(async { + /// println!("Hello from the Tokio runtime"); + /// }); + /// ``` + #[allow(unused_variables, unreachable_patterns)] + #[cfg(tokio_unstable)] + #[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] + pub fn build_local(&mut self, options: &LocalOptions) -> io::Result { + match &self.kind { + Kind::CurrentThread => self.build_current_thread_local_runtime(), + _ => panic!("Only current_thread is supported when building a local runtime"), + } + } + fn get_cfg(&self, workers: usize) -> driver::Cfg { driver::Cfg { enable_pause_time: match self.kind { @@ -1191,8 +1225,40 @@ impl Builder { } fn build_current_thread_runtime(&mut self) -> io::Result { - use crate::runtime::scheduler::{self, CurrentThread}; - use crate::runtime::{runtime::Scheduler, Config}; + use crate::runtime::runtime::Scheduler; + + let (scheduler, handle, blocking_pool) = + self.build_current_thread_runtime_components(None)?; + + Ok(Runtime::from_parts( + Scheduler::CurrentThread(scheduler), + handle, + blocking_pool, + )) + } + + #[cfg(tokio_unstable)] + fn build_current_thread_local_runtime(&mut self) -> io::Result { + use crate::runtime::local_runtime::LocalRuntimeScheduler; + + let tid = std::thread::current().id(); + + let (scheduler, handle, blocking_pool) = + self.build_current_thread_runtime_components(Some(tid))?; + + Ok(LocalRuntime::from_parts( + LocalRuntimeScheduler::CurrentThread(scheduler), + handle, + blocking_pool, + )) + } + + fn build_current_thread_runtime_components( + &mut self, + local_tid: Option, + ) -> io::Result<(CurrentThread, Handle, BlockingPool)> { + use crate::runtime::scheduler; + use crate::runtime::Config; let (driver, driver_handle) = driver::Driver::new(self.get_cfg(1))?; @@ -1227,17 +1293,14 @@ impl Builder { seed_generator: seed_generator_1, metrics_poll_count_histogram: self.metrics_poll_count_histogram_builder(), }, + local_tid, ); let handle = Handle { inner: scheduler::Handle::CurrentThread(handle), }; - Ok(Runtime::from_parts( - Scheduler::CurrentThread(scheduler), - handle, - blocking_pool, - )) + Ok((scheduler, handle, blocking_pool)) } fn metrics_poll_count_histogram_builder(&self) -> Option { diff --git a/tokio/src/runtime/handle.rs b/tokio/src/runtime/handle.rs index 9026e8773a0..752640d75bd 100644 --- a/tokio/src/runtime/handle.rs +++ b/tokio/src/runtime/handle.rs @@ -250,8 +250,8 @@ impl Handle { /// # Panics /// /// This function panics if the provided future panics, if called within an - /// asynchronous execution context, or if a timer future is executed on a - /// runtime that has been shut down. + /// asynchronous execution context, or if a timer future is executed on a runtime that has been + /// shut down. /// /// # Examples /// @@ -348,6 +348,31 @@ impl Handle { self.inner.spawn(future, id) } + #[track_caller] + #[allow(dead_code)] + pub(crate) unsafe fn spawn_local_named( + &self, + future: F, + _meta: SpawnMeta<'_>, + ) -> JoinHandle + where + F: Future + 'static, + F::Output: 'static, + { + let id = crate::runtime::task::Id::next(); + #[cfg(all( + tokio_unstable, + tokio_taskdump, + feature = "rt", + target_os = "linux", + any(target_arch = "aarch64", target_arch = "x86", target_arch = "x86_64") + ))] + let future = super::task::trace::Trace::root(future); + #[cfg(all(tokio_unstable, feature = "tracing"))] + let future = crate::util::trace::task(future, "task", _meta, id.as_u64()); + self.inner.spawn_local(future, id) + } + /// Returns the flavor of the current `Runtime`. /// /// # Examples diff --git a/tokio/src/runtime/local_runtime/mod.rs b/tokio/src/runtime/local_runtime/mod.rs new file mode 100644 index 00000000000..1ea7693f292 --- /dev/null +++ b/tokio/src/runtime/local_runtime/mod.rs @@ -0,0 +1,7 @@ +mod runtime; + +mod options; + +pub use options::LocalOptions; +pub use runtime::LocalRuntime; +pub(super) use runtime::LocalRuntimeScheduler; diff --git a/tokio/src/runtime/local_runtime/options.rs b/tokio/src/runtime/local_runtime/options.rs new file mode 100644 index 00000000000..ed25d9ccd44 --- /dev/null +++ b/tokio/src/runtime/local_runtime/options.rs @@ -0,0 +1,12 @@ +use std::marker::PhantomData; + +/// `LocalRuntime`-only config options +/// +/// Currently, there are no such options, but in the future, things like `!Send + !Sync` hooks may +/// be added. +#[derive(Default, Debug)] +#[non_exhaustive] +pub struct LocalOptions { + /// Marker used to make this !Send and !Sync. + _phantom: PhantomData<*mut u8>, +} diff --git a/tokio/src/runtime/local_runtime/runtime.rs b/tokio/src/runtime/local_runtime/runtime.rs new file mode 100644 index 00000000000..0f2b944e4eb --- /dev/null +++ b/tokio/src/runtime/local_runtime/runtime.rs @@ -0,0 +1,393 @@ +#![allow(irrefutable_let_patterns)] + +use crate::runtime::blocking::BlockingPool; +use crate::runtime::scheduler::CurrentThread; +use crate::runtime::{context, Builder, EnterGuard, Handle, BOX_FUTURE_THRESHOLD}; +use crate::task::JoinHandle; + +use crate::util::trace::SpawnMeta; +use std::future::Future; +use std::marker::PhantomData; +use std::mem; +use std::time::Duration; + +/// A local Tokio runtime. +/// +/// This runtime is capable of driving tasks which are not `Send + Sync` without the use of a +/// `LocalSet`, and thus supports `spawn_local` without the need for a `LocalSet` context. +/// +/// This runtime cannot be moved between threads or driven from different threads. +/// +/// This runtime is incompatible with `LocalSet`. You should not attempt to drive a `LocalSet` within a +/// `LocalRuntime`. +/// +/// Currently, this runtime supports one flavor, which is internally identical to `current_thread`, +/// save for the aforementioned differences related to `spawn_local`. +/// +/// For more general information on how to use runtimes, see the [module] docs. +/// +/// [runtime]: crate::runtime::Runtime +/// [module]: crate::runtime +#[derive(Debug)] +#[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] +pub struct LocalRuntime { + /// Task scheduler + scheduler: LocalRuntimeScheduler, + + /// Handle to runtime, also contains driver handles + handle: Handle, + + /// Blocking pool handle, used to signal shutdown + blocking_pool: BlockingPool, + + /// Marker used to make this !Send and !Sync. + _phantom: PhantomData<*mut u8>, +} + +/// The runtime scheduler is always a `current_thread` scheduler right now. +#[derive(Debug)] +pub(crate) enum LocalRuntimeScheduler { + /// Execute all tasks on the current-thread. + CurrentThread(CurrentThread), +} + +impl LocalRuntime { + pub(crate) fn from_parts( + scheduler: LocalRuntimeScheduler, + handle: Handle, + blocking_pool: BlockingPool, + ) -> LocalRuntime { + LocalRuntime { + scheduler, + handle, + blocking_pool, + _phantom: Default::default(), + } + } + + /// Creates a new local runtime instance with default configuration values. + /// + /// This results in the scheduler, I/O driver, and time driver being + /// initialized. + /// + /// When a more complex configuration is necessary, the [runtime builder] may be used. + /// + /// See [module level][mod] documentation for more details. + /// + /// # Examples + /// + /// Creating a new `LocalRuntime` with default configuration values. + /// + /// ``` + /// use tokio::runtime::LocalRuntime; + /// + /// let rt = LocalRuntime::new() + /// .unwrap(); + /// + /// // Use the runtime... + /// ``` + /// + /// [mod]: crate::runtime + /// [runtime builder]: crate::runtime::Builder + pub fn new() -> std::io::Result { + Builder::new_current_thread() + .enable_all() + .build_local(&Default::default()) + } + + /// Returns a handle to the runtime's spawner. + /// + /// The returned handle can be used to spawn tasks that run on this runtime, and can + /// be cloned to allow moving the `Handle` to other threads. + /// + /// As the handle can be sent to other threads, it can only be used to spawn tasks that are `Send`. + /// + /// Calling [`Handle::block_on`] on a handle to a `LocalRuntime` is error-prone. + /// Refer to the documentation of [`Handle::block_on`] for more. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::LocalRuntime; + /// + /// let rt = LocalRuntime::new() + /// .unwrap(); + /// + /// let handle = rt.handle(); + /// + /// // Use the handle... + /// ``` + pub fn handle(&self) -> &Handle { + &self.handle + } + + /// Spawns a task on the runtime. + /// + /// This is analogous to the [`spawn`] method on the standard [`Runtime`], but works even if the task is not thread-safe. + /// + /// [`spawn`]: crate::runtime::Runtime::spawn + /// [`Runtime`]: crate::runtime::Runtime + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::LocalRuntime; + /// + /// # fn dox() { + /// // Create the runtime + /// let rt = LocalRuntime::new().unwrap(); + /// + /// // Spawn a future onto the runtime + /// rt.spawn_local(async { + /// println!("now running on a worker thread"); + /// }); + /// # } + /// ``` + #[track_caller] + pub fn spawn_local(&self, future: F) -> JoinHandle + where + F: Future + 'static, + F::Output: 'static, + { + let fut_size = std::mem::size_of::(); + let meta = SpawnMeta::new_unnamed(fut_size); + + // safety: spawn_local can only be called from `LocalRuntime`, which this is + unsafe { + if std::mem::size_of::() > BOX_FUTURE_THRESHOLD { + self.handle.spawn_local_named(Box::pin(future), meta) + } else { + self.handle.spawn_local_named(future, meta) + } + } + } + + /// Runs the provided function on a thread from a dedicated blocking thread pool. + /// + /// This function _will_ be run on another thread. + /// + /// See the documentation in the non-local runtime for more information. + /// + /// [Runtime]: crate::runtime::Runtime::spawn_blocking + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::LocalRuntime; + /// + /// # fn dox() { + /// // Create the runtime + /// let rt = LocalRuntime::new().unwrap(); + /// + /// // Spawn a blocking function onto the runtime + /// rt.spawn_blocking(|| { + /// println!("now running on a worker thread"); + /// }); + /// # } + /// ``` + #[track_caller] + pub fn spawn_blocking(&self, func: F) -> JoinHandle + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + self.handle.spawn_blocking(func) + } + + /// Runs a future to completion on the Tokio runtime. This is the + /// runtime's entry point. + /// + /// See the documentation for [the equivalent method on Runtime] for more information. + /// + /// [Runtime]: crate::runtime::Runtime::block_on + /// + /// # Examples + /// + /// ```no_run + /// use tokio::runtime::LocalRuntime; + /// + /// // Create the runtime + /// let rt = LocalRuntime::new().unwrap(); + /// + /// // Execute the future, blocking the current thread until completion + /// rt.block_on(async { + /// println!("hello"); + /// }); + /// ``` + #[track_caller] + pub fn block_on(&self, future: F) -> F::Output { + let fut_size = mem::size_of::(); + let meta = SpawnMeta::new_unnamed(fut_size); + + if std::mem::size_of::() > BOX_FUTURE_THRESHOLD { + self.block_on_inner(Box::pin(future), meta) + } else { + self.block_on_inner(future, meta) + } + } + + #[track_caller] + fn block_on_inner(&self, future: F, _meta: SpawnMeta<'_>) -> F::Output { + #[cfg(all( + tokio_unstable, + tokio_taskdump, + feature = "rt", + target_os = "linux", + any(target_arch = "aarch64", target_arch = "x86", target_arch = "x86_64") + ))] + let future = crate::runtime::task::trace::Trace::root(future); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let future = crate::util::trace::task( + future, + "block_on", + _meta, + crate::runtime::task::Id::next().as_u64(), + ); + + let _enter = self.enter(); + + if let LocalRuntimeScheduler::CurrentThread(exec) = &self.scheduler { + exec.block_on(&self.handle.inner, future) + } else { + unreachable!("LocalRuntime only supports current_thread") + } + } + + /// Enters the runtime context. + /// + /// This allows you to construct types that must have an executor + /// available on creation such as [`Sleep`] or [`TcpStream`]. It will + /// also allow you to call methods such as [`tokio::spawn`]. + /// + /// If this is a handle to a [`LocalRuntime`], and this function is being invoked from the same + /// thread that the runtime was created on, you will also be able to call + /// [`tokio::task::spawn_local`]. + /// + /// [`Sleep`]: struct@crate::time::Sleep + /// [`TcpStream`]: struct@crate::net::TcpStream + /// [`tokio::spawn`]: fn@crate::spawn + /// [`LocalRuntime`]: struct@crate::runtime::LocalRuntime + /// [`tokio::task::spawn_local`]: fn@crate::task::spawn_local + /// + /// # Example + /// + /// ``` + /// use tokio::runtime::LocalRuntime; + /// use tokio::task::JoinHandle; + /// + /// fn function_that_spawns(msg: String) -> JoinHandle<()> { + /// // Had we not used `rt.enter` below, this would panic. + /// tokio::spawn(async move { + /// println!("{}", msg); + /// }) + /// } + /// + /// fn main() { + /// let rt = LocalRuntime::new().unwrap(); + /// + /// let s = "Hello World!".to_string(); + /// + /// // By entering the context, we tie `tokio::spawn` to this executor. + /// let _guard = rt.enter(); + /// let handle = function_that_spawns(s); + /// + /// // Wait for the task before we end the test. + /// rt.block_on(handle).unwrap(); + /// } + /// ``` + pub fn enter(&self) -> EnterGuard<'_> { + self.handle.enter() + } + + /// Shuts down the runtime, waiting for at most `duration` for all spawned + /// work to stop. + /// + /// Note that `spawn_blocking` tasks, and only `spawn_blocking` tasks, can get left behind if + /// the timeout expires. + /// + /// See the [struct level documentation](LocalRuntime#shutdown) for more details. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::LocalRuntime; + /// use tokio::task; + /// + /// use std::thread; + /// use std::time::Duration; + /// + /// fn main() { + /// let runtime = LocalRuntime::new().unwrap(); + /// + /// runtime.block_on(async move { + /// task::spawn_blocking(move || { + /// thread::sleep(Duration::from_secs(10_000)); + /// }); + /// }); + /// + /// runtime.shutdown_timeout(Duration::from_millis(100)); + /// } + /// ``` + pub fn shutdown_timeout(mut self, duration: Duration) { + // Wakeup and shutdown all the worker threads + self.handle.inner.shutdown(); + self.blocking_pool.shutdown(Some(duration)); + } + + /// Shuts down the runtime, without waiting for any spawned work to stop. + /// + /// This can be useful if you want to drop a runtime from within another runtime. + /// Normally, dropping a runtime will block indefinitely for spawned blocking tasks + /// to complete, which would normally not be permitted within an asynchronous context. + /// By calling `shutdown_background()`, you can drop the runtime from such a context. + /// + /// Note however, that because we do not wait for any blocking tasks to complete, this + /// may result in a resource leak (in that any blocking tasks are still running until they + /// return. No other tasks will leak. + /// + /// See the [struct level documentation](LocalRuntime#shutdown) for more details. + /// + /// This function is equivalent to calling `shutdown_timeout(Duration::from_nanos(0))`. + /// + /// ``` + /// use tokio::runtime::LocalRuntime; + /// + /// fn main() { + /// let runtime = LocalRuntime::new().unwrap(); + /// + /// runtime.block_on(async move { + /// let inner_runtime = LocalRuntime::new().unwrap(); + /// // ... + /// inner_runtime.shutdown_background(); + /// }); + /// } + /// ``` + pub fn shutdown_background(self) { + self.shutdown_timeout(Duration::from_nanos(0)); + } + + /// Returns a view that lets you get information about how the runtime + /// is performing. + pub fn metrics(&self) -> crate::runtime::RuntimeMetrics { + self.handle.metrics() + } +} + +#[allow(clippy::single_match)] // there are comments in the error branch, so we don't want if-let +impl Drop for LocalRuntime { + fn drop(&mut self) { + if let LocalRuntimeScheduler::CurrentThread(current_thread) = &mut self.scheduler { + // This ensures that tasks spawned on the current-thread + // runtime are dropped inside the runtime's context. + let _guard = context::try_set_current(&self.handle.inner); + current_thread.shutdown(&self.handle.inner); + } else { + unreachable!("LocalRuntime only supports current-thread") + } + } +} + +impl std::panic::UnwindSafe for LocalRuntime {} + +impl std::panic::RefUnwindSafe for LocalRuntime {} diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs index 3f2467f6dbc..c8efbe2f1cd 100644 --- a/tokio/src/runtime/mod.rs +++ b/tokio/src/runtime/mod.rs @@ -372,6 +372,9 @@ cfg_rt! { pub use self::builder::UnhandledPanic; pub use crate::util::rand::RngSeed; + + mod local_runtime; + pub use local_runtime::{LocalRuntime, LocalOptions}; } cfg_taskdump! { diff --git a/tokio/src/runtime/scheduler/current_thread/mod.rs b/tokio/src/runtime/scheduler/current_thread/mod.rs index 9959dff8e46..c66635e7bd6 100644 --- a/tokio/src/runtime/scheduler/current_thread/mod.rs +++ b/tokio/src/runtime/scheduler/current_thread/mod.rs @@ -18,6 +18,7 @@ use std::future::{poll_fn, Future}; use std::sync::atomic::Ordering::{AcqRel, Release}; use std::task::Poll::{Pending, Ready}; use std::task::Waker; +use std::thread::ThreadId; use std::time::Duration; use std::{fmt, thread}; @@ -47,6 +48,9 @@ pub(crate) struct Handle { /// User-supplied hooks to invoke for things pub(crate) task_hooks: TaskHooks, + + /// If this is a `LocalRuntime`, flags the owning thread ID. + pub(crate) local_tid: Option, } /// Data required for executing the scheduler. The struct is passed around to @@ -127,6 +131,7 @@ impl CurrentThread { blocking_spawner: blocking::Spawner, seed_generator: RngSeedGenerator, config: Config, + local_tid: Option, ) -> (CurrentThread, Arc) { let worker_metrics = WorkerMetrics::from_config(&config); worker_metrics.set_thread_id(thread::current().id()); @@ -152,6 +157,7 @@ impl CurrentThread { driver: driver_handle, blocking_spawner, seed_generator, + local_tid, }); let core = AtomicCell::new(Some(Box::new(Core { @@ -458,6 +464,35 @@ impl Handle { handle } + /// Spawn a task which isn't safe to send across thread boundaries onto the runtime. + /// + /// # Safety + /// This should only be used when this is a `LocalRuntime` or in another case where the runtime + /// provably cannot be driven from or moved to different threads from the one on which the task + /// is spawned. + pub(crate) unsafe fn spawn_local( + me: &Arc, + future: F, + id: crate::runtime::task::Id, + ) -> JoinHandle + where + F: crate::future::Future + 'static, + F::Output: 'static, + { + let (handle, notified) = me.shared.owned.bind_local(future, me.clone(), id); + + me.task_hooks.spawn(&TaskMeta { + id, + _phantom: Default::default(), + }); + + if let Some(notified) = notified { + me.schedule(notified); + } + + handle + } + /// Capture a snapshot of this runtime's state. #[cfg(all( tokio_unstable, diff --git a/tokio/src/runtime/scheduler/mod.rs b/tokio/src/runtime/scheduler/mod.rs index ada8efbad63..e0a1b20b5bc 100644 --- a/tokio/src/runtime/scheduler/mod.rs +++ b/tokio/src/runtime/scheduler/mod.rs @@ -113,6 +113,31 @@ cfg_rt! { match_flavor!(self, Handle(h) => &h.blocking_spawner) } + pub(crate) fn is_local(&self) -> bool { + match self { + Handle::CurrentThread(h) => h.local_tid.is_some(), + + #[cfg(feature = "rt-multi-thread")] + Handle::MultiThread(_) => false, + + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + Handle::MultiThreadAlt(_) => false, + } + } + + /// Returns true if this is a local runtime and the runtime is owned by the current thread. + pub(crate) fn can_spawn_local_on_local_runtime(&self) -> bool { + match self { + Handle::CurrentThread(h) => h.local_tid.map(|x| std::thread::current().id() == x).unwrap_or(false), + + #[cfg(feature = "rt-multi-thread")] + Handle::MultiThread(_) => false, + + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + Handle::MultiThreadAlt(_) => false, + } + } + pub(crate) fn spawn(&self, future: F, id: Id) -> JoinHandle where F: Future + Send + 'static, @@ -129,6 +154,24 @@ cfg_rt! { } } + /// Spawn a local task + /// + /// # Safety + /// This should only be called in `LocalRuntime` if the runtime has been verified to be owned + /// by the current thread. + #[allow(irrefutable_let_patterns)] + pub(crate) unsafe fn spawn_local(&self, future: F, id: Id) -> JoinHandle + where + F: Future + 'static, + F::Output: 'static, + { + if let Handle::CurrentThread(h) = self { + current_thread::Handle::spawn_local(h, future, id) + } else { + panic!("Only current_thread and LocalSet have spawn_local internals implemented") + } + } + pub(crate) fn shutdown(&self) { match *self { Handle::CurrentThread(_) => {}, diff --git a/tokio/src/runtime/task/list.rs b/tokio/src/runtime/task/list.rs index 988d422836d..54bfc01aafb 100644 --- a/tokio/src/runtime/task/list.rs +++ b/tokio/src/runtime/task/list.rs @@ -102,6 +102,26 @@ impl OwnedTasks { (join, notified) } + /// Bind a task that isn't safe to transfer across thread boundaries. + /// + /// # Safety + /// Only use this in `LocalRuntime` where the task cannot move + pub(crate) unsafe fn bind_local( + &self, + task: T, + scheduler: S, + id: super::Id, + ) -> (JoinHandle, Option>) + where + S: Schedule, + T: Future + 'static, + T::Output: 'static, + { + let (task, notified, join) = super::new_task(task, scheduler, id); + let notified = unsafe { self.bind_inner(task, notified) }; + (join, notified) + } + /// The part of `bind` that's the same for every type of future. unsafe fn bind_inner(&self, task: Task, notified: Notified) -> Option> where diff --git a/tokio/src/task/local.rs b/tokio/src/task/local.rs index d5341937893..edd02acbac0 100644 --- a/tokio/src/task/local.rs +++ b/tokio/src/task/local.rs @@ -322,7 +322,7 @@ impl<'a> Drop for LocalDataEnterGuard<'a> { } cfg_rt! { - /// Spawns a `!Send` future on the current [`LocalSet`]. + /// Spawns a `!Send` future on the current [`LocalSet`] or [`LocalRuntime`]. /// /// The spawned future will run on the same thread that called `spawn_local`. /// @@ -362,6 +362,7 @@ cfg_rt! { /// ``` /// /// [`LocalSet`]: struct@crate::task::LocalSet + /// [`LocalRuntime`]: struct@crate::runtime::LocalRuntime /// [`tokio::spawn`]: fn@crate::task::spawn #[track_caller] pub fn spawn_local(future: F) -> JoinHandle @@ -383,10 +384,51 @@ cfg_rt! { where F: Future + 'static, F::Output: 'static { - match CURRENT.with(|LocalData { ctx, .. }| ctx.get()) { - None => panic!("`spawn_local` called from outside of a `task::LocalSet`"), - Some(cx) => cx.spawn(future, meta) - } + use crate::runtime::{context, task}; + + let mut future = Some(future); + + let res = context::with_current(|handle| { + Some(if handle.is_local() { + if !handle.can_spawn_local_on_local_runtime() { + return None; + } + + let future = future.take().unwrap(); + + #[cfg(all( + tokio_unstable, + tokio_taskdump, + feature = "rt", + target_os = "linux", + any( + target_arch = "aarch64", + target_arch = "x86", + target_arch = "x86_64" + ) + ))] + let future = task::trace::Trace::root(future); + let id = task::Id::next(); + let task = crate::util::trace::task(future, "task", meta, id.as_u64()); + + // safety: we have verified that this is a `LocalRuntime` owned by the current thread + unsafe { handle.spawn_local(task, id) } + } else { + match CURRENT.with(|LocalData { ctx, .. }| ctx.get()) { + None => panic!("`spawn_local` called from outside of a `task::LocalSet` or LocalRuntime"), + Some(cx) => cx.spawn(future.take().unwrap(), meta) + } + }) + }); + + match res { + Ok(None) => panic!("Local tasks can only be spawned on a LocalRuntime from the thread the runtime was created on"), + Ok(Some(join_handle)) => join_handle, + Err(_) => match CURRENT.with(|LocalData { ctx, .. }| ctx.get()) { + None => panic!("`spawn_local` called from outside of a `task::LocalSet` or LocalRuntime"), + Some(cx) => cx.spawn(future.unwrap(), meta) + } + } } } diff --git a/tokio/src/util/trace.rs b/tokio/src/util/trace.rs index 97006df474e..b6eadba2205 100644 --- a/tokio/src/util/trace.rs +++ b/tokio/src/util/trace.rs @@ -1,6 +1,7 @@ cfg_rt! { use std::marker::PhantomData; + #[derive(Copy, Clone)] pub(crate) struct SpawnMeta<'a> { /// The name of the task #[cfg(all(tokio_unstable, feature = "tracing"))] diff --git a/tokio/tests/rt_local.rs b/tokio/tests/rt_local.rs new file mode 100644 index 00000000000..1f14f5444d3 --- /dev/null +++ b/tokio/tests/rt_local.rs @@ -0,0 +1,100 @@ +#![allow(unknown_lints, unexpected_cfgs)] +#![warn(rust_2018_idioms)] +#![cfg(all(feature = "full", tokio_unstable))] + +use tokio::runtime::LocalOptions; +use tokio::task::spawn_local; + +#[test] +fn test_spawn_local_in_runtime() { + let rt = rt(); + + let res = rt.block_on(async move { + let (tx, rx) = tokio::sync::oneshot::channel(); + + spawn_local(async { + tokio::task::yield_now().await; + tx.send(5).unwrap(); + }); + + rx.await.unwrap() + }); + + assert_eq!(res, 5); +} + +#[test] +fn test_spawn_from_handle() { + let rt = rt(); + + let (tx, rx) = tokio::sync::oneshot::channel(); + + rt.handle().spawn(async { + tokio::task::yield_now().await; + tx.send(5).unwrap(); + }); + + let res = rt.block_on(async move { rx.await.unwrap() }); + + assert_eq!(res, 5); +} + +#[test] +fn test_spawn_local_on_runtime_object() { + let rt = rt(); + + let (tx, rx) = tokio::sync::oneshot::channel(); + + rt.spawn_local(async { + tokio::task::yield_now().await; + tx.send(5).unwrap(); + }); + + let res = rt.block_on(async move { rx.await.unwrap() }); + + assert_eq!(res, 5); +} + +#[test] +fn test_spawn_local_from_guard() { + let rt = rt(); + + let (tx, rx) = tokio::sync::oneshot::channel(); + + let _guard = rt.enter(); + + spawn_local(async { + tokio::task::yield_now().await; + tx.send(5).unwrap(); + }); + + let res = rt.block_on(async move { rx.await.unwrap() }); + + assert_eq!(res, 5); +} + +#[test] +#[should_panic] +fn test_spawn_local_from_guard_other_thread() { + let (tx, rx) = std::sync::mpsc::channel(); + + std::thread::spawn(move || { + let rt = rt(); + let handle = rt.handle().clone(); + + tx.send(handle).unwrap(); + }); + + let handle = rx.recv().unwrap(); + + let _guard = handle.enter(); + + spawn_local(async {}); +} + +fn rt() -> tokio::runtime::LocalRuntime { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build_local(&LocalOptions::default()) + .unwrap() +}