diff --git a/Cargo.lock b/Cargo.lock index a630721..c983b60 100755 --- a/Cargo.lock +++ b/Cargo.lock @@ -66,6 +66,12 @@ version = "4.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b75356056920673b02621b35afd0f7dda9306d03c79a30f5c56c44cf256e3de" +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + [[package]] name = "autocfg" version = "1.3.0" @@ -360,19 +366,21 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" [[package]] name = "hyper" -version = "1.5.2" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "256fb8d4bd6413123cc9d91832d78325c48ff41677595be797d90f42969beae0" +checksum = "1744436df46f0bde35af3eda22aeaba453aada65d8f1c171cd8a5f59030bd69f" dependencies = [ + "atomic-waker", "bytes", "futures-channel", - "futures-util", + "futures-core", "http", "http-body", "httparse", "httpdate", "itoa", "pin-project-lite", + "pin-utils", "smallvec", "tokio", "want", @@ -404,18 +412,20 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.10" +version = "0.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" +checksum = "3c6995591a8f1380fcb4ba966a252a4b29188d51d2b89e3a252f5305be65aea8" dependencies = [ "bytes", "futures-channel", + "futures-core", "futures-util", "http", "http-body", "hyper", + "libc", "pin-project-lite", - "socket2", + "socket2 0.6.1", "tokio", "tower-service", "tracing", @@ -429,9 +439,9 @@ checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] name = "libc" -version = "0.2.169" +version = "0.2.177" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" +checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976" [[package]] name = "linux-raw-sys" @@ -611,6 +621,16 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "socket2" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17129e116933cf371d018bb80ae557e889637989d8638274fb25622827b03881" +dependencies = [ + "libc", + "windows-sys 0.60.2", +] + [[package]] name = "syn" version = "2.0.71" @@ -633,7 +653,7 @@ dependencies = [ "libc", "mio", "pin-project-lite", - "socket2", + "socket2 0.5.7", "tokio-macros", "windows-sys 0.52.0", ] @@ -733,13 +753,19 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + [[package]] name = "windows-sys" version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets", + "windows-targets 0.52.6", ] [[package]] @@ -748,7 +774,16 @@ version = "0.59.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" dependencies = [ - "windows-targets", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets 0.53.5", ] [[package]] @@ -757,14 +792,31 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_gnullvm", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm 0.52.6", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.53.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" +dependencies = [ + "windows-link", + "windows_aarch64_gnullvm 0.53.1", + "windows_aarch64_msvc 0.53.1", + "windows_i686_gnu 0.53.1", + "windows_i686_gnullvm 0.53.1", + "windows_i686_msvc 0.53.1", + "windows_x86_64_gnu 0.53.1", + "windows_x86_64_gnullvm 0.53.1", + "windows_x86_64_msvc 0.53.1", ] [[package]] @@ -773,44 +825,92 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" + [[package]] name = "windows_aarch64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" + [[package]] name = "windows_i686_gnu" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" +[[package]] +name = "windows_i686_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" + [[package]] name = "windows_i686_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" + [[package]] name = "windows_i686_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" +[[package]] +name = "windows_i686_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" + [[package]] name = "windows_x86_64_gnu" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" + [[package]] name = "windows_x86_64_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" + [[package]] name = "windows_x86_64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" diff --git a/Cargo.toml b/Cargo.toml index a507b2d..527d079 100755 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ version = "0.6.0" edition = "2021" license = "MIT" description = "A Hyper client library allowing access to Unix, VSock and Firecracker sockets" -rust-version = "1.63" +rust-version = "1.74" repository = "https://github.com/rust-firecracker/hyper-client-sockets" readme = "README.md" keywords = ["hyper", "client", "sockets"] @@ -19,7 +19,7 @@ hyper = { version = "1.5.2", default-features = false } # hyper-util support tower-service = { version = "0.3.3", optional = true } -hyper-util = { version = "0.1.10", optional = true, default-features = false } +hyper-util = { version = "0.1.17", optional = true, default-features = false } hex = { version = "0.4.3", optional = true } http = { version = "1.2.0", optional = true } # tokio backend @@ -31,7 +31,7 @@ smol-hyper = { version = "0.1.1", optional = true, default-features = false, fea "async-io", ] } futures-lite = { version = "2.6.0", optional = true } -# vsock sockets +# VSOCK sockets vsock = { version = "0.5.1", optional = true } [dev-dependencies] @@ -41,7 +41,7 @@ tokio = { version = "1.43.0", features = ["macros", "fs"] } async-executor = "1.13.1" # hyper utils hyper = { version = "1.5.2", features = ["server", "http1"] } -hyper-util = { version = "0.1.10", features = [ +hyper-util = { version = "0.1.17", features = [ "client", "client-legacy", "http1", diff --git a/README.md b/README.md index 40f7e0f..182ed4c 100755 --- a/README.md +++ b/README.md @@ -1,12 +1,12 @@ ## Hyper Client Sockets -Before hyper v1, hyperlocal was the most convenient solution to use Unix sockets for both client and server. With hyper v1, server socket support is no longer needed (just use `UnixListener` or `VsockListener` instead of `TcpListener`), yet hyperlocal still has it and hasn't received a release since several years. +Before hyper v1, hyperlocal was the most convenient solution to use Unix Domain sockets for both client and server. With hyper v1, server socket support is no longer needed (just use `UnixListener` or `VsockListener` instead of `TcpListener`), yet hyperlocal still has it and hasn't received a release since several years. This library provides hyper v1 client support for: - Unix (`AF_UNIX`) sockets (`HyperUnixStream` implementing hyper traits) - VSock (`AF_VSOCK`) sockets (`HyperVsockStream` implementing hyper traits) -- Firecracker Unix sockets that need `CONNECT` commands in order to establish a tunnel (`HyperFirecrackerStream` implementing hyper traits) +- Firecracker Unix Domain sockets that need `CONNECT` commands in order to establish a tunnel (`HyperFirecrackerStream` implementing hyper traits) Additionally, the library supports different async I/O backends: diff --git a/src/async_io.rs b/src/async_io.rs deleted file mode 100644 index feb22aa..0000000 --- a/src/async_io.rs +++ /dev/null @@ -1,235 +0,0 @@ -#[cfg(any(feature = "unix", feature = "firecracker"))] -use std::path::Path; -#[cfg(feature = "vsock")] -use std::{ - io::{Read, Write}, - os::fd::{AsRawFd, FromRawFd, IntoRawFd}, - pin::Pin, - task::{Context, Poll}, -}; - -#[cfg(any(feature = "unix", feature = "vsock", feature = "firecracker"))] -use async_io::Async; -#[cfg(feature = "firecracker")] -use futures_lite::{io::BufReader, AsyncBufReadExt, AsyncWriteExt, StreamExt}; -#[cfg(any(feature = "unix", feature = "firecracker"))] -use smol_hyper::rt::FuturesIo; -#[cfg(feature = "vsock")] -use vsock::VsockAddr; - -use crate::Backend; - -/// [Backend] for hyper-client-sockets that is implemented via the async-io crate's reactor. -#[derive(Debug, Clone)] -pub struct AsyncIoBackend; - -impl Backend for AsyncIoBackend { - #[cfg(feature = "unix")] - #[cfg_attr(docsrs, doc(cfg(feature = "unix")))] - type UnixIo = FuturesIo>; - - #[cfg(feature = "vsock")] - #[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] - type VsockIo = AsyncVsockIo; - - #[cfg(feature = "firecracker")] - #[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))] - type FirecrackerIo = FuturesIo>; - - #[cfg(feature = "unix")] - #[cfg_attr(docsrs, doc(cfg(feature = "unix")))] - async fn connect_to_unix_socket(socket_path: &Path) -> Result { - Ok(FuturesIo::new( - Async::::connect(socket_path).await?, - )) - } - - #[cfg(feature = "vsock")] - #[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] - async fn connect_to_vsock_socket(addr: vsock::VsockAddr) -> Result { - Ok(AsyncVsockIo::connect(addr).await?) - } - - #[cfg(feature = "firecracker")] - #[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))] - async fn connect_to_firecracker_socket( - host_socket_path: &Path, - guest_port: u32, - ) -> Result { - let mut stream = Async::::connect(host_socket_path).await?; - stream.write_all(format!("CONNECT {guest_port}\n").as_bytes()).await?; - - let mut lines = BufReader::new(&mut stream).lines(); - match lines.next().await { - Some(Ok(line)) => { - if !line.starts_with("OK") { - return Err(std::io::Error::new( - std::io::ErrorKind::ConnectionRefused, - "Firecracker refused to establish a tunnel to the given guest port", - )); - } - } - _ => { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidInput, - "Could not read Firecracker response", - )) - } - }; - - Ok(FuturesIo::new(stream)) - } -} - -#[cfg(feature = "vsock")] -#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] -pub struct AsyncVsockIo(Async); - -#[cfg(feature = "vsock")] -#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] -impl AsyncVsockIo { - async fn connect(addr: VsockAddr) -> Result { - let socket = unsafe { libc::socket(libc::AF_VSOCK, libc::SOCK_STREAM, 0) }; - if socket < 0 { - return Err(std::io::Error::last_os_error()); - } - - if unsafe { libc::fcntl(socket, libc::F_SETFL, libc::O_NONBLOCK | libc::O_CLOEXEC) } < 0 { - let _ = unsafe { libc::close(socket) }; - return Err(std::io::Error::last_os_error()); - } - - if unsafe { - libc::connect( - socket, - &addr as *const _ as *const libc::sockaddr, - size_of::() as libc::socklen_t, - ) - } < 0 - { - let err = std::io::Error::last_os_error(); - if let Some(os_err) = err.raw_os_error() { - if os_err != libc::EINPROGRESS { - let _ = unsafe { libc::close(socket) }; - return Err(err); - } - } - } - - let async_fd = Async::new(unsafe { std::fs::File::from_raw_fd(socket) })?; - - loop { - let connection_check = async_fd.write_with(|fd| { - let mut sock_err: libc::c_int = 0; - let mut sock_err_len: libc::socklen_t = size_of::() as libc::socklen_t; - let err = unsafe { - libc::getsockopt( - fd.as_raw_fd(), - libc::SOL_SOCKET, - libc::SO_ERROR, - &mut sock_err as *mut _ as *mut libc::c_void, - &mut sock_err_len as *mut libc::socklen_t, - ) - }; - - if err < 0 { - return Err(std::io::Error::last_os_error()); - } - - if sock_err == 0 { - Ok(()) - } else { - Err(std::io::Error::from_raw_os_error(sock_err)) - } - }); - - match connection_check.await { - Ok(_) => { - return Ok(AsyncVsockIo(Async::new(unsafe { - std::fs::File::from_raw_fd(async_fd.into_inner()?.into_raw_fd()) - })?)) - } - Err(err) - if err.kind() == std::io::ErrorKind::WouldBlock - || err.kind() == std::io::ErrorKind::Interrupted => - { - continue - } - Err(err) => return Err(err), - } - } - } -} - -#[cfg(feature = "vsock")] -#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] -impl hyper::rt::Write for AsyncVsockIo { - fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { - loop { - match self.0.poll_writable(cx) { - Poll::Ready(Ok(_)) => {} - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => return Poll::Pending, - }; - - match self.0.get_ref().write(buf) { - Ok(amount) => return Poll::Ready(Ok(amount)), - Err(ref err) - if err.kind() == std::io::ErrorKind::Interrupted - || err.kind() == std::io::ErrorKind::WouldBlock => - { - continue - } - Err(err) => return Poll::Ready(Err(err)), - } - } - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } -} - -#[cfg(feature = "vsock")] -#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] -impl hyper::rt::Read for AsyncVsockIo { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - mut buf: hyper::rt::ReadBufCursor<'_>, - ) -> Poll> { - let b; - unsafe { - b = &mut *(buf.as_mut() as *mut [std::mem::MaybeUninit] as *mut [u8]); - }; - - loop { - match self.0.poll_readable(cx) { - Poll::Ready(Ok(_)) => {} - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => return Poll::Pending, - }; - - match self.0.get_ref().read(b) { - Ok(amount) => { - unsafe { - buf.advance(amount); - } - - return Poll::Ready(Ok(())); - } - Err(ref err) - if err.kind() == std::io::ErrorKind::Interrupted - || err.kind() == std::io::ErrorKind::WouldBlock => - { - continue - } - Err(err) => return Poll::Ready(Err(err)), - } - } - } -} diff --git a/src/async_io/firecracker.rs b/src/async_io/firecracker.rs new file mode 100644 index 0000000..4c4b227 --- /dev/null +++ b/src/async_io/firecracker.rs @@ -0,0 +1,63 @@ +use std::{ + io::Result, + ops::{Deref, DerefMut}, + os::unix::net::UnixStream, + path::Path, +}; + +use async_io::Async; +use futures_lite::{io::BufReader, AsyncBufReadExt as _, AsyncWriteExt as _, StreamExt as _}; +use smol_hyper::rt::FuturesIo; + +use crate::utils::{ + firecracker::{format_request, parse_connection_response}, + hyper_io_by_deref, hyper_util_connection_default, +}; + +pub type AsyncFirecrackerIoInner = FuturesIo>; + +#[derive(Debug)] +pub struct AsyncFirecrackerIo(pub AsyncFirecrackerIoInner); + +impl AsyncFirecrackerIo { + pub(super) async fn connect

(host_socket_path: P, guest_port: u32) -> Result + where + P: AsRef, + { + let mut stream = Async::::connect(host_socket_path).await?; + stream.write_all(format_request(guest_port).as_bytes()).await?; + let response = BufReader::new(&mut stream).lines().next().await.transpose(); + parse_connection_response(stream, response) + .map(FuturesIo::new) + .map(Self) + } +} + +impl Deref for AsyncFirecrackerIo { + type Target = AsyncFirecrackerIoInner; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for AsyncFirecrackerIo { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl From for AsyncFirecrackerIo { + fn from(inner: AsyncFirecrackerIoInner) -> Self { + Self(inner) + } +} + +impl From for AsyncFirecrackerIoInner { + fn from(AsyncFirecrackerIo(inner): AsyncFirecrackerIo) -> Self { + inner + } +} + +hyper_io_by_deref!(AsyncFirecrackerIo); +hyper_util_connection_default!(AsyncFirecrackerIo); diff --git a/src/async_io/mod.rs b/src/async_io/mod.rs new file mode 100644 index 0000000..ab49ea3 --- /dev/null +++ b/src/async_io/mod.rs @@ -0,0 +1,70 @@ +#[cfg(feature = "firecracker")] +#[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))] +pub mod firecracker; + +#[cfg(feature = "unix")] +#[cfg_attr(docsrs, doc(cfg(feature = "unix")))] +pub mod unix; + +#[cfg(feature = "vsock")] +#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] +pub mod vsock; + +#[cfg(any(feature = "firecracker", feature = "unix", feature = "vsock"))] +use std::io::Result; + +#[cfg(any(feature = "firecracker", feature = "unix"))] +use std::path::Path; + +#[cfg(feature = "vsock")] +use ::vsock::VsockAddr; + +#[cfg(feature = "firecracker")] +#[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))] +pub use self::firecracker::{AsyncFirecrackerIo, AsyncFirecrackerIoInner}; + +#[cfg(feature = "unix")] +#[cfg_attr(docsrs, doc(cfg(feature = "unix")))] +pub use self::unix::{AsyncUnixIo, AsyncUnixIoInner}; + +#[cfg(feature = "vsock")] +#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] +pub use self::vsock::{AsyncVsockIo, AsyncVsockIoInner}; + +use crate::Backend; + +/// [`Backend`] for hyper-client-sockets that is implemented via the async-io crate's reactor. +#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub struct AsyncIoBackend; + +impl Backend for AsyncIoBackend { + #[cfg(feature = "firecracker")] + #[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))] + type FirecrackerIo = AsyncFirecrackerIo; + + #[cfg(feature = "unix")] + #[cfg_attr(docsrs, doc(cfg(feature = "unix")))] + type UnixIo = AsyncUnixIo; + + #[cfg(feature = "vsock")] + #[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] + type VsockIo = AsyncVsockIo; + + #[cfg(feature = "firecracker")] + #[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))] + async fn connect_to_firecracker_socket(host_socket_path: &Path, guest_port: u32) -> Result { + Self::FirecrackerIo::connect(host_socket_path, guest_port).await + } + + #[cfg(feature = "unix")] + #[cfg_attr(docsrs, doc(cfg(feature = "unix")))] + async fn connect_to_unix_socket(socket_path: &Path) -> Result { + Self::UnixIo::connect(socket_path).await + } + + #[cfg(feature = "vsock")] + #[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] + async fn connect_to_vsock_socket(addr: VsockAddr) -> Result { + Self::VsockIo::connect(addr).await + } +} diff --git a/src/async_io/unix.rs b/src/async_io/unix.rs new file mode 100644 index 0000000..a21d813 --- /dev/null +++ b/src/async_io/unix.rs @@ -0,0 +1,57 @@ +use std::{ + io::Result, + ops::{Deref, DerefMut}, + os::unix::net::UnixStream, + path::Path, +}; + +use async_io::Async; +use smol_hyper::rt::FuturesIo; + +use crate::utils::{hyper_io_by_deref, hyper_util_connection_default}; + +pub type AsyncUnixIoInner = FuturesIo>; + +#[derive(Debug)] +pub struct AsyncUnixIo(pub AsyncUnixIoInner); + +impl AsyncUnixIo { + pub(super) async fn connect

(socket_path: P) -> Result + where + P: AsRef, + { + Async::::connect(socket_path) + .await + .map(FuturesIo::new) + .map(Self) + } +} + +impl Deref for AsyncUnixIo { + type Target = AsyncUnixIoInner; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for AsyncUnixIo { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl From for AsyncUnixIo { + fn from(inner: AsyncUnixIoInner) -> Self { + Self(inner) + } +} + +impl From for AsyncUnixIoInner { + fn from(AsyncUnixIo(inner): AsyncUnixIo) -> Self { + inner + } +} + +hyper_io_by_deref!(AsyncUnixIo); +hyper_util_connection_default!(AsyncUnixIo); diff --git a/src/async_io/vsock.rs b/src/async_io/vsock.rs new file mode 100644 index 0000000..af37613 --- /dev/null +++ b/src/async_io/vsock.rs @@ -0,0 +1,133 @@ +use std::{ + fs::File, + io::{ErrorKind, Read as _, Result, Write as _}, + mem::MaybeUninit, + ops::{Deref, DerefMut}, + os::fd::{AsRawFd as _, FromRawFd as _, IntoRawFd as _}, + pin::Pin, + task::{Context, Poll}, +}; + +use async_io::Async; +use hyper::rt::{Read, ReadBufCursor, Write}; +use vsock::VsockAddr; + +use crate::utils::{ + hyper_util_connection_default, + vsock::{check_connection, raw_connect, try_advance_cursor, try_poll_write}, +}; + +pub type AsyncVsockIoInner = Async; + +#[derive(Debug)] +pub struct AsyncVsockIo(pub AsyncVsockIoInner); + +impl AsyncVsockIo { + pub(super) async fn connect(addr: VsockAddr) -> Result { + let socket = raw_connect(addr)?; + let async_fd = Async::new(unsafe { File::from_raw_fd(socket) })?; + + loop { + let connection_check = async_fd.write_with(|fd| check_connection(fd.as_raw_fd())); + + break match connection_check.await { + Ok(_) => { + let raw_fd = async_fd.into_inner()?.into_raw_fd(); + let inner = unsafe { File::from_raw_fd(raw_fd) }; + Async::new(inner).map(Self) + } + Err(err) => match err.kind() { + ErrorKind::Interrupted | ErrorKind::WouldBlock => continue, + _ => Err(err), + }, + }; + } + } + + fn try_poll_read( + self: Pin<&mut Self>, + context: &mut Context<'_>, + cursor: &mut ReadBufCursor<'_>, + ) -> Option>> { + match self.0.poll_readable(context) { + Poll::Ready(Ok(_)) => { + // TODO: Once https://github.com/rust-lang/rust/issues/63569 is stable, use `assume_init_mut`: + let buffer = unsafe { &mut *(cursor.as_mut() as *mut [MaybeUninit] as *mut [u8]) }; + let amount = self.0.get_ref().read(buffer); + try_advance_cursor(cursor, amount) + } + other => Some(other), + } + } + + fn try_poll_write(self: Pin<&mut Self>, context: &mut Context<'_>, buffer: &[u8]) -> Option>> { + match self.0.poll_writable(context) { + Poll::Ready(Ok(_)) => try_poll_write(self.0.get_ref().write(buffer)), + other => Some(other.map_ok(|_| 0)), + } + } +} + +impl Deref for AsyncVsockIo { + type Target = AsyncVsockIoInner; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for AsyncVsockIo { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl From for AsyncVsockIo { + fn from(inner: AsyncVsockIoInner) -> Self { + Self(inner) + } +} + +impl From for AsyncVsockIoInner { + fn from(AsyncVsockIo(inner): AsyncVsockIo) -> Self { + inner + } +} + +impl Read for AsyncVsockIo { + #[inline(always)] + fn poll_read( + mut self: Pin<&mut Self>, + context: &mut Context<'_>, + mut cursor: ReadBufCursor<'_>, + ) -> Poll> { + loop { + if let Some(poll_result) = self.as_mut().try_poll_read(context, &mut cursor) { + break poll_result; + } + } + } +} + +impl Write for AsyncVsockIo { + #[inline(always)] + fn poll_write(mut self: Pin<&mut Self>, context: &mut Context<'_>, buffer: &[u8]) -> Poll> { + loop { + if let Some(poll_result) = self.as_mut().try_poll_write(context, buffer) { + break poll_result; + } + } + } + + #[inline(always)] + fn poll_flush(self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + #[inline(always)] + fn poll_shutdown(self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +hyper_util_connection_default!(AsyncVsockIo); diff --git a/src/connector.rs b/src/connector.rs deleted file mode 100644 index f9f5b72..0000000 --- a/src/connector.rs +++ /dev/null @@ -1,197 +0,0 @@ -use std::{ - pin::Pin, - task::{Context, Poll}, -}; - -use hyper_util::client::legacy::connect::{Connected, Connection}; - -/// This is an internal wrapper over an IO type that implements [hyper::rt::Write] and -/// [hyper::rt::Read] that also implements [Connection] to achieve compatibility with hyper-util. -pub struct ConnectableIo(IO); - -impl hyper::rt::Write for ConnectableIo { - #[inline(always)] - fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { - Pin::new(&mut self.get_mut().0).poll_write(cx, buf) - } - - #[inline(always)] - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.get_mut().0).poll_flush(cx) - } - - #[inline(always)] - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.get_mut().0).poll_shutdown(cx) - } -} - -impl hyper::rt::Read for ConnectableIo { - #[inline(always)] - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: hyper::rt::ReadBufCursor<'_>, - ) -> Poll> { - Pin::new(&mut self.get_mut().0).poll_read(cx, buf) - } -} - -impl Connection for ConnectableIo { - fn connected(&self) -> Connected { - Connected::new() - } -} - -#[cfg(feature = "unix")] -#[cfg_attr(docsrs, doc(cfg(feature = "unix")))] -mod unix { - use std::{future::Future, marker::PhantomData, pin::Pin, task::Poll}; - - use http::Uri; - - use crate::{uri::UnixUri, Backend}; - - use super::ConnectableIo; - - /// A hyper-util connector that accepts hex-encoded Unix URIs and uses them to connect - /// to Unix sockets via the given [Backend]. - #[derive(Debug, Clone)] - pub struct UnixConnector { - marker: PhantomData, - } - - impl UnixConnector { - pub fn new() -> Self { - Self { marker: PhantomData } - } - } - - impl tower_service::Service for UnixConnector { - type Response = ConnectableIo; - - type Error = std::io::Error; - - type Future = Pin> + Send + 'static>>; - - #[inline(always)] - fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - #[inline(always)] - fn call(&mut self, uri: Uri) -> Self::Future { - Box::pin(async move { - let socket_path = uri.parse_unix()?; - let io = B::connect_to_unix_socket(&socket_path).await?; - Ok(ConnectableIo(io)) - }) - } - } -} - -#[cfg(feature = "unix")] -#[cfg_attr(docsrs, doc(cfg(feature = "unix")))] -pub use unix::UnixConnector; - -#[cfg(feature = "vsock")] -#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] -mod vsock { - use std::{future::Future, marker::PhantomData, pin::Pin, task::Poll}; - - use http::Uri; - - use crate::{uri::VsockUri, Backend}; - - use super::ConnectableIo; - - /// A hyper-util connector that accepts hex-encoded virtio-vsock URIs and uses them to connect - /// to virtio-vsock sockets via the given [Backend]. - #[derive(Debug, Clone)] - pub struct VsockConnector { - marker: PhantomData, - } - - impl VsockConnector { - pub fn new() -> Self { - Self { marker: PhantomData } - } - } - - impl tower_service::Service for VsockConnector { - type Response = ConnectableIo; - - type Error = std::io::Error; - - type Future = Pin> + Send + 'static>>; - - #[inline(always)] - fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - #[inline(always)] - fn call(&mut self, uri: Uri) -> Self::Future { - Box::pin(async move { - let addr = uri.parse_vsock()?; - let io = B::connect_to_vsock_socket(addr).await?; - Ok(ConnectableIo(io)) - }) - } - } -} - -#[cfg(feature = "vsock")] -#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] -pub use vsock::VsockConnector; - -#[cfg(feature = "firecracker")] -#[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))] -mod firecracker { - use std::{future::Future, marker::PhantomData, pin::Pin, task::Poll}; - - use http::Uri; - - use crate::{uri::FirecrackerUri, Backend}; - - use super::ConnectableIo; - - /// A hyper-util connector that accepts hex-encoded Firecracker URIs and uses them to connect - /// to Firecracker sockets via the given [Backend]. - #[derive(Debug, Clone)] - pub struct FirecrackerConnector { - marker: PhantomData, - } - - impl FirecrackerConnector { - pub fn new() -> Self { - Self { marker: PhantomData } - } - } - - impl tower_service::Service for FirecrackerConnector { - type Response = ConnectableIo; - - type Error = std::io::Error; - - type Future = Pin> + Send + 'static>>; - - #[inline(always)] - fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - #[inline(always)] - fn call(&mut self, uri: Uri) -> Self::Future { - Box::pin(async move { - let (host_socket_path, guest_port) = uri.parse_firecracker()?; - let io = B::connect_to_firecracker_socket(&host_socket_path, guest_port).await?; - Ok(ConnectableIo(io)) - }) - } - } -} - -#[cfg(feature = "firecracker")] -#[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))] -pub use firecracker::FirecrackerConnector; diff --git a/src/connector/firecracker.rs b/src/connector/firecracker.rs new file mode 100644 index 0000000..85c2fc9 --- /dev/null +++ b/src/connector/firecracker.rs @@ -0,0 +1,46 @@ +use std::{ + future::Future, + io::Error, + marker::PhantomData, + pin::Pin, + task::{Context, Poll}, +}; + +use http::Uri; +use tower_service::Service; + +use super::ConnectableIo; +use crate::{uri::FirecrackerUri, Backend}; + +/// A hyper-util connector that accepts hex-encoded Firecracker URIs and uses them to connect +/// to Firecracker sockets via the given [Backend]. +#[derive(Debug, Default, Clone)] +pub struct FirecrackerConnector { + marker: PhantomData, +} + +impl FirecrackerConnector { + pub const fn new() -> Self { + Self { marker: PhantomData } + } +} + +impl Service for FirecrackerConnector { + type Response = ConnectableIo; + type Error = Error; + type Future = Pin> + Send + 'static>>; + + #[inline(always)] + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + #[inline(always)] + fn call(&mut self, uri: Uri) -> Self::Future { + Box::pin(async move { + let (host_socket_path, guest_port) = uri.parse_firecracker()?; + let io = B::connect_to_firecracker_socket(&host_socket_path, guest_port).await?; + Ok(ConnectableIo(io)) + }) + } +} diff --git a/src/connector/mod.rs b/src/connector/mod.rs new file mode 100644 index 0000000..9f6fcb1 --- /dev/null +++ b/src/connector/mod.rs @@ -0,0 +1,83 @@ +#[cfg(feature = "firecracker")] +#[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))] +pub mod firecracker; + +#[cfg(feature = "unix")] +#[cfg_attr(docsrs, doc(cfg(feature = "unix")))] +pub mod unix; + +#[cfg(feature = "vsock")] +#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] +pub mod vsock; + +use std::{ + io::{IoSlice, Result}, + pin::Pin, + task::{Context, Poll}, +}; + +use hyper::rt::{Read, ReadBufCursor, Write}; +use hyper_util::client::legacy::connect::{Connected, Connection}; + +#[cfg(feature = "firecracker")] +#[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))] +pub use self::firecracker::FirecrackerConnector; + +#[cfg(feature = "unix")] +#[cfg_attr(docsrs, doc(cfg(feature = "unix")))] +pub use self::unix::UnixConnector; + +#[cfg(feature = "vsock")] +#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] +pub use self::vsock::VsockConnector; + +/// This is an internal wrapper over an IO type that implements [`Write`] and +/// [`Read`] that also implements [`Connection`] to achieve compatibility with hyper-util. +#[derive(Debug)] +pub struct ConnectableIo(pub IO); + +impl From for ConnectableIo { + fn from(inner: IO) -> Self { + Self(inner) + } +} + +impl Connection for ConnectableIo { + fn connected(&self) -> Connected { + Connected::new() + } +} + +impl Read for ConnectableIo { + #[inline(always)] + fn poll_read(self: Pin<&mut Self>, ctx: &mut Context<'_>, buf: ReadBufCursor<'_>) -> Poll> { + Pin::new(&mut self.get_mut().0).poll_read(ctx, buf) + } +} + +impl Write for ConnectableIo { + #[inline(always)] + fn poll_write(self: Pin<&mut Self>, ctx: &mut Context<'_>, buf: &[u8]) -> Poll> { + Pin::new(&mut self.get_mut().0).poll_write(ctx, buf) + } + + #[inline(always)] + fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.get_mut().0).poll_flush(ctx) + } + + #[inline(always)] + fn poll_shutdown(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.get_mut().0).poll_shutdown(ctx) + } + + #[inline(always)] + fn is_write_vectored(&self) -> bool { + self.0.is_write_vectored() + } + + #[inline(always)] + fn poll_write_vectored(self: Pin<&mut Self>, ctx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll> { + Pin::new(&mut self.get_mut().0).poll_write_vectored(ctx, bufs) + } +} diff --git a/src/connector/unix.rs b/src/connector/unix.rs new file mode 100644 index 0000000..89407b4 --- /dev/null +++ b/src/connector/unix.rs @@ -0,0 +1,46 @@ +use std::{ + future::Future, + io::Error, + marker::PhantomData, + pin::Pin, + task::{Context, Poll}, +}; + +use http::Uri; +use tower_service::Service; + +use super::ConnectableIo; +use crate::{uri::UnixUri, Backend}; + +/// A hyper-util connector that accepts hex-encoded Unix URIs and +/// uses them to connect to Unix Domain sockets via the given [`Backend`]. +#[derive(Debug, Default, Clone)] +pub struct UnixConnector { + marker: PhantomData, +} + +impl UnixConnector { + pub const fn new() -> Self { + Self { marker: PhantomData } + } +} + +impl Service for UnixConnector { + type Response = ConnectableIo; + type Error = Error; + type Future = Pin> + Send + 'static>>; + + #[inline(always)] + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + #[inline(always)] + fn call(&mut self, uri: Uri) -> Self::Future { + Box::pin(async move { + let socket_path = uri.parse_unix()?; + let io = B::connect_to_unix_socket(&socket_path).await?; + Ok(ConnectableIo(io)) + }) + } +} diff --git a/src/connector/vsock.rs b/src/connector/vsock.rs new file mode 100644 index 0000000..1c8ffbb --- /dev/null +++ b/src/connector/vsock.rs @@ -0,0 +1,41 @@ +use std::{future::Future, marker::PhantomData, pin::Pin, task::Poll}; + +use http::Uri; + +use super::ConnectableIo; +use crate::{uri::VsockUri, Backend}; + +/// A hyper-util connector that accepts hex-encoded VSOCK URIs and +/// uses them to connect to VSOCK sockets via the given [`Backend`]. +#[derive(Debug, Default, Clone)] +pub struct VsockConnector { + marker: PhantomData, +} + +impl VsockConnector { + pub const fn new() -> Self { + Self { marker: PhantomData } + } +} + +impl tower_service::Service for VsockConnector { + type Response = ConnectableIo; + + type Error = std::io::Error; + + type Future = Pin> + Send + 'static>>; + + #[inline(always)] + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + #[inline(always)] + fn call(&mut self, uri: Uri) -> Self::Future { + Box::pin(async move { + let addr = uri.parse_vsock()?; + let io = B::connect_to_vsock_socket(addr).await?; + Ok(ConnectableIo(io)) + }) + } +} diff --git a/src/lib.rs b/src/lib.rs index 031b0d8..c705ae0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,62 +1,68 @@ #![cfg_attr(docsrs, feature(doc_cfg))] -#[cfg(any(feature = "unix", feature = "vsock", feature = "firecracker"))] -use std::future::Future; +pub mod utils; -#[cfg(any(feature = "unix", feature = "firecracker"))] +#[cfg(any(feature = "firecracker", feature = "unix"))] use std::path::Path; -#[cfg(feature = "tokio-backend")] -#[cfg_attr(docsrs, doc(cfg(feature = "tokio-backend")))] -pub mod tokio; +#[cfg(feature = "vsock")] +use vsock::VsockAddr; + +#[cfg(any(feature = "firecracker", feature = "unix", feature = "vsock"))] +use std::{future::Future, io::Result}; + +#[cfg(any(feature = "firecracker", feature = "unix", feature = "vsock"))] +use hyper::rt::{Read, Write}; #[cfg(feature = "async-io-backend")] #[cfg_attr(docsrs, doc(cfg(feature = "async-io-backend")))] pub mod async_io; +#[cfg(feature = "tokio-backend")] +#[cfg_attr(docsrs, doc(cfg(feature = "tokio-backend")))] +pub mod tokio; + #[cfg(feature = "hyper-util")] #[cfg_attr(docsrs, doc(cfg(feature = "hyper-util")))] -pub mod uri; +pub mod connector; #[cfg(feature = "hyper-util")] #[cfg_attr(docsrs, doc(cfg(feature = "hyper-util")))] -pub mod connector; +pub mod uri; -/// A [Backend] is a runtime- and reactor-agnostic way to use hyper client-side with various types of sockets. +/// A [`Backend`] is a runtime- and reactor-agnostic way to use hyper client-side with various types of sockets. pub trait Backend: Clone { - /// An IO object representing a connected Unix socket. - #[cfg(feature = "unix")] - #[cfg_attr(docsrs, doc(cfg(feature = "unix")))] - type UnixIo: hyper::rt::Read + hyper::rt::Write + Send + Unpin; - - /// An IO object representing a connected virtio-vsock socket. - #[cfg(feature = "vsock")] - #[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] - type VsockIo: hyper::rt::Read + hyper::rt::Write + Send + Unpin; - - /// An IO object representing a connected Firecracker socket (a specialized Unix socket). + /// An IO object representing a connected Firecracker socket (a specialized Unix Domain socket). #[cfg(feature = "firecracker")] #[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))] - type FirecrackerIo: hyper::rt::Read + hyper::rt::Write + Send + Unpin; + type FirecrackerIo: Read + Write + Send + Unpin; - /// Connect to a Unix socket at the given [Path]. + /// An IO object representing a connected Unix Domain socket. #[cfg(feature = "unix")] #[cfg_attr(docsrs, doc(cfg(feature = "unix")))] - fn connect_to_unix_socket(socket_path: &Path) -> impl Future> + Send; + type UnixIo: Read + Write + Send + Unpin; - /// Connect to a virtio-vsock socket at the given vsock address. + /// An IO object representing a connected VSOCK socket. #[cfg(feature = "vsock")] #[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] - fn connect_to_vsock_socket( - addr: vsock::VsockAddr, - ) -> impl Future> + Send; + type VsockIo: Read + Write + Send + Unpin; - /// Connect to a Firecracker socket at the given [Path], establishing a tunnel to the given - /// guest vsock port. + /// Connect to a Firecracker socket at the given [`Path`], + /// establishing a tunnel to the given guest VSOCK port. #[cfg(feature = "firecracker")] #[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))] fn connect_to_firecracker_socket( host_socket_path: &Path, guest_port: u32, - ) -> impl Future> + Send; + ) -> impl Future> + Send; + + /// Connect to a Unix Domain socket at the given [`Path`]. + #[cfg(feature = "unix")] + #[cfg_attr(docsrs, doc(cfg(feature = "unix")))] + fn connect_to_unix_socket(socket_path: &Path) -> impl Future> + Send; + + /// Connect to a VSOCK socket at the given address. + #[cfg(feature = "vsock")] + #[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] + fn connect_to_vsock_socket(addr: VsockAddr) -> impl Future> + Send; } diff --git a/src/tokio.rs b/src/tokio.rs deleted file mode 100644 index 50e5183..0000000 --- a/src/tokio.rs +++ /dev/null @@ -1,229 +0,0 @@ -#[cfg(feature = "vsock")] -use std::{ - io::{Read, Write}, - os::fd::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd}, - task::Poll, -}; -#[cfg(feature = "vsock")] -use std::{pin::Pin, task::Context}; - -#[cfg(any(feature = "unix", feature = "firecracker"))] -use hyper_util::rt::TokioIo; -#[cfg(any(feature = "unix", feature = "firecracker"))] -use std::path::Path; -#[cfg(any(feature = "unix", feature = "firecracker"))] -use tokio::net::UnixStream; - -#[cfg(feature = "firecracker")] -use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; - -#[cfg(feature = "vsock")] -use tokio::io::unix::AsyncFd; - -use crate::Backend; - -/// [Backend] for hyper-client-sockets that is implemented via the Tokio reactor. -#[derive(Debug, Clone)] -pub struct TokioBackend; - -impl Backend for TokioBackend { - #[cfg(feature = "unix")] - #[cfg_attr(docsrs, doc(cfg(feature = "unix")))] - type UnixIo = TokioIo; - - #[cfg(feature = "vsock")] - #[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] - type VsockIo = TokioVsockIo; - - #[cfg(feature = "firecracker")] - #[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))] - type FirecrackerIo = TokioIo; - - #[cfg(feature = "unix")] - #[cfg_attr(docsrs, doc(cfg(feature = "unix")))] - async fn connect_to_unix_socket(socket_path: &Path) -> Result { - Ok(TokioIo::new(UnixStream::connect(socket_path).await?)) - } - - #[cfg(feature = "vsock")] - #[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] - async fn connect_to_vsock_socket(addr: vsock::VsockAddr) -> Result { - TokioVsockIo::connect(addr).await - } - - #[cfg(feature = "firecracker")] - #[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))] - async fn connect_to_firecracker_socket( - host_socket_path: &Path, - guest_port: u32, - ) -> Result { - let mut stream = UnixStream::connect(host_socket_path).await?; - stream.write_all(format!("CONNECT {guest_port}\n").as_bytes()).await?; - - let mut lines = BufReader::new(&mut stream).lines(); - match lines.next_line().await { - Ok(Some(line)) => { - if !line.starts_with("OK") { - return Err(std::io::Error::new( - std::io::ErrorKind::ConnectionRefused, - "Firecracker refused to establish a tunnel to the given guest port", - )); - } - } - _ => { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidInput, - "Could not read Firecracker response", - )) - } - }; - - Ok(TokioIo::new(stream)) - } -} - -/// IO object representing an active vsock connection controlled via a Tokio [AsyncFd]. -/// This is internally a reimplementation of a relevant part of the tokio-vsock crate. -#[cfg(feature = "vsock")] -#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] -pub struct TokioVsockIo(AsyncFd); - -#[cfg(feature = "vsock")] -#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] -impl TokioVsockIo { - async fn connect(addr: vsock::VsockAddr) -> Result { - let socket = unsafe { libc::socket(libc::AF_VSOCK, libc::SOCK_STREAM, 0) }; - if socket < 0 { - return Err(std::io::Error::last_os_error()); - } - - if unsafe { libc::fcntl(socket, libc::F_SETFL, libc::O_NONBLOCK | libc::O_CLOEXEC) } < 0 { - let _ = unsafe { libc::close(socket) }; - return Err(std::io::Error::last_os_error()); - } - - if unsafe { - libc::connect( - socket, - &addr as *const _ as *const libc::sockaddr, - size_of::() as libc::socklen_t, - ) - } < 0 - { - let err = std::io::Error::last_os_error(); - if let Some(os_err) = err.raw_os_error() { - if os_err != libc::EINPROGRESS { - let _ = unsafe { libc::close(socket) }; - return Err(err); - } - } - } - - let async_fd = AsyncFd::new(unsafe { OwnedFd::from_raw_fd(socket) })?; - - loop { - let mut guard = async_fd.writable().await?; - - let connection_check = guard.try_io(|fd| { - let mut sock_err: libc::c_int = 0; - let mut sock_err_len: libc::socklen_t = size_of::() as libc::socklen_t; - let err = unsafe { - libc::getsockopt( - fd.as_raw_fd(), - libc::SOL_SOCKET, - libc::SO_ERROR, - &mut sock_err as *mut _ as *mut libc::c_void, - &mut sock_err_len as *mut libc::socklen_t, - ) - }; - - if err < 0 { - return Err(std::io::Error::last_os_error()); - } - - if sock_err == 0 { - Ok(()) - } else { - Err(std::io::Error::from_raw_os_error(sock_err)) - } - }); - - match connection_check { - Ok(Ok(_)) => { - return Ok(TokioVsockIo(AsyncFd::new(unsafe { - vsock::VsockStream::from_raw_fd(async_fd.into_inner().into_raw_fd()) - })?)) - } - Ok(Err(err)) => return Err(err), - Err(_would_block) => continue, - } - } - } -} - -#[cfg(feature = "vsock")] -#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] -impl hyper::rt::Write for TokioVsockIo { - fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { - loop { - let mut guard = match self.0.poll_write_ready(cx) { - Poll::Ready(Ok(guard)) => guard, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => return Poll::Pending, - }; - - match guard.try_io(|inner| inner.get_ref().write(buf)) { - Ok(Ok(amount)) => return Ok(amount).into(), - Ok(Err(ref err)) if err.kind() == std::io::ErrorKind::Interrupted => continue, - Ok(Err(err)) => return Err(err).into(), - Err(_would_block) => continue, - } - } - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } -} - -#[cfg(feature = "vsock")] -#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] -impl hyper::rt::Read for TokioVsockIo { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - mut buf: hyper::rt::ReadBufCursor<'_>, - ) -> Poll> { - let b; - unsafe { - b = &mut *(buf.as_mut() as *mut [std::mem::MaybeUninit] as *mut [u8]); - }; - - loop { - let mut guard = match self.0.poll_read_ready(cx) { - Poll::Ready(Ok(guard)) => guard, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => return Poll::Pending, - }; - - match guard.try_io(|inner| inner.get_ref().read(b)) { - Ok(Ok(amount)) => { - unsafe { - buf.advance(amount); - } - - return Ok(()).into(); - } - Ok(Err(ref err)) if err.kind() == std::io::ErrorKind::Interrupted => continue, - Ok(Err(err)) => return Err(err).into(), - Err(_would_block) => { - continue; - } - } - } - } -} diff --git a/src/tokio/firecracker.rs b/src/tokio/firecracker.rs new file mode 100644 index 0000000..33624ed --- /dev/null +++ b/src/tokio/firecracker.rs @@ -0,0 +1,63 @@ +use std::{ + io::Result, + ops::{Deref, DerefMut}, + path::Path, +}; + +use hyper_util::rt::TokioIo; +use tokio::{ + io::{AsyncBufReadExt as _, AsyncWriteExt as _, BufReader}, + net::UnixStream, +}; + +use crate::utils::{ + firecracker::{format_request, parse_connection_response}, + hyper_io_by_deref, hyper_util_connection_by_deref, +}; + +pub type TokioFirecrackerIoInner = TokioIo; + +#[derive(Debug)] +pub struct TokioFirecrackerIo(pub TokioIo); + +impl TokioFirecrackerIo { + pub(super) async fn connect

(host_socket_path: P, guest_port: u32) -> Result + where + P: AsRef, + { + let mut stream = UnixStream::connect(host_socket_path).await?; + stream.write_all(format_request(guest_port).as_bytes()).await?; + let response = BufReader::new(&mut stream).lines().next_line().await; + + parse_connection_response(stream, response).map(TokioIo::new).map(Self) + } +} + +impl Deref for TokioFirecrackerIo { + type Target = TokioFirecrackerIoInner; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for TokioFirecrackerIo { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl From for TokioFirecrackerIo { + fn from(inner: TokioFirecrackerIoInner) -> Self { + Self(inner) + } +} + +impl From for TokioFirecrackerIoInner { + fn from(TokioFirecrackerIo(inner): TokioFirecrackerIo) -> Self { + inner + } +} + +hyper_io_by_deref!(TokioFirecrackerIo); +hyper_util_connection_by_deref!(TokioFirecrackerIo); diff --git a/src/tokio/mod.rs b/src/tokio/mod.rs new file mode 100644 index 0000000..025bb16 --- /dev/null +++ b/src/tokio/mod.rs @@ -0,0 +1,70 @@ +#[cfg(feature = "firecracker")] +#[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))] +pub mod firecracker; + +#[cfg(feature = "unix")] +#[cfg_attr(docsrs, doc(cfg(feature = "unix")))] +pub mod unix; + +#[cfg(feature = "vsock")] +#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] +pub mod vsock; + +#[cfg(any(feature = "firecracker", feature = "unix", feature = "vsock"))] +use std::io::Result; + +#[cfg(any(feature = "firecracker", feature = "unix"))] +use std::path::Path; + +#[cfg(feature = "vsock")] +use ::vsock::VsockAddr; + +#[cfg(feature = "firecracker")] +#[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))] +pub use self::firecracker::{TokioFirecrackerIo, TokioFirecrackerIoInner}; + +#[cfg(feature = "unix")] +#[cfg_attr(docsrs, doc(cfg(feature = "unix")))] +pub use self::unix::{TokioUnixIo, TokioUnixIoInner}; + +#[cfg(feature = "vsock")] +#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] +pub use self::vsock::{TokioVsockIo, TokioVsockIoInner}; + +use crate::Backend; + +/// [`Backend`] for hyper-client-sockets that is implemented via the Tokio reactor. +#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub struct TokioBackend; + +impl Backend for TokioBackend { + #[cfg(feature = "firecracker")] + #[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))] + type FirecrackerIo = TokioFirecrackerIo; + + #[cfg(feature = "unix")] + #[cfg_attr(docsrs, doc(cfg(feature = "unix")))] + type UnixIo = TokioUnixIo; + + #[cfg(feature = "vsock")] + #[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] + type VsockIo = TokioVsockIo; + + #[cfg(feature = "firecracker")] + #[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))] + async fn connect_to_firecracker_socket(host_socket_path: &Path, guest_port: u32) -> Result { + Self::FirecrackerIo::connect(host_socket_path, guest_port).await + } + + #[cfg(feature = "unix")] + #[cfg_attr(docsrs, doc(cfg(feature = "unix")))] + async fn connect_to_unix_socket(socket_path: &Path) -> Result { + Self::UnixIo::connect(socket_path).await + } + + #[cfg(feature = "vsock")] + #[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] + async fn connect_to_vsock_socket(addr: VsockAddr) -> Result { + Self::VsockIo::connect(addr).await + } +} diff --git a/src/tokio/unix.rs b/src/tokio/unix.rs new file mode 100644 index 0000000..6c93e47 --- /dev/null +++ b/src/tokio/unix.rs @@ -0,0 +1,53 @@ +use std::{ + io::Result, + ops::{Deref, DerefMut}, + path::Path, +}; + +use hyper_util::rt::TokioIo; +use tokio::net::UnixStream; + +use crate::utils::{hyper_io_by_deref, hyper_util_connection_by_deref}; + +pub type TokioUnixIoInner = TokioIo; + +#[derive(Debug)] +pub struct TokioUnixIo(pub TokioUnixIoInner); + +impl TokioUnixIo { + pub(super) async fn connect

(socket_path: P) -> Result + where + P: AsRef, + { + UnixStream::connect(socket_path).await.map(TokioIo::new).map(Self) + } +} + +impl Deref for TokioUnixIo { + type Target = TokioUnixIoInner; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for TokioUnixIo { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl From for TokioUnixIo { + fn from(inner: TokioUnixIoInner) -> Self { + Self(inner) + } +} + +impl From for TokioUnixIoInner { + fn from(TokioUnixIo(inner): TokioUnixIo) -> Self { + inner + } +} + +hyper_io_by_deref!(TokioUnixIo); +hyper_util_connection_by_deref!(TokioUnixIo); diff --git a/src/tokio/vsock.rs b/src/tokio/vsock.rs new file mode 100644 index 0000000..f895c28 --- /dev/null +++ b/src/tokio/vsock.rs @@ -0,0 +1,135 @@ +use std::{ + io::{Read as _, Result, Write as _}, + mem::MaybeUninit, + ops::{Deref, DerefMut}, + os::fd::{AsRawFd as _, FromRawFd as _, IntoRawFd as _, OwnedFd}, + pin::Pin, + task::{Context, Poll}, +}; + +use hyper::rt::{Read, ReadBufCursor, Write}; +use tokio::io::unix::AsyncFd; +use vsock::{VsockAddr, VsockStream}; + +use crate::utils::{ + hyper_util_connection_default, + vsock::{check_connection, raw_connect, try_advance_cursor, try_poll_write}, +}; + +pub type TokioVsockIoInner = AsyncFd; + +/// IO object representing an active VSOCK connection controlled via a Tokio [`AsyncFd`]. +/// This is internally a reimplementation of a relevant part of the tokio-vsock crate. +#[derive(Debug)] +pub struct TokioVsockIo(pub TokioVsockIoInner); + +impl TokioVsockIo { + pub(super) async fn connect(addr: VsockAddr) -> Result { + let socket = raw_connect(addr)?; + let async_fd = AsyncFd::new(unsafe { OwnedFd::from_raw_fd(socket) })?; + + loop { + let connection_check = { + let mut guard = async_fd.writable().await?; + guard.try_io(|fd| check_connection(fd.as_raw_fd())) + }; + + break match connection_check { + Ok(Ok(_)) => { + let raw_fd = async_fd.into_inner().into_raw_fd(); + let inner = unsafe { VsockStream::from_raw_fd(raw_fd) }; + AsyncFd::new(inner).map(Self) + } + Ok(Err(err)) => Err(err), + Err(_would_block) => continue, + }; + } + } + + fn try_poll_read( + self: Pin<&mut Self>, + context: &mut Context<'_>, + cursor: &mut ReadBufCursor<'_>, + ) -> Option>> { + match self.0.poll_read_ready(context) { + Poll::Ready(Ok(mut guard)) => { + // TODO: Once https://github.com/rust-lang/rust/issues/63569 is stable, use `assume_init_mut`: + let buffer = unsafe { &mut *(cursor.as_mut() as *mut [MaybeUninit] as *mut [u8]) }; + let amount = guard.try_io(|inner| inner.get_ref().read(buffer)).ok()?; + try_advance_cursor(cursor, amount) + } + other => Some(other.map_ok(|_| ())), + } + } + + fn try_poll_write(self: Pin<&mut Self>, context: &mut Context<'_>, buffer: &[u8]) -> Option>> { + match self.0.poll_write_ready(context) { + Poll::Ready(Ok(mut guard)) => try_poll_write(guard.try_io(|inner| inner.get_ref().write(buffer)).ok()?), + other => Some(other.map_ok(|_| 0)), + } + } +} + +impl Deref for TokioVsockIo { + type Target = TokioVsockIoInner; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for TokioVsockIo { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl From for TokioVsockIo { + fn from(inner: TokioVsockIoInner) -> Self { + Self(inner) + } +} + +impl From for TokioVsockIoInner { + fn from(TokioVsockIo(inner): TokioVsockIo) -> Self { + inner + } +} + +impl Read for TokioVsockIo { + #[inline(always)] + fn poll_read( + mut self: Pin<&mut Self>, + context: &mut Context<'_>, + mut cursor: ReadBufCursor<'_>, + ) -> Poll> { + loop { + if let Some(poll_result) = self.as_mut().try_poll_read(context, &mut cursor) { + break poll_result; + } + } + } +} + +impl Write for TokioVsockIo { + #[inline(always)] + fn poll_write(mut self: Pin<&mut Self>, context: &mut Context<'_>, buffer: &[u8]) -> Poll> { + loop { + if let Some(poll_result) = self.as_mut().try_poll_write(context, buffer) { + break poll_result; + } + } + } + + #[inline(always)] + fn poll_flush(self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + #[inline(always)] + fn poll_shutdown(self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +hyper_util_connection_default!(TokioVsockIo); diff --git a/src/uri.rs b/src/uri.rs deleted file mode 100644 index a6a00f9..0000000 --- a/src/uri.rs +++ /dev/null @@ -1,216 +0,0 @@ -#[cfg(any(feature = "unix", feature = "vsock", feature = "firecracker"))] -use hex::FromHex; -#[cfg(any(feature = "unix", feature = "vsock", feature = "firecracker"))] -use http::{uri::InvalidUri, Uri}; -#[cfg(any(feature = "unix", feature = "firecracker"))] -use std::path::{Path, PathBuf}; - -/// An extension trait for a URI that allows constructing a hex-encoded Unix socket URI. -#[cfg(feature = "unix")] -#[cfg_attr(docsrs, doc(cfg(feature = "unix")))] -pub trait UnixUri { - /// Create a new Unix URI with the given socket path and in-socket URI. - fn unix(socket_path: impl AsRef, url: impl AsRef) -> Result; - - /// Try to deconstruct this Unix URI's socket path. - fn parse_unix(&self) -> Result; -} - -#[cfg(feature = "unix")] -#[cfg_attr(docsrs, doc(cfg(feature = "unix")))] -impl UnixUri for Uri { - fn unix(socket_path: impl AsRef, url: impl AsRef) -> Result { - let host = hex::encode(socket_path.as_ref().to_string_lossy().to_string()); - let uri_str = format!("unix://{host}/{}", url.as_ref().trim_start_matches('/')); - let uri = uri_str.parse::()?; - Ok(uri) - } - - fn parse_unix(&self) -> Result { - if self.scheme_str() != Some("unix") { - return Err(io_input_err("URI scheme on a Unix socket must be unix://")); - } - - match self.host() { - Some(host) => { - let bytes = Vec::from_hex(host).map_err(|_| io_input_err("URI host must be hex"))?; - Ok(PathBuf::from(String::from_utf8_lossy(&bytes).into_owned())) - } - None => Err(io_input_err("URI host must be present")), - } - } -} - -/// An extension trait for hyper URI that allows constructing a hex-encoded virtio-vsock socket URI. -#[cfg(feature = "vsock")] -#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] -pub trait VsockUri { - /// Create a new vsock URI with the given vsock address and in-socket URL - fn vsock(addr: vsock::VsockAddr, url: impl AsRef) -> Result; - - /// Deconstruct this vsock URI into its address. - fn parse_vsock(&self) -> Result; -} - -#[cfg(feature = "vsock")] -#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] -impl VsockUri for Uri { - fn vsock(addr: vsock::VsockAddr, url: impl AsRef) -> Result { - let host = hex::encode(format!("{}.{}", addr.cid(), addr.port())); - let uri_str = format!("vsock://{host}/{}", url.as_ref().trim_start_matches('/')); - let uri = uri_str.parse::()?; - Ok(uri) - } - - fn parse_vsock(&self) -> Result { - if self.scheme_str() != Some("vsock") { - return Err(io_input_err("URI scheme on a vsock socket must be vsock://")); - } - - match self.host() { - Some(host) => { - let full_str = Vec::from_hex(host) - .map_err(|_| io_input_err("URI host must be hex")) - .map(|bytes| String::from_utf8_lossy(&bytes).into_owned())?; - let splits = full_str - .split_once('.') - .ok_or_else(|| io_input_err("URI host could not be split at . into 2 slices (CID, then port)"))?; - let cid: u32 = splits - .0 - .parse() - .map_err(|_| io_input_err("First split of URI (CID) can't be parsed"))?; - let port: u32 = splits - .1 - .parse() - .map_err(|_| io_input_err("Second split of URI (port) can't be parsed"))?; - - Ok(vsock::VsockAddr::new(cid, port)) - } - None => Err(io_input_err("URI host must be present")), - } - } -} - -/// An extension trait for hyper URI that allows constructing a hex-encoded Firecracker socket URI. -#[cfg(feature = "firecracker")] -#[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))] -pub trait FirecrackerUri { - /// Create a new Firecracker URI with the given host socket path, guest port and in-socket URL - fn firecracker( - host_socket_path: impl AsRef, - guest_port: u32, - url: impl AsRef, - ) -> Result; - - /// Deconstruct this Firecracker URI into its host socket path and guest port - fn parse_firecracker(&self) -> Result<(PathBuf, u32), std::io::Error>; -} - -#[cfg(feature = "firecracker")] -#[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))] -impl FirecrackerUri for Uri { - fn firecracker( - host_socket_path: impl AsRef, - guest_port: u32, - url: impl AsRef, - ) -> Result { - let host = hex::encode(format!( - "{}:{guest_port}", - host_socket_path.as_ref().to_string_lossy().to_string() - )); - let uri_str = format!("fc://{host}/{}", url.as_ref().trim_start_matches('/')); - let uri = uri_str.parse::()?; - Ok(uri) - } - - fn parse_firecracker(&self) -> Result<(PathBuf, u32), std::io::Error> { - if self.scheme_str() != Some("fc") { - return Err(io_input_err("URI scheme on a Firecracker socket must be fc://")); - } - - let host = self.host().ok_or_else(|| io_input_err("URI host must be present"))?; - let hex_decoded = Vec::from_hex(host).map_err(|_| io_input_err("URI host must be hex"))?; - let full_str = String::from_utf8_lossy(&hex_decoded).into_owned(); - let splits = full_str - .split_once(':') - .ok_or_else(|| io_input_err("URI host could not be split in halves with a ."))?; - let host_socket_path = PathBuf::try_from(splits.0) - .map_err(|_| io_input_err("URI socket path could not be converted to a path"))?; - let guest_port: u32 = splits - .1 - .parse() - .map_err(|_| io_input_err("URI guest port could not converted to u32"))?; - - Ok((host_socket_path, guest_port)) - } -} - -#[cfg(any(feature = "unix", feature = "vsock", feature = "firecracker"))] -#[inline(always)] -fn io_input_err(detail: &str) -> std::io::Error { - std::io::Error::new(std::io::ErrorKind::InvalidInput, detail) -} - -#[cfg(test)] -mod tests { - use std::path::PathBuf; - - use hyper::Uri; - use vsock::VsockAddr; - - use crate::uri::{FirecrackerUri, UnixUri, VsockUri}; - - #[test] - fn unix_uri_should_be_constructed_correctly() { - let uri_str = format!("unix://{}/route", hex::encode("/tmp/socket.sock")); - assert_eq!( - Uri::unix("/tmp/socket.sock", "/route").unwrap(), - uri_str.parse::().unwrap() - ); - } - - #[test] - fn unix_uri_should_be_deconstructed_correctly() { - let uri = format!("unix://{}/route", hex::encode("/tmp/socket.sock")); - assert_eq!( - uri.parse::().unwrap().parse_unix().unwrap(), - PathBuf::from("/tmp/socket.sock") - ); - } - - #[test] - fn vsock_uri_should_be_constructed_correctly() { - let uri = format!("vsock://{}/route", hex::encode("10.20")); - assert_eq!( - uri.parse::().unwrap(), - Uri::vsock(VsockAddr::new(10, 20), "/route").unwrap() - ); - } - - #[test] - fn vsock_uri_should_be_deconstructed_correctly() { - let uri = format!("vsock://{}/route", hex::encode("10.20")) - .parse::() - .unwrap(); - assert_eq!(uri.parse_vsock().unwrap(), VsockAddr::new(10, 20)); - } - - #[test] - fn firecracker_uri_should_be_constructed_correctly() { - let uri_str = format!("fc://{}/route", hex::encode("/tmp/socket.sock:1000")); - assert_eq!( - Uri::firecracker("/tmp/socket.sock", 1000, "/route").unwrap(), - uri_str.parse::().unwrap() - ); - } - - #[test] - fn firecracker_uri_should_be_deconstructed_correctly() { - let uri = format!("fc://{}/route", hex::encode("/tmp/socket.sock:1000")) - .parse::() - .unwrap(); - let (socket_path, port) = uri.parse_firecracker().unwrap(); - assert_eq!(socket_path, PathBuf::from("/tmp/socket.sock")); - assert_eq!(port, 1000); - } -} diff --git a/src/uri/firecracker/mod.rs b/src/uri/firecracker/mod.rs new file mode 100644 index 0000000..3ebaa19 --- /dev/null +++ b/src/uri/firecracker/mod.rs @@ -0,0 +1,131 @@ +#[cfg(test)] +mod tests; + +use std::{ + ffi::OsString, + io::Result as IoResult, + os::unix::{ffi::OsStringExt as _, net::SocketAddr as UnixSocketAddr}, + path::{Path, PathBuf}, +}; + +#[cfg(target_os = "android")] +use std::os::android::net::SocketAddrExt as _; + +#[cfg(target_os = "linux")] +use std::os::linux::net::SocketAddrExt as _; + +use hex::{encode, FromHex}; +use http::uri::{InvalidUri, Uri}; + +use super::io_input_err; + +/// An extension trait for hyper URI that allows constructing a hex-encoded Firecracker socket URI. +pub trait FirecrackerUri { + /// Create a new Firecracker URI with the given host socket path, guest port and in-socket URL. + fn firecracker, S: AsRef>( + host_socket_path: P, + guest_port: u32, + url: S, + ) -> Result; + + /// Create a new Firecracker URI with the given host socket address, guest port and in-socket URL. + fn firecracker_addr>( + host_socket_addr: &UnixSocketAddr, + guest_port: u32, + url: S, + ) -> Result; + + /// Deconstruct this Firecracker URI into its host socket path and guest port. + fn parse_firecracker(&self) -> IoResult<(PathBuf, u32)>; + + /// Deconstruct this Firecracker URI into its host socket address and guest port. + fn parse_firecracker_addr(&self) -> IoResult<(UnixSocketAddr, u32)> { + let (path, port) = self.parse_firecracker()?; + + #[cfg(any(target_os = "android", target_os = "linux"))] + let from_abstract_name = |err| { + let octets = path.as_os_str().as_encoded_bytes(); + match octets.split_first() { + Some((0, name)) => UnixSocketAddr::from_abstract_name(name), + _ => Err(err), + } + }; + + #[cfg(not(any(target_os = "android", target_os = "linux")))] + let from_abstract_name = |err| Err(err); + + UnixSocketAddr::from_pathname(&path) + .or_else(from_abstract_name) + .map(|address| (address, port)) + } +} + +fn from_octets(prefix: &str, octets: &[u8], guest_port: u32, url: S) -> Result +where + S: AsRef, +{ + let host = encode(octets); + let guest_port = encode(guest_port.to_string()); + let authority = format!("{prefix}{host}{:02x}{guest_port}", b':'); + let path_and_query = url.as_ref().trim_start_matches('/'); + let uri_str = format!("fc://{authority}/{path_and_query}"); + uri_str.parse() +} + +impl FirecrackerUri for Uri { + fn firecracker(host_socket_path: P, guest_port: u32, url: S) -> Result + where + P: AsRef, + S: AsRef, + { + let octets = host_socket_path.as_ref().as_os_str().as_encoded_bytes(); + from_octets("", octets, guest_port, url) + } + + fn firecracker_addr(host_socket_addr: &UnixSocketAddr, guest_port: u32, url: S) -> Result + where + S: AsRef, + { + let (prefix, octets) = match host_socket_addr.as_pathname() { + Some(host_socket_path) => ("", host_socket_path.as_os_str().as_encoded_bytes()), + + None => { + #[cfg(any(target_os = "android", target_os = "linux"))] + let octets = host_socket_addr.as_abstract_name().unwrap_or_default(); + + #[cfg(not(any(target_os = "android", target_os = "linux")))] + let octets = &[]; + + // Unnamed Unix Domain sockets are encoded as `00`: + ("00", octets) + } + }; + from_octets(prefix, octets, guest_port, url) + } + + fn parse_firecracker(&self) -> IoResult<(PathBuf, u32)> { + if self.scheme_str() == Some("fc") { + let host_hex = self.host().ok_or_else(|| io_input_err("URI host must be present"))?; + let mut host_octets = + Vec::from_hex(host_hex).map_err(|_| io_input_err("URI host must be hexadecimal encoded"))?; + + let colon_pos = host_octets + .iter() + .rposition(|octet| *octet == b':') + .ok_or_else(|| io_input_err("URI host does not encode port"))?; + + let guest_port = String::from_utf8(host_octets.split_off(colon_pos)) + .map_err(|_| io_input_err("URI guest port is not valid UTF8"))? + .split_at(1) + .1 + .parse::() + .map_err(|_| io_input_err("URI guest port could not be parsed"))?; + + let host_socket_path = OsString::from_vec(host_octets).into(); + + Ok((host_socket_path, guest_port)) + } else { + Err(io_input_err("URI scheme on a Firecracker socket must be fc://")) + } + } +} diff --git a/src/uri/firecracker/tests.rs b/src/uri/firecracker/tests.rs new file mode 100644 index 0000000..a696857 --- /dev/null +++ b/src/uri/firecracker/tests.rs @@ -0,0 +1,172 @@ +use std::{fmt::Debug, os::unix::net::SocketAddr as UnixSocketAddr, path::PathBuf}; + +#[cfg(target_os = "android")] +use std::os::android::net::SocketAddrExt as _; + +#[cfg(target_os = "linux")] +use std::os::linux::net::SocketAddrExt as _; + +use super::FirecrackerUri as _; +use hyper::Uri; + +fn assert_debug_eq(expected: E, value: T) +where + E: Debug, + T: Debug, +{ + assert_eq!(format!("{expected:?}"), format!("{value:?}")) +} + +#[cfg(any(target_os = "android", target_os = "linux"))] +#[test] +fn decode_abstract_name_address() { + // TODO: Randomise: + let abstract_name = "abstract"; + let port = 1000; + let path_and_query = "/route"; + + let uri = { + let formatted = format!("{abstract_name}:{port}"); + let uri_str = format!("fc://00{}{path_and_query}", hex::encode(formatted)); + uri_str.parse::().unwrap() + }; + + let expected = { + let address = UnixSocketAddr::from_abstract_name(abstract_name).unwrap(); + (address, port) + }; + + assert_debug_eq(expected, uri.parse_firecracker_addr().unwrap()); +} + +#[test] +fn decode_abstract_name_path() { + // TODO: Randomise: + let abstract_name = "\0abstract"; + let port = 1000; + let path_and_query = "/route"; + + let uri = { + let formatted = format!("{abstract_name}:{port}"); + let uri_str = format!("fc://{}{path_and_query}", hex::encode(formatted)); + uri_str.parse::().unwrap() + }; + + let expected = { + let path = PathBuf::from(abstract_name); + (path, port) + }; + + assert_eq!(expected, uri.parse_firecracker().unwrap()); +} + +#[test] +fn decode_pathname_address() { + // TODO: Randomise: + let path = "/tmp/socket.sock"; + let port = 1000; + let path_and_query = "/route"; + + let uri = { + let formatted = format!("{path}:{port}"); + let uri_str = format!("fc://{}{path_and_query}", hex::encode(formatted)); + uri_str.parse::().unwrap() + }; + + let expected = { + let address = UnixSocketAddr::from_pathname(path).unwrap(); + (address, port) + }; + + assert_debug_eq(expected, uri.parse_firecracker_addr().unwrap()); +} + +#[test] +fn decode_pathname_path() { + // TODO: Randomise: + let path = "/tmp/socket.sock"; + let port = 1000; + let path_and_query = "/route"; + + let uri = { + let formatted = format!("{path}:{port}"); + let uri_str = format!("fc://{}{path_and_query}", hex::encode(formatted)); + uri_str.parse::().unwrap() + }; + + let expected = { + let path = PathBuf::from(path); + (path, port) + }; + + assert_eq!(expected, uri.parse_firecracker().unwrap()); +} + +#[cfg(any(target_os = "android", target_os = "linux"))] +#[test] +fn encode_abstract_name_as_address() { + // TODO: Randomise: + let abstract_name = "abstract"; + let port = 1000; + let path_and_query = "/route"; + + let address = UnixSocketAddr::from_abstract_name(abstract_name).unwrap(); + + let expected = { + let formatted = format!("{abstract_name}:{port}"); + let uri_str = format!("fc://00{}{path_and_query}", hex::encode(formatted)); + uri_str.parse::().unwrap() + }; + + assert_eq!(expected, Uri::firecracker_addr(&address, port, path_and_query).unwrap()); +} + +#[test] +fn encode_abstract_name_as_path() { + // TODO: Randomise: + let abstract_name = "\0abstract"; + let port = 1000; + let path_and_query = "/route"; + + let expected = { + let formatted = format!("{abstract_name}:{port}"); + let uri_str = format!("fc://{}{path_and_query}", hex::encode(formatted)); + uri_str.parse::().unwrap() + }; + + assert_eq!(expected, Uri::firecracker(abstract_name, port, path_and_query).unwrap()); +} + +#[test] +fn encode_pathname_as_address() { + // TODO: Randomise: + let path = "/tmp/socket.sock"; + let port = 1000; + let path_and_query = "/route"; + + let address = UnixSocketAddr::from_pathname(path).unwrap(); + + let expected = { + let formatted = format!("{path}:{port}"); + let uri_str = format!("fc://{}{path_and_query}", hex::encode(formatted)); + uri_str.parse::().unwrap() + }; + + assert_eq!(expected, Uri::firecracker_addr(&address, port, path_and_query).unwrap()); +} + +#[test] +fn encode_pathname_as_path() { + // TODO: Randomise: + let path = "/tmp/socket.sock"; + let port = 1000; + let path_and_query = "/route"; + + let expected = { + let formatted = format!("{path}:{port}"); + let uri_str = format!("fc://{}{path_and_query}", hex::encode(formatted)); + uri_str.parse::().unwrap() + }; + + assert_eq!(expected, Uri::firecracker(path, port, path_and_query).unwrap()); +} diff --git a/src/uri/mod.rs b/src/uri/mod.rs new file mode 100644 index 0000000..93bce0c --- /dev/null +++ b/src/uri/mod.rs @@ -0,0 +1,32 @@ +#[cfg(feature = "firecracker")] +#[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))] +pub mod firecracker; + +#[cfg(feature = "unix")] +#[cfg_attr(docsrs, doc(cfg(feature = "unix")))] +pub mod unix; + +#[cfg(feature = "vsock")] +#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] +pub mod vsock; + +#[cfg(any(feature = "firecracker", feature = "unix", feature = "vsock"))] +use std::io::{Error, ErrorKind}; + +#[cfg(feature = "firecracker")] +#[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))] +pub use self::firecracker::FirecrackerUri; + +#[cfg(feature = "unix")] +#[cfg_attr(docsrs, doc(cfg(feature = "unix")))] +pub use self::unix::UnixUri; + +#[cfg(feature = "vsock")] +#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] +pub use self::vsock::VsockUri; + +#[cfg(any(feature = "firecracker", feature = "unix", feature = "vsock"))] +#[inline(always)] +fn io_input_err(detail: &str) -> Error { + Error::new(ErrorKind::InvalidInput, detail) +} diff --git a/src/uri/unix/mod.rs b/src/uri/unix/mod.rs new file mode 100644 index 0000000..e502cc0 --- /dev/null +++ b/src/uri/unix/mod.rs @@ -0,0 +1,108 @@ +#[cfg(test)] +mod tests; + +use std::{ + ffi::OsString, + io::Result as IoResult, + os::unix::{ffi::OsStringExt as _, net::SocketAddr as UnixSocketAddr}, + path::{Path, PathBuf}, +}; + +#[cfg(target_os = "android")] +use std::os::android::net::SocketAddrExt as _; + +#[cfg(target_os = "linux")] +use std::os::linux::net::SocketAddrExt as _; + +use hex::{encode, FromHex}; +use http::uri::{InvalidUri, Uri}; + +use super::io_input_err; + +/// An extension trait for a URI that allows constructing a hex-encoded Unix Domain socket URI. +pub trait UnixUri { + /// Create a new Unix URI with the given socket path and in-socket URI. + fn unix, S: AsRef>(socket_path: P, url: S) -> Result; + + /// Create a new Unix URI with the given socket address and in-socket URI. + fn unix_addr>(socket_addr: &UnixSocketAddr, url: S) -> Result; + + /// Try to deconstruct this Unix URI's socket path. + fn parse_unix(&self) -> IoResult; + + /// Try to deconstruct this Unix URI's socket address. + fn parse_unix_addr(&self) -> IoResult { + let path = self.parse_unix()?; + + #[cfg(any(target_os = "android", target_os = "linux"))] + let from_abstract_name = |err| { + let octets = path.as_os_str().as_encoded_bytes(); + match octets.split_first() { + Some((0, name)) => UnixSocketAddr::from_abstract_name(name), + _ => Err(err), + } + }; + + #[cfg(not(any(target_os = "android", target_os = "linux")))] + let from_abstract_name = |err| Err(err); + + UnixSocketAddr::from_pathname(&path).or_else(from_abstract_name) + } +} + +fn from_octets(prefix: &str, octets: &[u8], url: S) -> Result +where + S: AsRef, +{ + let authority = encode(octets); + let path_and_query = url.as_ref().trim_start_matches('/'); + let uri_str = format!("unix://{prefix}{authority}/{path_and_query}"); + uri_str.parse() +} + +impl UnixUri for Uri { + fn unix(socket_path: P, url: S) -> Result + where + P: AsRef, + S: AsRef, + { + let octets = socket_path.as_ref().as_os_str().as_encoded_bytes(); + from_octets("", octets, url) + } + + fn unix_addr(socket_addr: &UnixSocketAddr, url: S) -> Result + where + S: AsRef, + { + let (prefix, octets) = match socket_addr.as_pathname() { + Some(socket_path) => ("", socket_path.as_os_str().as_encoded_bytes()), + + None => { + #[cfg(any(target_os = "android", target_os = "linux"))] + let octets = socket_addr.as_abstract_name().unwrap_or_default(); + + #[cfg(not(any(target_os = "android", target_os = "linux")))] + let octets = &[]; + + // Unnamed Unix Domain sockets are encoded as `00`: + ("00", octets) + } + }; + from_octets(prefix, octets, url) + } + + fn parse_unix(&self) -> IoResult { + if self.scheme_str() == Some("unix") { + match self.host() { + Some(host) => { + let octets = + Vec::from_hex(host).map_err(|_| io_input_err("URI host must be hexadecimal encoded"))?; + Ok(OsString::from_vec(octets).into()) + } + None => Err(io_input_err("URI host must be present")), + } + } else { + Err(io_input_err("URI scheme on a Unix Domain socket must be unix://")) + } + } +} diff --git a/src/uri/unix/tests.rs b/src/uri/unix/tests.rs new file mode 100644 index 0000000..3412f41 --- /dev/null +++ b/src/uri/unix/tests.rs @@ -0,0 +1,144 @@ +use std::{fmt::Debug, os::unix::net::SocketAddr as UnixSocketAddr, path::PathBuf}; + +#[cfg(target_os = "android")] +use std::os::android::net::SocketAddrExt as _; + +#[cfg(target_os = "linux")] +use std::os::linux::net::SocketAddrExt as _; + +use super::UnixUri as _; +use hyper::Uri; + +fn assert_debug_eq(expected: E, value: T) +where + E: Debug, + T: Debug, +{ + assert_eq!(format!("{expected:?}"), format!("{value:?}")) +} + +#[cfg(any(target_os = "android", target_os = "linux"))] +#[test] +fn decode_abstract_name_address() { + // TODO: Randomise: + let abstract_name = "abstract"; + let path_and_query = "/route"; + + let uri = { + let uri_str = format!("unix://00{}{path_and_query}", hex::encode(abstract_name)); + uri_str.parse::().unwrap() + }; + + let expected = UnixSocketAddr::from_abstract_name(abstract_name).unwrap(); + + assert_debug_eq(expected, uri.parse_unix_addr().unwrap()); +} + +#[test] +fn decode_abstract_name_path() { + // TODO: Randomise: + let abstract_name = "\0abstract"; + let path_and_query = "/route"; + + let uri = { + let uri_str = format!("unix://{}{path_and_query}", hex::encode(abstract_name)); + uri_str.parse::().unwrap() + }; + + let expected = PathBuf::from(abstract_name); + + assert_eq!(expected, uri.parse_unix().unwrap()); +} + +#[test] +fn decode_pathname_address() { + // TODO: Randomise: + let path = "/tmp/socket.sock"; + let path_and_query = "/route"; + + let uri = { + let uri_str = format!("unix://{}{path_and_query}", hex::encode(path)); + uri_str.parse::().unwrap() + }; + + let expected = UnixSocketAddr::from_pathname(path).unwrap(); + + assert_debug_eq(expected, uri.parse_unix_addr().unwrap()); +} + +#[test] +fn decode_pathname_path() { + // TODO: Randomise: + let path = "/tmp/socket.sock"; + let path_and_query = "/route"; + + let uri = { + let uri_str = format!("unix://{}{path_and_query}", hex::encode(path)); + uri_str.parse::().unwrap() + }; + + let expected = PathBuf::from(path); + + assert_eq!(expected, uri.parse_unix().unwrap()); +} + +#[cfg(any(target_os = "android", target_os = "linux"))] +#[test] +fn encode_abstract_name_as_address() { + // TODO: Randomise: + let abstract_name = "abstract"; + let path_and_query = "/route"; + + let address = UnixSocketAddr::from_abstract_name(abstract_name).unwrap(); + + let expected = { + let uri_str = format!("unix://00{}{path_and_query}", hex::encode(abstract_name)); + uri_str.parse::().unwrap() + }; + + assert_eq!(expected, Uri::unix_addr(&address, path_and_query).unwrap()); +} + +#[test] +fn encode_abstract_name_as_path() { + // TODO: Randomise: + let abstract_name = "\0abstract"; + let path_and_query = "/route"; + + let expected = { + let uri_str = format!("unix://{}{path_and_query}", hex::encode(abstract_name)); + uri_str.parse::().unwrap() + }; + + assert_eq!(expected, Uri::unix(abstract_name, path_and_query).unwrap()); +} + +#[test] +fn encode_pathname_as_address() { + // TODO: Randomise: + let path = "/tmp/socket.sock"; + let path_and_query = "/route"; + + let address = UnixSocketAddr::from_pathname(path).unwrap(); + + let expected = { + let uri_str = format!("unix://{}{path_and_query}", hex::encode(path)); + uri_str.parse::().unwrap() + }; + + assert_eq!(expected, Uri::unix_addr(&address, path_and_query).unwrap()); +} + +#[test] +fn encode_pathname_as_path() { + // TODO: Randomise: + let path = "/tmp/socket.sock"; + let path_and_query = "/route"; + + let expected = { + let uri_str = format!("unix://{}{path_and_query}", hex::encode(path)); + uri_str.parse::().unwrap() + }; + + assert_eq!(expected, Uri::unix(path, path_and_query).unwrap()); +} diff --git a/src/uri/vsock/mod.rs b/src/uri/vsock/mod.rs new file mode 100644 index 0000000..b7f587b --- /dev/null +++ b/src/uri/vsock/mod.rs @@ -0,0 +1,59 @@ +#[cfg(test)] +mod tests; + +use std::io::Result as IoResult; + +use hex::{encode, FromHex}; +use http::uri::{InvalidUri, Uri}; +use vsock::VsockAddr; + +use super::io_input_err; + +/// An extension trait for hyper URI that allows constructing a hex-encoded VSOCK socket URI. +pub trait VsockUri { + /// Create a new VSOCK URI with the given address and in-socket URL. + fn vsock>(addr: VsockAddr, url: S) -> Result; + + /// Deconstruct this VSOCK URI into its address. + fn parse_vsock(&self) -> IoResult; +} + +impl VsockUri for Uri { + fn vsock(addr: VsockAddr, url: S) -> Result + where + S: AsRef, + { + let authority = encode(format!("{}.{}", addr.cid(), addr.port())); + let path_and_query = url.as_ref().trim_start_matches('/'); + let uri_str = format!("vsock://{authority}/{path_and_query}"); + uri_str.parse() + } + + fn parse_vsock(&self) -> IoResult { + if self.scheme_str() == Some("vsock") { + match self.host() { + Some(host) => { + let full_str = Vec::from_hex(host) + .map_err(|_| io_input_err("URI host must be hex")) + .map(|bytes| String::from_utf8_lossy(&bytes).into_owned())?; + let splits = full_str.split_once('.').ok_or_else(|| { + io_input_err("URI host could not be split at . into 2 slices (CID, then port)") + })?; + let cid = splits + .0 + .parse::() + .map_err(|_| io_input_err("First split of URI (CID) can't be parsed"))?; + let port = splits + .1 + .parse::() + .map_err(|_| io_input_err("Second split of URI (port) can't be parsed"))?; + + Ok(VsockAddr::new(cid, port)) + } + None => Err(io_input_err("URI host must be present")), + } + } else { + Err(io_input_err("URI scheme on a VSOCK socket must be vsock://")) + } + } +} diff --git a/src/uri/vsock/tests.rs b/src/uri/vsock/tests.rs new file mode 100644 index 0000000..96d4d05 --- /dev/null +++ b/src/uri/vsock/tests.rs @@ -0,0 +1,40 @@ +use hyper::Uri; +use vsock::VsockAddr; + +use super::VsockUri as _; + +#[test] +fn decode() { + // TODO: Randomise: + let cid = 10; + let port = 20; + let path_and_query = "/route"; + + let uri = { + let formatted = format!("{cid}.{port}"); + let uri_str = format!("vsock://{}{path_and_query}", hex::encode(formatted)); + uri_str.parse::().unwrap() + }; + + let expected = VsockAddr::new(cid, port); + + assert_eq!(expected, uri.parse_vsock().unwrap()); +} + +#[test] +fn encode() { + // TODO: Randomise: + let cid = 10; + let port = 20; + let path_and_query = "/route"; + + let address = VsockAddr::new(cid, port); + + let expected = { + let formatted = format!("{cid}.{port}"); + let uri_str = format!("vsock://{}{path_and_query}", hex::encode(formatted)); + uri_str.parse::().unwrap() + }; + + assert_eq!(expected, Uri::vsock(address, path_and_query).unwrap()); +} diff --git a/src/utils/firecracker.rs b/src/utils/firecracker.rs new file mode 100644 index 0000000..ecb6646 --- /dev/null +++ b/src/utils/firecracker.rs @@ -0,0 +1,24 @@ +use std::io::{Error, ErrorKind, Result}; + +pub fn format_request(guest_port: u32) -> String { + format!("CONNECT {guest_port}\n") +} + +pub fn parse_connection_response(stream: S, response: Result>) -> Result { + match response { + Ok(Some(line)) => { + if line.starts_with("OK") { + Ok(stream) + } else { + Err(Error::new( + ErrorKind::ConnectionRefused, + "Firecracker refused to establish a tunnel to the given guest port", + )) + } + } + _ => Err(Error::new( + ErrorKind::InvalidInput, + "Could not read Firecracker response", + )), + } +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs new file mode 100644 index 0000000..382b366 --- /dev/null +++ b/src/utils/mod.rs @@ -0,0 +1,99 @@ +#[cfg(feature = "firecracker")] +#[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))] +pub mod firecracker; + +#[cfg(feature = "vsock")] +#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))] +pub mod vsock; + +#[allow(unused)] +macro_rules! hyper_io_by_deref { + ($ty:ty) => { + const _: () = { + use std::{ + io::{IoSlice, Result}, + pin::Pin, + task::{Context, Poll}, + }; + + use hyper::rt::{Read, ReadBufCursor, Write}; + + impl Read for $ty { + #[inline(always)] + fn poll_read( + self: Pin<&mut Self>, + context: &mut Context<'_>, + cursor: ReadBufCursor<'_>, + ) -> Poll> { + Pin::new(&mut self.get_mut().0).poll_read(context, cursor) + } + } + + impl Write for $ty { + #[inline(always)] + fn poll_write(self: Pin<&mut Self>, context: &mut Context<'_>, buffer: &[u8]) -> Poll> { + Pin::new(&mut self.get_mut().0).poll_write(context, buffer) + } + + #[inline(always)] + fn poll_flush(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.get_mut().0).poll_flush(context) + } + + #[inline(always)] + fn poll_shutdown(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.get_mut().0).poll_shutdown(context) + } + + #[inline(always)] + fn is_write_vectored(&self) -> bool { + self.0.is_write_vectored() + } + + #[inline(always)] + fn poll_write_vectored( + self: Pin<&mut Self>, + context: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut self.get_mut().0).poll_write_vectored(context, bufs) + } + } + }; + }; +} + +#[allow(unused)] +macro_rules! hyper_util_connection_by_deref { + ($ty:ty) => { + const _: () = { + use std::ops::Deref as _; + + use hyper_util::client::legacy::connect::{Connected, Connection}; + + impl Connection for $ty { + fn connected(&self) -> Connected { + self.deref().connected() + } + } + }; + }; +} + +#[allow(unused)] +macro_rules! hyper_util_connection_default { + ($ty:ty) => { + const _: () = { + use hyper_util::client::legacy::connect::{Connected, Connection}; + + impl Connection for $ty { + fn connected(&self) -> Connected { + Connected::new() + } + } + }; + }; +} + +#[allow(unused)] +pub(crate) use {hyper_io_by_deref, hyper_util_connection_by_deref, hyper_util_connection_default}; diff --git a/src/utils/vsock.rs b/src/utils/vsock.rs new file mode 100644 index 0000000..c377f35 --- /dev/null +++ b/src/utils/vsock.rs @@ -0,0 +1,78 @@ +use std::{ + io::{Error, ErrorKind, Result}, + os::fd::RawFd, + task::Poll, +}; + +use hyper::rt::ReadBufCursor; +use vsock::VsockAddr; + +pub fn check_connection(raw_fd: RawFd) -> Result<()> { + let mut sock_err: libc::c_int = 0; + let mut sock_err_len: libc::socklen_t = size_of::() as libc::socklen_t; + let err = unsafe { + libc::getsockopt( + raw_fd, + libc::SOL_SOCKET, + libc::SO_ERROR, + &mut sock_err as *mut _ as *mut libc::c_void, + &mut sock_err_len as *mut libc::socklen_t, + ) + }; + + if err < 0 { + Err(Error::last_os_error()) + } else if sock_err != 0 { + Err(Error::from_raw_os_error(sock_err)) + } else { + Ok(()) + } +} + +pub fn raw_connect(addr: VsockAddr) -> Result { + let socket = unsafe { libc::socket(libc::AF_VSOCK, libc::SOCK_STREAM, 0) }; + if socket < 0 { + Err(Error::last_os_error()) + } else if unsafe { libc::fcntl(socket, libc::F_SETFL, libc::O_NONBLOCK | libc::O_CLOEXEC) } < 0 { + let _ = unsafe { libc::close(socket) }; + Err(Error::last_os_error()) + } else { + let addr = &addr as *const _ as *const libc::sockaddr; + let addrlen = size_of::() as libc::socklen_t; + if unsafe { libc::connect(socket, addr, addrlen) } >= 0 { + Ok(socket) + } else { + let err = Error::last_os_error(); + match err.raw_os_error() { + Some(libc::EINPROGRESS) => Ok(socket), + Some(_os_err) => { + let _ = unsafe { libc::close(socket) }; + Err(err) + } + None => unreachable!(), + } + } + } +} + +pub fn try_advance_cursor(cursor: &mut ReadBufCursor<'_>, amount: Result) -> Option>> { + match amount { + Ok(amount) => { + unsafe { + cursor.advance(amount); + } + Some(Poll::Ready(Ok(()))) + } + Err(err) => match err.kind() { + ErrorKind::Interrupted | ErrorKind::WouldBlock => None, + _ => Some(Poll::Ready(Err(err))), + }, + } +} + +pub fn try_poll_write(amount: Result) -> Option>> { + match amount { + Err(err) if matches!(err.kind(), ErrorKind::Interrupted | ErrorKind::WouldBlock) => None, + other => Some(Poll::Ready(other)), + } +} diff --git a/tests/async_io.rs b/tests/async_io.rs index b3db5c1..e5ba296 100644 --- a/tests/async_io.rs +++ b/tests/async_io.rs @@ -2,7 +2,7 @@ use std::{future::Future, sync::Arc}; use async_executor::Executor; use bytes::Bytes; -use common::{check_response, serve_firecracker, serve_unix, serve_vsock}; +use common::{check_response, serve_firecracker, serve_unix, serve_unix_abstract, serve_vsock}; use http::{Request, Uri}; use http_body_util::Full; use hyper::client::conn::http1::handshake; @@ -18,7 +18,7 @@ use smol_hyper::rt::SmolExecutor; mod common; #[test] -fn async_io_unix_raw_connectivity() { +fn async_io_unix_raw_connectivity_with_pathname() { run(|executor| async move { let socket_path = serve_unix(); let io = AsyncIoBackend::connect_to_unix_socket(&socket_path).await.unwrap(); @@ -32,6 +32,21 @@ fn async_io_unix_raw_connectivity() { }); } +#[test] +fn async_io_unix_raw_connectivity_with_abstract_name() { + run(|executor| async move { + let socket_path = serve_unix_abstract(); + let io = AsyncIoBackend::connect_to_unix_socket(&socket_path).await.unwrap(); + let (mut send_request, conn) = handshake::<_, Full>(io).await.unwrap(); + executor.spawn(conn).detach(); + let response = send_request + .send_request(Request::new(Full::new(Bytes::new()))) + .await + .unwrap(); + check_response(response).await; + }); +} + #[test] fn async_io_unix_pooled_connectivity() { run(|executor| async move { diff --git a/tests/common.rs b/tests/common.rs index f91dd1d..d791a3e 100644 --- a/tests/common.rs +++ b/tests/common.rs @@ -37,6 +37,29 @@ pub fn serve_unix() -> PathBuf { socket_path } +#[allow(unused)] +pub fn serve_unix_abstract() -> PathBuf { + let socket_path = PathBuf::from("\0test").with_extension(Uuid::new_v4().to_string()); + + let cloned_socket_path = socket_path.clone(); + in_tokio_thread(async move { + let listener = UnixListener::bind(cloned_socket_path).unwrap(); + + loop { + let (stream, _) = listener.accept().await.unwrap(); + tokio::spawn(async move { + http1::Builder::new() + .serve_connection(TokioIo::new(stream), service_fn(responder)) + .await + .unwrap(); + }); + } + }); + + std::thread::sleep(Duration::from_millis(1)); + socket_path +} + #[allow(unused)] pub fn serve_vsock() -> VsockAddr { let port = fastrand::u32(15000..=65536); diff --git a/tests/tokio.rs b/tests/tokio.rs index 8d2f670..e8f2b6c 100644 --- a/tests/tokio.rs +++ b/tests/tokio.rs @@ -1,5 +1,5 @@ use bytes::Bytes; -use common::{check_response, serve_firecracker, serve_unix, serve_vsock}; +use common::{check_response, serve_firecracker, serve_unix, serve_unix_abstract, serve_vsock}; use http::{Request, Uri}; use http_body_util::Full; use hyper::client::conn::http1::handshake; @@ -14,7 +14,7 @@ use hyper_util::{client::legacy::Client, rt::TokioExecutor}; mod common; #[tokio::test] -async fn tokio_unix_raw_connectivity() { +async fn tokio_unix_raw_connectivity_with_pathname() { let socket_path = serve_unix(); let io = TokioBackend::connect_to_unix_socket(&socket_path).await.unwrap(); let (mut send_request, conn) = handshake::<_, Full>(io).await.unwrap(); @@ -26,6 +26,19 @@ async fn tokio_unix_raw_connectivity() { check_response(response).await; } +#[tokio::test] +async fn tokio_unix_raw_connectivity_with_abstract_name() { + let socket_path = serve_unix_abstract(); + let io = TokioBackend::connect_to_unix_socket(&socket_path).await.unwrap(); + let (mut send_request, conn) = handshake::<_, Full>(io).await.unwrap(); + tokio::spawn(conn); + let response = send_request + .send_request(Request::new(Full::new(Bytes::new()))) + .await + .unwrap(); + check_response(response).await; +} + #[tokio::test] async fn tokio_unix_pooled_connectivity() { let socket_path = serve_unix();