diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1e59d5a..fbe84d6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -38,6 +38,7 @@ jobs: - name: Install Rust # --no-self-update is necessary because the windows environment cannot self-update rustup.exe. run: rustup update ${{ matrix.rust }} --no-self-update && rustup default ${{ matrix.rust }} + - run: rustup target add wasm32-unknown-unknown - run: cargo build --all --all-features --all-targets - name: Run cargo check (without dev-dependencies to catch missing feature flags) if: startsWith(matrix.rust, 'nightly') @@ -50,6 +51,9 @@ jobs: # if: startsWith(matrix.rust, 'nightly') && matrix.os == 'ubuntu-latest' # run: cargo check -Z build-std --target=riscv32imc-esp-espidf - run: cargo test + - uses: taiki-e/install-action@wasm-pack + - run: cargo check --target wasm32-unknown-unknown --all-features --tests + - run: wasm-pack test --firefox --headless # Copied from: https://github.com/rust-lang/stacker/pull/19/files windows_gnu: diff --git a/Cargo.toml b/Cargo.toml index 701b1d0..c6aa16f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,11 +23,13 @@ name = "timer" harness = false [dependencies] -async-lock = "2.6" cfg-if = "1" +futures-lite = { version = "1.11.0", default-features = false } + +[target.'cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))'.dependencies] +async-lock = "2.6" concurrent-queue = "2.2.0" futures-io = { version = "0.3.28", default-features = false, features = ["std"] } -futures-lite = { version = "1.11.0", default-features = false } parking = "2.0.0" polling = "3.0.0" rustix = { version = "0.38.2", default-features = false, features = ["std", "fs"] } @@ -36,8 +38,15 @@ socket2 = { version = "0.5.3", features = ["all"] } tracing = { version = "0.1.37", default-features = false } waker-fn = "1.1.0" +[target.'cfg(all(target_family = "wasm", not(target_os = "wasi")))'.dependencies] +atomic-waker = "1.1.1" +wasm-bindgen = "0.2.87" +web-sys = { version = "0.3.0", features = ["Window"] } + [dev-dependencies] async-channel = "1" + +[target.'cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))'.dev-dependencies] async-net = "1" blocking = "1" criterion = { version = "0.4", default-features = false, features = ["cargo_bench_support"] } @@ -45,6 +54,12 @@ getrandom = "0.2.7" signal-hook = "0.3" tempfile = "3" +[target.'cfg(all(target_family = "wasm", not(target_os = "wasi")))'.dev-dependencies] +console_error_panic_hook = "0.1.7" +wasm-bindgen-futures = "0.4.37" +wasm-bindgen-test = "0.3.37" +web-time = "0.2.0" + [target.'cfg(target_os = "linux")'.dev-dependencies] inotify = { version = "0.10.1", default-features = false } timerfd = "1" diff --git a/src/io.rs b/src/io.rs new file mode 100644 index 0000000..4a17551 --- /dev/null +++ b/src/io.rs @@ -0,0 +1,1577 @@ +//! Implements [`Async`], allowing users to register FDs/sockets into the reactor. + +use std::convert::TryFrom; +use std::future::Future; +use std::io::{self, IoSlice, IoSliceMut, Read, Write}; +use std::net::{SocketAddr, TcpListener, TcpStream, UdpSocket}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +#[cfg(unix)] +use std::{ + os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd}, + os::unix::net::{SocketAddr as UnixSocketAddr, UnixDatagram, UnixListener, UnixStream}, + path::Path, +}; + +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, AsSocket, BorrowedSocket, OwnedSocket, RawSocket}; + +use futures_io::{AsyncRead, AsyncWrite}; +use futures_lite::stream::{self, Stream}; +use futures_lite::{future, pin, ready}; +use socket2::{Domain, Protocol, SockAddr, Socket, Type}; + +use crate::reactor::{ + Reactor, Readable, ReadableOwned, Registration, Source, Writable, WritableOwned, +}; + +/// Async adapter for I/O types. +/// +/// This type puts an I/O handle into non-blocking mode, registers it in +/// [epoll]/[kqueue]/[event ports]/[IOCP], and then provides an async interface for it. +/// +/// [epoll]: https://en.wikipedia.org/wiki/Epoll +/// [kqueue]: https://en.wikipedia.org/wiki/Kqueue +/// [event ports]: https://illumos.org/man/port_create +/// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports +/// +/// # Caveats +/// +/// [`Async`] is a low-level primitive, and as such it comes with some caveats. +/// +/// For higher-level primitives built on top of [`Async`], look into [`async-net`] or +/// [`async-process`] (on Unix). +/// +/// The most notable caveat is that it is unsafe to access the inner I/O source mutably +/// using this primitive. Traits likes [`AsyncRead`] and [`AsyncWrite`] are not implemented by +/// default unless it is guaranteed that the resource won't be invalidated by reading or writing. +/// See the [`IoSafe`] trait for more information. +/// +/// [`async-net`]: https://github.com/smol-rs/async-net +/// [`async-process`]: https://github.com/smol-rs/async-process +/// [`AsyncRead`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncRead.html +/// [`AsyncWrite`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncWrite.html +/// +/// ### Supported types +/// +/// [`Async`] supports all networking types, as well as some OS-specific file descriptors like +/// [timerfd] and [inotify]. +/// +/// However, do not use [`Async`] with types like [`File`][`std::fs::File`], +/// [`Stdin`][`std::io::Stdin`], [`Stdout`][`std::io::Stdout`], or [`Stderr`][`std::io::Stderr`] +/// because all operating systems have issues with them when put in non-blocking mode. +/// +/// [timerfd]: https://github.com/smol-rs/async-io/blob/master/examples/linux-timerfd.rs +/// [inotify]: https://github.com/smol-rs/async-io/blob/master/examples/linux-inotify.rs +/// +/// ### Concurrent I/O +/// +/// Note that [`&Async`][`Async`] implements [`AsyncRead`] and [`AsyncWrite`] if `&T` +/// implements those traits, which means tasks can concurrently read and write using shared +/// references. +/// +/// But there is a catch: only one task can read a time, and only one task can write at a time. It +/// is okay to have two tasks where one is reading and the other is writing at the same time, but +/// it is not okay to have two tasks reading at the same time or writing at the same time. If you +/// try to do that, conflicting tasks will just keep waking each other in turn, thus wasting CPU +/// time. +/// +/// Besides [`AsyncRead`] and [`AsyncWrite`], this caveat also applies to +/// [`poll_readable()`][`Async::poll_readable()`] and +/// [`poll_writable()`][`Async::poll_writable()`]. +/// +/// However, any number of tasks can be concurrently calling other methods like +/// [`readable()`][`Async::readable()`] or [`read_with()`][`Async::read_with()`]. +/// +/// ### Closing +/// +/// Closing the write side of [`Async`] with [`close()`][`futures_lite::AsyncWriteExt::close()`] +/// simply flushes. If you want to shutdown a TCP or Unix socket, use +/// [`Shutdown`][`std::net::Shutdown`]. +/// +/// # Examples +/// +/// Connect to a server and echo incoming messages back to the server: +/// +/// ```no_run +/// use async_io::Async; +/// use futures_lite::io; +/// use std::net::TcpStream; +/// +/// # futures_lite::future::block_on(async { +/// // Connect to a local server. +/// let stream = Async::::connect(([127, 0, 0, 1], 8000)).await?; +/// +/// // Echo all messages from the read side of the stream into the write side. +/// io::copy(&stream, &stream).await?; +/// # std::io::Result::Ok(()) }); +/// ``` +/// +/// You can use either predefined async methods or wrap blocking I/O operations in +/// [`Async::read_with()`], [`Async::read_with_mut()`], [`Async::write_with()`], and +/// [`Async::write_with_mut()`]: +/// +/// ```no_run +/// use async_io::Async; +/// use std::net::TcpListener; +/// +/// # futures_lite::future::block_on(async { +/// let listener = Async::::bind(([127, 0, 0, 1], 0))?; +/// +/// // These two lines are equivalent: +/// let (stream, addr) = listener.accept().await?; +/// let (stream, addr) = listener.read_with(|inner| inner.accept()).await?; +/// # std::io::Result::Ok(()) }); +/// ``` +#[derive(Debug)] +pub struct Async { + /// A source registered in the reactor. + pub(crate) source: Arc, + + /// The inner I/O handle. + pub(crate) io: Option, +} + +impl Unpin for Async {} + +#[cfg(unix)] +impl Async { + /// Creates an async I/O handle. + /// + /// This method will put the handle in non-blocking mode and register it in + /// [epoll]/[kqueue]/[event ports]/[IOCP]. + /// + /// On Unix systems, the handle must implement `AsRawFd`, while on Windows it must implement + /// `AsRawSocket`. + /// + /// [epoll]: https://en.wikipedia.org/wiki/Epoll + /// [kqueue]: https://en.wikipedia.org/wiki/Kqueue + /// [event ports]: https://illumos.org/man/port_create + /// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports + /// + /// # Examples + /// + /// ``` + /// use async_io::Async; + /// use std::net::{SocketAddr, TcpListener}; + /// + /// # futures_lite::future::block_on(async { + /// let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))?; + /// let listener = Async::new(listener)?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn new(io: T) -> io::Result> { + // Put the file descriptor in non-blocking mode. + let fd = io.as_fd(); + cfg_if::cfg_if! { + // ioctl(FIONBIO) sets the flag atomically, but we use this only on Linux + // for now, as with the standard library, because it seems to behave + // differently depending on the platform. + // https://github.com/rust-lang/rust/commit/efeb42be2837842d1beb47b51bb693c7474aba3d + // https://github.com/libuv/libuv/blob/e9d91fccfc3e5ff772d5da90e1c4a24061198ca0/src/unix/poll.c#L78-L80 + // https://github.com/tokio-rs/mio/commit/0db49f6d5caf54b12176821363d154384357e70a + if #[cfg(target_os = "linux")] { + rustix::io::ioctl_fionbio(fd, true)?; + } else { + let previous = rustix::fs::fcntl_getfl(fd)?; + let new = previous | rustix::fs::OFlags::NONBLOCK; + if new != previous { + rustix::fs::fcntl_setfl(fd, new)?; + } + } + } + + // SAFETY: It is impossible to drop the I/O source while it is registered through + // this type. + let registration = unsafe { Registration::new(fd) }; + + Ok(Async { + source: Reactor::get().insert_io(registration)?, + io: Some(io), + }) + } +} + +#[cfg(unix)] +impl AsRawFd for Async { + fn as_raw_fd(&self) -> RawFd { + self.get_ref().as_raw_fd() + } +} + +#[cfg(unix)] +impl AsFd for Async { + fn as_fd(&self) -> BorrowedFd<'_> { + self.get_ref().as_fd() + } +} + +#[cfg(unix)] +impl> TryFrom for Async { + type Error = io::Error; + + fn try_from(value: OwnedFd) -> Result { + Async::new(value.into()) + } +} + +#[cfg(unix)] +impl> TryFrom> for OwnedFd { + type Error = io::Error; + + fn try_from(value: Async) -> Result { + value.into_inner().map(Into::into) + } +} + +#[cfg(windows)] +impl Async { + /// Creates an async I/O handle. + /// + /// This method will put the handle in non-blocking mode and register it in + /// [epoll]/[kqueue]/[event ports]/[IOCP]. + /// + /// On Unix systems, the handle must implement `AsRawFd`, while on Windows it must implement + /// `AsRawSocket`. + /// + /// [epoll]: https://en.wikipedia.org/wiki/Epoll + /// [kqueue]: https://en.wikipedia.org/wiki/Kqueue + /// [event ports]: https://illumos.org/man/port_create + /// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports + /// + /// # Examples + /// + /// ``` + /// use async_io::Async; + /// use std::net::{SocketAddr, TcpListener}; + /// + /// # futures_lite::future::block_on(async { + /// let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))?; + /// let listener = Async::new(listener)?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn new(io: T) -> io::Result> { + let borrowed = io.as_socket(); + + // Put the socket in non-blocking mode. + // + // Safety: We assume `as_raw_socket()` returns a valid fd. When we can + // depend on Rust >= 1.63, where `AsFd` is stabilized, and when + // `TimerFd` implements it, we can remove this unsafe and simplify this. + rustix::io::ioctl_fionbio(borrowed, true)?; + + // Create the registration. + // + // SAFETY: It is impossible to drop the I/O source while it is registered through + // this type. + let registration = unsafe { Registration::new(borrowed) }; + + Ok(Async { + source: Reactor::get().insert_io(registration)?, + io: Some(io), + }) + } +} + +#[cfg(windows)] +impl AsRawSocket for Async { + fn as_raw_socket(&self) -> RawSocket { + self.get_ref().as_raw_socket() + } +} + +#[cfg(windows)] +impl AsSocket for Async { + fn as_socket(&self) -> BorrowedSocket<'_> { + self.get_ref().as_socket() + } +} + +#[cfg(windows)] +impl> TryFrom for Async { + type Error = io::Error; + + fn try_from(value: OwnedSocket) -> Result { + Async::new(value.into()) + } +} + +#[cfg(windows)] +impl> TryFrom> for OwnedSocket { + type Error = io::Error; + + fn try_from(value: Async) -> Result { + value.into_inner().map(Into::into) + } +} + +impl Async { + /// Gets a reference to the inner I/O handle. + /// + /// # Examples + /// + /// ``` + /// use async_io::Async; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let listener = Async::::bind(([127, 0, 0, 1], 0))?; + /// let inner = listener.get_ref(); + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn get_ref(&self) -> &T { + self.io.as_ref().unwrap() + } + + /// Gets a mutable reference to the inner I/O handle. + /// + /// # Safety + /// + /// The underlying I/O source must not be dropped using this function. + /// + /// # Examples + /// + /// ``` + /// use async_io::Async; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let mut listener = Async::::bind(([127, 0, 0, 1], 0))?; + /// let inner = unsafe { listener.get_mut() }; + /// # std::io::Result::Ok(()) }); + /// ``` + pub unsafe fn get_mut(&mut self) -> &mut T { + self.io.as_mut().unwrap() + } + + /// Unwraps the inner I/O handle. + /// + /// This method will **not** put the I/O handle back into blocking mode. + /// + /// # Examples + /// + /// ``` + /// use async_io::Async; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let listener = Async::::bind(([127, 0, 0, 1], 0))?; + /// let inner = listener.into_inner()?; + /// + /// // Put the listener back into blocking mode. + /// inner.set_nonblocking(false)?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn into_inner(mut self) -> io::Result { + let io = self.io.take().unwrap(); + Reactor::get().remove_io(&self.source)?; + Ok(io) + } + + /// Waits until the I/O handle is readable. + /// + /// This method completes when a read operation on this I/O handle wouldn't block. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let mut listener = Async::::bind(([127, 0, 0, 1], 0))?; + /// + /// // Wait until a client can be accepted. + /// listener.readable().await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn readable(&self) -> Readable<'_, T> { + Source::readable(self) + } + + /// Waits until the I/O handle is readable. + /// + /// This method completes when a read operation on this I/O handle wouldn't block. + pub fn readable_owned(self: Arc) -> ReadableOwned { + Source::readable_owned(self) + } + + /// Waits until the I/O handle is writable. + /// + /// This method completes when a write operation on this I/O handle wouldn't block. + /// + /// # Examples + /// + /// ``` + /// use async_io::Async; + /// use std::net::{TcpStream, ToSocketAddrs}; + /// + /// # futures_lite::future::block_on(async { + /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap(); + /// let stream = Async::::connect(addr).await?; + /// + /// // Wait until the stream is writable. + /// stream.writable().await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn writable(&self) -> Writable<'_, T> { + Source::writable(self) + } + + /// Waits until the I/O handle is writable. + /// + /// This method completes when a write operation on this I/O handle wouldn't block. + pub fn writable_owned(self: Arc) -> WritableOwned { + Source::writable_owned(self) + } + + /// Polls the I/O handle for readability. + /// + /// When this method returns [`Poll::Ready`], that means the OS has delivered an event + /// indicating readability since the last time this task has called the method and received + /// [`Poll::Pending`]. + /// + /// # Caveats + /// + /// Two different tasks should not call this method concurrently. Otherwise, conflicting tasks + /// will just keep waking each other in turn, thus wasting CPU time. + /// + /// Note that the [`AsyncRead`] implementation for [`Async`] also uses this method. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use futures_lite::future; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let mut listener = Async::::bind(([127, 0, 0, 1], 0))?; + /// + /// // Wait until a client can be accepted. + /// future::poll_fn(|cx| listener.poll_readable(cx)).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn poll_readable(&self, cx: &mut Context<'_>) -> Poll> { + self.source.poll_readable(cx) + } + + /// Polls the I/O handle for writability. + /// + /// When this method returns [`Poll::Ready`], that means the OS has delivered an event + /// indicating writability since the last time this task has called the method and received + /// [`Poll::Pending`]. + /// + /// # Caveats + /// + /// Two different tasks should not call this method concurrently. Otherwise, conflicting tasks + /// will just keep waking each other in turn, thus wasting CPU time. + /// + /// Note that the [`AsyncWrite`] implementation for [`Async`] also uses this method. + /// + /// # Examples + /// + /// ``` + /// use async_io::Async; + /// use futures_lite::future; + /// use std::net::{TcpStream, ToSocketAddrs}; + /// + /// # futures_lite::future::block_on(async { + /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap(); + /// let stream = Async::::connect(addr).await?; + /// + /// // Wait until the stream is writable. + /// future::poll_fn(|cx| stream.poll_writable(cx)).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn poll_writable(&self, cx: &mut Context<'_>) -> Poll> { + self.source.poll_writable(cx) + } + + /// Performs a read operation asynchronously. + /// + /// The I/O handle is registered in the reactor and put in non-blocking mode. This method + /// invokes the `op` closure in a loop until it succeeds or returns an error other than + /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS + /// sends a notification that the I/O handle is readable. + /// + /// The closure receives a shared reference to the I/O handle. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let listener = Async::::bind(([127, 0, 0, 1], 0))?; + /// + /// // Accept a new client asynchronously. + /// let (stream, addr) = listener.read_with(|l| l.accept()).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn read_with(&self, op: impl FnMut(&T) -> io::Result) -> io::Result { + let mut op = op; + loop { + match op(self.get_ref()) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return res, + } + optimistic(self.readable()).await?; + } + } + + /// Performs a read operation asynchronously. + /// + /// The I/O handle is registered in the reactor and put in non-blocking mode. This method + /// invokes the `op` closure in a loop until it succeeds or returns an error other than + /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS + /// sends a notification that the I/O handle is readable. + /// + /// The closure receives a mutable reference to the I/O handle. + /// + /// # Safety + /// + /// In the closure, the underlying I/O source must not be dropped. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let mut listener = Async::::bind(([127, 0, 0, 1], 0))?; + /// + /// // Accept a new client asynchronously. + /// let (stream, addr) = unsafe { listener.read_with_mut(|l| l.accept()).await? }; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async unsafe fn read_with_mut( + &mut self, + op: impl FnMut(&mut T) -> io::Result, + ) -> io::Result { + let mut op = op; + loop { + match op(self.get_mut()) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return res, + } + optimistic(self.readable()).await?; + } + } + + /// Performs a write operation asynchronously. + /// + /// The I/O handle is registered in the reactor and put in non-blocking mode. This method + /// invokes the `op` closure in a loop until it succeeds or returns an error other than + /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS + /// sends a notification that the I/O handle is writable. + /// + /// The closure receives a shared reference to the I/O handle. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::UdpSocket; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; + /// socket.get_ref().connect("127.0.0.1:9000")?; + /// + /// let msg = b"hello"; + /// let len = socket.write_with(|s| s.send(msg)).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn write_with(&self, op: impl FnMut(&T) -> io::Result) -> io::Result { + let mut op = op; + loop { + match op(self.get_ref()) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return res, + } + optimistic(self.writable()).await?; + } + } + + /// Performs a write operation asynchronously. + /// + /// The I/O handle is registered in the reactor and put in non-blocking mode. This method + /// invokes the `op` closure in a loop until it succeeds or returns an error other than + /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS + /// sends a notification that the I/O handle is writable. + /// + /// # Safety + /// + /// The closure receives a mutable reference to the I/O handle. In the closure, the underlying + /// I/O source must not be dropped. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::UdpSocket; + /// + /// # futures_lite::future::block_on(async { + /// let mut socket = Async::::bind(([127, 0, 0, 1], 8000))?; + /// socket.get_ref().connect("127.0.0.1:9000")?; + /// + /// let msg = b"hello"; + /// let len = unsafe { socket.write_with_mut(|s| s.send(msg)).await? }; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async unsafe fn write_with_mut( + &mut self, + op: impl FnMut(&mut T) -> io::Result, + ) -> io::Result { + let mut op = op; + loop { + match op(self.get_mut()) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return res, + } + optimistic(self.writable()).await?; + } + } +} + +impl AsRef for Async { + fn as_ref(&self) -> &T { + self.get_ref() + } +} + +impl Drop for Async { + fn drop(&mut self) { + if self.io.is_some() { + // Deregister and ignore errors because destructors should not panic. + Reactor::get().remove_io(&self.source).ok(); + + // Drop the I/O handle to close it. + self.io.take(); + } + } +} + +/// Types whose I/O trait implementations do not drop the underlying I/O source. +/// +/// The resource contained inside of the [`Async`] cannot be invalidated. This invalidation can +/// happen if the inner resource (the [`TcpStream`], [`UnixListener`] or other `T`) is moved out +/// and dropped before the [`Async`]. Because of this, functions that grant mutable access to +/// the inner type are unsafe, as there is no way to guarantee that the source won't be dropped +/// and a dangling handle won't be left behind. +/// +/// Unfortunately this extends to implementations of [`Read`] and [`Write`]. Since methods on those +/// traits take `&mut`, there is no guarantee that the implementor of those traits won't move the +/// source out while the method is being run. +/// +/// This trait is an antidote to this predicament. By implementing this trait, the user pledges +/// that using any I/O traits won't destroy the source. This way, [`Async`] can implement the +/// `async` version of these I/O traits, like [`AsyncRead`] and [`AsyncWrite`]. +/// +/// # Safety +/// +/// Any I/O trait implementations for this type must not drop the underlying I/O source. Traits +/// affected by this trait include [`Read`], [`Write`], [`Seek`] and [`BufRead`]. +/// +/// This trait is implemented by default on top of `libstd` types. In addition, it is implemented +/// for immutable reference types, as it is impossible to invalidate any outstanding references +/// while holding an immutable reference, even with interior mutability. As Rust's current pinning +/// system relies on similar guarantees, I believe that this approach is robust. +/// +/// [`BufRead`]: https://doc.rust-lang.org/std/io/trait.BufRead.html +/// [`Read`]: https://doc.rust-lang.org/std/io/trait.Read.html +/// [`Seek`]: https://doc.rust-lang.org/std/io/trait.Seek.html +/// [`Write`]: https://doc.rust-lang.org/std/io/trait.Write.html +/// +/// [`AsyncRead`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncRead.html +/// [`AsyncWrite`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncWrite.html +pub unsafe trait IoSafe {} + +/// Reference types can't be mutated. +/// +/// The worst thing that can happen is that external state is used to change what kind of pointer +/// `as_fd()` returns. For instance: +/// +/// ``` +/// # #[cfg(unix)] { +/// use std::cell::Cell; +/// use std::net::TcpStream; +/// use std::os::unix::io::{AsFd, BorrowedFd}; +/// +/// struct Bar { +/// flag: Cell, +/// a: TcpStream, +/// b: TcpStream +/// } +/// +/// impl AsFd for Bar { +/// fn as_fd(&self) -> BorrowedFd<'_> { +/// if self.flag.replace(!self.flag.get()) { +/// self.a.as_fd() +/// } else { +/// self.b.as_fd() +/// } +/// } +/// } +/// # } +/// ``` +/// +/// We solve this problem by only calling `as_fd()` once to get the original source. Implementations +/// like this are considered buggy (but not unsound) and are thus not really supported by `async-io`. +unsafe impl IoSafe for &T {} + +// Can be implemented on top of libstd types. +unsafe impl IoSafe for std::fs::File {} +unsafe impl IoSafe for std::io::Stderr {} +unsafe impl IoSafe for std::io::Stdin {} +unsafe impl IoSafe for std::io::Stdout {} +unsafe impl IoSafe for std::io::StderrLock<'_> {} +unsafe impl IoSafe for std::io::StdinLock<'_> {} +unsafe impl IoSafe for std::io::StdoutLock<'_> {} +unsafe impl IoSafe for std::net::TcpStream {} + +#[cfg(unix)] +unsafe impl IoSafe for std::os::unix::net::UnixStream {} + +unsafe impl IoSafe for std::io::BufReader {} +unsafe impl IoSafe for std::io::BufWriter {} +unsafe impl IoSafe for std::io::LineWriter {} +unsafe impl IoSafe for &mut T {} +unsafe impl IoSafe for Box {} +unsafe impl IoSafe for std::borrow::Cow<'_, T> {} + +impl AsyncRead for Async { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + loop { + match unsafe { (*self).get_mut() }.read(buf) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_readable(cx))?; + } + } + + fn poll_read_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [IoSliceMut<'_>], + ) -> Poll> { + loop { + match unsafe { (*self).get_mut() }.read_vectored(bufs) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_readable(cx))?; + } + } +} + +// Since this is through a reference, we can't mutate the inner I/O source. +// Therefore this is safe! +impl AsyncRead for &Async +where + for<'a> &'a T: Read, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + loop { + match (*self).get_ref().read(buf) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_readable(cx))?; + } + } + + fn poll_read_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [IoSliceMut<'_>], + ) -> Poll> { + loop { + match (*self).get_ref().read_vectored(bufs) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_readable(cx))?; + } + } +} + +impl AsyncWrite for Async { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + loop { + match unsafe { (*self).get_mut() }.write(buf) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_writable(cx))?; + } + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + loop { + match unsafe { (*self).get_mut() }.write_vectored(bufs) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_writable(cx))?; + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match unsafe { (*self).get_mut() }.flush() { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_writable(cx))?; + } + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_flush(cx) + } +} + +impl AsyncWrite for &Async +where + for<'a> &'a T: Write, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + loop { + match (*self).get_ref().write(buf) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_writable(cx))?; + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + loop { + match (*self).get_ref().write_vectored(bufs) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_writable(cx))?; + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match (*self).get_ref().flush() { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + ready!(self.poll_writable(cx))?; + } + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_flush(cx) + } +} + +impl Async { + /// Creates a TCP listener bound to the specified address. + /// + /// Binding with port number 0 will request an available port from the OS. + /// + /// # Examples + /// + /// ``` + /// use async_io::Async; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let listener = Async::::bind(([127, 0, 0, 1], 0))?; + /// println!("Listening on {}", listener.get_ref().local_addr()?); + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn bind>(addr: A) -> io::Result> { + let addr = addr.into(); + Async::new(TcpListener::bind(addr)?) + } + + /// Accepts a new incoming TCP connection. + /// + /// When a connection is established, it will be returned as a TCP stream together with its + /// remote address. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let listener = Async::::bind(([127, 0, 0, 1], 8000))?; + /// let (stream, addr) = listener.accept().await?; + /// println!("Accepted client: {}", addr); + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn accept(&self) -> io::Result<(Async, SocketAddr)> { + let (stream, addr) = self.read_with(|io| io.accept()).await?; + Ok((Async::new(stream)?, addr)) + } + + /// Returns a stream of incoming TCP connections. + /// + /// The stream is infinite, i.e. it never stops with a [`None`]. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use futures_lite::{pin, stream::StreamExt}; + /// use std::net::TcpListener; + /// + /// # futures_lite::future::block_on(async { + /// let listener = Async::::bind(([127, 0, 0, 1], 8000))?; + /// let incoming = listener.incoming(); + /// pin!(incoming); + /// + /// while let Some(stream) = incoming.next().await { + /// let stream = stream?; + /// println!("Accepted client: {}", stream.get_ref().peer_addr()?); + /// } + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn incoming(&self) -> impl Stream>> + Send + '_ { + stream::unfold(self, |listener| async move { + let res = listener.accept().await.map(|(stream, _)| stream); + Some((res, listener)) + }) + } +} + +impl TryFrom for Async { + type Error = io::Error; + + fn try_from(listener: std::net::TcpListener) -> io::Result { + Async::new(listener) + } +} + +impl Async { + /// Creates a TCP connection to the specified address. + /// + /// # Examples + /// + /// ``` + /// use async_io::Async; + /// use std::net::{TcpStream, ToSocketAddrs}; + /// + /// # futures_lite::future::block_on(async { + /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap(); + /// let stream = Async::::connect(addr).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn connect>(addr: A) -> io::Result> { + // Begin async connect. + let addr = addr.into(); + let domain = Domain::for_address(addr); + let socket = connect(addr.into(), domain, Some(Protocol::TCP))?; + let stream = Async::new(TcpStream::from(socket))?; + + // The stream becomes writable when connected. + stream.writable().await?; + + // Check if there was an error while connecting. + match stream.get_ref().take_error()? { + None => Ok(stream), + Some(err) => Err(err), + } + } + + /// Reads data from the stream without removing it from the buffer. + /// + /// Returns the number of bytes read. Successive calls of this method read the same data. + /// + /// # Examples + /// + /// ``` + /// use async_io::Async; + /// use futures_lite::{io::AsyncWriteExt, stream::StreamExt}; + /// use std::net::{TcpStream, ToSocketAddrs}; + /// + /// # futures_lite::future::block_on(async { + /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap(); + /// let mut stream = Async::::connect(addr).await?; + /// + /// stream + /// .write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") + /// .await?; + /// + /// let mut buf = [0u8; 1024]; + /// let len = stream.peek(&mut buf).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn peek(&self, buf: &mut [u8]) -> io::Result { + self.read_with(|io| io.peek(buf)).await + } +} + +impl TryFrom for Async { + type Error = io::Error; + + fn try_from(stream: std::net::TcpStream) -> io::Result { + Async::new(stream) + } +} + +impl Async { + /// Creates a UDP socket bound to the specified address. + /// + /// Binding with port number 0 will request an available port from the OS. + /// + /// # Examples + /// + /// ``` + /// use async_io::Async; + /// use std::net::UdpSocket; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind(([127, 0, 0, 1], 0))?; + /// println!("Bound to {}", socket.get_ref().local_addr()?); + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn bind>(addr: A) -> io::Result> { + let addr = addr.into(); + Async::new(UdpSocket::bind(addr)?) + } + + /// Receives a single datagram message. + /// + /// Returns the number of bytes read and the address the message came from. + /// + /// This method must be called with a valid byte slice of sufficient size to hold the message. + /// If the message is too long to fit, excess bytes may get discarded. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::UdpSocket; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; + /// + /// let mut buf = [0u8; 1024]; + /// let (len, addr) = socket.recv_from(&mut buf).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.read_with(|io| io.recv_from(buf)).await + } + + /// Receives a single datagram message without removing it from the queue. + /// + /// Returns the number of bytes read and the address the message came from. + /// + /// This method must be called with a valid byte slice of sufficient size to hold the message. + /// If the message is too long to fit, excess bytes may get discarded. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::UdpSocket; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; + /// + /// let mut buf = [0u8; 1024]; + /// let (len, addr) = socket.peek_from(&mut buf).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.read_with(|io| io.peek_from(buf)).await + } + + /// Sends data to the specified address. + /// + /// Returns the number of bytes writen. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::UdpSocket; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind(([127, 0, 0, 1], 0))?; + /// let addr = socket.get_ref().local_addr()?; + /// + /// let msg = b"hello"; + /// let len = socket.send_to(msg, addr).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn send_to>(&self, buf: &[u8], addr: A) -> io::Result { + let addr = addr.into(); + self.write_with(|io| io.send_to(buf, addr)).await + } + + /// Receives a single datagram message from the connected peer. + /// + /// Returns the number of bytes read. + /// + /// This method must be called with a valid byte slice of sufficient size to hold the message. + /// If the message is too long to fit, excess bytes may get discarded. + /// + /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address. + /// This method will fail if the socket is not connected. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::UdpSocket; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; + /// socket.get_ref().connect("127.0.0.1:9000")?; + /// + /// let mut buf = [0u8; 1024]; + /// let len = socket.recv(&mut buf).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn recv(&self, buf: &mut [u8]) -> io::Result { + self.read_with(|io| io.recv(buf)).await + } + + /// Receives a single datagram message from the connected peer without removing it from the + /// queue. + /// + /// Returns the number of bytes read and the address the message came from. + /// + /// This method must be called with a valid byte slice of sufficient size to hold the message. + /// If the message is too long to fit, excess bytes may get discarded. + /// + /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address. + /// This method will fail if the socket is not connected. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::UdpSocket; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; + /// socket.get_ref().connect("127.0.0.1:9000")?; + /// + /// let mut buf = [0u8; 1024]; + /// let len = socket.peek(&mut buf).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn peek(&self, buf: &mut [u8]) -> io::Result { + self.read_with(|io| io.peek(buf)).await + } + + /// Sends data to the connected peer. + /// + /// Returns the number of bytes written. + /// + /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address. + /// This method will fail if the socket is not connected. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::UdpSocket; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; + /// socket.get_ref().connect("127.0.0.1:9000")?; + /// + /// let msg = b"hello"; + /// let len = socket.send(msg).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn send(&self, buf: &[u8]) -> io::Result { + self.write_with(|io| io.send(buf)).await + } +} + +impl TryFrom for Async { + type Error = io::Error; + + fn try_from(socket: std::net::UdpSocket) -> io::Result { + Async::new(socket) + } +} + +#[cfg(unix)] +impl Async { + /// Creates a UDS listener bound to the specified path. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixListener; + /// + /// # futures_lite::future::block_on(async { + /// let listener = Async::::bind("/tmp/socket")?; + /// println!("Listening on {:?}", listener.get_ref().local_addr()?); + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn bind>(path: P) -> io::Result> { + let path = path.as_ref().to_owned(); + Async::new(UnixListener::bind(path)?) + } + + /// Accepts a new incoming UDS stream connection. + /// + /// When a connection is established, it will be returned as a stream together with its remote + /// address. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixListener; + /// + /// # futures_lite::future::block_on(async { + /// let listener = Async::::bind("/tmp/socket")?; + /// let (stream, addr) = listener.accept().await?; + /// println!("Accepted client: {:?}", addr); + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn accept(&self) -> io::Result<(Async, UnixSocketAddr)> { + let (stream, addr) = self.read_with(|io| io.accept()).await?; + Ok((Async::new(stream)?, addr)) + } + + /// Returns a stream of incoming UDS connections. + /// + /// The stream is infinite, i.e. it never stops with a [`None`] item. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use futures_lite::{pin, stream::StreamExt}; + /// use std::os::unix::net::UnixListener; + /// + /// # futures_lite::future::block_on(async { + /// let listener = Async::::bind("/tmp/socket")?; + /// let incoming = listener.incoming(); + /// pin!(incoming); + /// + /// while let Some(stream) = incoming.next().await { + /// let stream = stream?; + /// println!("Accepted client: {:?}", stream.get_ref().peer_addr()?); + /// } + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn incoming(&self) -> impl Stream>> + Send + '_ { + stream::unfold(self, |listener| async move { + let res = listener.accept().await.map(|(stream, _)| stream); + Some((res, listener)) + }) + } +} + +#[cfg(unix)] +impl TryFrom for Async { + type Error = io::Error; + + fn try_from(listener: std::os::unix::net::UnixListener) -> io::Result { + Async::new(listener) + } +} + +#[cfg(unix)] +impl Async { + /// Creates a UDS stream connected to the specified path. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixStream; + /// + /// # futures_lite::future::block_on(async { + /// let stream = Async::::connect("/tmp/socket").await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn connect>(path: P) -> io::Result> { + // Begin async connect. + let socket = connect(SockAddr::unix(path)?, Domain::UNIX, None)?; + let stream = Async::new(UnixStream::from(socket))?; + + // The stream becomes writable when connected. + stream.writable().await?; + + // On Linux, it appears the socket may become writable even when connecting fails, so we + // must do an extra check here and see if the peer address is retrievable. + stream.get_ref().peer_addr()?; + Ok(stream) + } + + /// Creates an unnamed pair of connected UDS stream sockets. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixStream; + /// + /// # futures_lite::future::block_on(async { + /// let (stream1, stream2) = Async::::pair()?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn pair() -> io::Result<(Async, Async)> { + let (stream1, stream2) = UnixStream::pair()?; + Ok((Async::new(stream1)?, Async::new(stream2)?)) + } +} + +#[cfg(unix)] +impl TryFrom for Async { + type Error = io::Error; + + fn try_from(stream: std::os::unix::net::UnixStream) -> io::Result { + Async::new(stream) + } +} + +#[cfg(unix)] +impl Async { + /// Creates a UDS datagram socket bound to the specified path. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixDatagram; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind("/tmp/socket")?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn bind>(path: P) -> io::Result> { + let path = path.as_ref().to_owned(); + Async::new(UnixDatagram::bind(path)?) + } + + /// Creates a UDS datagram socket not bound to any address. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixDatagram; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::unbound()?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn unbound() -> io::Result> { + Async::new(UnixDatagram::unbound()?) + } + + /// Creates an unnamed pair of connected Unix datagram sockets. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixDatagram; + /// + /// # futures_lite::future::block_on(async { + /// let (socket1, socket2) = Async::::pair()?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn pair() -> io::Result<(Async, Async)> { + let (socket1, socket2) = UnixDatagram::pair()?; + Ok((Async::new(socket1)?, Async::new(socket2)?)) + } + + /// Receives data from the socket. + /// + /// Returns the number of bytes read and the address the message came from. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixDatagram; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind("/tmp/socket")?; + /// + /// let mut buf = [0u8; 1024]; + /// let (len, addr) = socket.recv_from(&mut buf).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, UnixSocketAddr)> { + self.read_with(|io| io.recv_from(buf)).await + } + + /// Sends data to the specified address. + /// + /// Returns the number of bytes written. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixDatagram; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::unbound()?; + /// + /// let msg = b"hello"; + /// let addr = "/tmp/socket"; + /// let len = socket.send_to(msg, addr).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn send_to>(&self, buf: &[u8], path: P) -> io::Result { + self.write_with(|io| io.send_to(buf, &path)).await + } + + /// Receives data from the connected peer. + /// + /// Returns the number of bytes read and the address the message came from. + /// + /// The [`connect`][`UnixDatagram::connect()`] method connects this socket to a remote address. + /// This method will fail if the socket is not connected. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixDatagram; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind("/tmp/socket1")?; + /// socket.get_ref().connect("/tmp/socket2")?; + /// + /// let mut buf = [0u8; 1024]; + /// let len = socket.recv(&mut buf).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn recv(&self, buf: &mut [u8]) -> io::Result { + self.read_with(|io| io.recv(buf)).await + } + + /// Sends data to the connected peer. + /// + /// Returns the number of bytes written. + /// + /// The [`connect`][`UnixDatagram::connect()`] method connects this socket to a remote address. + /// This method will fail if the socket is not connected. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixDatagram; + /// + /// # futures_lite::future::block_on(async { + /// let socket = Async::::bind("/tmp/socket1")?; + /// socket.get_ref().connect("/tmp/socket2")?; + /// + /// let msg = b"hello"; + /// let len = socket.send(msg).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn send(&self, buf: &[u8]) -> io::Result { + self.write_with(|io| io.send(buf)).await + } +} + +#[cfg(unix)] +impl TryFrom for Async { + type Error = io::Error; + + fn try_from(socket: std::os::unix::net::UnixDatagram) -> io::Result { + Async::new(socket) + } +} + +/// Polls a future once, waits for a wakeup, and then optimistically assumes the future is ready. +async fn optimistic(fut: impl Future>) -> io::Result<()> { + let mut polled = false; + pin!(fut); + + future::poll_fn(|cx| { + if !polled { + polled = true; + fut.as_mut().poll(cx) + } else { + Poll::Ready(Ok(())) + } + }) + .await +} + +fn connect(addr: SockAddr, domain: Domain, protocol: Option) -> io::Result { + let sock_type = Type::STREAM; + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd" + ))] + // If we can, set nonblocking at socket creation for unix + let sock_type = sock_type.nonblocking(); + // This automatically handles cloexec on unix, no_inherit on windows and nosigpipe on macos + let socket = Socket::new(domain, sock_type, protocol)?; + #[cfg(not(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd" + )))] + // If the current platform doesn't support nonblocking at creation, enable it after creation + socket.set_nonblocking(true)?; + match socket.connect(&addr) { + Ok(_) => {} + #[cfg(unix)] + Err(err) if err.raw_os_error() == Some(rustix::io::Errno::INPROGRESS.raw_os_error()) => {} + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + Err(err) => return Err(err), + } + Ok(socket) +} diff --git a/src/lib.rs b/src/lib.rs index fd671b1..1caf524 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -61,38 +61,40 @@ html_logo_url = "https://raw.githubusercontent.com/smol-rs/smol/master/assets/images/logo_fullsize_transparent.png" )] -use std::convert::TryFrom; use std::future::Future; -use std::io::{self, IoSlice, IoSliceMut, Read, Write}; -use std::net::{SocketAddr, TcpListener, TcpStream, UdpSocket}; use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll, Waker}; -use std::time::{Duration, Instant}; +use std::task::{Context, Poll}; +use std::time::Duration; -#[cfg(unix)] -use std::{ - os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd}, - os::unix::net::{SocketAddr as UnixSocketAddr, UnixDatagram, UnixListener, UnixStream}, - path::Path, -}; +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] +use std::time::Instant; -#[cfg(windows)] -use std::os::windows::io::{AsRawSocket, AsSocket, BorrowedSocket, OwnedSocket, RawSocket}; - -use futures_io::{AsyncRead, AsyncWrite}; -use futures_lite::stream::{self, Stream}; -use futures_lite::{future, pin, ready}; -use socket2::{Domain, Protocol, SockAddr, Socket, Type}; - -use crate::reactor::{Reactor, Registration, Source}; +use futures_lite::stream::Stream; +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] mod driver; +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] +mod io; +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] mod reactor; +#[cfg_attr( + not(all(target_family = "wasm", not(target_os = "wasi"))), + path = "timer/native.rs" +)] +#[cfg_attr( + all(target_family = "wasm", not(target_os = "wasi")), + path = "timer/web.rs" +)] +mod timer; + pub mod os; +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] pub use driver::block_on; +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] +pub use io::{Async, IoSafe}; +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] pub use reactor::{Readable, ReadableOwned, Writable, WritableOwned}; /// A future or stream that emits timed events. @@ -138,22 +140,7 @@ pub use reactor::{Readable, ReadableOwned, Writable, WritableOwned}; /// # std::io::Result::Ok(()) }); /// ``` #[derive(Debug)] -pub struct Timer { - /// This timer's ID and last waker that polled it. - /// - /// When this field is set to `None`, this timer is not registered in the reactor. - id_and_waker: Option<(usize, Waker)>, - - /// The next instant at which this timer fires. - /// - /// If this timer is a blank timer, this value is None. If the timer - /// must be set, this value contains the next instant at which the - /// timer must fire. - when: Option, - - /// The period. - period: Duration, -} +pub struct Timer(timer::Timer); impl Timer { /// Creates a timer that will never fire. @@ -188,12 +175,9 @@ impl Timer { /// run_with_timeout(None).await; /// # }); /// ``` + #[inline] pub fn never() -> Timer { - Timer { - id_and_waker: None, - when: None, - period: Duration::MAX, - } + Timer(timer::Timer::never()) } /// Creates a timer that emits an event once after the given duration of time. @@ -208,10 +192,9 @@ impl Timer { /// Timer::after(Duration::from_secs(1)).await; /// # }); /// ``` + #[inline] pub fn after(duration: Duration) -> Timer { - Instant::now() - .checked_add(duration) - .map_or_else(Timer::never, Timer::at) + Timer(timer::Timer::after(duration)) } /// Creates a timer that emits an event once at the given time instant. @@ -228,8 +211,10 @@ impl Timer { /// Timer::at(when).await; /// # }); /// ``` + #[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] + #[inline] pub fn at(instant: Instant) -> Timer { - Timer::interval_at(instant, Duration::MAX) + Timer(timer::Timer::at(instant)) } /// Creates a timer that emits events periodically. @@ -246,10 +231,9 @@ impl Timer { /// Timer::interval(period).next().await; /// # }); /// ``` + #[inline] pub fn interval(period: Duration) -> Timer { - Instant::now() - .checked_add(period) - .map_or_else(Timer::never, |at| Timer::interval_at(at, period)) + Timer(timer::Timer::interval(period)) } /// Creates a timer that emits events periodically, starting at `start`. @@ -267,12 +251,10 @@ impl Timer { /// Timer::interval_at(start, period).next().await; /// # }); /// ``` + #[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] + #[inline] pub fn interval_at(start: Instant, period: Duration) -> Timer { - Timer { - id_and_waker: None, - when: Some(start), - period, - } + Timer(timer::Timer::interval_at(start, period)) } /// Indicates whether or not this timer will ever fire. @@ -314,7 +296,7 @@ impl Timer { /// ``` #[inline] pub fn will_fire(&self) -> bool { - self.when.is_some() + self.0.will_fire() } /// Sets the timer to emit an en event once after the given duration of time. @@ -334,15 +316,9 @@ impl Timer { /// t.set_after(Duration::from_millis(100)); /// # }); /// ``` + #[inline] pub fn set_after(&mut self, duration: Duration) { - match Instant::now().checked_add(duration) { - Some(instant) => self.set_at(instant), - None => { - // Overflow to never going off. - self.clear(); - self.when = None; - } - } + self.0.set_after(duration) } /// Sets the timer to emit an event once at the given time instant. @@ -365,16 +341,10 @@ impl Timer { /// t.set_at(when); /// # }); /// ``` + #[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] + #[inline] pub fn set_at(&mut self, instant: Instant) { - self.clear(); - - // Update the timeout. - self.when = Some(instant); - - if let Some((id, waker)) = self.id_and_waker.as_mut() { - // Re-register the timer with the new timeout. - *id = Reactor::get().insert_timer(instant, waker); - } + self.0.set_at(instant) } /// Sets the timer to emit events periodically. @@ -397,15 +367,9 @@ impl Timer { /// t.set_interval(period); /// # }); /// ``` + #[inline] pub fn set_interval(&mut self, period: Duration) { - match Instant::now().checked_add(period) { - Some(instant) => self.set_interval_at(instant, period), - None => { - // Overflow to never going off. - self.clear(); - self.when = None; - } - } + self.0.set_interval(period) } /// Sets the timer to emit events periodically, starting at `start`. @@ -429,39 +393,21 @@ impl Timer { /// t.set_interval_at(start, period); /// # }); /// ``` + #[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] + #[inline] pub fn set_interval_at(&mut self, start: Instant, period: Duration) { - self.clear(); - - self.when = Some(start); - self.period = period; - - if let Some((id, waker)) = self.id_and_waker.as_mut() { - // Re-register the timer with the new timeout. - *id = Reactor::get().insert_timer(start, waker); - } - } - - /// Helper function to clear the current timer. - fn clear(&mut self) { - if let (Some(when), Some((id, _))) = (self.when, self.id_and_waker.as_ref()) { - // Deregister the timer from the reactor. - Reactor::get().remove_timer(when, *id); - } - } -} - -impl Drop for Timer { - fn drop(&mut self) { - if let (Some(when), Some((id, _))) = (self.when, self.id_and_waker.take()) { - // Deregister the timer from the reactor. - Reactor::get().remove_timer(when, id); - } + self.0.set_interval_at(start, period) } } impl Future for Timer { + #[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] type Output = Instant; + #[cfg(all(target_family = "wasm", not(target_os = "wasi")))] + type Output = (); + + #[inline] fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.poll_next(cx) { Poll::Ready(Some(when)) => Poll::Ready(when), @@ -472,1597 +418,14 @@ impl Future for Timer { } impl Stream for Timer { + #[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] type Item = Instant; - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - - if let Some(ref mut when) = this.when { - // Check if the timer has already fired. - if Instant::now() >= *when { - if let Some((id, _)) = this.id_and_waker.take() { - // Deregister the timer from the reactor. - Reactor::get().remove_timer(*when, id); - } - let result_time = *when; - if let Some(next) = (*when).checked_add(this.period) { - *when = next; - // Register the timer in the reactor. - let id = Reactor::get().insert_timer(next, cx.waker()); - this.id_and_waker = Some((id, cx.waker().clone())); - } else { - this.when = None; - } - return Poll::Ready(Some(result_time)); - } else { - match &this.id_and_waker { - None => { - // Register the timer in the reactor. - let id = Reactor::get().insert_timer(*when, cx.waker()); - this.id_and_waker = Some((id, cx.waker().clone())); - } - Some((id, w)) if !w.will_wake(cx.waker()) => { - // Deregister the timer from the reactor to remove the old waker. - Reactor::get().remove_timer(*when, *id); - - // Register the timer in the reactor with the new waker. - let id = Reactor::get().insert_timer(*when, cx.waker()); - this.id_and_waker = Some((id, cx.waker().clone())); - } - Some(_) => {} - } - } - } - - Poll::Pending - } -} - -/// Async adapter for I/O types. -/// -/// This type puts an I/O handle into non-blocking mode, registers it in -/// [epoll]/[kqueue]/[event ports]/[IOCP], and then provides an async interface for it. -/// -/// [epoll]: https://en.wikipedia.org/wiki/Epoll -/// [kqueue]: https://en.wikipedia.org/wiki/Kqueue -/// [event ports]: https://illumos.org/man/port_create -/// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports -/// -/// # Caveats -/// -/// [`Async`] is a low-level primitive, and as such it comes with some caveats. -/// -/// For higher-level primitives built on top of [`Async`], look into [`async-net`] or -/// [`async-process`] (on Unix). -/// -/// The most notable caveat is that it is unsafe to access the inner I/O source mutably -/// using this primitive. Traits likes [`AsyncRead`] and [`AsyncWrite`] are not implemented by -/// default unless it is guaranteed that the resource won't be invalidated by reading or writing. -/// See the [`IoSafe`] trait for more information. -/// -/// [`async-net`]: https://github.com/smol-rs/async-net -/// [`async-process`]: https://github.com/smol-rs/async-process -/// [`AsyncRead`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncRead.html -/// [`AsyncWrite`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncWrite.html -/// -/// ### Supported types -/// -/// [`Async`] supports all networking types, as well as some OS-specific file descriptors like -/// [timerfd] and [inotify]. -/// -/// However, do not use [`Async`] with types like [`File`][`std::fs::File`], -/// [`Stdin`][`std::io::Stdin`], [`Stdout`][`std::io::Stdout`], or [`Stderr`][`std::io::Stderr`] -/// because all operating systems have issues with them when put in non-blocking mode. -/// -/// [timerfd]: https://github.com/smol-rs/async-io/blob/master/examples/linux-timerfd.rs -/// [inotify]: https://github.com/smol-rs/async-io/blob/master/examples/linux-inotify.rs -/// -/// ### Concurrent I/O -/// -/// Note that [`&Async`][`Async`] implements [`AsyncRead`] and [`AsyncWrite`] if `&T` -/// implements those traits, which means tasks can concurrently read and write using shared -/// references. -/// -/// But there is a catch: only one task can read a time, and only one task can write at a time. It -/// is okay to have two tasks where one is reading and the other is writing at the same time, but -/// it is not okay to have two tasks reading at the same time or writing at the same time. If you -/// try to do that, conflicting tasks will just keep waking each other in turn, thus wasting CPU -/// time. -/// -/// Besides [`AsyncRead`] and [`AsyncWrite`], this caveat also applies to -/// [`poll_readable()`][`Async::poll_readable()`] and -/// [`poll_writable()`][`Async::poll_writable()`]. -/// -/// However, any number of tasks can be concurrently calling other methods like -/// [`readable()`][`Async::readable()`] or [`read_with()`][`Async::read_with()`]. -/// -/// ### Closing -/// -/// Closing the write side of [`Async`] with [`close()`][`futures_lite::AsyncWriteExt::close()`] -/// simply flushes. If you want to shutdown a TCP or Unix socket, use -/// [`Shutdown`][`std::net::Shutdown`]. -/// -/// # Examples -/// -/// Connect to a server and echo incoming messages back to the server: -/// -/// ```no_run -/// use async_io::Async; -/// use futures_lite::io; -/// use std::net::TcpStream; -/// -/// # futures_lite::future::block_on(async { -/// // Connect to a local server. -/// let stream = Async::::connect(([127, 0, 0, 1], 8000)).await?; -/// -/// // Echo all messages from the read side of the stream into the write side. -/// io::copy(&stream, &stream).await?; -/// # std::io::Result::Ok(()) }); -/// ``` -/// -/// You can use either predefined async methods or wrap blocking I/O operations in -/// [`Async::read_with()`], [`Async::read_with_mut()`], [`Async::write_with()`], and -/// [`Async::write_with_mut()`]: -/// -/// ```no_run -/// use async_io::Async; -/// use std::net::TcpListener; -/// -/// # futures_lite::future::block_on(async { -/// let listener = Async::::bind(([127, 0, 0, 1], 0))?; -/// -/// // These two lines are equivalent: -/// let (stream, addr) = listener.accept().await?; -/// let (stream, addr) = listener.read_with(|inner| inner.accept()).await?; -/// # std::io::Result::Ok(()) }); -/// ``` -#[derive(Debug)] -pub struct Async { - /// A source registered in the reactor. - source: Arc, - - /// The inner I/O handle. - io: Option, -} - -impl Unpin for Async {} - -#[cfg(unix)] -impl Async { - /// Creates an async I/O handle. - /// - /// This method will put the handle in non-blocking mode and register it in - /// [epoll]/[kqueue]/[event ports]/[IOCP]. - /// - /// On Unix systems, the handle must implement `AsRawFd`, while on Windows it must implement - /// `AsRawSocket`. - /// - /// [epoll]: https://en.wikipedia.org/wiki/Epoll - /// [kqueue]: https://en.wikipedia.org/wiki/Kqueue - /// [event ports]: https://illumos.org/man/port_create - /// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports - /// - /// # Examples - /// - /// ``` - /// use async_io::Async; - /// use std::net::{SocketAddr, TcpListener}; - /// - /// # futures_lite::future::block_on(async { - /// let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))?; - /// let listener = Async::new(listener)?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn new(io: T) -> io::Result> { - // Put the file descriptor in non-blocking mode. - let fd = io.as_fd(); - cfg_if::cfg_if! { - // ioctl(FIONBIO) sets the flag atomically, but we use this only on Linux - // for now, as with the standard library, because it seems to behave - // differently depending on the platform. - // https://github.com/rust-lang/rust/commit/efeb42be2837842d1beb47b51bb693c7474aba3d - // https://github.com/libuv/libuv/blob/e9d91fccfc3e5ff772d5da90e1c4a24061198ca0/src/unix/poll.c#L78-L80 - // https://github.com/tokio-rs/mio/commit/0db49f6d5caf54b12176821363d154384357e70a - if #[cfg(target_os = "linux")] { - rustix::io::ioctl_fionbio(fd, true)?; - } else { - let previous = rustix::fs::fcntl_getfl(fd)?; - let new = previous | rustix::fs::OFlags::NONBLOCK; - if new != previous { - rustix::fs::fcntl_setfl(fd, new)?; - } - } - } - - // SAFETY: It is impossible to drop the I/O source while it is registered through - // this type. - let registration = unsafe { Registration::new(fd) }; - - Ok(Async { - source: Reactor::get().insert_io(registration)?, - io: Some(io), - }) - } -} - -#[cfg(unix)] -impl AsRawFd for Async { - fn as_raw_fd(&self) -> RawFd { - self.get_ref().as_raw_fd() - } -} - -#[cfg(unix)] -impl AsFd for Async { - fn as_fd(&self) -> BorrowedFd<'_> { - self.get_ref().as_fd() - } -} - -#[cfg(unix)] -impl> TryFrom for Async { - type Error = io::Error; - - fn try_from(value: OwnedFd) -> Result { - Async::new(value.into()) - } -} - -#[cfg(unix)] -impl> TryFrom> for OwnedFd { - type Error = io::Error; - - fn try_from(value: Async) -> Result { - value.into_inner().map(Into::into) - } -} - -#[cfg(windows)] -impl Async { - /// Creates an async I/O handle. - /// - /// This method will put the handle in non-blocking mode and register it in - /// [epoll]/[kqueue]/[event ports]/[IOCP]. - /// - /// On Unix systems, the handle must implement `AsRawFd`, while on Windows it must implement - /// `AsRawSocket`. - /// - /// [epoll]: https://en.wikipedia.org/wiki/Epoll - /// [kqueue]: https://en.wikipedia.org/wiki/Kqueue - /// [event ports]: https://illumos.org/man/port_create - /// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports - /// - /// # Examples - /// - /// ``` - /// use async_io::Async; - /// use std::net::{SocketAddr, TcpListener}; - /// - /// # futures_lite::future::block_on(async { - /// let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))?; - /// let listener = Async::new(listener)?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn new(io: T) -> io::Result> { - let borrowed = io.as_socket(); - - // Put the socket in non-blocking mode. - // - // Safety: We assume `as_raw_socket()` returns a valid fd. When we can - // depend on Rust >= 1.63, where `AsFd` is stabilized, and when - // `TimerFd` implements it, we can remove this unsafe and simplify this. - rustix::io::ioctl_fionbio(borrowed, true)?; - - // Create the registration. - // - // SAFETY: It is impossible to drop the I/O source while it is registered through - // this type. - let registration = unsafe { Registration::new(borrowed) }; - - Ok(Async { - source: Reactor::get().insert_io(registration)?, - io: Some(io), - }) - } -} - -#[cfg(windows)] -impl AsRawSocket for Async { - fn as_raw_socket(&self) -> RawSocket { - self.get_ref().as_raw_socket() - } -} - -#[cfg(windows)] -impl AsSocket for Async { - fn as_socket(&self) -> BorrowedSocket<'_> { - self.get_ref().as_socket() - } -} - -#[cfg(windows)] -impl> TryFrom for Async { - type Error = io::Error; - - fn try_from(value: OwnedSocket) -> Result { - Async::new(value.into()) - } -} - -#[cfg(windows)] -impl> TryFrom> for OwnedSocket { - type Error = io::Error; - - fn try_from(value: Async) -> Result { - value.into_inner().map(Into::into) - } -} - -impl Async { - /// Gets a reference to the inner I/O handle. - /// - /// # Examples - /// - /// ``` - /// use async_io::Async; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let listener = Async::::bind(([127, 0, 0, 1], 0))?; - /// let inner = listener.get_ref(); - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn get_ref(&self) -> &T { - self.io.as_ref().unwrap() - } - - /// Gets a mutable reference to the inner I/O handle. - /// - /// # Safety - /// - /// The underlying I/O source must not be dropped using this function. - /// - /// # Examples - /// - /// ``` - /// use async_io::Async; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let mut listener = Async::::bind(([127, 0, 0, 1], 0))?; - /// let inner = unsafe { listener.get_mut() }; - /// # std::io::Result::Ok(()) }); - /// ``` - pub unsafe fn get_mut(&mut self) -> &mut T { - self.io.as_mut().unwrap() - } - - /// Unwraps the inner I/O handle. - /// - /// This method will **not** put the I/O handle back into blocking mode. - /// - /// # Examples - /// - /// ``` - /// use async_io::Async; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let listener = Async::::bind(([127, 0, 0, 1], 0))?; - /// let inner = listener.into_inner()?; - /// - /// // Put the listener back into blocking mode. - /// inner.set_nonblocking(false)?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn into_inner(mut self) -> io::Result { - let io = self.io.take().unwrap(); - Reactor::get().remove_io(&self.source)?; - Ok(io) - } - - /// Waits until the I/O handle is readable. - /// - /// This method completes when a read operation on this I/O handle wouldn't block. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let mut listener = Async::::bind(([127, 0, 0, 1], 0))?; - /// - /// // Wait until a client can be accepted. - /// listener.readable().await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn readable(&self) -> Readable<'_, T> { - Source::readable(self) - } - - /// Waits until the I/O handle is readable. - /// - /// This method completes when a read operation on this I/O handle wouldn't block. - pub fn readable_owned(self: Arc) -> ReadableOwned { - Source::readable_owned(self) - } - - /// Waits until the I/O handle is writable. - /// - /// This method completes when a write operation on this I/O handle wouldn't block. - /// - /// # Examples - /// - /// ``` - /// use async_io::Async; - /// use std::net::{TcpStream, ToSocketAddrs}; - /// - /// # futures_lite::future::block_on(async { - /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap(); - /// let stream = Async::::connect(addr).await?; - /// - /// // Wait until the stream is writable. - /// stream.writable().await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn writable(&self) -> Writable<'_, T> { - Source::writable(self) - } - - /// Waits until the I/O handle is writable. - /// - /// This method completes when a write operation on this I/O handle wouldn't block. - pub fn writable_owned(self: Arc) -> WritableOwned { - Source::writable_owned(self) - } - - /// Polls the I/O handle for readability. - /// - /// When this method returns [`Poll::Ready`], that means the OS has delivered an event - /// indicating readability since the last time this task has called the method and received - /// [`Poll::Pending`]. - /// - /// # Caveats - /// - /// Two different tasks should not call this method concurrently. Otherwise, conflicting tasks - /// will just keep waking each other in turn, thus wasting CPU time. - /// - /// Note that the [`AsyncRead`] implementation for [`Async`] also uses this method. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use futures_lite::future; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let mut listener = Async::::bind(([127, 0, 0, 1], 0))?; - /// - /// // Wait until a client can be accepted. - /// future::poll_fn(|cx| listener.poll_readable(cx)).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn poll_readable(&self, cx: &mut Context<'_>) -> Poll> { - self.source.poll_readable(cx) - } - - /// Polls the I/O handle for writability. - /// - /// When this method returns [`Poll::Ready`], that means the OS has delivered an event - /// indicating writability since the last time this task has called the method and received - /// [`Poll::Pending`]. - /// - /// # Caveats - /// - /// Two different tasks should not call this method concurrently. Otherwise, conflicting tasks - /// will just keep waking each other in turn, thus wasting CPU time. - /// - /// Note that the [`AsyncWrite`] implementation for [`Async`] also uses this method. - /// - /// # Examples - /// - /// ``` - /// use async_io::Async; - /// use futures_lite::future; - /// use std::net::{TcpStream, ToSocketAddrs}; - /// - /// # futures_lite::future::block_on(async { - /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap(); - /// let stream = Async::::connect(addr).await?; - /// - /// // Wait until the stream is writable. - /// future::poll_fn(|cx| stream.poll_writable(cx)).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn poll_writable(&self, cx: &mut Context<'_>) -> Poll> { - self.source.poll_writable(cx) - } - - /// Performs a read operation asynchronously. - /// - /// The I/O handle is registered in the reactor and put in non-blocking mode. This method - /// invokes the `op` closure in a loop until it succeeds or returns an error other than - /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS - /// sends a notification that the I/O handle is readable. - /// - /// The closure receives a shared reference to the I/O handle. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let listener = Async::::bind(([127, 0, 0, 1], 0))?; - /// - /// // Accept a new client asynchronously. - /// let (stream, addr) = listener.read_with(|l| l.accept()).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn read_with(&self, op: impl FnMut(&T) -> io::Result) -> io::Result { - let mut op = op; - loop { - match op(self.get_ref()) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return res, - } - optimistic(self.readable()).await?; - } - } - - /// Performs a read operation asynchronously. - /// - /// The I/O handle is registered in the reactor and put in non-blocking mode. This method - /// invokes the `op` closure in a loop until it succeeds or returns an error other than - /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS - /// sends a notification that the I/O handle is readable. - /// - /// The closure receives a mutable reference to the I/O handle. - /// - /// # Safety - /// - /// In the closure, the underlying I/O source must not be dropped. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let mut listener = Async::::bind(([127, 0, 0, 1], 0))?; - /// - /// // Accept a new client asynchronously. - /// let (stream, addr) = unsafe { listener.read_with_mut(|l| l.accept()).await? }; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async unsafe fn read_with_mut( - &mut self, - op: impl FnMut(&mut T) -> io::Result, - ) -> io::Result { - let mut op = op; - loop { - match op(self.get_mut()) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return res, - } - optimistic(self.readable()).await?; - } - } + #[cfg(all(target_family = "wasm", not(target_os = "wasi")))] + type Item = (); - /// Performs a write operation asynchronously. - /// - /// The I/O handle is registered in the reactor and put in non-blocking mode. This method - /// invokes the `op` closure in a loop until it succeeds or returns an error other than - /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS - /// sends a notification that the I/O handle is writable. - /// - /// The closure receives a shared reference to the I/O handle. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::net::UdpSocket; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; - /// socket.get_ref().connect("127.0.0.1:9000")?; - /// - /// let msg = b"hello"; - /// let len = socket.write_with(|s| s.send(msg)).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn write_with(&self, op: impl FnMut(&T) -> io::Result) -> io::Result { - let mut op = op; - loop { - match op(self.get_ref()) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return res, - } - optimistic(self.writable()).await?; - } - } - - /// Performs a write operation asynchronously. - /// - /// The I/O handle is registered in the reactor and put in non-blocking mode. This method - /// invokes the `op` closure in a loop until it succeeds or returns an error other than - /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS - /// sends a notification that the I/O handle is writable. - /// - /// # Safety - /// - /// The closure receives a mutable reference to the I/O handle. In the closure, the underlying - /// I/O source must not be dropped. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::net::UdpSocket; - /// - /// # futures_lite::future::block_on(async { - /// let mut socket = Async::::bind(([127, 0, 0, 1], 8000))?; - /// socket.get_ref().connect("127.0.0.1:9000")?; - /// - /// let msg = b"hello"; - /// let len = unsafe { socket.write_with_mut(|s| s.send(msg)).await? }; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async unsafe fn write_with_mut( - &mut self, - op: impl FnMut(&mut T) -> io::Result, - ) -> io::Result { - let mut op = op; - loop { - match op(self.get_mut()) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return res, - } - optimistic(self.writable()).await?; - } - } -} - -impl AsRef for Async { - fn as_ref(&self) -> &T { - self.get_ref() - } -} - -impl Drop for Async { - fn drop(&mut self) { - if self.io.is_some() { - // Deregister and ignore errors because destructors should not panic. - Reactor::get().remove_io(&self.source).ok(); - - // Drop the I/O handle to close it. - self.io.take(); - } - } -} - -/// Types whose I/O trait implementations do not drop the underlying I/O source. -/// -/// The resource contained inside of the [`Async`] cannot be invalidated. This invalidation can -/// happen if the inner resource (the [`TcpStream`], [`UnixListener`] or other `T`) is moved out -/// and dropped before the [`Async`]. Because of this, functions that grant mutable access to -/// the inner type are unsafe, as there is no way to guarantee that the source won't be dropped -/// and a dangling handle won't be left behind. -/// -/// Unfortunately this extends to implementations of [`Read`] and [`Write`]. Since methods on those -/// traits take `&mut`, there is no guarantee that the implementor of those traits won't move the -/// source out while the method is being run. -/// -/// This trait is an antidote to this predicament. By implementing this trait, the user pledges -/// that using any I/O traits won't destroy the source. This way, [`Async`] can implement the -/// `async` version of these I/O traits, like [`AsyncRead`] and [`AsyncWrite`]. -/// -/// # Safety -/// -/// Any I/O trait implementations for this type must not drop the underlying I/O source. Traits -/// affected by this trait include [`Read`], [`Write`], [`Seek`] and [`BufRead`]. -/// -/// This trait is implemented by default on top of `libstd` types. In addition, it is implemented -/// for immutable reference types, as it is impossible to invalidate any outstanding references -/// while holding an immutable reference, even with interior mutability. As Rust's current pinning -/// system relies on similar guarantees, I believe that this approach is robust. -/// -/// [`BufRead`]: https://doc.rust-lang.org/std/io/trait.BufRead.html -/// [`Read`]: https://doc.rust-lang.org/std/io/trait.Read.html -/// [`Seek`]: https://doc.rust-lang.org/std/io/trait.Seek.html -/// [`Write`]: https://doc.rust-lang.org/std/io/trait.Write.html -/// -/// [`AsyncRead`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncRead.html -/// [`AsyncWrite`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncWrite.html -pub unsafe trait IoSafe {} - -/// Reference types can't be mutated. -/// -/// The worst thing that can happen is that external state is used to change what kind of pointer -/// `as_fd()` returns. For instance: -/// -/// ``` -/// # #[cfg(unix)] { -/// use std::cell::Cell; -/// use std::net::TcpStream; -/// use std::os::unix::io::{AsFd, BorrowedFd}; -/// -/// struct Bar { -/// flag: Cell, -/// a: TcpStream, -/// b: TcpStream -/// } -/// -/// impl AsFd for Bar { -/// fn as_fd(&self) -> BorrowedFd<'_> { -/// if self.flag.replace(!self.flag.get()) { -/// self.a.as_fd() -/// } else { -/// self.b.as_fd() -/// } -/// } -/// } -/// # } -/// ``` -/// -/// We solve this problem by only calling `as_fd()` once to get the original source. Implementations -/// like this are considered buggy (but not unsound) and are thus not really supported by `async-io`. -unsafe impl IoSafe for &T {} - -// Can be implemented on top of libstd types. -unsafe impl IoSafe for std::fs::File {} -unsafe impl IoSafe for std::io::Stderr {} -unsafe impl IoSafe for std::io::Stdin {} -unsafe impl IoSafe for std::io::Stdout {} -unsafe impl IoSafe for std::io::StderrLock<'_> {} -unsafe impl IoSafe for std::io::StdinLock<'_> {} -unsafe impl IoSafe for std::io::StdoutLock<'_> {} -unsafe impl IoSafe for std::net::TcpStream {} - -#[cfg(unix)] -unsafe impl IoSafe for std::os::unix::net::UnixStream {} - -unsafe impl IoSafe for std::io::BufReader {} -unsafe impl IoSafe for std::io::BufWriter {} -unsafe impl IoSafe for std::io::LineWriter {} -unsafe impl IoSafe for &mut T {} -unsafe impl IoSafe for Box {} -unsafe impl IoSafe for std::borrow::Cow<'_, T> {} - -impl AsyncRead for Async { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - loop { - match unsafe { (*self).get_mut() }.read(buf) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_readable(cx))?; - } - } - - fn poll_read_vectored( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &mut [IoSliceMut<'_>], - ) -> Poll> { - loop { - match unsafe { (*self).get_mut() }.read_vectored(bufs) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_readable(cx))?; - } - } -} - -// Since this is through a reference, we can't mutate the inner I/O source. -// Therefore this is safe! -impl AsyncRead for &Async -where - for<'a> &'a T: Read, -{ - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - loop { - match (*self).get_ref().read(buf) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_readable(cx))?; - } - } - - fn poll_read_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &mut [IoSliceMut<'_>], - ) -> Poll> { - loop { - match (*self).get_ref().read_vectored(bufs) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_readable(cx))?; - } - } -} - -impl AsyncWrite for Async { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - loop { - match unsafe { (*self).get_mut() }.write(buf) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_writable(cx))?; - } - } - - fn poll_write_vectored( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - loop { - match unsafe { (*self).get_mut() }.write_vectored(bufs) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_writable(cx))?; - } - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - loop { - match unsafe { (*self).get_mut() }.flush() { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_writable(cx))?; - } - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.poll_flush(cx) - } -} - -impl AsyncWrite for &Async -where - for<'a> &'a T: Write, -{ - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - loop { - match (*self).get_ref().write(buf) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_writable(cx))?; - } - } - - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - loop { - match (*self).get_ref().write_vectored(bufs) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_writable(cx))?; - } - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - loop { - match (*self).get_ref().flush() { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return Poll::Ready(res), - } - ready!(self.poll_writable(cx))?; - } - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.poll_flush(cx) - } -} - -impl Async { - /// Creates a TCP listener bound to the specified address. - /// - /// Binding with port number 0 will request an available port from the OS. - /// - /// # Examples - /// - /// ``` - /// use async_io::Async; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let listener = Async::::bind(([127, 0, 0, 1], 0))?; - /// println!("Listening on {}", listener.get_ref().local_addr()?); - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn bind>(addr: A) -> io::Result> { - let addr = addr.into(); - Async::new(TcpListener::bind(addr)?) - } - - /// Accepts a new incoming TCP connection. - /// - /// When a connection is established, it will be returned as a TCP stream together with its - /// remote address. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let listener = Async::::bind(([127, 0, 0, 1], 8000))?; - /// let (stream, addr) = listener.accept().await?; - /// println!("Accepted client: {}", addr); - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn accept(&self) -> io::Result<(Async, SocketAddr)> { - let (stream, addr) = self.read_with(|io| io.accept()).await?; - Ok((Async::new(stream)?, addr)) - } - - /// Returns a stream of incoming TCP connections. - /// - /// The stream is infinite, i.e. it never stops with a [`None`]. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use futures_lite::{pin, stream::StreamExt}; - /// use std::net::TcpListener; - /// - /// # futures_lite::future::block_on(async { - /// let listener = Async::::bind(([127, 0, 0, 1], 8000))?; - /// let incoming = listener.incoming(); - /// pin!(incoming); - /// - /// while let Some(stream) = incoming.next().await { - /// let stream = stream?; - /// println!("Accepted client: {}", stream.get_ref().peer_addr()?); - /// } - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn incoming(&self) -> impl Stream>> + Send + '_ { - stream::unfold(self, |listener| async move { - let res = listener.accept().await.map(|(stream, _)| stream); - Some((res, listener)) - }) - } -} - -impl TryFrom for Async { - type Error = io::Error; - - fn try_from(listener: std::net::TcpListener) -> io::Result { - Async::new(listener) - } -} - -impl Async { - /// Creates a TCP connection to the specified address. - /// - /// # Examples - /// - /// ``` - /// use async_io::Async; - /// use std::net::{TcpStream, ToSocketAddrs}; - /// - /// # futures_lite::future::block_on(async { - /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap(); - /// let stream = Async::::connect(addr).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn connect>(addr: A) -> io::Result> { - // Begin async connect. - let addr = addr.into(); - let domain = Domain::for_address(addr); - let socket = connect(addr.into(), domain, Some(Protocol::TCP))?; - let stream = Async::new(TcpStream::from(socket))?; - - // The stream becomes writable when connected. - stream.writable().await?; - - // Check if there was an error while connecting. - match stream.get_ref().take_error()? { - None => Ok(stream), - Some(err) => Err(err), - } - } - - /// Reads data from the stream without removing it from the buffer. - /// - /// Returns the number of bytes read. Successive calls of this method read the same data. - /// - /// # Examples - /// - /// ``` - /// use async_io::Async; - /// use futures_lite::{io::AsyncWriteExt, stream::StreamExt}; - /// use std::net::{TcpStream, ToSocketAddrs}; - /// - /// # futures_lite::future::block_on(async { - /// let addr = "example.com:80".to_socket_addrs()?.next().unwrap(); - /// let mut stream = Async::::connect(addr).await?; - /// - /// stream - /// .write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") - /// .await?; - /// - /// let mut buf = [0u8; 1024]; - /// let len = stream.peek(&mut buf).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn peek(&self, buf: &mut [u8]) -> io::Result { - self.read_with(|io| io.peek(buf)).await - } -} - -impl TryFrom for Async { - type Error = io::Error; - - fn try_from(stream: std::net::TcpStream) -> io::Result { - Async::new(stream) - } -} - -impl Async { - /// Creates a UDP socket bound to the specified address. - /// - /// Binding with port number 0 will request an available port from the OS. - /// - /// # Examples - /// - /// ``` - /// use async_io::Async; - /// use std::net::UdpSocket; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind(([127, 0, 0, 1], 0))?; - /// println!("Bound to {}", socket.get_ref().local_addr()?); - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn bind>(addr: A) -> io::Result> { - let addr = addr.into(); - Async::new(UdpSocket::bind(addr)?) - } - - /// Receives a single datagram message. - /// - /// Returns the number of bytes read and the address the message came from. - /// - /// This method must be called with a valid byte slice of sufficient size to hold the message. - /// If the message is too long to fit, excess bytes may get discarded. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::net::UdpSocket; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; - /// - /// let mut buf = [0u8; 1024]; - /// let (len, addr) = socket.recv_from(&mut buf).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { - self.read_with(|io| io.recv_from(buf)).await - } - - /// Receives a single datagram message without removing it from the queue. - /// - /// Returns the number of bytes read and the address the message came from. - /// - /// This method must be called with a valid byte slice of sufficient size to hold the message. - /// If the message is too long to fit, excess bytes may get discarded. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::net::UdpSocket; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; - /// - /// let mut buf = [0u8; 1024]; - /// let (len, addr) = socket.peek_from(&mut buf).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { - self.read_with(|io| io.peek_from(buf)).await - } - - /// Sends data to the specified address. - /// - /// Returns the number of bytes writen. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::net::UdpSocket; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind(([127, 0, 0, 1], 0))?; - /// let addr = socket.get_ref().local_addr()?; - /// - /// let msg = b"hello"; - /// let len = socket.send_to(msg, addr).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn send_to>(&self, buf: &[u8], addr: A) -> io::Result { - let addr = addr.into(); - self.write_with(|io| io.send_to(buf, addr)).await - } - - /// Receives a single datagram message from the connected peer. - /// - /// Returns the number of bytes read. - /// - /// This method must be called with a valid byte slice of sufficient size to hold the message. - /// If the message is too long to fit, excess bytes may get discarded. - /// - /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address. - /// This method will fail if the socket is not connected. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::net::UdpSocket; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; - /// socket.get_ref().connect("127.0.0.1:9000")?; - /// - /// let mut buf = [0u8; 1024]; - /// let len = socket.recv(&mut buf).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn recv(&self, buf: &mut [u8]) -> io::Result { - self.read_with(|io| io.recv(buf)).await - } - - /// Receives a single datagram message from the connected peer without removing it from the - /// queue. - /// - /// Returns the number of bytes read and the address the message came from. - /// - /// This method must be called with a valid byte slice of sufficient size to hold the message. - /// If the message is too long to fit, excess bytes may get discarded. - /// - /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address. - /// This method will fail if the socket is not connected. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::net::UdpSocket; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; - /// socket.get_ref().connect("127.0.0.1:9000")?; - /// - /// let mut buf = [0u8; 1024]; - /// let len = socket.peek(&mut buf).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn peek(&self, buf: &mut [u8]) -> io::Result { - self.read_with(|io| io.peek(buf)).await - } - - /// Sends data to the connected peer. - /// - /// Returns the number of bytes written. - /// - /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address. - /// This method will fail if the socket is not connected. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::net::UdpSocket; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind(([127, 0, 0, 1], 8000))?; - /// socket.get_ref().connect("127.0.0.1:9000")?; - /// - /// let msg = b"hello"; - /// let len = socket.send(msg).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn send(&self, buf: &[u8]) -> io::Result { - self.write_with(|io| io.send(buf)).await - } -} - -impl TryFrom for Async { - type Error = io::Error; - - fn try_from(socket: std::net::UdpSocket) -> io::Result { - Async::new(socket) - } -} - -#[cfg(unix)] -impl Async { - /// Creates a UDS listener bound to the specified path. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::os::unix::net::UnixListener; - /// - /// # futures_lite::future::block_on(async { - /// let listener = Async::::bind("/tmp/socket")?; - /// println!("Listening on {:?}", listener.get_ref().local_addr()?); - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn bind>(path: P) -> io::Result> { - let path = path.as_ref().to_owned(); - Async::new(UnixListener::bind(path)?) - } - - /// Accepts a new incoming UDS stream connection. - /// - /// When a connection is established, it will be returned as a stream together with its remote - /// address. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::os::unix::net::UnixListener; - /// - /// # futures_lite::future::block_on(async { - /// let listener = Async::::bind("/tmp/socket")?; - /// let (stream, addr) = listener.accept().await?; - /// println!("Accepted client: {:?}", addr); - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn accept(&self) -> io::Result<(Async, UnixSocketAddr)> { - let (stream, addr) = self.read_with(|io| io.accept()).await?; - Ok((Async::new(stream)?, addr)) - } - - /// Returns a stream of incoming UDS connections. - /// - /// The stream is infinite, i.e. it never stops with a [`None`] item. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use futures_lite::{pin, stream::StreamExt}; - /// use std::os::unix::net::UnixListener; - /// - /// # futures_lite::future::block_on(async { - /// let listener = Async::::bind("/tmp/socket")?; - /// let incoming = listener.incoming(); - /// pin!(incoming); - /// - /// while let Some(stream) = incoming.next().await { - /// let stream = stream?; - /// println!("Accepted client: {:?}", stream.get_ref().peer_addr()?); - /// } - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn incoming(&self) -> impl Stream>> + Send + '_ { - stream::unfold(self, |listener| async move { - let res = listener.accept().await.map(|(stream, _)| stream); - Some((res, listener)) - }) - } -} - -#[cfg(unix)] -impl TryFrom for Async { - type Error = io::Error; - - fn try_from(listener: std::os::unix::net::UnixListener) -> io::Result { - Async::new(listener) - } -} - -#[cfg(unix)] -impl Async { - /// Creates a UDS stream connected to the specified path. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::os::unix::net::UnixStream; - /// - /// # futures_lite::future::block_on(async { - /// let stream = Async::::connect("/tmp/socket").await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn connect>(path: P) -> io::Result> { - // Begin async connect. - let socket = connect(SockAddr::unix(path)?, Domain::UNIX, None)?; - let stream = Async::new(UnixStream::from(socket))?; - - // The stream becomes writable when connected. - stream.writable().await?; - - // On Linux, it appears the socket may become writable even when connecting fails, so we - // must do an extra check here and see if the peer address is retrievable. - stream.get_ref().peer_addr()?; - Ok(stream) - } - - /// Creates an unnamed pair of connected UDS stream sockets. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::os::unix::net::UnixStream; - /// - /// # futures_lite::future::block_on(async { - /// let (stream1, stream2) = Async::::pair()?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn pair() -> io::Result<(Async, Async)> { - let (stream1, stream2) = UnixStream::pair()?; - Ok((Async::new(stream1)?, Async::new(stream2)?)) - } -} - -#[cfg(unix)] -impl TryFrom for Async { - type Error = io::Error; - - fn try_from(stream: std::os::unix::net::UnixStream) -> io::Result { - Async::new(stream) - } -} - -#[cfg(unix)] -impl Async { - /// Creates a UDS datagram socket bound to the specified path. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::os::unix::net::UnixDatagram; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind("/tmp/socket")?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn bind>(path: P) -> io::Result> { - let path = path.as_ref().to_owned(); - Async::new(UnixDatagram::bind(path)?) - } - - /// Creates a UDS datagram socket not bound to any address. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::os::unix::net::UnixDatagram; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::unbound()?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn unbound() -> io::Result> { - Async::new(UnixDatagram::unbound()?) - } - - /// Creates an unnamed pair of connected Unix datagram sockets. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::os::unix::net::UnixDatagram; - /// - /// # futures_lite::future::block_on(async { - /// let (socket1, socket2) = Async::::pair()?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub fn pair() -> io::Result<(Async, Async)> { - let (socket1, socket2) = UnixDatagram::pair()?; - Ok((Async::new(socket1)?, Async::new(socket2)?)) - } - - /// Receives data from the socket. - /// - /// Returns the number of bytes read and the address the message came from. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::os::unix::net::UnixDatagram; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind("/tmp/socket")?; - /// - /// let mut buf = [0u8; 1024]; - /// let (len, addr) = socket.recv_from(&mut buf).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, UnixSocketAddr)> { - self.read_with(|io| io.recv_from(buf)).await - } - - /// Sends data to the specified address. - /// - /// Returns the number of bytes written. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::os::unix::net::UnixDatagram; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::unbound()?; - /// - /// let msg = b"hello"; - /// let addr = "/tmp/socket"; - /// let len = socket.send_to(msg, addr).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn send_to>(&self, buf: &[u8], path: P) -> io::Result { - self.write_with(|io| io.send_to(buf, &path)).await - } - - /// Receives data from the connected peer. - /// - /// Returns the number of bytes read and the address the message came from. - /// - /// The [`connect`][`UnixDatagram::connect()`] method connects this socket to a remote address. - /// This method will fail if the socket is not connected. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::os::unix::net::UnixDatagram; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind("/tmp/socket1")?; - /// socket.get_ref().connect("/tmp/socket2")?; - /// - /// let mut buf = [0u8; 1024]; - /// let len = socket.recv(&mut buf).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn recv(&self, buf: &mut [u8]) -> io::Result { - self.read_with(|io| io.recv(buf)).await - } - - /// Sends data to the connected peer. - /// - /// Returns the number of bytes written. - /// - /// The [`connect`][`UnixDatagram::connect()`] method connects this socket to a remote address. - /// This method will fail if the socket is not connected. - /// - /// # Examples - /// - /// ```no_run - /// use async_io::Async; - /// use std::os::unix::net::UnixDatagram; - /// - /// # futures_lite::future::block_on(async { - /// let socket = Async::::bind("/tmp/socket1")?; - /// socket.get_ref().connect("/tmp/socket2")?; - /// - /// let msg = b"hello"; - /// let len = socket.send(msg).await?; - /// # std::io::Result::Ok(()) }); - /// ``` - pub async fn send(&self, buf: &[u8]) -> io::Result { - self.write_with(|io| io.send(buf)).await - } -} - -#[cfg(unix)] -impl TryFrom for Async { - type Error = io::Error; - - fn try_from(socket: std::os::unix::net::UnixDatagram) -> io::Result { - Async::new(socket) - } -} - -/// Polls a future once, waits for a wakeup, and then optimistically assumes the future is ready. -async fn optimistic(fut: impl Future>) -> io::Result<()> { - let mut polled = false; - pin!(fut); - - future::poll_fn(|cx| { - if !polled { - polled = true; - fut.as_mut().poll(cx) - } else { - Poll::Ready(Ok(())) - } - }) - .await -} - -fn connect(addr: SockAddr, domain: Domain, protocol: Option) -> io::Result { - let sock_type = Type::STREAM; - #[cfg(any( - target_os = "android", - target_os = "dragonfly", - target_os = "freebsd", - target_os = "fuchsia", - target_os = "illumos", - target_os = "linux", - target_os = "netbsd", - target_os = "openbsd" - ))] - // If we can, set nonblocking at socket creation for unix - let sock_type = sock_type.nonblocking(); - // This automatically handles cloexec on unix, no_inherit on windows and nosigpipe on macos - let socket = Socket::new(domain, sock_type, protocol)?; - #[cfg(not(any( - target_os = "android", - target_os = "dragonfly", - target_os = "freebsd", - target_os = "fuchsia", - target_os = "illumos", - target_os = "linux", - target_os = "netbsd", - target_os = "openbsd" - )))] - // If the current platform doesn't support nonblocking at creation, enable it after creation - socket.set_nonblocking(true)?; - match socket.connect(&addr) { - Ok(_) => {} - #[cfg(unix)] - Err(err) if err.raw_os_error() == Some(rustix::io::Errno::INPROGRESS.raw_os_error()) => {} - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - Err(err) => return Err(err), + #[inline] + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.0.poll_next(cx) } - Ok(socket) } diff --git a/src/os/unix.rs b/src/os/unix.rs index ffb0832..d43462e 100644 --- a/src/os/unix.rs +++ b/src/os/unix.rs @@ -62,7 +62,7 @@ pub fn reactor_fd() -> Option> { not(polling_test_poll_backend), ))] { use std::os::unix::io::AsFd; - Some(crate::Reactor::get().poller.as_fd()) + Some(crate::reactor::Reactor::get().poller.as_fd()) } else { None } diff --git a/src/timer/native.rs b/src/timer/native.rs new file mode 100644 index 0000000..2dde232 --- /dev/null +++ b/src/timer/native.rs @@ -0,0 +1,192 @@ +//! Timer implementation for non-Web platforms. + +use std::task::{Context, Poll, Waker}; +use std::time::{Duration, Instant}; + +use crate::reactor::Reactor; + +/// A timer for non-Web platforms. +/// +/// self registers a timeout in the global reactor, which in turn sets a timeout in the poll call. +#[derive(Debug)] +pub(super) struct Timer { + /// self timer's ID and last waker that polled it. + /// + /// When self field is set to `None`, self timer is not registered in the reactor. + id_and_waker: Option<(usize, Waker)>, + + /// The next instant at which self timer fires. + /// + /// If self timer is a blank timer, self value is None. If the timer + /// must be set, self value contains the next instant at which the + /// timer must fire. + when: Option, + + /// The period. + period: Duration, +} + +impl Timer { + /// Create a timer that will never fire. + #[inline] + pub(super) fn never() -> Self { + Self { + id_and_waker: None, + when: None, + period: Duration::MAX, + } + } + + /// Create a timer that will fire at the given instant. + #[inline] + pub(super) fn after(duration: Duration) -> Timer { + Instant::now() + .checked_add(duration) + .map_or_else(Timer::never, Timer::at) + } + + /// Create a timer that will fire at the given instant. + #[inline] + pub(super) fn at(instant: Instant) -> Timer { + Timer::interval_at(instant, Duration::MAX) + } + + /// Create a timer that will fire at the given instant. + #[inline] + pub(super) fn interval(period: Duration) -> Timer { + Instant::now() + .checked_add(period) + .map_or_else(Timer::never, |at| Timer::interval_at(at, period)) + } + + /// Create a timer that will fire at self interval at self point. + #[inline] + pub(super) fn interval_at(start: Instant, period: Duration) -> Timer { + Timer { + id_and_waker: None, + when: Some(start), + period, + } + } + + /// Returns `true` if self timer will fire at some point. + #[inline] + pub(super) fn will_fire(&self) -> bool { + self.when.is_some() + } + + /// Set the timer to fire after the given duration. + #[inline] + pub(super) fn set_after(&mut self, duration: Duration) { + match Instant::now().checked_add(duration) { + Some(instant) => self.set_at(instant), + None => { + // Overflow to never going off. + self.clear(); + self.when = None; + } + } + } + + /// Set the timer to fire at the given instant. + #[inline] + pub(super) fn set_at(&mut self, instant: Instant) { + self.clear(); + + // Update the timeout. + self.when = Some(instant); + + if let Some((id, waker)) = self.id_and_waker.as_mut() { + // Re-register the timer with the new timeout. + *id = Reactor::get().insert_timer(instant, waker); + } + } + + /// Set the timer to emit events periodically. + #[inline] + pub(super) fn set_interval(&mut self, period: Duration) { + match Instant::now().checked_add(period) { + Some(instant) => self.set_interval_at(instant, period), + None => { + // Overflow to never going off. + self.clear(); + self.when = None; + } + } + } + + /// Set the timer to emit events periodically starting at a given instant. + #[inline] + pub(super) fn set_interval_at(&mut self, start: Instant, period: Duration) { + self.clear(); + + self.when = Some(start); + self.period = period; + + if let Some((id, waker)) = self.id_and_waker.as_mut() { + // Re-register the timer with the new timeout. + *id = Reactor::get().insert_timer(start, waker); + } + } + + /// Helper function to clear the current timer. + #[inline] + fn clear(&mut self) { + if let (Some(when), Some((id, _))) = (self.when, self.id_and_waker.as_ref()) { + // Deregister the timer from the reactor. + Reactor::get().remove_timer(when, *id); + } + } + + /// Poll for the next timer event. + #[inline] + pub(super) fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll> { + if let Some(ref mut when) = self.when { + // Check if the timer has already fired. + if Instant::now() >= *when { + if let Some((id, _)) = self.id_and_waker.take() { + // Deregister the timer from the reactor. + Reactor::get().remove_timer(*when, id); + } + let result_time = *when; + if let Some(next) = (*when).checked_add(self.period) { + *when = next; + // Register the timer in the reactor. + let id = Reactor::get().insert_timer(next, cx.waker()); + self.id_and_waker = Some((id, cx.waker().clone())); + } else { + self.when = None; + } + return Poll::Ready(Some(result_time)); + } else { + match &self.id_and_waker { + None => { + // Register the timer in the reactor. + let id = Reactor::get().insert_timer(*when, cx.waker()); + self.id_and_waker = Some((id, cx.waker().clone())); + } + Some((id, w)) if !w.will_wake(cx.waker()) => { + // Deregister the timer from the reactor to remove the old waker. + Reactor::get().remove_timer(*when, *id); + + // Register the timer in the reactor with the new waker. + let id = Reactor::get().insert_timer(*when, cx.waker()); + self.id_and_waker = Some((id, cx.waker().clone())); + } + Some(_) => {} + } + } + } + + Poll::Pending + } +} + +impl Drop for Timer { + fn drop(&mut self) { + if let (Some(when), Some((id, _))) = (self.when, self.id_and_waker.take()) { + // Deregister the timer from the reactor. + Reactor::get().remove_timer(when, id); + } + } +} diff --git a/src/timer/web.rs b/src/timer/web.rs new file mode 100644 index 0000000..4fdfc19 --- /dev/null +++ b/src/timer/web.rs @@ -0,0 +1,212 @@ +//! Timers for web targets. +//! +//! These use the `setTimeout` function on the web to handle timing. + +use std::convert::TryInto; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::time::Duration; + +use atomic_waker::AtomicWaker; +use wasm_bindgen::closure::Closure; +use wasm_bindgen::JsCast; + +/// A timer for non-Web platforms. +/// +/// self registers a timeout in the global reactor, which in turn sets a timeout in the poll call. +#[derive(Debug)] +pub(super) struct Timer { + /// The waker to wake when the timer fires. + waker: Arc, + + /// The ongoing timeout or interval. + ongoing_timeout: TimerId, + + /// Keep the closure alive so we don't drop it. + closure: Option>, +} + +#[derive(Debug)] +struct State { + /// The number of times this timer has been woken. + woken: AtomicUsize, + + /// The waker to wake when the timer fires. + waker: AtomicWaker, +} + +#[derive(Debug)] +enum TimerId { + NoTimer, + Timeout(i32), + Interval(i32), +} + +impl Timer { + /// Create a timer that will never fire. + #[inline] + pub(super) fn never() -> Self { + Self { + waker: Arc::new(State { + woken: AtomicUsize::new(0), + waker: AtomicWaker::new(), + }), + ongoing_timeout: TimerId::NoTimer, + closure: None, + } + } + + /// Create a timer that will fire at the given instant. + #[inline] + pub(super) fn after(duration: Duration) -> Timer { + let mut this = Self::never(); + this.set_after(duration); + this + } + + /// Create a timer that will fire at the given instant. + #[inline] + pub(super) fn interval(period: Duration) -> Timer { + let mut this = Self::never(); + this.set_interval(period); + this + } + + /// Returns `true` if self timer will fire at some point. + #[inline] + pub(super) fn will_fire(&self) -> bool { + matches!( + self.ongoing_timeout, + TimerId::Timeout(_) | TimerId::Interval(_) + ) + } + + /// Set the timer to fire after the given duration. + #[inline] + pub(super) fn set_after(&mut self, duration: Duration) { + // Set the timeout. + let id = { + let waker = self.waker.clone(); + let closure: Closure = Closure::wrap(Box::new(move || { + waker.wake(); + })); + + let result = web_sys::window() + .unwrap() + .set_timeout_with_callback_and_timeout_and_arguments_0( + closure.as_ref().unchecked_ref(), + duration.as_millis().try_into().expect("timeout too long"), + ); + + // Make sure we don't drop the closure before it's called. + self.closure = Some(closure); + + match result { + Ok(id) => id, + Err(_) => { + panic!("failed to set timeout") + } + } + }; + + // Set our ID. + self.ongoing_timeout = TimerId::Timeout(id); + } + + /// Set the timer to emit events periodically. + #[inline] + pub(super) fn set_interval(&mut self, period: Duration) { + // Set the timeout. + let id = { + let waker = self.waker.clone(); + let closure: Closure = Closure::wrap(Box::new(move || { + waker.wake(); + })); + + let result = web_sys::window() + .unwrap() + .set_interval_with_callback_and_timeout_and_arguments_0( + closure.as_ref().unchecked_ref(), + period.as_millis().try_into().expect("timeout too long"), + ); + + // Make sure we don't drop the closure before it's called. + self.closure = Some(closure); + + match result { + Ok(id) => id, + Err(_) => { + panic!("failed to set interval") + } + } + }; + + // Set our ID. + self.ongoing_timeout = TimerId::Interval(id); + } + + /// Poll for the next timer event. + #[inline] + pub(super) fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll> { + let mut registered = false; + let mut woken = self.waker.woken.load(Ordering::Acquire); + + loop { + if woken > 0 { + // Try to decrement the number of woken events. + if let Err(new_woken) = self.waker.woken.compare_exchange( + woken, + woken - 1, + Ordering::SeqCst, + Ordering::Acquire, + ) { + woken = new_woken; + continue; + } + + // If we are using a one-shot timer, clear it. + if let TimerId::Timeout(_) = self.ongoing_timeout { + self.clear(); + } + + return Poll::Ready(Some(())); + } + + if !registered { + // Register the waker. + self.waker.waker.register(cx.waker()); + registered = true; + } else { + // We've already registered, so we can just return pending. + return Poll::Pending; + } + } + } + + /// Clear the current timeout. + fn clear(&mut self) { + match self.ongoing_timeout { + TimerId::NoTimer => {} + TimerId::Timeout(id) => { + web_sys::window().unwrap().clear_timeout_with_handle(id); + } + TimerId::Interval(id) => { + web_sys::window().unwrap().clear_interval_with_handle(id); + } + } + } +} + +impl State { + fn wake(&self) { + self.woken.fetch_add(1, Ordering::SeqCst); + self.waker.wake(); + } +} + +impl Drop for Timer { + fn drop(&mut self) { + self.clear(); + } +} diff --git a/tests/async.rs b/tests/async.rs index c856760..bfb5a30 100644 --- a/tests/async.rs +++ b/tests/async.rs @@ -1,3 +1,5 @@ +#![cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] + use std::future::Future; use std::io; use std::net::{Shutdown, TcpListener, TcpStream, UdpSocket}; diff --git a/tests/block_on.rs b/tests/block_on.rs index 70241f0..330ce6d 100644 --- a/tests/block_on.rs +++ b/tests/block_on.rs @@ -1,3 +1,5 @@ +#![cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] + use async_io::block_on; use std::{ future::Future, diff --git a/tests/timer.rs b/tests/timer.rs index cdd90db..f08858e 100644 --- a/tests/timer.rs +++ b/tests/timer.rs @@ -1,12 +1,26 @@ use std::future::Future; +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] use std::pin::Pin; +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] use std::sync::{Arc, Mutex}; +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] use std::thread; + +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] use std::time::{Duration, Instant}; +#[cfg(all(target_family = "wasm", not(target_os = "wasi")))] +use web_time::{Duration, Instant}; + +#[cfg(all(target_family = "wasm", not(target_os = "wasi")))] +wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); use async_io::Timer; -use futures_lite::{future, FutureExt, StreamExt}; +use futures_lite::{FutureExt, StreamExt}; + +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] +use futures_lite::future; +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] fn spawn( f: impl Future + Send + 'static, ) -> impl Future + Send + 'static { @@ -21,18 +35,60 @@ fn spawn( Box::pin(async move { r.recv().await.unwrap() }) } -#[test] -fn smoke() { - future::block_on(async { +#[cfg(all(target_family = "wasm", not(target_os = "wasi")))] +fn spawn(f: impl Future + 'static) -> impl Future + 'static { + let (s, r) = async_channel::bounded(1); + + #[cfg(all(target_family = "wasm", not(target_os = "wasi")))] + wasm_bindgen_futures::spawn_local(async move { + s.send(f.await).await.ok(); + }); + + Box::pin(async move { r.recv().await.unwrap() }) +} + +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] +macro_rules! test { + ( + $(#[$meta:meta])* + async fn $name:ident () $bl:block + ) => { + #[test] + $(#[$meta])* + fn $name() { + futures_lite::future::block_on(async { + $bl + }) + } + }; +} + +#[cfg(all(target_family = "wasm", not(target_os = "wasi")))] +macro_rules! test { + ( + $(#[$meta:meta])* + async fn $name:ident () $bl:block + ) => { + // wasm-bindgen-test handles waiting on the future for us + #[wasm_bindgen_test::wasm_bindgen_test] + $(#[$meta])* + async fn $name() { + console_error_panic_hook::set_once(); + $bl + } + }; +} + +test! { + async fn smoke() { let start = Instant::now(); Timer::after(Duration::from_secs(1)).await; assert!(start.elapsed() >= Duration::from_secs(1)); - }); + } } -#[test] -fn interval() { - future::block_on(async { +test! { + async fn interval() { let period = Duration::from_secs(1); let jitter = Duration::from_millis(500); let start = Instant::now(); @@ -43,12 +99,11 @@ fn interval() { timer.next().await; let elapsed = start.elapsed(); assert!(elapsed >= period * 2 && elapsed - period * 2 < jitter); - }); + } } -#[test] -fn poll_across_tasks() { - future::block_on(async { +test! { + async fn poll_across_tasks() { let start = Instant::now(); let (sender, receiver) = async_channel::bounded(1); @@ -74,9 +129,10 @@ fn poll_across_tasks() { task2.await; assert!(start.elapsed() >= Duration::from_secs(1)); - }); + } } +#[cfg(not(all(target_family = "wasm", not(target_os = "wasi"))))] #[test] fn set() { future::block_on(async {