From a96d4bec4e25326db17539e1d531fc5f71bee88a Mon Sep 17 00:00:00 2001 From: John Nunley Date: Fri, 22 Sep 2023 17:23:55 -0700 Subject: [PATCH 1/5] Refractor Async into a new file This commit moves the implementation for Async into a new file called io.rs. This is done, as adding wasm32-unknown-unknown does not support async I/O, so it's easiest to move it all to another file to cut if off. As this commit will make up the majority of the lines of the overall PR, I've elected to make it separate for ease of review. Signed-off-by: John Nunley --- src/io.rs | 1577 ++++++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 1572 +-------------------------------------------------- 2 files changed, 1581 insertions(+), 1568 deletions(-) create mode 100644 src/io.rs diff --git a/src/io.rs b/src/io.rs new file mode 100644 index 0000000..4a17551 --- /dev/null +++ b/src/io.rs @@ -0,0 +1,1577 @@ +//! Implements [`Async`], allowing users to register FDs/sockets into the reactor. + +use std::convert::TryFrom; +use std::future::Future; +use std::io::{self, IoSlice, IoSliceMut, Read, Write}; +use std::net::{SocketAddr, TcpListener, TcpStream, UdpSocket}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +#[cfg(unix)] +use std::{ + os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd}, + os::unix::net::{SocketAddr as UnixSocketAddr, UnixDatagram, UnixListener, UnixStream}, + path::Path, +}; + +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, AsSocket, BorrowedSocket, OwnedSocket, RawSocket}; + +use futures_io::{AsyncRead, AsyncWrite}; +use futures_lite::stream::{self, Stream}; +use futures_lite::{future, pin, ready}; +use socket2::{Domain, Protocol, SockAddr, Socket, Type}; + +use crate::reactor::{ + Reactor, Readable, ReadableOwned, Registration, Source, Writable, WritableOwned, +}; + +/// Async adapter for I/O types. +/// +/// This type puts an I/O handle into non-blocking mode, registers it in +/// [epoll]/[kqueue]/[event ports]/[IOCP], and then provides an async interface for it. +/// +/// [epoll]: https://en.wikipedia.org/wiki/Epoll +/// [kqueue]: https://en.wikipedia.org/wiki/Kqueue +/// [event ports]: https://illumos.org/man/port_create +/// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports +/// +/// # Caveats +/// +/// [`Async`] is a low-level primitive, and as such it comes with some caveats. +/// +/// For higher-level primitives built on top of [`Async`], look into [`async-net`] or +/// [`async-process`] (on Unix). +/// +/// The most notable caveat is that it is unsafe to access the inner I/O source mutably +/// using this primitive. Traits likes [`AsyncRead`] and [`AsyncWrite`] are not implemented by +/// default unless it is guaranteed that the resource won't be invalidated by reading or writing. +/// See the [`IoSafe`] trait for more information. +/// +/// [`async-net`]: https://github.com/smol-rs/async-net +/// [`async-process`]: https://github.com/smol-rs/async-process +/// [`AsyncRead`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncRead.html +/// [`AsyncWrite`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncWrite.html +/// +/// ### Supported types +/// +/// [`Async`] supports all networking types, as well as some OS-specific file descriptors like +/// [timerfd] and [inotify]. +/// +/// However, do not use [`Async`] with types like [`File`][`std::fs::File`], +/// [`Stdin`][`std::io::Stdin`], [`Stdout`][`std::io::Stdout`], or [`Stderr`][`std::io::Stderr`] +/// because all operating systems have issues with them when put in non-blocking mode. +/// +/// [timerfd]: https://github.com/smol-rs/async-io/blob/master/examples/linux-timerfd.rs +/// [inotify]: https://github.com/smol-rs/async-io/blob/master/examples/linux-inotify.rs +/// +/// ### Concurrent I/O +/// +/// Note that [`&Async`][`Async`] implements [`AsyncRead`] and [`AsyncWrite`] if `&T` +/// implements those traits, which means tasks can concurrently read and write using shared +/// references. +/// +/// But there is a catch: only one task can read a time, and only one task can write at a time. It +/// is okay to have two tasks where one is reading and the other is writing at the same time, but +/// it is not okay to have two tasks reading at the same time or writing at the same time. If you +/// try to do that, conflicting tasks will just keep waking each other in turn, thus wasting CPU +/// time. +/// +/// Besides [`AsyncRead`] and [`AsyncWrite`], this caveat also applies to +/// [`poll_readable()`][`Async::poll_readable()`] and +/// [`poll_writable()`][`Async::poll_writable()`]. +/// +/// However, any number of tasks can be concurrently calling other methods like +/// [`readable()`][`Async::readable()`] or [`read_with()`][`Async::read_with()`]. +/// +/// ### Closing +/// +/// Closing the write side of [`Async`] with [`close()`][`futures_lite::AsyncWriteExt::close()`] +/// simply flushes. If you want to shutdown a TCP or Unix socket, use +/// [`Shutdown`][`std::net::Shutdown`]. +/// +/// # Examples +/// +/// Connect to a server and echo incoming messages back to the server: +/// +/// ```no_run +/// use async_io::Async; +/// use futures_lite::io; +/// use std::net::TcpStream; +/// +/// # futures_lite::future::block_on(async { +/// // Connect to a local server. +/// let stream = Async::::connect(([127, 0, 0, 1], 8000)).await?; +/// +/// // Echo all messages from the read side of the stream into the write side. +/// io::copy(&stream, &stream).await?; +/// # std::io::Result::Ok(()) }); +/// ``` +/// +/// You can use either predefined async methods or wrap blocking I/O operations in +/// [`Async::read_with()`], [`Async::read_with_mut()`], [`Async::write_with()`], and +/// [`Async::write_with_mut()`]: +/// +/// ```no_run +/// use async_io::Async; +/// use std::net::TcpListener; +/// +/// # futures_lite::future::block_on(async { +/// let listener = Async::::bind(([127, 0, 0, 1], 0))?; +/// +/// // These two lines are equivalent: +/// let (stream, addr) = listener.accept().await?; +/// let (stream, addr) = listener.read_with(|inner| inner.accept()).await?; +/// # std::io::Result::Ok(()) }); +/// ``` +#[derive(Debug)] +pub struct Async { + /// A source registered in the reactor. + pub(crate) source: Arc, + + /// The inner I/O handle. + pub(crate) io: Option, +} + +impl Unpin for Async {} + +#[cfg(unix)] +impl Async { + /// Creates an async I/O handle. + /// + /// This method will put the handle in non-blocking mode and register it in + /// [epoll]/[kqueue]/[event ports]/[IOCP]. + /// + /// On Unix systems, the handle must implement `AsRawFd`, while on Windows it must implement + /// `AsRawSocket`. + /// + /// [epoll]: https://en.wikipedia.org/wiki/Epoll + /// [kqueue]: https://en.wikipedia.org/wiki/Kqueue + /// [event ports]: https://illumos.org/man/port_create + /// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports + /// + /// # Examples + /// + /// ``` + /// use async_io::Async; + /// use std::net::{SocketAddr, TcpListener}; + /// + /// # futures_lite::future::block_on(async { + /// let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))?; + /// let listener = Async::new(listener)?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn new(io: T) -> io::Result> { + // Put the file descriptor in non-blocking mode. + let fd = io.as_fd(); + cfg_if::cfg_if! { + // ioctl(FIONBIO) sets the flag atomically, but we use this only on Linux + // for now, as with the standard library, because it seems to behave + // differently depending on the platform. + // https://github.com/rust-lang/rust/commit/efeb42be2837842d1beb47b51bb693c7474aba3d + // https://github.com/libuv/libuv/blob/e9d91fccfc3e5ff772d5da90e1c4a24061198ca0/src/unix/poll.c#L78-L80 + // https://github.com/tokio-rs/mio/commit/0db49f6d5caf54b12176821363d154384357e70a + if #[cfg(target_os = "linux")] { + rustix::io::ioctl_fionbio(fd, true)?; + } else { + let previous = rustix::fs::fcntl_getfl(fd)?; + let new = previous | rustix::fs::OFlags::NONBLOCK; + if new != previous { + rustix::fs::fcntl_setfl(fd, new)?; + } + } + } + + // SAFETY: It is impossible to drop the I/O source while it is registered through + // this type. + let registration = unsafe { Registration::new(fd) }; + + Ok(Async { + source: Reactor::get().insert_io(registration)?, + io: Some(io), + }) + } +} + +#[cfg(unix)] +impl AsRawFd for Async { + fn as_raw_fd(&self) -> RawFd { + self.get_ref().as_raw_fd() + } +} + +#[cfg(unix)] +impl AsFd for Async { + fn as_fd(&self) -> BorrowedFd<'_> { + self.get_ref().as_fd() + } +} + +#[cfg(unix)] +impl> TryFrom for Async { + type Error = io::Error; + + fn try_from(value: OwnedFd) -> Result { + Async::new(value.into()) + } +} + +#[cfg(unix)] +impl> TryFrom> for OwnedFd { + type Error = io::Error; + + fn try_from(value: Async) -> Result { + value.into_inner().map(Into::into) + } +} + +#[cfg(windows)] +impl Async { + /// Creates an async I/O handle. + /// + /// This method will put the handle in non-blocking mode and register it in + /// [epoll]/[kqueue]/[event ports]/[IOCP]. + /// + /// On Unix systems, the handle must implement `AsRawFd`, while on Windows it must implement + /// `AsRawSocket`. + /// + /// [epoll]: https://en.wikipedia.org/wiki/Epoll + /// [kqueue]: https://en.wikipedia.org/wiki/Kqueue + /// [event ports]: https://illumos.org/man/port_create + /// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports + /// + /// # Examples + /// + /// ``` + /// use async_io::Async; + /// use std::net::{SocketAddr, TcpListener}; + /// + /// # futures_lite::future::block_on(async { + /// let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))?; + /// let listener = Async::new(listener)?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn new(io: T) -> io::Result> { + let borrowed = io.as_socket(); + + // Put the socket in non-blocking mode. + // + // Safety: We assume `as_raw_socket()` returns a valid fd. When we can + // depend on Rust >= 1.63, where `AsFd` is stabilized, and when + // `TimerFd` implements it, we can remove this unsafe and simplify this. + rustix::io::ioctl_fionbio(borrowed, true)?; + + // Create the registration. + // + // SAFETY: It is impossible to drop the I/O source while it is registered through + // this type. + let registration = unsafe { Registration::new(borrowed) }; + + Ok(Async { + source: Reactor::get().insert_io(registration)?, + io: Some(io), + }) + } +} + +#[cfg(windows)] +impl AsRawSocket for Async { + fn as_raw_socket(&self) -> RawSocket { + self.get_ref().as_raw_socket() + } +} + +#[cfg(windows)] +impl AsSocket for Async { + fn as_socket(&self) -> BorrowedSocket<'_> { + self.get_ref().as_socket() + } +} + +#[cfg(windows)] +impl> TryFrom for Async { + type Error = io::Error; + + fn try_from(value: OwnedSocket) -> Result { + Async::new(value.into()) + } +} + +#[cfg(windows)] +impl> TryFrom> for OwnedSocket { + type Error = io::Error; + + fn try_from(value: Async) -> Result { + value.into_inner().map(Into::into) + } +} + +impl Async { + /// Gets a reference to the inner I/O handle. + /// + /// # Examples + /// + /// ``` + /// use async_io::Async; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let listener = Async::::bind(([127, 0, 0, 1], 0))?; + /// let inner = listener.get_ref(); + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn get_ref(&self) -> &T { + self.io.as_ref().unwrap() + } + + /// Gets a mutable reference to the inner I/O handle. + /// + /// # Safety + /// + /// The underlying I/O source must not be dropped using this function. + /// + /// # Examples + /// + /// ``` + /// use async_io::Async; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let mut listener = Async::::bind(([127, 0, 0, 1], 0))?; + /// let inner = unsafe { listener.get_mut() }; + /// # std::io::Result::Ok(()) }); + /// ``` + pub unsafe fn get_mut(&mut self) -> &mut T { + self.io.as_mut().unwrap() + } + + /// Unwraps the inner I/O handle. + /// + /// This method will **not** put the I/O handle back into blocking mode. + /// + /// # Examples + /// + /// ``` + /// use async_io::Async; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let listener = Async::::bind(([127, 0, 0, 1], 0))?; + /// let inner = listener.into_inner()?; + /// + /// // Put the listener back into blocking mode. + /// inner.set_nonblocking(false)?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn into_inner(mut self) -> io::Result { + let io = self.io.take().unwrap(); + Reactor::get().remove_io(&self.source)?; + Ok(io) + } + + /// Waits until the I/O handle is readable. + /// + /// This method completes when a read operation on this I/O handle wouldn't block. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let mut listener = Async::::bind(([127, 0, 0, 1], 0))?; + /// + /// // Wait until a client can be accepted. + /// listener.readable().await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn readable(&self) -> Readable<'_, T> { + Source::readable(self) + } + + /// Waits until the I/O handle is readable. + /// + /// This method completes when a read operation on this I/O handle wouldn't block. + pub fn readable_owned(self: Arc) -> ReadableOwned { + Source::readable_owned(self) + } + + /// Waits until the I/O handle is writable. + /// + /// This method completes when a write operation on this I/O handle wouldn't block. + /// + /// # Examples + /// + /// ``` + /// use async_io::Async; + /// use std::net::{TcpStream, ToSocketAddrs}; + /// + /// # futures_lite::future::block_on(async { + /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap(); + /// let stream = Async::::connect(addr).await?; + /// + /// // Wait until the stream is writable. + /// stream.writable().await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn writable(&self) -> Writable<'_, T> { + Source::writable(self) + } + + /// Waits until the I/O handle is writable. + /// + /// This method completes when a write operation on this I/O handle wouldn't block. + pub fn writable_owned(self: Arc) -> WritableOwned { + Source::writable_owned(self) + } + + /// Polls the I/O handle for readability. + /// + /// When this method returns [`Poll::Ready`], that means the OS has delivered an event + /// indicating readability since the last time this task has called the method and received + /// [`Poll::Pending`]. + /// + /// # Caveats + /// + /// Two different tasks should not call this method concurrently. Otherwise, conflicting tasks + /// will just keep waking each other in turn, thus wasting CPU time. + /// + /// Note that the [`AsyncRead`] implementation for [`Async`] also uses this method. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use futures_lite::future; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let mut listener = Async::::bind(([127, 0, 0, 1], 0))?; + /// + /// // Wait until a client can be accepted. + /// future::poll_fn(|cx| listener.poll_readable(cx)).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn poll_readable(&self, cx: &mut Context<'_>) -> Poll> { + self.source.poll_readable(cx) + } + + /// Polls the I/O handle for writability. + /// + /// When this method returns [`Poll::Ready`], that means the OS has delivered an event + /// indicating writability since the last time this task has called the method and received + /// [`Poll::Pending`]. + /// + /// # Caveats + /// + /// Two different tasks should not call this method concurrently. Otherwise, conflicting tasks + /// will just keep waking each other in turn, thus wasting CPU time. + /// + /// Note that the [`AsyncWrite`] implementation for [`Async`] also uses this method. + /// + /// # Examples + /// + /// ``` + /// use async_io::Async; + /// use futures_lite::future; + /// use std::net::{TcpStream, ToSocketAddrs}; + /// + /// # futures_lite::future::block_on(async { + /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap(); + /// let stream = Async::::connect(addr).await?; + /// + /// // Wait until the stream is writable. + /// future::poll_fn(|cx| stream.poll_writable(cx)).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn poll_writable(&self, cx: &mut Context<'_>) -> Poll> { + self.source.poll_writable(cx) + } + + /// Performs a read operation asynchronously. + /// + /// The I/O handle is registered in the reactor and put in non-blocking mode. This method + /// invokes the `op` closure in a loop until it succeeds or returns an error other than + /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS + /// sends a notification that the I/O handle is readable. + /// + /// The closure receives a shared reference to the I/O handle. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let listener = Async::::bind(([127, 0, 0, 1], 0))?; + /// + /// // Accept a new client asynchronously. + /// let (stream, addr) = listener.read_with(|l| l.accept()).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn read_with(&self, op: impl FnMut(&T) -> io::Result) -> io::Result { + let mut op = op; + loop { + match op(self.get_ref()) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return res, + } + optimistic(self.readable()).await?; + } + } + + /// Performs a read operation asynchronously. + /// + /// The I/O handle is registered in the reactor and put in non-blocking mode. This method + /// invokes the `op` closure in a loop until it succeeds or returns an error other than + /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS + /// sends a notification that the I/O handle is readable. + /// + /// The closure receives a mutable reference to the I/O handle. + /// + /// # Safety + /// + /// In the closure, the underlying I/O source must not be dropped. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let mut listener = Async::::bind(([127, 0, 0, 1], 0))?; + /// + /// // Accept a new client asynchronously. + /// let (stream, addr) = unsafe { listener.read_with_mut(|l| l.accept()).await? }; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async unsafe fn read_with_mut( + &mut self, + op: impl FnMut(&mut T) -> io::Result, + ) -> io::Result { + let mut op = op; + loop { + match op(self.get_mut()) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return res, + } + optimistic(self.readable()).await?; + } + } + + /// Performs a write operation asynchronously. + /// + /// The I/O handle is registered in the reactor and put in non-blocking mode. This method + /// invokes the `op` closure in a loop until it succeeds or returns an error other than + /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS + /// sends a notification that the I/O handle is writable. + /// + /// The closure receives a shared reference to the I/O handle. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::UdpSocket; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; + /// socket.get_ref().connect("127.0.0.1:9000")?; + /// + /// let msg = b"hello"; + /// let len = socket.write_with(|s| s.send(msg)).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn write_with(&self, op: impl FnMut(&T) -> io::Result) -> io::Result { + let mut op = op; + loop { + match op(self.get_ref()) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return res, + } + optimistic(self.writable()).await?; + } + } + + /// Performs a write operation asynchronously. + /// + /// The I/O handle is registered in the reactor and put in non-blocking mode. This method + /// invokes the `op` closure in a loop until it succeeds or returns an error other than + /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS + /// sends a notification that the I/O handle is writable. + /// + /// # Safety + /// + /// The closure receives a mutable reference to the I/O handle. In the closure, the underlying + /// I/O source must not be dropped. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::UdpSocket; + /// + /// # futures_lite::future::block_on(async { + /// let mut socket = Async::::bind(([127, 0, 0, 1], 8000))?; + /// socket.get_ref().connect("127.0.0.1:9000")?; + /// + /// let msg = b"hello"; + /// let len = unsafe { socket.write_with_mut(|s| s.send(msg)).await? }; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async unsafe fn write_with_mut( + &mut self, + op: impl FnMut(&mut T) -> io::Result, + ) -> io::Result { + let mut op = op; + loop { + match op(self.get_mut()) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return res, + } + optimistic(self.writable()).await?; + } + } +} + +impl AsRef for Async { + fn as_ref(&self) -> &T { + self.get_ref() + } +} + +impl Drop for Async { + fn drop(&mut self) { + if self.io.is_some() { + // Deregister and ignore errors because destructors should not panic. + Reactor::get().remove_io(&self.source).ok(); + + // Drop the I/O handle to close it. + self.io.take(); + } + } +} + +/// Types whose I/O trait implementations do not drop the underlying I/O source. +/// +/// The resource contained inside of the [`Async`] cannot be invalidated. This invalidation can +/// happen if the inner resource (the [`TcpStream`], [`UnixListener`] or other `T`) is moved out +/// and dropped before the [`Async`]. Because of this, functions that grant mutable access to +/// the inner type are unsafe, as there is no way to guarantee that the source won't be dropped +/// and a dangling handle won't be left behind. +/// +/// Unfortunately this extends to implementations of [`Read`] and [`Write`]. Since methods on those +/// traits take `&mut`, there is no guarantee that the implementor of those traits won't move the +/// source out while the method is being run. +/// +/// This trait is an antidote to this predicament. By implementing this trait, the user pledges +/// that using any I/O traits won't destroy the source. This way, [`Async`] can implement the +/// `async` version of these I/O traits, like [`AsyncRead`] and [`AsyncWrite`]. +/// +/// # Safety +/// +/// Any I/O trait implementations for this type must not drop the underlying I/O source. Traits +/// affected by this trait include [`Read`], [`Write`], [`Seek`] and [`BufRead`]. +/// +/// This trait is implemented by default on top of `libstd` types. In addition, it is implemented +/// for immutable reference types, as it is impossible to invalidate any outstanding references +/// while holding an immutable reference, even with interior mutability. As Rust's current pinning +/// system relies on similar guarantees, I believe that this approach is robust. +/// +/// [`BufRead`]: https://doc.rust-lang.org/std/io/trait.BufRead.html +/// [`Read`]: https://doc.rust-lang.org/std/io/trait.Read.html +/// [`Seek`]: https://doc.rust-lang.org/std/io/trait.Seek.html +/// [`Write`]: https://doc.rust-lang.org/std/io/trait.Write.html +/// +/// [`AsyncRead`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncRead.html +/// [`AsyncWrite`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncWrite.html +pub unsafe trait IoSafe {} + +/// Reference types can't be mutated. +/// +/// The worst thing that can happen is that external state is used to change what kind of pointer +/// `as_fd()` returns. For instance: +/// +/// ``` +/// # #[cfg(unix)] { +/// use std::cell::Cell; +/// use std::net::TcpStream; +/// use std::os::unix::io::{AsFd, BorrowedFd}; +/// +/// struct Bar { +/// flag: Cell, +/// a: TcpStream, +/// b: TcpStream +/// } +/// +/// impl AsFd for Bar { +/// fn as_fd(&self) -> BorrowedFd<'_> { +/// if self.flag.replace(!self.flag.get()) { +/// self.a.as_fd() +/// } else { +/// self.b.as_fd() +/// } +/// } +/// } +/// # } +/// ``` +/// +/// We solve this problem by only calling `as_fd()` once to get the original source. Implementations +/// like this are considered buggy (but not unsound) and are thus not really supported by `async-io`. +unsafe impl IoSafe for &T {} + +// Can be implemented on top of libstd types. +unsafe impl IoSafe for std::fs::File {} +unsafe impl IoSafe for std::io::Stderr {} +unsafe impl IoSafe for std::io::Stdin {} +unsafe impl IoSafe for std::io::Stdout {} +unsafe impl IoSafe for std::io::StderrLock<'_> {} +unsafe impl IoSafe for std::io::StdinLock<'_> {} +unsafe impl IoSafe for std::io::StdoutLock<'_> {} +unsafe impl IoSafe for std::net::TcpStream {} + +#[cfg(unix)] +unsafe impl IoSafe for std::os::unix::net::UnixStream {} + +unsafe impl IoSafe for std::io::BufReader {} +unsafe impl IoSafe for std::io::BufWriter {} +unsafe impl IoSafe for std::io::LineWriter {} +unsafe impl IoSafe for &mut T {} +unsafe impl IoSafe for Box {} +unsafe impl IoSafe for std::borrow::Cow<'_, T> {} + +impl AsyncRead for Async { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + loop { + match unsafe { (*self).get_mut() }.read(buf) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_readable(cx))?; + } + } + + fn poll_read_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [IoSliceMut<'_>], + ) -> Poll> { + loop { + match unsafe { (*self).get_mut() }.read_vectored(bufs) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_readable(cx))?; + } + } +} + +// Since this is through a reference, we can't mutate the inner I/O source. +// Therefore this is safe! +impl AsyncRead for &Async +where + for<'a> &'a T: Read, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + loop { + match (*self).get_ref().read(buf) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_readable(cx))?; + } + } + + fn poll_read_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [IoSliceMut<'_>], + ) -> Poll> { + loop { + match (*self).get_ref().read_vectored(bufs) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_readable(cx))?; + } + } +} + +impl AsyncWrite for Async { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + loop { + match unsafe { (*self).get_mut() }.write(buf) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_writable(cx))?; + } + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + loop { + match unsafe { (*self).get_mut() }.write_vectored(bufs) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_writable(cx))?; + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match unsafe { (*self).get_mut() }.flush() { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_writable(cx))?; + } + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_flush(cx) + } +} + +impl AsyncWrite for &Async +where + for<'a> &'a T: Write, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + loop { + match (*self).get_ref().write(buf) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_writable(cx))?; + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + loop { + match (*self).get_ref().write_vectored(bufs) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_writable(cx))?; + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match (*self).get_ref().flush() { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_writable(cx))?; + } + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_flush(cx) + } +} + +impl Async { + /// Creates a TCP listener bound to the specified address. + /// + /// Binding with port number 0 will request an available port from the OS. + /// + /// # Examples + /// + /// ``` + /// use async_io::Async; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let listener = Async::::bind(([127, 0, 0, 1], 0))?; + /// println!("Listening on {}", listener.get_ref().local_addr()?); + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn bind>(addr: A) -> io::Result> { + let addr = addr.into(); + Async::new(TcpListener::bind(addr)?) + } + + /// Accepts a new incoming TCP connection. + /// + /// When a connection is established, it will be returned as a TCP stream together with its + /// remote address. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let listener = Async::::bind(([127, 0, 0, 1], 8000))?; + /// let (stream, addr) = listener.accept().await?; + /// println!("Accepted client: {}", addr); + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn accept(&self) -> io::Result<(Async, SocketAddr)> { + let (stream, addr) = self.read_with(|io| io.accept()).await?; + Ok((Async::new(stream)?, addr)) + } + + /// Returns a stream of incoming TCP connections. + /// + /// The stream is infinite, i.e. it never stops with a [`None`]. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use futures_lite::{pin, stream::StreamExt}; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let listener = Async::::bind(([127, 0, 0, 1], 8000))?; + /// let incoming = listener.incoming(); + /// pin!(incoming); + /// + /// while let Some(stream) = incoming.next().await { + /// let stream = stream?; + /// println!("Accepted client: {}", stream.get_ref().peer_addr()?); + /// } + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn incoming(&self) -> impl Stream>> + Send + '_ { + stream::unfold(self, |listener| async move { + let res = listener.accept().await.map(|(stream, _)| stream); + Some((res, listener)) + }) + } +} + +impl TryFrom for Async { + type Error = io::Error; + + fn try_from(listener: std::net::TcpListener) -> io::Result { + Async::new(listener) + } +} + +impl Async { + /// Creates a TCP connection to the specified address. + /// + /// # Examples + /// + /// ``` + /// use async_io::Async; + /// use std::net::{TcpStream, ToSocketAddrs}; + /// + /// # futures_lite::future::block_on(async { + /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap(); + /// let stream = Async::::connect(addr).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn connect>(addr: A) -> io::Result> { + // Begin async connect. + let addr = addr.into(); + let domain = Domain::for_address(addr); + let socket = connect(addr.into(), domain, Some(Protocol::TCP))?; + let stream = Async::new(TcpStream::from(socket))?; + + // The stream becomes writable when connected. + stream.writable().await?; + + // Check if there was an error while connecting. + match stream.get_ref().take_error()? { + None => Ok(stream), + Some(err) => Err(err), + } + } + + /// Reads data from the stream without removing it from the buffer. + /// + /// Returns the number of bytes read. Successive calls of this method read the same data. + /// + /// # Examples + /// + /// ``` + /// use async_io::Async; + /// use futures_lite::{io::AsyncWriteExt, stream::StreamExt}; + /// use std::net::{TcpStream, ToSocketAddrs}; + /// + /// # futures_lite::future::block_on(async { + /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap(); + /// let mut stream = Async::::connect(addr).await?; + /// + /// stream + /// .write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") + /// .await?; + /// + /// let mut buf = [0u8; 1024]; + /// let len = stream.peek(&mut buf).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn peek(&self, buf: &mut [u8]) -> io::Result { + self.read_with(|io| io.peek(buf)).await + } +} + +impl TryFrom for Async { + type Error = io::Error; + + fn try_from(stream: std::net::TcpStream) -> io::Result { + Async::new(stream) + } +} + +impl Async { + /// Creates a UDP socket bound to the specified address. + /// + /// Binding with port number 0 will request an available port from the OS. + /// + /// # Examples + /// + /// ``` + /// use async_io::Async; + /// use std::net::UdpSocket; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind(([127, 0, 0, 1], 0))?; + /// println!("Bound to {}", socket.get_ref().local_addr()?); + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn bind>(addr: A) -> io::Result> { + let addr = addr.into(); + Async::new(UdpSocket::bind(addr)?) + } + + /// Receives a single datagram message. + /// + /// Returns the number of bytes read and the address the message came from. + /// + /// This method must be called with a valid byte slice of sufficient size to hold the message. + /// If the message is too long to fit, excess bytes may get discarded. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::UdpSocket; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; + /// + /// let mut buf = [0u8; 1024]; + /// let (len, addr) = socket.recv_from(&mut buf).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.read_with(|io| io.recv_from(buf)).await + } + + /// Receives a single datagram message without removing it from the queue. + /// + /// Returns the number of bytes read and the address the message came from. + /// + /// This method must be called with a valid byte slice of sufficient size to hold the message. + /// If the message is too long to fit, excess bytes may get discarded. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::UdpSocket; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; + /// + /// let mut buf = [0u8; 1024]; + /// let (len, addr) = socket.peek_from(&mut buf).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.read_with(|io| io.peek_from(buf)).await + } + + /// Sends data to the specified address. + /// + /// Returns the number of bytes writen. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::UdpSocket; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind(([127, 0, 0, 1], 0))?; + /// let addr = socket.get_ref().local_addr()?; + /// + /// let msg = b"hello"; + /// let len = socket.send_to(msg, addr).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn send_to>(&self, buf: &[u8], addr: A) -> io::Result { + let addr = addr.into(); + self.write_with(|io| io.send_to(buf, addr)).await + } + + /// Receives a single datagram message from the connected peer. + /// + /// Returns the number of bytes read. + /// + /// This method must be called with a valid byte slice of sufficient size to hold the message. + /// If the message is too long to fit, excess bytes may get discarded. + /// + /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address. + /// This method will fail if the socket is not connected. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::UdpSocket; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; + /// socket.get_ref().connect("127.0.0.1:9000")?; + /// + /// let mut buf = [0u8; 1024]; + /// let len = socket.recv(&mut buf).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn recv(&self, buf: &mut [u8]) -> io::Result { + self.read_with(|io| io.recv(buf)).await + } + + /// Receives a single datagram message from the connected peer without removing it from the + /// queue. + /// + /// Returns the number of bytes read and the address the message came from. + /// + /// This method must be called with a valid byte slice of sufficient size to hold the message. + /// If the message is too long to fit, excess bytes may get discarded. + /// + /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address. + /// This method will fail if the socket is not connected. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::UdpSocket; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; + /// socket.get_ref().connect("127.0.0.1:9000")?; + /// + /// let mut buf = [0u8; 1024]; + /// let len = socket.peek(&mut buf).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn peek(&self, buf: &mut [u8]) -> io::Result { + self.read_with(|io| io.peek(buf)).await + } + + /// Sends data to the connected peer. + /// + /// Returns the number of bytes written. + /// + /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address. + /// This method will fail if the socket is not connected. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::UdpSocket; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; + /// socket.get_ref().connect("127.0.0.1:9000")?; + /// + /// let msg = b"hello"; + /// let len = socket.send(msg).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn send(&self, buf: &[u8]) -> io::Result { + self.write_with(|io| io.send(buf)).await + } +} + +impl TryFrom for Async { + type Error = io::Error; + + fn try_from(socket: std::net::UdpSocket) -> io::Result { + Async::new(socket) + } +} + +#[cfg(unix)] +impl Async { + /// Creates a UDS listener bound to the specified path. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixListener; + /// + /// # futures_lite::future::block_on(async { + /// let listener = Async::::bind("/tmp/socket")?; + /// println!("Listening on {:?}", listener.get_ref().local_addr()?); + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn bind>(path: P) -> io::Result> { + let path = path.as_ref().to_owned(); + Async::new(UnixListener::bind(path)?) + } + + /// Accepts a new incoming UDS stream connection. + /// + /// When a connection is established, it will be returned as a stream together with its remote + /// address. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixListener; + /// + /// # futures_lite::future::block_on(async { + /// let listener = Async::::bind("/tmp/socket")?; + /// let (stream, addr) = listener.accept().await?; + /// println!("Accepted client: {:?}", addr); + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn accept(&self) -> io::Result<(Async, UnixSocketAddr)> { + let (stream, addr) = self.read_with(|io| io.accept()).await?; + Ok((Async::new(stream)?, addr)) + } + + /// Returns a stream of incoming UDS connections. + /// + /// The stream is infinite, i.e. it never stops with a [`None`] item. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use futures_lite::{pin, stream::StreamExt}; + /// use std::os::unix::net::UnixListener; + /// + /// # futures_lite::future::block_on(async { + /// let listener = Async::::bind("/tmp/socket")?; + /// let incoming = listener.incoming(); + /// pin!(incoming); + /// + /// while let Some(stream) = incoming.next().await { + /// let stream = stream?; + /// println!("Accepted client: {:?}", stream.get_ref().peer_addr()?); + /// } + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn incoming(&self) -> impl Stream>> + Send + '_ { + stream::unfold(self, |listener| async move { + let res = listener.accept().await.map(|(stream, _)| stream); + Some((res, listener)) + }) + } +} + +#[cfg(unix)] +impl TryFrom for Async { + type Error = io::Error; + + fn try_from(listener: std::os::unix::net::UnixListener) -> io::Result { + Async::new(listener) + } +} + +#[cfg(unix)] +impl Async { + /// Creates a UDS stream connected to the specified path. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixStream; + /// + /// # futures_lite::future::block_on(async { + /// let stream = Async::::connect("/tmp/socket").await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn connect>(path: P) -> io::Result> { + // Begin async connect. + let socket = connect(SockAddr::unix(path)?, Domain::UNIX, None)?; + let stream = Async::new(UnixStream::from(socket))?; + + // The stream becomes writable when connected. + stream.writable().await?; + + // On Linux, it appears the socket may become writable even when connecting fails, so we + // must do an extra check here and see if the peer address is retrievable. + stream.get_ref().peer_addr()?; + Ok(stream) + } + + /// Creates an unnamed pair of connected UDS stream sockets. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixStream; + /// + /// # futures_lite::future::block_on(async { + /// let (stream1, stream2) = Async::::pair()?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn pair() -> io::Result<(Async, Async)> { + let (stream1, stream2) = UnixStream::pair()?; + Ok((Async::new(stream1)?, Async::new(stream2)?)) + } +} + +#[cfg(unix)] +impl TryFrom for Async { + type Error = io::Error; + + fn try_from(stream: std::os::unix::net::UnixStream) -> io::Result { + Async::new(stream) + } +} + +#[cfg(unix)] +impl Async { + /// Creates a UDS datagram socket bound to the specified path. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixDatagram; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind("/tmp/socket")?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn bind>(path: P) -> io::Result> { + let path = path.as_ref().to_owned(); + Async::new(UnixDatagram::bind(path)?) + } + + /// Creates a UDS datagram socket not bound to any address. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixDatagram; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::unbound()?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn unbound() -> io::Result> { + Async::new(UnixDatagram::unbound()?) + } + + /// Creates an unnamed pair of connected Unix datagram sockets. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixDatagram; + /// + /// # futures_lite::future::block_on(async { + /// let (socket1, socket2) = Async::::pair()?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn pair() -> io::Result<(Async, Async)> { + let (socket1, socket2) = UnixDatagram::pair()?; + Ok((Async::new(socket1)?, Async::new(socket2)?)) + } + + /// Receives data from the socket. + /// + /// Returns the number of bytes read and the address the message came from. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixDatagram; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind("/tmp/socket")?; + /// + /// let mut buf = [0u8; 1024]; + /// let (len, addr) = socket.recv_from(&mut buf).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, UnixSocketAddr)> { + self.read_with(|io| io.recv_from(buf)).await + } + + /// Sends data to the specified address. + /// + /// Returns the number of bytes written. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixDatagram; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::unbound()?; + /// + /// let msg = b"hello"; + /// let addr = "/tmp/socket"; + /// let len = socket.send_to(msg, addr).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn send_to>(&self, buf: &[u8], path: P) -> io::Result { + self.write_with(|io| io.send_to(buf, &path)).await + } + + /// Receives data from the connected peer. + /// + /// Returns the number of bytes read and the address the message came from. + /// + /// The [`connect`][`UnixDatagram::connect()`] method connects this socket to a remote address. + /// This method will fail if the socket is not connected. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixDatagram; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind("/tmp/socket1")?; + /// socket.get_ref().connect("/tmp/socket2")?; + /// + /// let mut buf = [0u8; 1024]; + /// let len = socket.recv(&mut buf).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn recv(&self, buf: &mut [u8]) -> io::Result { + self.read_with(|io| io.recv(buf)).await + } + + /// Sends data to the connected peer. + /// + /// Returns the number of bytes written. + /// + /// The [`connect`][`UnixDatagram::connect()`] method connects this socket to a remote address. + /// This method will fail if the socket is not connected. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixDatagram; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind("/tmp/socket1")?; + /// socket.get_ref().connect("/tmp/socket2")?; + /// + /// let msg = b"hello"; + /// let len = socket.send(msg).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn send(&self, buf: &[u8]) -> io::Result { + self.write_with(|io| io.send(buf)).await + } +} + +#[cfg(unix)] +impl TryFrom for Async { + type Error = io::Error; + + fn try_from(socket: std::os::unix::net::UnixDatagram) -> io::Result { + Async::new(socket) + } +} + +/// Polls a future once, waits for a wakeup, and then optimistically assumes the future is ready. +async fn optimistic(fut: impl Future>) -> io::Result<()> { + let mut polled = false; + pin!(fut); + + future::poll_fn(|cx| { + if !polled { + polled = true; + fut.as_mut().poll(cx) + } else { + Poll::Ready(Ok(())) + } + }) + .await +} + +fn connect(addr: SockAddr, domain: Domain, protocol: Option) -> io::Result { + let sock_type = Type::STREAM; + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd" + ))] + // If we can, set nonblocking at socket creation for unix + let sock_type = sock_type.nonblocking(); + // This automatically handles cloexec on unix, no_inherit on windows and nosigpipe on macos + let socket = Socket::new(domain, sock_type, protocol)?; + #[cfg(not(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd" + )))] + // If the current platform doesn't support nonblocking at creation, enable it after creation + socket.set_nonblocking(true)?; + match socket.connect(&addr) { + Ok(_) => {} + #[cfg(unix)] + Err(err) if err.raw_os_error() == Some(rustix::io::Errno::INPROGRESS.raw_os_error()) => {} + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + Err(err) => return Err(err), + } + Ok(socket) +} diff --git a/src/lib.rs b/src/lib.rs index fd671b1..24915f7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -61,38 +61,23 @@ html_logo_url = "https://raw.githubusercontent.com/smol-rs/smol/master/assets/images/logo_fullsize_transparent.png" )] -use std::convert::TryFrom; use std::future::Future; -use std::io::{self, IoSlice, IoSliceMut, Read, Write}; -use std::net::{SocketAddr, TcpListener, TcpStream, UdpSocket}; use std::pin::Pin; -use std::sync::Arc; use std::task::{Context, Poll, Waker}; use std::time::{Duration, Instant}; -#[cfg(unix)] -use std::{ - os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd}, - os::unix::net::{SocketAddr as UnixSocketAddr, UnixDatagram, UnixListener, UnixStream}, - path::Path, -}; +use futures_lite::stream::Stream; -#[cfg(windows)] -use std::os::windows::io::{AsRawSocket, AsSocket, BorrowedSocket, OwnedSocket, RawSocket}; - -use futures_io::{AsyncRead, AsyncWrite}; -use futures_lite::stream::{self, Stream}; -use futures_lite::{future, pin, ready}; -use socket2::{Domain, Protocol, SockAddr, Socket, Type}; - -use crate::reactor::{Reactor, Registration, Source}; +use crate::reactor::Reactor; mod driver; +mod io; mod reactor; pub mod os; pub use driver::block_on; +pub use io::{Async, IoSafe}; pub use reactor::{Readable, ReadableOwned, Writable, WritableOwned}; /// A future or stream that emits timed events. @@ -517,1552 +502,3 @@ impl Stream for Timer { Poll::Pending } } - -/// Async adapter for I/O types. -/// -/// This type puts an I/O handle into non-blocking mode, registers it in -/// [epoll]/[kqueue]/[event ports]/[IOCP], and then provides an async interface for it. -/// -/// [epoll]: https://en.wikipedia.org/wiki/Epoll -/// [kqueue]: https://en.wikipedia.org/wiki/Kqueue -/// [event ports]: https://illumos.org/man/port_create -/// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports -/// -/// # Caveats -/// -/// [`Async`] is a low-level primitive, and as such it comes with some caveats. -/// -/// For higher-level primitives built on top of [`Async`], look into [`async-net`] or -/// [`async-process`] (on Unix). -/// -/// The most notable caveat is that it is unsafe to access the inner I/O source mutably -/// using this primitive. Traits likes [`AsyncRead`] and [`AsyncWrite`] are not implemented by -/// default unless it is guaranteed that the resource won't be invalidated by reading or writing. -/// See the [`IoSafe`] trait for more information. -/// -/// [`async-net`]: https://github.com/smol-rs/async-net -/// [`async-process`]: https://github.com/smol-rs/async-process -/// [`AsyncRead`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncRead.html -/// [`AsyncWrite`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncWrite.html -/// -/// ### Supported types -/// -/// [`Async`] supports all networking types, as well as some OS-specific file descriptors like -/// [timerfd] and [inotify]. -/// -/// However, do not use [`Async`] with types like [`File`][`std::fs::File`], -/// [`Stdin`][`std::io::Stdin`], [`Stdout`][`std::io::Stdout`], or [`Stderr`][`std::io::Stderr`] -/// because all operating systems have issues with them when put in non-blocking mode. -/// -/// [timerfd]: https://github.com/smol-rs/async-io/blob/master/examples/linux-timerfd.rs -/// [inotify]: https://github.com/smol-rs/async-io/blob/master/examples/linux-inotify.rs -/// -/// ### Concurrent I/O -/// -/// Note that [`&Async`][`Async`] implements [`AsyncRead`] and [`AsyncWrite`] if `&T` -/// implements those traits, which means tasks can concurrently read and write using shared -/// references. -/// -/// But there is a catch: only one task can read a time, and only one task can write at a time. It -/// is okay to have two tasks where one is reading and the other is writing at the same time, but -/// it is not okay to have two tasks reading at the same time or writing at the same time. If you -/// try to do that, conflicting tasks will just keep waking each other in turn, thus wasting CPU -/// time. -/// -/// Besides [`AsyncRead`] and [`AsyncWrite`], this caveat also applies to -/// [`poll_readable()`][`Async::poll_readable()`] and -/// [`poll_writable()`][`Async::poll_writable()`]. -/// -/// However, any number of tasks can be concurrently calling other methods like -/// [`readable()`][`Async::readable()`] or [`read_with()`][`Async::read_with()`]. -/// -/// ### Closing -/// -/// Closing the write side of [`Async`] with [`close()`][`futures_lite::AsyncWriteExt::close()`] -/// simply flushes. If you want to shutdown a TCP or Unix socket, use -/// [`Shutdown`][`std::net::Shutdown`]. -/// -/// # Examples -/// -/// Connect to a server and echo incoming messages back to the server: -/// -/// ```no_run -/// use async_io::Async; -/// use futures_lite::io; -/// use std::net::TcpStream; -/// -/// # futures_lite::future::block_on(async { -/// // Connect to a local server. -/// let stream = Async::::connect(([127, 0, 0, 1], 8000)).await?; -/// -/// // Echo all messages from the read side of the stream into the write side. -/// io::copy(&stream, &stream).await?; -/// # std::io::Result::Ok(()) }); -/// ``` -/// -/// You can use either predefined async methods or wrap blocking I/O operations in -/// [`Async::read_with()`], [`Async::read_with_mut()`], [`Async::write_with()`], and -/// [`Async::write_with_mut()`]: -/// -/// ```no_run -/// use async_io::Async; -/// use std::net::TcpListener; -/// -/// # futures_lite::future::block_on(async { -/// let listener = Async::::bind(([127, 0, 0, 1], 0))?; -/// -/// // These two lines are equivalent: -/// let (stream, addr) = listener.accept().await?; -/// let (stream, addr) = listener.read_with(|inner| inner.accept()).await?; -/// # std::io::Result::Ok(()) }); -/// ``` -#[derive(Debug)] -pub struct Async { - /// A source registered in the reactor. - source: Arc, - - /// The inner I/O handle. - io: Option, -} - -impl Unpin for Async {} - -#[cfg(unix)] -impl Async { - /// Creates an async I/O handle. - /// - /// This method will put the handle in non-blocking mode and register it in - /// [epoll]/[kqueue]/[event ports]/[IOCP]. - /// - /// On Unix systems, the handle must implement `AsRawFd`, while on Windows it must implement - /// `AsRawSocket`. - /// - /// [epoll]: https://en.wikipedia.org/wiki/Epoll - /// [kqueue]: https://en.wikipedia.org/wiki/Kqueue - /// [event ports]: https://illumos.org/man/port_create - /// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports - /// - /// # Examples - /// - /// ``` - /// use async_io::Async; - /// use std::net::{SocketAddr, TcpListener}; - /// - /// # futures_lite::future::block_on(async { - /// let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))?; - /// let listener = Async::new(listener)?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn new(io: T) -> io::Result> { - // Put the file descriptor in non-blocking mode. - let fd = io.as_fd(); - cfg_if::cfg_if! { - // ioctl(FIONBIO) sets the flag atomically, but we use this only on Linux - // for now, as with the standard library, because it seems to behave - // differently depending on the platform. - // https://github.com/rust-lang/rust/commit/efeb42be2837842d1beb47b51bb693c7474aba3d - // https://github.com/libuv/libuv/blob/e9d91fccfc3e5ff772d5da90e1c4a24061198ca0/src/unix/poll.c#L78-L80 - // https://github.com/tokio-rs/mio/commit/0db49f6d5caf54b12176821363d154384357e70a - if #[cfg(target_os = "linux")] { - rustix::io::ioctl_fionbio(fd, true)?; - } else { - let previous = rustix::fs::fcntl_getfl(fd)?; - let new = previous | rustix::fs::OFlags::NONBLOCK; - if new != previous { - rustix::fs::fcntl_setfl(fd, new)?; - } - } - } - - // SAFETY: It is impossible to drop the I/O source while it is registered through - // this type. - let registration = unsafe { Registration::new(fd) }; - - Ok(Async { - source: Reactor::get().insert_io(registration)?, - io: Some(io), - }) - } -} - -#[cfg(unix)] -impl AsRawFd for Async { - fn as_raw_fd(&self) -> RawFd { - self.get_ref().as_raw_fd() - } -} - -#[cfg(unix)] -impl AsFd for Async { - fn as_fd(&self) -> BorrowedFd<'_> { - self.get_ref().as_fd() - } -} - -#[cfg(unix)] -impl> TryFrom for Async { - type Error = io::Error; - - fn try_from(value: OwnedFd) -> Result { - Async::new(value.into()) - } -} - -#[cfg(unix)] -impl> TryFrom> for OwnedFd { - type Error = io::Error; - - fn try_from(value: Async) -> Result { - value.into_inner().map(Into::into) - } -} - -#[cfg(windows)] -impl Async { - /// Creates an async I/O handle. - /// - /// This method will put the handle in non-blocking mode and register it in - /// [epoll]/[kqueue]/[event ports]/[IOCP]. - /// - /// On Unix systems, the handle must implement `AsRawFd`, while on Windows it must implement - /// `AsRawSocket`. - /// - /// [epoll]: https://en.wikipedia.org/wiki/Epoll - /// [kqueue]: https://en.wikipedia.org/wiki/Kqueue - /// [event ports]: https://illumos.org/man/port_create - /// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports - /// - /// # Examples - /// - /// ``` - /// use async_io::Async; - /// use std::net::{SocketAddr, TcpListener}; - /// - /// # futures_lite::future::block_on(async { - /// let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))?; - /// let listener = Async::new(listener)?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn new(io: T) -> io::Result> { - let borrowed = io.as_socket(); - - // Put the socket in non-blocking mode. - // - // Safety: We assume `as_raw_socket()` returns a valid fd. When we can - // depend on Rust >= 1.63, where `AsFd` is stabilized, and when - // `TimerFd` implements it, we can remove this unsafe and simplify this. - rustix::io::ioctl_fionbio(borrowed, true)?; - - // Create the registration. - // - // SAFETY: It is impossible to drop the I/O source while it is registered through - // this type. - let registration = unsafe { Registration::new(borrowed) }; - - Ok(Async { - source: Reactor::get().insert_io(registration)?, - io: Some(io), - }) - } -} - -#[cfg(windows)] -impl AsRawSocket for Async { - fn as_raw_socket(&self) -> RawSocket { - self.get_ref().as_raw_socket() - } -} - -#[cfg(windows)] -impl AsSocket for Async { - fn as_socket(&self) -> BorrowedSocket<'_> { - self.get_ref().as_socket() - } -} - -#[cfg(windows)] -impl> TryFrom for Async { - type Error = io::Error; - - fn try_from(value: OwnedSocket) -> Result { - Async::new(value.into()) - } -} - -#[cfg(windows)] -impl> TryFrom> for OwnedSocket { - type Error = io::Error; - - fn try_from(value: Async) -> Result { - value.into_inner().map(Into::into) - } -} - -impl Async { - /// Gets a reference to the inner I/O handle. - /// - /// # Examples - /// - /// ``` - /// use async_io::Async; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let listener = Async::::bind(([127, 0, 0, 1], 0))?; - /// let inner = listener.get_ref(); - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn get_ref(&self) -> &T { - self.io.as_ref().unwrap() - } - - /// Gets a mutable reference to the inner I/O handle. - /// - /// # Safety - /// - /// The underlying I/O source must not be dropped using this function. - /// - /// # Examples - /// - /// ``` - /// use async_io::Async; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let mut listener = Async::::bind(([127, 0, 0, 1], 0))?; - /// let inner = unsafe { listener.get_mut() }; - /// # std::io::Result::Ok(()) }); - /// ``` - pub unsafe fn get_mut(&mut self) -> &mut T { - self.io.as_mut().unwrap() - } - - /// Unwraps the inner I/O handle. - /// - /// This method will **not** put the I/O handle back into blocking mode. - /// - /// # Examples - /// - /// ``` - /// use async_io::Async; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let listener = Async::::bind(([127, 0, 0, 1], 0))?; - /// let inner = listener.into_inner()?; - /// - /// // Put the listener back into blocking mode. - /// inner.set_nonblocking(false)?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn into_inner(mut self) -> io::Result { - let io = self.io.take().unwrap(); - Reactor::get().remove_io(&self.source)?; - Ok(io) - } - - /// Waits until the I/O handle is readable. - /// - /// This method completes when a read operation on this I/O handle wouldn't block. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let mut listener = Async::::bind(([127, 0, 0, 1], 0))?; - /// - /// // Wait until a client can be accepted. - /// listener.readable().await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn readable(&self) -> Readable<'_, T> { - Source::readable(self) - } - - /// Waits until the I/O handle is readable. - /// - /// This method completes when a read operation on this I/O handle wouldn't block. - pub fn readable_owned(self: Arc) -> ReadableOwned { - Source::readable_owned(self) - } - - /// Waits until the I/O handle is writable. - /// - /// This method completes when a write operation on this I/O handle wouldn't block. - /// - /// # Examples - /// - /// ``` - /// use async_io::Async; - /// use std::net::{TcpStream, ToSocketAddrs}; - /// - /// # futures_lite::future::block_on(async { - /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap(); - /// let stream = Async::::connect(addr).await?; - /// - /// // Wait until the stream is writable. - /// stream.writable().await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn writable(&self) -> Writable<'_, T> { - Source::writable(self) - } - - /// Waits until the I/O handle is writable. - /// - /// This method completes when a write operation on this I/O handle wouldn't block. - pub fn writable_owned(self: Arc) -> WritableOwned { - Source::writable_owned(self) - } - - /// Polls the I/O handle for readability. - /// - /// When this method returns [`Poll::Ready`], that means the OS has delivered an event - /// indicating readability since the last time this task has called the method and received - /// [`Poll::Pending`]. - /// - /// # Caveats - /// - /// Two different tasks should not call this method concurrently. Otherwise, conflicting tasks - /// will just keep waking each other in turn, thus wasting CPU time. - /// - /// Note that the [`AsyncRead`] implementation for [`Async`] also uses this method. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use futures_lite::future; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let mut listener = Async::::bind(([127, 0, 0, 1], 0))?; - /// - /// // Wait until a client can be accepted. - /// future::poll_fn(|cx| listener.poll_readable(cx)).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn poll_readable(&self, cx: &mut Context<'_>) -> Poll> { - self.source.poll_readable(cx) - } - - /// Polls the I/O handle for writability. - /// - /// When this method returns [`Poll::Ready`], that means the OS has delivered an event - /// indicating writability since the last time this task has called the method and received - /// [`Poll::Pending`]. - /// - /// # Caveats - /// - /// Two different tasks should not call this method concurrently. Otherwise, conflicting tasks - /// will just keep waking each other in turn, thus wasting CPU time. - /// - /// Note that the [`AsyncWrite`] implementation for [`Async`] also uses this method. - /// - /// # Examples - /// - /// ``` - /// use async_io::Async; - /// use futures_lite::future; - /// use std::net::{TcpStream, ToSocketAddrs}; - /// - /// # futures_lite::future::block_on(async { - /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap(); - /// let stream = Async::::connect(addr).await?; - /// - /// // Wait until the stream is writable. - /// future::poll_fn(|cx| stream.poll_writable(cx)).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn poll_writable(&self, cx: &mut Context<'_>) -> Poll> { - self.source.poll_writable(cx) - } - - /// Performs a read operation asynchronously. - /// - /// The I/O handle is registered in the reactor and put in non-blocking mode. This method - /// invokes the `op` closure in a loop until it succeeds or returns an error other than - /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS - /// sends a notification that the I/O handle is readable. - /// - /// The closure receives a shared reference to the I/O handle. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let listener = Async::::bind(([127, 0, 0, 1], 0))?; - /// - /// // Accept a new client asynchronously. - /// let (stream, addr) = listener.read_with(|l| l.accept()).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn read_with(&self, op: impl FnMut(&T) -> io::Result) -> io::Result { - let mut op = op; - loop { - match op(self.get_ref()) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return res, - } - optimistic(self.readable()).await?; - } - } - - /// Performs a read operation asynchronously. - /// - /// The I/O handle is registered in the reactor and put in non-blocking mode. This method - /// invokes the `op` closure in a loop until it succeeds or returns an error other than - /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS - /// sends a notification that the I/O handle is readable. - /// - /// The closure receives a mutable reference to the I/O handle. - /// - /// # Safety - /// - /// In the closure, the underlying I/O source must not be dropped. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let mut listener = Async::::bind(([127, 0, 0, 1], 0))?; - /// - /// // Accept a new client asynchronously. - /// let (stream, addr) = unsafe { listener.read_with_mut(|l| l.accept()).await? }; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async unsafe fn read_with_mut( - &mut self, - op: impl FnMut(&mut T) -> io::Result, - ) -> io::Result { - let mut op = op; - loop { - match op(self.get_mut()) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return res, - } - optimistic(self.readable()).await?; - } - } - - /// Performs a write operation asynchronously. - /// - /// The I/O handle is registered in the reactor and put in non-blocking mode. This method - /// invokes the `op` closure in a loop until it succeeds or returns an error other than - /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS - /// sends a notification that the I/O handle is writable. - /// - /// The closure receives a shared reference to the I/O handle. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::net::UdpSocket; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; - /// socket.get_ref().connect("127.0.0.1:9000")?; - /// - /// let msg = b"hello"; - /// let len = socket.write_with(|s| s.send(msg)).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn write_with(&self, op: impl FnMut(&T) -> io::Result) -> io::Result { - let mut op = op; - loop { - match op(self.get_ref()) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return res, - } - optimistic(self.writable()).await?; - } - } - - /// Performs a write operation asynchronously. - /// - /// The I/O handle is registered in the reactor and put in non-blocking mode. This method - /// invokes the `op` closure in a loop until it succeeds or returns an error other than - /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS - /// sends a notification that the I/O handle is writable. - /// - /// # Safety - /// - /// The closure receives a mutable reference to the I/O handle. In the closure, the underlying - /// I/O source must not be dropped. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::net::UdpSocket; - /// - /// # futures_lite::future::block_on(async { - /// let mut socket = Async::::bind(([127, 0, 0, 1], 8000))?; - /// socket.get_ref().connect("127.0.0.1:9000")?; - /// - /// let msg = b"hello"; - /// let len = unsafe { socket.write_with_mut(|s| s.send(msg)).await? }; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async unsafe fn write_with_mut( - &mut self, - op: impl FnMut(&mut T) -> io::Result, - ) -> io::Result { - let mut op = op; - loop { - match op(self.get_mut()) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return res, - } - optimistic(self.writable()).await?; - } - } -} - -impl AsRef for Async { - fn as_ref(&self) -> &T { - self.get_ref() - } -} - -impl Drop for Async { - fn drop(&mut self) { - if self.io.is_some() { - // Deregister and ignore errors because destructors should not panic. - Reactor::get().remove_io(&self.source).ok(); - - // Drop the I/O handle to close it. - self.io.take(); - } - } -} - -/// Types whose I/O trait implementations do not drop the underlying I/O source. -/// -/// The resource contained inside of the [`Async`] cannot be invalidated. This invalidation can -/// happen if the inner resource (the [`TcpStream`], [`UnixListener`] or other `T`) is moved out -/// and dropped before the [`Async`]. Because of this, functions that grant mutable access to -/// the inner type are unsafe, as there is no way to guarantee that the source won't be dropped -/// and a dangling handle won't be left behind. -/// -/// Unfortunately this extends to implementations of [`Read`] and [`Write`]. Since methods on those -/// traits take `&mut`, there is no guarantee that the implementor of those traits won't move the -/// source out while the method is being run. -/// -/// This trait is an antidote to this predicament. By implementing this trait, the user pledges -/// that using any I/O traits won't destroy the source. This way, [`Async`] can implement the -/// `async` version of these I/O traits, like [`AsyncRead`] and [`AsyncWrite`]. -/// -/// # Safety -/// -/// Any I/O trait implementations for this type must not drop the underlying I/O source. Traits -/// affected by this trait include [`Read`], [`Write`], [`Seek`] and [`BufRead`]. -/// -/// This trait is implemented by default on top of `libstd` types. In addition, it is implemented -/// for immutable reference types, as it is impossible to invalidate any outstanding references -/// while holding an immutable reference, even with interior mutability. As Rust's current pinning -/// system relies on similar guarantees, I believe that this approach is robust. -/// -/// [`BufRead`]: https://doc.rust-lang.org/std/io/trait.BufRead.html -/// [`Read`]: https://doc.rust-lang.org/std/io/trait.Read.html -/// [`Seek`]: https://doc.rust-lang.org/std/io/trait.Seek.html -/// [`Write`]: https://doc.rust-lang.org/std/io/trait.Write.html -/// -/// [`AsyncRead`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncRead.html -/// [`AsyncWrite`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncWrite.html -pub unsafe trait IoSafe {} - -/// Reference types can't be mutated. -/// -/// The worst thing that can happen is that external state is used to change what kind of pointer -/// `as_fd()` returns. For instance: -/// -/// ``` -/// # #[cfg(unix)] { -/// use std::cell::Cell; -/// use std::net::TcpStream; -/// use std::os::unix::io::{AsFd, BorrowedFd}; -/// -/// struct Bar { -/// flag: Cell, -/// a: TcpStream, -/// b: TcpStream -/// } -/// -/// impl AsFd for Bar { -/// fn as_fd(&self) -> BorrowedFd<'_> { -/// if self.flag.replace(!self.flag.get()) { -/// self.a.as_fd() -/// } else { -/// self.b.as_fd() -/// } -/// } -/// } -/// # } -/// ``` -/// -/// We solve this problem by only calling `as_fd()` once to get the original source. Implementations -/// like this are considered buggy (but not unsound) and are thus not really supported by `async-io`. -unsafe impl IoSafe for &T {} - -// Can be implemented on top of libstd types. -unsafe impl IoSafe for std::fs::File {} -unsafe impl IoSafe for std::io::Stderr {} -unsafe impl IoSafe for std::io::Stdin {} -unsafe impl IoSafe for std::io::Stdout {} -unsafe impl IoSafe for std::io::StderrLock<'_> {} -unsafe impl IoSafe for std::io::StdinLock<'_> {} -unsafe impl IoSafe for std::io::StdoutLock<'_> {} -unsafe impl IoSafe for std::net::TcpStream {} - -#[cfg(unix)] -unsafe impl IoSafe for std::os::unix::net::UnixStream {} - -unsafe impl IoSafe for std::io::BufReader {} -unsafe impl IoSafe for std::io::BufWriter {} -unsafe impl IoSafe for std::io::LineWriter {} -unsafe impl IoSafe for &mut T {} -unsafe impl IoSafe for Box {} -unsafe impl IoSafe for std::borrow::Cow<'_, T> {} - -impl AsyncRead for Async { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - loop { - match unsafe { (*self).get_mut() }.read(buf) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_readable(cx))?; - } - } - - fn poll_read_vectored( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &mut [IoSliceMut<'_>], - ) -> Poll> { - loop { - match unsafe { (*self).get_mut() }.read_vectored(bufs) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_readable(cx))?; - } - } -} - -// Since this is through a reference, we can't mutate the inner I/O source. -// Therefore this is safe! -impl AsyncRead for &Async -where - for<'a> &'a T: Read, -{ - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - loop { - match (*self).get_ref().read(buf) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_readable(cx))?; - } - } - - fn poll_read_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &mut [IoSliceMut<'_>], - ) -> Poll> { - loop { - match (*self).get_ref().read_vectored(bufs) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_readable(cx))?; - } - } -} - -impl AsyncWrite for Async { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - loop { - match unsafe { (*self).get_mut() }.write(buf) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_writable(cx))?; - } - } - - fn poll_write_vectored( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - loop { - match unsafe { (*self).get_mut() }.write_vectored(bufs) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_writable(cx))?; - } - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - loop { - match unsafe { (*self).get_mut() }.flush() { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_writable(cx))?; - } - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.poll_flush(cx) - } -} - -impl AsyncWrite for &Async -where - for<'a> &'a T: Write, -{ - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - loop { - match (*self).get_ref().write(buf) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_writable(cx))?; - } - } - - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - loop { - match (*self).get_ref().write_vectored(bufs) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_writable(cx))?; - } - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - loop { - match (*self).get_ref().flush() { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_writable(cx))?; - } - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.poll_flush(cx) - } -} - -impl Async { - /// Creates a TCP listener bound to the specified address. - /// - /// Binding with port number 0 will request an available port from the OS. - /// - /// # Examples - /// - /// ``` - /// use async_io::Async; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let listener = Async::::bind(([127, 0, 0, 1], 0))?; - /// println!("Listening on {}", listener.get_ref().local_addr()?); - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn bind>(addr: A) -> io::Result> { - let addr = addr.into(); - Async::new(TcpListener::bind(addr)?) - } - - /// Accepts a new incoming TCP connection. - /// - /// When a connection is established, it will be returned as a TCP stream together with its - /// remote address. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let listener = Async::::bind(([127, 0, 0, 1], 8000))?; - /// let (stream, addr) = listener.accept().await?; - /// println!("Accepted client: {}", addr); - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn accept(&self) -> io::Result<(Async, SocketAddr)> { - let (stream, addr) = self.read_with(|io| io.accept()).await?; - Ok((Async::new(stream)?, addr)) - } - - /// Returns a stream of incoming TCP connections. - /// - /// The stream is infinite, i.e. it never stops with a [`None`]. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use futures_lite::{pin, stream::StreamExt}; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let listener = Async::::bind(([127, 0, 0, 1], 8000))?; - /// let incoming = listener.incoming(); - /// pin!(incoming); - /// - /// while let Some(stream) = incoming.next().await { - /// let stream = stream?; - /// println!("Accepted client: {}", stream.get_ref().peer_addr()?); - /// } - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn incoming(&self) -> impl Stream>> + Send + '_ { - stream::unfold(self, |listener| async move { - let res = listener.accept().await.map(|(stream, _)| stream); - Some((res, listener)) - }) - } -} - -impl TryFrom for Async { - type Error = io::Error; - - fn try_from(listener: std::net::TcpListener) -> io::Result { - Async::new(listener) - } -} - -impl Async { - /// Creates a TCP connection to the specified address. - /// - /// # Examples - /// - /// ``` - /// use async_io::Async; - /// use std::net::{TcpStream, ToSocketAddrs}; - /// - /// # futures_lite::future::block_on(async { - /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap(); - /// let stream = Async::::connect(addr).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn connect>(addr: A) -> io::Result> { - // Begin async connect. - let addr = addr.into(); - let domain = Domain::for_address(addr); - let socket = connect(addr.into(), domain, Some(Protocol::TCP))?; - let stream = Async::new(TcpStream::from(socket))?; - - // The stream becomes writable when connected. - stream.writable().await?; - - // Check if there was an error while connecting. - match stream.get_ref().take_error()? { - None => Ok(stream), - Some(err) => Err(err), - } - } - - /// Reads data from the stream without removing it from the buffer. - /// - /// Returns the number of bytes read. Successive calls of this method read the same data. - /// - /// # Examples - /// - /// ``` - /// use async_io::Async; - /// use futures_lite::{io::AsyncWriteExt, stream::StreamExt}; - /// use std::net::{TcpStream, ToSocketAddrs}; - /// - /// # futures_lite::future::block_on(async { - /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap(); - /// let mut stream = Async::::connect(addr).await?; - /// - /// stream - /// .write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") - /// .await?; - /// - /// let mut buf = [0u8; 1024]; - /// let len = stream.peek(&mut buf).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn peek(&self, buf: &mut [u8]) -> io::Result { - self.read_with(|io| io.peek(buf)).await - } -} - -impl TryFrom for Async { - type Error = io::Error; - - fn try_from(stream: std::net::TcpStream) -> io::Result { - Async::new(stream) - } -} - -impl Async { - /// Creates a UDP socket bound to the specified address. - /// - /// Binding with port number 0 will request an available port from the OS. - /// - /// # Examples - /// - /// ``` - /// use async_io::Async; - /// use std::net::UdpSocket; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind(([127, 0, 0, 1], 0))?; - /// println!("Bound to {}", socket.get_ref().local_addr()?); - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn bind>(addr: A) -> io::Result> { - let addr = addr.into(); - Async::new(UdpSocket::bind(addr)?) - } - - /// Receives a single datagram message. - /// - /// Returns the number of bytes read and the address the message came from. - /// - /// This method must be called with a valid byte slice of sufficient size to hold the message. - /// If the message is too long to fit, excess bytes may get discarded. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::net::UdpSocket; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; - /// - /// let mut buf = [0u8; 1024]; - /// let (len, addr) = socket.recv_from(&mut buf).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { - self.read_with(|io| io.recv_from(buf)).await - } - - /// Receives a single datagram message without removing it from the queue. - /// - /// Returns the number of bytes read and the address the message came from. - /// - /// This method must be called with a valid byte slice of sufficient size to hold the message. - /// If the message is too long to fit, excess bytes may get discarded. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::net::UdpSocket; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; - /// - /// let mut buf = [0u8; 1024]; - /// let (len, addr) = socket.peek_from(&mut buf).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { - self.read_with(|io| io.peek_from(buf)).await - } - - /// Sends data to the specified address. - /// - /// Returns the number of bytes writen. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::net::UdpSocket; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind(([127, 0, 0, 1], 0))?; - /// let addr = socket.get_ref().local_addr()?; - /// - /// let msg = b"hello"; - /// let len = socket.send_to(msg, addr).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn send_to>(&self, buf: &[u8], addr: A) -> io::Result { - let addr = addr.into(); - self.write_with(|io| io.send_to(buf, addr)).await - } - - /// Receives a single datagram message from the connected peer. - /// - /// Returns the number of bytes read. - /// - /// This method must be called with a valid byte slice of sufficient size to hold the message. - /// If the message is too long to fit, excess bytes may get discarded. - /// - /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address. - /// This method will fail if the socket is not connected. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::net::UdpSocket; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; - /// socket.get_ref().connect("127.0.0.1:9000")?; - /// - /// let mut buf = [0u8; 1024]; - /// let len = socket.recv(&mut buf).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn recv(&self, buf: &mut [u8]) -> io::Result { - self.read_with(|io| io.recv(buf)).await - } - - /// Receives a single datagram message from the connected peer without removing it from the - /// queue. - /// - /// Returns the number of bytes read and the address the message came from. - /// - /// This method must be called with a valid byte slice of sufficient size to hold the message. - /// If the message is too long to fit, excess bytes may get discarded. - /// - /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address. - /// This method will fail if the socket is not connected. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::net::UdpSocket; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; - /// socket.get_ref().connect("127.0.0.1:9000")?; - /// - /// let mut buf = [0u8; 1024]; - /// let len = socket.peek(&mut buf).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn peek(&self, buf: &mut [u8]) -> io::Result { - self.read_with(|io| io.peek(buf)).await - } - - /// Sends data to the connected peer. - /// - /// Returns the number of bytes written. - /// - /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address. - /// This method will fail if the socket is not connected. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::net::UdpSocket; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; - /// socket.get_ref().connect("127.0.0.1:9000")?; - /// - /// let msg = b"hello"; - /// let len = socket.send(msg).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn send(&self, buf: &[u8]) -> io::Result { - self.write_with(|io| io.send(buf)).await - } -} - -impl TryFrom for Async { - type Error = io::Error; - - fn try_from(socket: std::net::UdpSocket) -> io::Result { - Async::new(socket) - } -} - -#[cfg(unix)] -impl Async { - /// Creates a UDS listener bound to the specified path. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::os::unix::net::UnixListener; - /// - /// # futures_lite::future::block_on(async { - /// let listener = Async::::bind("/tmp/socket")?; - /// println!("Listening on {:?}", listener.get_ref().local_addr()?); - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn bind>(path: P) -> io::Result> { - let path = path.as_ref().to_owned(); - Async::new(UnixListener::bind(path)?) - } - - /// Accepts a new incoming UDS stream connection. - /// - /// When a connection is established, it will be returned as a stream together with its remote - /// address. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::os::unix::net::UnixListener; - /// - /// # futures_lite::future::block_on(async { - /// let listener = Async::::bind("/tmp/socket")?; - /// let (stream, addr) = listener.accept().await?; - /// println!("Accepted client: {:?}", addr); - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn accept(&self) -> io::Result<(Async, UnixSocketAddr)> { - let (stream, addr) = self.read_with(|io| io.accept()).await?; - Ok((Async::new(stream)?, addr)) - } - - /// Returns a stream of incoming UDS connections. - /// - /// The stream is infinite, i.e. it never stops with a [`None`] item. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use futures_lite::{pin, stream::StreamExt}; - /// use std::os::unix::net::UnixListener; - /// - /// # futures_lite::future::block_on(async { - /// let listener = Async::::bind("/tmp/socket")?; - /// let incoming = listener.incoming(); - /// pin!(incoming); - /// - /// while let Some(stream) = incoming.next().await { - /// let stream = stream?; - /// println!("Accepted client: {:?}", stream.get_ref().peer_addr()?); - /// } - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn incoming(&self) -> impl Stream>> + Send + '_ { - stream::unfold(self, |listener| async move { - let res = listener.accept().await.map(|(stream, _)| stream); - Some((res, listener)) - }) - } -} - -#[cfg(unix)] -impl TryFrom for Async { - type Error = io::Error; - - fn try_from(listener: std::os::unix::net::UnixListener) -> io::Result { - Async::new(listener) - } -} - -#[cfg(unix)] -impl Async { - /// Creates a UDS stream connected to the specified path. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::os::unix::net::UnixStream; - /// - /// # futures_lite::future::block_on(async { - /// let stream = Async::::connect("/tmp/socket").await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn connect>(path: P) -> io::Result> { - // Begin async connect. - let socket = connect(SockAddr::unix(path)?, Domain::UNIX, None)?; - let stream = Async::new(UnixStream::from(socket))?; - - // The stream becomes writable when connected. - stream.writable().await?; - - // On Linux, it appears the socket may become writable even when connecting fails, so we - // must do an extra check here and see if the peer address is retrievable. - stream.get_ref().peer_addr()?; - Ok(stream) - } - - /// Creates an unnamed pair of connected UDS stream sockets. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::os::unix::net::UnixStream; - /// - /// # futures_lite::future::block_on(async { - /// let (stream1, stream2) = Async::::pair()?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn pair() -> io::Result<(Async, Async)> { - let (stream1, stream2) = UnixStream::pair()?; - Ok((Async::new(stream1)?, Async::new(stream2)?)) - } -} - -#[cfg(unix)] -impl TryFrom for Async { - type Error = io::Error; - - fn try_from(stream: std::os::unix::net::UnixStream) -> io::Result { - Async::new(stream) - } -} - -#[cfg(unix)] -impl Async { - /// Creates a UDS datagram socket bound to the specified path. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::os::unix::net::UnixDatagram; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind("/tmp/socket")?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn bind>(path: P) -> io::Result> { - let path = path.as_ref().to_owned(); - Async::new(UnixDatagram::bind(path)?) - } - - /// Creates a UDS datagram socket not bound to any address. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::os::unix::net::UnixDatagram; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::unbound()?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn unbound() -> io::Result> { - Async::new(UnixDatagram::unbound()?) - } - - /// Creates an unnamed pair of connected Unix datagram sockets. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::os::unix::net::UnixDatagram; - /// - /// # futures_lite::future::block_on(async { - /// let (socket1, socket2) = Async::::pair()?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn pair() -> io::Result<(Async, Async)> { - let (socket1, socket2) = UnixDatagram::pair()?; - Ok((Async::new(socket1)?, Async::new(socket2)?)) - } - - /// Receives data from the socket. - /// - /// Returns the number of bytes read and the address the message came from. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::os::unix::net::UnixDatagram; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind("/tmp/socket")?; - /// - /// let mut buf = [0u8; 1024]; - /// let (len, addr) = socket.recv_from(&mut buf).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, UnixSocketAddr)> { - self.read_with(|io| io.recv_from(buf)).await - } - - /// Sends data to the specified address. - /// - /// Returns the number of bytes written. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::os::unix::net::UnixDatagram; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::unbound()?; - /// - /// let msg = b"hello"; - /// let addr = "/tmp/socket"; - /// let len = socket.send_to(msg, addr).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn send_to>(&self, buf: &[u8], path: P) -> io::Result { - self.write_with(|io| io.send_to(buf, &path)).await - } - - /// Receives data from the connected peer. - /// - /// Returns the number of bytes read and the address the message came from. - /// - /// The [`connect`][`UnixDatagram::connect()`] method connects this socket to a remote address. - /// This method will fail if the socket is not connected. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::os::unix::net::UnixDatagram; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind("/tmp/socket1")?; - /// socket.get_ref().connect("/tmp/socket2")?; - /// - /// let mut buf = [0u8; 1024]; - /// let len = socket.recv(&mut buf).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn recv(&self, buf: &mut [u8]) -> io::Result { - self.read_with(|io| io.recv(buf)).await - } - - /// Sends data to the connected peer. - /// - /// Returns the number of bytes written. - /// - /// The [`connect`][`UnixDatagram::connect()`] method connects this socket to a remote address. - /// This method will fail if the socket is not connected. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::os::unix::net::UnixDatagram; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind("/tmp/socket1")?; - /// socket.get_ref().connect("/tmp/socket2")?; - /// - /// let msg = b"hello"; - /// let len = socket.send(msg).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn send(&self, buf: &[u8]) -> io::Result { - self.write_with(|io| io.send(buf)).await - } -} - -#[cfg(unix)] -impl TryFrom for Async { - type Error = io::Error; - - fn try_from(socket: std::os::unix::net::UnixDatagram) -> io::Result { - Async::new(socket) - } -} - -/// Polls a future once, waits for a wakeup, and then optimistically assumes the future is ready. -async fn optimistic(fut: impl Future>) -> io::Result<()> { - let mut polled = false; - pin!(fut); - - future::poll_fn(|cx| { - if !polled { - polled = true; - fut.as_mut().poll(cx) - } else { - Poll::Ready(Ok(())) - } - }) - .await -} - -fn connect(addr: SockAddr, domain: Domain, protocol: Option) -> io::Result { - let sock_type = Type::STREAM; - #[cfg(any( - target_os = "android", - target_os = "dragonfly", - target_os = "freebsd", - target_os = "fuchsia", - target_os = "illumos", - target_os = "linux", - target_os = "netbsd", - target_os = "openbsd" - ))] - // If we can, set nonblocking at socket creation for unix - let sock_type = sock_type.nonblocking(); - // This automatically handles cloexec on unix, no_inherit on windows and nosigpipe on macos - let socket = Socket::new(domain, sock_type, protocol)?; - #[cfg(not(any( - target_os = "android", - target_os = "dragonfly", - target_os = "freebsd", - target_os = "fuchsia", - target_os = "illumos", - target_os = "linux", - target_os = "netbsd", - target_os = "openbsd" - )))] - // If the current platform doesn't support nonblocking at creation, enable it after creation - socket.set_nonblocking(true)?; - match socket.connect(&addr) { - Ok(_) => {} - #[cfg(unix)] - Err(err) if err.raw_os_error() == Some(rustix::io::Errno::INPROGRESS.raw_os_error()) => {} - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - Err(err) => return Err(err), - } - Ok(socket) -} From a9aabe5f6f62a6e269c6dcee225dfd4d0ec9630f Mon Sep 17 00:00:00 2001 From: John Nunley Date: Fri, 22 Sep 2023 17:41:02 -0700 Subject: [PATCH 2/5] Move actual Timer implementation to other file This commit moves the native Timer implementation into another file. The goal is to setup the ability to toggle between web-based and IO-based Timer implementations. Signed-off-by: John Nunley --- src/lib.rs | 155 +++++++---------------------------- src/timer/native.rs | 192 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 220 insertions(+), 127 deletions(-) create mode 100644 src/timer/native.rs diff --git a/src/lib.rs b/src/lib.rs index 24915f7..7d3f042 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -63,7 +63,7 @@ use std::future::Future; use std::pin::Pin; -use std::task::{Context, Poll, Waker}; +use std::task::{Context, Poll}; use std::time::{Duration, Instant}; use futures_lite::stream::Stream; @@ -74,6 +74,9 @@ mod driver; mod io; mod reactor; +#[path = "timer/native.rs"] +mod timer; + pub mod os; pub use driver::block_on; @@ -123,22 +126,7 @@ pub use reactor::{Readable, ReadableOwned, Writable, WritableOwned}; /// # std::io::Result::Ok(()) }); /// ``` #[derive(Debug)] -pub struct Timer { - /// This timer's ID and last waker that polled it. - /// - /// When this field is set to `None`, this timer is not registered in the reactor. - id_and_waker: Option<(usize, Waker)>, - - /// The next instant at which this timer fires. - /// - /// If this timer is a blank timer, this value is None. If the timer - /// must be set, this value contains the next instant at which the - /// timer must fire. - when: Option, - - /// The period. - period: Duration, -} +pub struct Timer(timer::Timer); impl Timer { /// Creates a timer that will never fire. @@ -173,12 +161,9 @@ impl Timer { /// run_with_timeout(None).await; /// # }); /// ``` + #[inline] pub fn never() -> Timer { - Timer { - id_and_waker: None, - when: None, - period: Duration::MAX, - } + Timer(timer::Timer::never()) } /// Creates a timer that emits an event once after the given duration of time. @@ -193,10 +178,9 @@ impl Timer { /// Timer::after(Duration::from_secs(1)).await; /// # }); /// ``` + #[inline] pub fn after(duration: Duration) -> Timer { - Instant::now() - .checked_add(duration) - .map_or_else(Timer::never, Timer::at) + Timer(timer::Timer::after(duration)) } /// Creates a timer that emits an event once at the given time instant. @@ -213,8 +197,9 @@ impl Timer { /// Timer::at(when).await; /// # }); /// ``` + #[inline] pub fn at(instant: Instant) -> Timer { - Timer::interval_at(instant, Duration::MAX) + Timer(timer::Timer::at(instant)) } /// Creates a timer that emits events periodically. @@ -231,10 +216,9 @@ impl Timer { /// Timer::interval(period).next().await; /// # }); /// ``` + #[inline] pub fn interval(period: Duration) -> Timer { - Instant::now() - .checked_add(period) - .map_or_else(Timer::never, |at| Timer::interval_at(at, period)) + Timer(timer::Timer::interval(period)) } /// Creates a timer that emits events periodically, starting at `start`. @@ -252,12 +236,9 @@ impl Timer { /// Timer::interval_at(start, period).next().await; /// # }); /// ``` + #[inline] pub fn interval_at(start: Instant, period: Duration) -> Timer { - Timer { - id_and_waker: None, - when: Some(start), - period, - } + Timer(timer::Timer::interval_at(start, period)) } /// Indicates whether or not this timer will ever fire. @@ -299,7 +280,7 @@ impl Timer { /// ``` #[inline] pub fn will_fire(&self) -> bool { - self.when.is_some() + self.0.will_fire() } /// Sets the timer to emit an en event once after the given duration of time. @@ -319,15 +300,9 @@ impl Timer { /// t.set_after(Duration::from_millis(100)); /// # }); /// ``` + #[inline] pub fn set_after(&mut self, duration: Duration) { - match Instant::now().checked_add(duration) { - Some(instant) => self.set_at(instant), - None => { - // Overflow to never going off. - self.clear(); - self.when = None; - } - } + self.0.set_after(duration) } /// Sets the timer to emit an event once at the given time instant. @@ -350,16 +325,9 @@ impl Timer { /// t.set_at(when); /// # }); /// ``` + #[inline] pub fn set_at(&mut self, instant: Instant) { - self.clear(); - - // Update the timeout. - self.when = Some(instant); - - if let Some((id, waker)) = self.id_and_waker.as_mut() { - // Re-register the timer with the new timeout. - *id = Reactor::get().insert_timer(instant, waker); - } + self.0.set_at(instant) } /// Sets the timer to emit events periodically. @@ -382,15 +350,9 @@ impl Timer { /// t.set_interval(period); /// # }); /// ``` + #[inline] pub fn set_interval(&mut self, period: Duration) { - match Instant::now().checked_add(period) { - Some(instant) => self.set_interval_at(instant, period), - None => { - // Overflow to never going off. - self.clear(); - self.when = None; - } - } + self.0.set_interval(period) } /// Sets the timer to emit events periodically, starting at `start`. @@ -414,39 +376,16 @@ impl Timer { /// t.set_interval_at(start, period); /// # }); /// ``` + #[inline] pub fn set_interval_at(&mut self, start: Instant, period: Duration) { - self.clear(); - - self.when = Some(start); - self.period = period; - - if let Some((id, waker)) = self.id_and_waker.as_mut() { - // Re-register the timer with the new timeout. - *id = Reactor::get().insert_timer(start, waker); - } - } - - /// Helper function to clear the current timer. - fn clear(&mut self) { - if let (Some(when), Some((id, _))) = (self.when, self.id_and_waker.as_ref()) { - // Deregister the timer from the reactor. - Reactor::get().remove_timer(when, *id); - } - } -} - -impl Drop for Timer { - fn drop(&mut self) { - if let (Some(when), Some((id, _))) = (self.when, self.id_and_waker.take()) { - // Deregister the timer from the reactor. - Reactor::get().remove_timer(when, id); - } + self.0.set_interval_at(start, period) } } impl Future for Timer { type Output = Instant; + #[inline] fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.poll_next(cx) { Poll::Ready(Some(when)) => Poll::Ready(when), @@ -459,46 +398,8 @@ impl Future for Timer { impl Stream for Timer { type Item = Instant; - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - - if let Some(ref mut when) = this.when { - // Check if the timer has already fired. - if Instant::now() >= *when { - if let Some((id, _)) = this.id_and_waker.take() { - // Deregister the timer from the reactor. - Reactor::get().remove_timer(*when, id); - } - let result_time = *when; - if let Some(next) = (*when).checked_add(this.period) { - *when = next; - // Register the timer in the reactor. - let id = Reactor::get().insert_timer(next, cx.waker()); - this.id_and_waker = Some((id, cx.waker().clone())); - } else { - this.when = None; - } - return Poll::Ready(Some(result_time)); - } else { - match &this.id_and_waker { - None => { - // Register the timer in the reactor. - let id = Reactor::get().insert_timer(*when, cx.waker()); - this.id_and_waker = Some((id, cx.waker().clone())); - } - Some((id, w)) if !w.will_wake(cx.waker()) => { - // Deregister the timer from the reactor to remove the old waker. - Reactor::get().remove_timer(*when, *id); - - // Register the timer in the reactor with the new waker. - let id = Reactor::get().insert_timer(*when, cx.waker()); - this.id_and_waker = Some((id, cx.waker().clone())); - } - Some(_) => {} - } - } - } - - Poll::Pending + #[inline] + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.0.poll_next(cx) } } diff --git a/src/timer/native.rs b/src/timer/native.rs new file mode 100644 index 0000000..2dde232 --- /dev/null +++ b/src/timer/native.rs @@ -0,0 +1,192 @@ +//! Timer implementation for non-Web platforms. + +use std::task::{Context, Poll, Waker}; +use std::time::{Duration, Instant}; + +use crate::reactor::Reactor; + +/// A timer for non-Web platforms. +/// +/// self registers a timeout in the global reactor, which in turn sets a timeout in the poll call. +#[derive(Debug)] +pub(super) struct Timer { + /// self timer's ID and last waker that polled it. + /// + /// When self field is set to `None`, self timer is not registered in the reactor. + id_and_waker: Option<(usize, Waker)>, + + /// The next instant at which self timer fires. + /// + /// If self timer is a blank timer, self value is None. If the timer + /// must be set, self value contains the next instant at which the + /// timer must fire. + when: Option, + + /// The period. + period: Duration, +} + +impl Timer { + /// Create a timer that will never fire. + #[inline] + pub(super) fn never() -> Self { + Self { + id_and_waker: None, + when: None, + period: Duration::MAX, + } + } + + /// Create a timer that will fire at the given instant. + #[inline] + pub(super) fn after(duration: Duration) -> Timer { + Instant::now() + .checked_add(duration) + .map_or_else(Timer::never, Timer::at) + } + + /// Create a timer that will fire at the given instant. + #[inline] + pub(super) fn at(instant: Instant) -> Timer { + Timer::interval_at(instant, Duration::MAX) + } + + /// Create a timer that will fire at the given instant. + #[inline] + pub(super) fn interval(period: Duration) -> Timer { + Instant::now() + .checked_add(period) + .map_or_else(Timer::never, |at| Timer::interval_at(at, period)) + } + + /// Create a timer that will fire at self interval at self point. + #[inline] + pub(super) fn interval_at(start: Instant, period: Duration) -> Timer { + Timer { + id_and_waker: None, + when: Some(start), + period, + } + } + + /// Returns `true` if self timer will fire at some point. + #[inline] + pub(super) fn will_fire(&self) -> bool { + self.when.is_some() + } + + /// Set the timer to fire after the given duration. + #[inline] + pub(super) fn set_after(&mut self, duration: Duration) { + match Instant::now().checked_add(duration) { + Some(instant) => self.set_at(instant), + None => { + // Overflow to never going off. + self.clear(); + self.when = None; + } + } + } + + /// Set the timer to fire at the given instant. + #[inline] + pub(super) fn set_at(&mut self, instant: Instant) { + self.clear(); + + // Update the timeout. + self.when = Some(instant); + + if let Some((id, waker)) = self.id_and_waker.as_mut() { + // Re-register the timer with the new timeout. + *id = Reactor::get().insert_timer(instant, waker); + } + } + + /// Set the timer to emit events periodically. + #[inline] + pub(super) fn set_interval(&mut self, period: Duration) { + match Instant::now().checked_add(period) { + Some(instant) => self.set_interval_at(instant, period), + None => { + // Overflow to never going off. + self.clear(); + self.when = None; + } + } + } + + /// Set the timer to emit events periodically starting at a given instant. + #[inline] + pub(super) fn set_interval_at(&mut self, start: Instant, period: Duration) { + self.clear(); + + self.when = Some(start); + self.period = period; + + if let Some((id, waker)) = self.id_and_waker.as_mut() { + // Re-register the timer with the new timeout. + *id = Reactor::get().insert_timer(start, waker); + } + } + + /// Helper function to clear the current timer. + #[inline] + fn clear(&mut self) { + if let (Some(when), Some((id, _))) = (self.when, self.id_and_waker.as_ref()) { + // Deregister the timer from the reactor. + Reactor::get().remove_timer(when, *id); + } + } + + /// Poll for the next timer event. + #[inline] + pub(super) fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll> { + if let Some(ref mut when) = self.when { + // Check if the timer has already fired. + if Instant::now() >= *when { + if let Some((id, _)) = self.id_and_waker.take() { + // Deregister the timer from the reactor. + Reactor::get().remove_timer(*when, id); + } + let result_time = *when; + if let Some(next) = (*when).checked_add(self.period) { + *when = next; + // Register the timer in the reactor. + let id = Reactor::get().insert_timer(next, cx.waker()); + self.id_and_waker = Some((id, cx.waker().clone())); + } else { + self.when = None; + } + return Poll::Ready(Some(result_time)); + } else { + match &self.id_and_waker { + None => { + // Register the timer in the reactor. + let id = Reactor::get().insert_timer(*when, cx.waker()); + self.id_and_waker = Some((id, cx.waker().clone())); + } + Some((id, w)) if !w.will_wake(cx.waker()) => { + // Deregister the timer from the reactor to remove the old waker. + Reactor::get().remove_timer(*when, *id); + + // Register the timer in the reactor with the new waker. + let id = Reactor::get().insert_timer(*when, cx.waker()); + self.id_and_waker = Some((id, cx.waker().clone())); + } + Some(_) => {} + } + } + } + + Poll::Pending + } +} + +impl Drop for Timer { + fn drop(&mut self) { + if let (Some(when), Some((id, _))) = (self.when, self.id_and_waker.take()) { + // Deregister the timer from the reactor. + Reactor::get().remove_timer(when, id); + } + } +} From 8b7c787cfdc0ca7c5b5440d22ca26f6d40d245d3 Mon Sep 17 00:00:00 2001 From: John Nunley Date: Fri, 22 Sep 2023 18:51:57 -0700 Subject: [PATCH 3/5] Add support for timers on web platforms This commit adds support for async-io on wasm32-unknown-unknown. Not all features of async-io can be ported to WASM; for instance: - Async can't be ported over as WASM doesn't really have a reactor. WASI could eventually be supported here, but that is dependent on smol-rs/polling#102 - block_on() can't be ported over, as blocking isn't allowed on the web. The only thing left is Timer, which can be implemented using setTimeout and setInterval. So that's what's been done: when the WASM target family is enabled, Async and block_on() will be disabled and Timer will switch to an implementation that uses web timeouts. This is not a breaking change, as this crate previously failed to compile on web platforms anyways. This functionality currently does not support Node.js. Signed-off-by: John Nunley --- .github/workflows/ci.yml | 4 + Cargo.toml | 19 +++- src/lib.rs | 28 +++++- src/os/unix.rs | 2 +- src/timer/web.rs | 212 +++++++++++++++++++++++++++++++++++++++ tests/async.rs | 2 + tests/block_on.rs | 2 + tests/timer.rs | 82 ++++++++++++--- 8 files changed, 331 insertions(+), 20 deletions(-) create mode 100644 src/timer/web.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1e59d5a..1afcf4f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -38,6 +38,7 @@ jobs: - name: Install Rust # --no-self-update is necessary because the windows environment cannot self-update rustup.exe. run: rustup update ${{ matrix.rust }} --no-self-update && rustup default ${{ matrix.rust }} + - run: rustup target add wasm32-unknown-unknown - run: cargo build --all --all-features --all-targets - name: Run cargo check (without dev-dependencies to catch missing feature flags) if: startsWith(matrix.rust, 'nightly') @@ -50,6 +51,9 @@ jobs: # if: startsWith(matrix.rust, 'nightly') && matrix.os == 'ubuntu-latest' # run: cargo check -Z build-std --target=riscv32imc-esp-espidf - run: cargo test + - uses: taiki-e/install-action@wasm-pack + - run: cargo check --target wasm32-unknown-unknown --all-features --tests + - run: wasm-pack test --node # Copied from: https://github.com/rust-lang/stacker/pull/19/files windows_gnu: diff --git a/Cargo.toml b/Cargo.toml index 701b1d0..cf69099 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,11 +23,13 @@ name = "timer" harness = false [dependencies] -async-lock = "2.6" cfg-if = "1" +futures-lite = { version = "1.11.0", default-features = false } + +[target.'cfg(not(target_family = "wasm"))'.dependencies] +async-lock = "2.6" concurrent-queue = "2.2.0" futures-io = { version = "0.3.28", default-features = false, features = ["std"] } -futures-lite = { version = "1.11.0", default-features = false } parking = "2.0.0" polling = "3.0.0" rustix = { version = "0.38.2", default-features = false, features = ["std", "fs"] } @@ -36,8 +38,15 @@ socket2 = { version = "0.5.3", features = ["all"] } tracing = { version = "0.1.37", default-features = false } waker-fn = "1.1.0" +[target.'cfg(target_family = "wasm")'.dependencies] +atomic-waker = "1.1.1" +wasm-bindgen = "0.2.87" +web-sys = { version = "0.3.0", features = ["Window"] } + [dev-dependencies] async-channel = "1" + +[target.'cfg(not(target_family = "wasm"))'.dev-dependencies] async-net = "1" blocking = "1" criterion = { version = "0.4", default-features = false, features = ["cargo_bench_support"] } @@ -45,6 +54,12 @@ getrandom = "0.2.7" signal-hook = "0.3" tempfile = "3" +[target.'cfg(target_family = "wasm")'.dev-dependencies] +console_error_panic_hook = "0.1.7" +wasm-bindgen-futures = "0.4.37" +wasm-bindgen-test = "0.3.37" +web-time = "0.2.0" + [target.'cfg(target_os = "linux")'.dev-dependencies] inotify = { version = "0.10.1", default-features = false } timerfd = "1" diff --git a/src/lib.rs b/src/lib.rs index 7d3f042..a7759a5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -64,23 +64,31 @@ use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; -use std::time::{Duration, Instant}; +use std::time::Duration; -use futures_lite::stream::Stream; +#[cfg(not(target_family = "wasm"))] +use std::time::Instant; -use crate::reactor::Reactor; +use futures_lite::stream::Stream; +#[cfg(not(target_family = "wasm"))] mod driver; +#[cfg(not(target_family = "wasm"))] mod io; +#[cfg(not(target_family = "wasm"))] mod reactor; -#[path = "timer/native.rs"] +#[cfg_attr(not(target_family = "wasm"), path = "timer/native.rs")] +#[cfg_attr(target_family = "wasm", path = "timer/web.rs")] mod timer; pub mod os; +#[cfg(not(target_family = "wasm"))] pub use driver::block_on; +#[cfg(not(target_family = "wasm"))] pub use io::{Async, IoSafe}; +#[cfg(not(target_family = "wasm"))] pub use reactor::{Readable, ReadableOwned, Writable, WritableOwned}; /// A future or stream that emits timed events. @@ -197,6 +205,7 @@ impl Timer { /// Timer::at(when).await; /// # }); /// ``` + #[cfg(not(target_family = "wasm"))] #[inline] pub fn at(instant: Instant) -> Timer { Timer(timer::Timer::at(instant)) @@ -236,6 +245,7 @@ impl Timer { /// Timer::interval_at(start, period).next().await; /// # }); /// ``` + #[cfg(not(target_family = "wasm"))] #[inline] pub fn interval_at(start: Instant, period: Duration) -> Timer { Timer(timer::Timer::interval_at(start, period)) @@ -325,6 +335,7 @@ impl Timer { /// t.set_at(when); /// # }); /// ``` + #[cfg(not(target_family = "wasm"))] #[inline] pub fn set_at(&mut self, instant: Instant) { self.0.set_at(instant) @@ -376,6 +387,7 @@ impl Timer { /// t.set_interval_at(start, period); /// # }); /// ``` + #[cfg(not(target_family = "wasm"))] #[inline] pub fn set_interval_at(&mut self, start: Instant, period: Duration) { self.0.set_interval_at(start, period) @@ -383,8 +395,12 @@ impl Timer { } impl Future for Timer { + #[cfg(not(target_family = "wasm"))] type Output = Instant; + #[cfg(target_family = "wasm")] + type Output = (); + #[inline] fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.poll_next(cx) { @@ -396,8 +412,12 @@ impl Future for Timer { } impl Stream for Timer { + #[cfg(not(target_family = "wasm"))] type Item = Instant; + #[cfg(target_family = "wasm")] + type Item = (); + #[inline] fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.0.poll_next(cx) diff --git a/src/os/unix.rs b/src/os/unix.rs index ffb0832..d43462e 100644 --- a/src/os/unix.rs +++ b/src/os/unix.rs @@ -62,7 +62,7 @@ pub fn reactor_fd() -> Option> { not(polling_test_poll_backend), ))] { use std::os::unix::io::AsFd; - Some(crate::Reactor::get().poller.as_fd()) + Some(crate::reactor::Reactor::get().poller.as_fd()) } else { None } diff --git a/src/timer/web.rs b/src/timer/web.rs new file mode 100644 index 0000000..4fdfc19 --- /dev/null +++ b/src/timer/web.rs @@ -0,0 +1,212 @@ +//! Timers for web targets. +//! +//! These use the `setTimeout` function on the web to handle timing. + +use std::convert::TryInto; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::time::Duration; + +use atomic_waker::AtomicWaker; +use wasm_bindgen::closure::Closure; +use wasm_bindgen::JsCast; + +/// A timer for non-Web platforms. +/// +/// self registers a timeout in the global reactor, which in turn sets a timeout in the poll call. +#[derive(Debug)] +pub(super) struct Timer { + /// The waker to wake when the timer fires. + waker: Arc, + + /// The ongoing timeout or interval. + ongoing_timeout: TimerId, + + /// Keep the closure alive so we don't drop it. + closure: Option>, +} + +#[derive(Debug)] +struct State { + /// The number of times this timer has been woken. + woken: AtomicUsize, + + /// The waker to wake when the timer fires. + waker: AtomicWaker, +} + +#[derive(Debug)] +enum TimerId { + NoTimer, + Timeout(i32), + Interval(i32), +} + +impl Timer { + /// Create a timer that will never fire. + #[inline] + pub(super) fn never() -> Self { + Self { + waker: Arc::new(State { + woken: AtomicUsize::new(0), + waker: AtomicWaker::new(), + }), + ongoing_timeout: TimerId::NoTimer, + closure: None, + } + } + + /// Create a timer that will fire at the given instant. + #[inline] + pub(super) fn after(duration: Duration) -> Timer { + let mut this = Self::never(); + this.set_after(duration); + this + } + + /// Create a timer that will fire at the given instant. + #[inline] + pub(super) fn interval(period: Duration) -> Timer { + let mut this = Self::never(); + this.set_interval(period); + this + } + + /// Returns `true` if self timer will fire at some point. + #[inline] + pub(super) fn will_fire(&self) -> bool { + matches!( + self.ongoing_timeout, + TimerId::Timeout(_) | TimerId::Interval(_) + ) + } + + /// Set the timer to fire after the given duration. + #[inline] + pub(super) fn set_after(&mut self, duration: Duration) { + // Set the timeout. + let id = { + let waker = self.waker.clone(); + let closure: Closure = Closure::wrap(Box::new(move || { + waker.wake(); + })); + + let result = web_sys::window() + .unwrap() + .set_timeout_with_callback_and_timeout_and_arguments_0( + closure.as_ref().unchecked_ref(), + duration.as_millis().try_into().expect("timeout too long"), + ); + + // Make sure we don't drop the closure before it's called. + self.closure = Some(closure); + + match result { + Ok(id) => id, + Err(_) => { + panic!("failed to set timeout") + } + } + }; + + // Set our ID. + self.ongoing_timeout = TimerId::Timeout(id); + } + + /// Set the timer to emit events periodically. + #[inline] + pub(super) fn set_interval(&mut self, period: Duration) { + // Set the timeout. + let id = { + let waker = self.waker.clone(); + let closure: Closure = Closure::wrap(Box::new(move || { + waker.wake(); + })); + + let result = web_sys::window() + .unwrap() + .set_interval_with_callback_and_timeout_and_arguments_0( + closure.as_ref().unchecked_ref(), + period.as_millis().try_into().expect("timeout too long"), + ); + + // Make sure we don't drop the closure before it's called. + self.closure = Some(closure); + + match result { + Ok(id) => id, + Err(_) => { + panic!("failed to set interval") + } + } + }; + + // Set our ID. + self.ongoing_timeout = TimerId::Interval(id); + } + + /// Poll for the next timer event. + #[inline] + pub(super) fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll> { + let mut registered = false; + let mut woken = self.waker.woken.load(Ordering::Acquire); + + loop { + if woken > 0 { + // Try to decrement the number of woken events. + if let Err(new_woken) = self.waker.woken.compare_exchange( + woken, + woken - 1, + Ordering::SeqCst, + Ordering::Acquire, + ) { + woken = new_woken; + continue; + } + + // If we are using a one-shot timer, clear it. + if let TimerId::Timeout(_) = self.ongoing_timeout { + self.clear(); + } + + return Poll::Ready(Some(())); + } + + if !registered { + // Register the waker. + self.waker.waker.register(cx.waker()); + registered = true; + } else { + // We've already registered, so we can just return pending. + return Poll::Pending; + } + } + } + + /// Clear the current timeout. + fn clear(&mut self) { + match self.ongoing_timeout { + TimerId::NoTimer => {} + TimerId::Timeout(id) => { + web_sys::window().unwrap().clear_timeout_with_handle(id); + } + TimerId::Interval(id) => { + web_sys::window().unwrap().clear_interval_with_handle(id); + } + } + } +} + +impl State { + fn wake(&self) { + self.woken.fetch_add(1, Ordering::SeqCst); + self.waker.wake(); + } +} + +impl Drop for Timer { + fn drop(&mut self) { + self.clear(); + } +} diff --git a/tests/async.rs b/tests/async.rs index c856760..5d9bd12 100644 --- a/tests/async.rs +++ b/tests/async.rs @@ -1,3 +1,5 @@ +#![cfg(not(target_family = "wasm"))] + use std::future::Future; use std::io; use std::net::{Shutdown, TcpListener, TcpStream, UdpSocket}; diff --git a/tests/block_on.rs b/tests/block_on.rs index 70241f0..3a5c1ba 100644 --- a/tests/block_on.rs +++ b/tests/block_on.rs @@ -1,3 +1,5 @@ +#![cfg(not(target_family = "wasm"))] + use async_io::block_on; use std::{ future::Future, diff --git a/tests/timer.rs b/tests/timer.rs index cdd90db..5a16089 100644 --- a/tests/timer.rs +++ b/tests/timer.rs @@ -1,12 +1,26 @@ use std::future::Future; +#[cfg(not(target_family = "wasm"))] use std::pin::Pin; +#[cfg(not(target_family = "wasm"))] use std::sync::{Arc, Mutex}; +#[cfg(not(target_family = "wasm"))] use std::thread; + +#[cfg(not(target_family = "wasm"))] use std::time::{Duration, Instant}; +#[cfg(target_family = "wasm")] +use web_time::{Duration, Instant}; + +#[cfg(target_family = "wasm")] +wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); use async_io::Timer; -use futures_lite::{future, FutureExt, StreamExt}; +use futures_lite::{FutureExt, StreamExt}; + +#[cfg(not(target_family = "wasm"))] +use futures_lite::future; +#[cfg(not(target_family = "wasm"))] fn spawn( f: impl Future + Send + 'static, ) -> impl Future + Send + 'static { @@ -21,18 +35,60 @@ fn spawn( Box::pin(async move { r.recv().await.unwrap() }) } -#[test] -fn smoke() { - future::block_on(async { +#[cfg(target_family = "wasm")] +fn spawn(f: impl Future + 'static) -> impl Future + 'static { + let (s, r) = async_channel::bounded(1); + + #[cfg(target_family = "wasm")] + wasm_bindgen_futures::spawn_local(async move { + s.send(f.await).await.ok(); + }); + + Box::pin(async move { r.recv().await.unwrap() }) +} + +#[cfg(not(target_family = "wasm"))] +macro_rules! test { + ( + $(#[$meta:meta])* + async fn $name:ident () $bl:block + ) => { + #[test] + $(#[$meta])* + fn $name() { + futures_lite::future::block_on(async { + $bl + }) + } + }; +} + +#[cfg(target_family = "wasm")] +macro_rules! test { + ( + $(#[$meta:meta])* + async fn $name:ident () $bl:block + ) => { + // wasm-bindgen-test handles waiting on the future for us + #[wasm_bindgen_test::wasm_bindgen_test] + $(#[$meta])* + async fn $name() { + console_error_panic_hook::set_once(); + $bl + } + }; +} + +test! { + async fn smoke() { let start = Instant::now(); Timer::after(Duration::from_secs(1)).await; assert!(start.elapsed() >= Duration::from_secs(1)); - }); + } } -#[test] -fn interval() { - future::block_on(async { +test! { + async fn interval() { let period = Duration::from_secs(1); let jitter = Duration::from_millis(500); let start = Instant::now(); @@ -43,12 +99,11 @@ fn interval() { timer.next().await; let elapsed = start.elapsed(); assert!(elapsed >= period * 2 && elapsed - period * 2 < jitter); - }); + } } -#[test] -fn poll_across_tasks() { - future::block_on(async { +test! { + async fn poll_across_tasks() { let start = Instant::now(); let (sender, receiver) = async_channel::bounded(1); @@ -74,9 +129,10 @@ fn poll_across_tasks() { task2.await; assert!(start.elapsed() >= Duration::from_secs(1)); - }); + } } +#[cfg(not(target_family = "wasm"))] #[test] fn set() { future::block_on(async { From 5c36d2e233886c1a7141bb15eb4c4854c06f5aee Mon Sep 17 00:00:00 2001 From: John Nunley Date: Fri, 29 Sep 2023 16:04:06 -0700 Subject: [PATCH 4/5] Exclude WASI from web targets Signed-off-by: John Nunley --- Cargo.toml | 8 ++++---- src/lib.rs | 40 +++++++++++++++++++++++----------------- tests/async.rs | 2 +- tests/block_on.rs | 2 +- tests/timer.rs | 26 +++++++++++++------------- 5 files changed, 42 insertions(+), 36 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index cf69099..c6aa16f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,7 @@ harness = false cfg-if = "1" futures-lite = { version = "1.11.0", default-features = false } -[target.'cfg(not(target_family = "wasm"))'.dependencies] +[target.'cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))'.dependencies] async-lock = "2.6" concurrent-queue = "2.2.0" futures-io = { version = "0.3.28", default-features = false, features = ["std"] } @@ -38,7 +38,7 @@ socket2 = { version = "0.5.3", features = ["all"] } tracing = { version = "0.1.37", default-features = false } waker-fn = "1.1.0" -[target.'cfg(target_family = "wasm")'.dependencies] +[target.'cfg(all(target_family = "wasm", not(target_os = "wasi")))'.dependencies] atomic-waker = "1.1.1" wasm-bindgen = "0.2.87" web-sys = { version = "0.3.0", features = ["Window"] } @@ -46,7 +46,7 @@ web-sys = { version = "0.3.0", features = ["Window"] } [dev-dependencies] async-channel = "1" -[target.'cfg(not(target_family = "wasm"))'.dev-dependencies] +[target.'cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))'.dev-dependencies] async-net = "1" blocking = "1" criterion = { version = "0.4", default-features = false, features = ["cargo_bench_support"] } @@ -54,7 +54,7 @@ getrandom = "0.2.7" signal-hook = "0.3" tempfile = "3" -[target.'cfg(target_family = "wasm")'.dev-dependencies] +[target.'cfg(all(target_family = "wasm", not(target_os = "wasi")))'.dev-dependencies] console_error_panic_hook = "0.1.7" wasm-bindgen-futures = "0.4.37" wasm-bindgen-test = "0.3.37" diff --git a/src/lib.rs b/src/lib.rs index a7759a5..1caf524 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -66,29 +66,35 @@ use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; -#[cfg(not(target_family = "wasm"))] +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] use std::time::Instant; use futures_lite::stream::Stream; -#[cfg(not(target_family = "wasm"))] +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] mod driver; -#[cfg(not(target_family = "wasm"))] +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] mod io; -#[cfg(not(target_family = "wasm"))] +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] mod reactor; -#[cfg_attr(not(target_family = "wasm"), path = "timer/native.rs")] -#[cfg_attr(target_family = "wasm", path = "timer/web.rs")] +#[cfg_attr( + not(all(target_family = "wasm", not(target_os = "wasi"))), + path = "timer/native.rs" +)] +#[cfg_attr( + all(target_family = "wasm", not(target_os = "wasi")), + path = "timer/web.rs" +)] mod timer; pub mod os; -#[cfg(not(target_family = "wasm"))] +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] pub use driver::block_on; -#[cfg(not(target_family = "wasm"))] +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] pub use io::{Async, IoSafe}; -#[cfg(not(target_family = "wasm"))] +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] pub use reactor::{Readable, ReadableOwned, Writable, WritableOwned}; /// A future or stream that emits timed events. @@ -205,7 +211,7 @@ impl Timer { /// Timer::at(when).await; /// # }); /// ``` - #[cfg(not(target_family = "wasm"))] + #[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] #[inline] pub fn at(instant: Instant) -> Timer { Timer(timer::Timer::at(instant)) @@ -245,7 +251,7 @@ impl Timer { /// Timer::interval_at(start, period).next().await; /// # }); /// ``` - #[cfg(not(target_family = "wasm"))] + #[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] #[inline] pub fn interval_at(start: Instant, period: Duration) -> Timer { Timer(timer::Timer::interval_at(start, period)) @@ -335,7 +341,7 @@ impl Timer { /// t.set_at(when); /// # }); /// ``` - #[cfg(not(target_family = "wasm"))] + #[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] #[inline] pub fn set_at(&mut self, instant: Instant) { self.0.set_at(instant) @@ -387,7 +393,7 @@ impl Timer { /// t.set_interval_at(start, period); /// # }); /// ``` - #[cfg(not(target_family = "wasm"))] + #[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] #[inline] pub fn set_interval_at(&mut self, start: Instant, period: Duration) { self.0.set_interval_at(start, period) @@ -395,10 +401,10 @@ impl Timer { } impl Future for Timer { - #[cfg(not(target_family = "wasm"))] + #[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] type Output = Instant; - #[cfg(target_family = "wasm")] + #[cfg(all(target_family = "wasm", not(target_os = "wasi")))] type Output = (); #[inline] @@ -412,10 +418,10 @@ impl Future for Timer { } impl Stream for Timer { - #[cfg(not(target_family = "wasm"))] + #[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] type Item = Instant; - #[cfg(target_family = "wasm")] + #[cfg(all(target_family = "wasm", not(target_os = "wasi")))] type Item = (); #[inline] diff --git a/tests/async.rs b/tests/async.rs index 5d9bd12..bfb5a30 100644 --- a/tests/async.rs +++ b/tests/async.rs @@ -1,4 +1,4 @@ -#![cfg(not(target_family = "wasm"))] +#![cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] use std::future::Future; use std::io; diff --git a/tests/block_on.rs b/tests/block_on.rs index 3a5c1ba..330ce6d 100644 --- a/tests/block_on.rs +++ b/tests/block_on.rs @@ -1,4 +1,4 @@ -#![cfg(not(target_family = "wasm"))] +#![cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] use async_io::block_on; use std::{ diff --git a/tests/timer.rs b/tests/timer.rs index 5a16089..f08858e 100644 --- a/tests/timer.rs +++ b/tests/timer.rs @@ -1,26 +1,26 @@ use std::future::Future; -#[cfg(not(target_family = "wasm"))] +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] use std::pin::Pin; -#[cfg(not(target_family = "wasm"))] +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] use std::sync::{Arc, Mutex}; -#[cfg(not(target_family = "wasm"))] +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] use std::thread; -#[cfg(not(target_family = "wasm"))] +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] use std::time::{Duration, Instant}; -#[cfg(target_family = "wasm")] +#[cfg(all(target_family = "wasm", not(target_os = "wasi")))] use web_time::{Duration, Instant}; -#[cfg(target_family = "wasm")] +#[cfg(all(target_family = "wasm", not(target_os = "wasi")))] wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); use async_io::Timer; use futures_lite::{FutureExt, StreamExt}; -#[cfg(not(target_family = "wasm"))] +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] use futures_lite::future; -#[cfg(not(target_family = "wasm"))] +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] fn spawn( f: impl Future + Send + 'static, ) -> impl Future + Send + 'static { @@ -35,11 +35,11 @@ fn spawn( Box::pin(async move { r.recv().await.unwrap() }) } -#[cfg(target_family = "wasm")] +#[cfg(all(target_family = "wasm", not(target_os = "wasi")))] fn spawn(f: impl Future + 'static) -> impl Future + 'static { let (s, r) = async_channel::bounded(1); - #[cfg(target_family = "wasm")] + #[cfg(all(target_family = "wasm", not(target_os = "wasi")))] wasm_bindgen_futures::spawn_local(async move { s.send(f.await).await.ok(); }); @@ -47,7 +47,7 @@ fn spawn(f: impl Future + 'static) -> impl Future Date: Sat, 30 Sep 2023 07:32:21 -0700 Subject: [PATCH 5/5] Use headless Firefox tests in CI Signed-off-by: John Nunley --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1afcf4f..fbe84d6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -53,7 +53,7 @@ jobs: - run: cargo test - uses: taiki-e/install-action@wasm-pack - run: cargo check --target wasm32-unknown-unknown --all-features --tests - - run: wasm-pack test --node + - run: wasm-pack test --firefox --headless # Copied from: https://github.com/rust-lang/stacker/pull/19/files windows_gnu: