From ad7fa6f49a33aac429402270b91b64fbcfedc522 Mon Sep 17 00:00:00 2001 From: Chris Branch Date: Wed, 29 Nov 2023 21:01:04 +0000 Subject: [PATCH] Add "borrowed-fd" feature for sending BorrowedFd on a socket This is a more ergonomic way to use slices of ownership-safe file descriptors instead of coercing them into raw objects. --- Cargo.toml | 3 ++ src/lib.rs | 101 +++++++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 102 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 055ec85..cca2719 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,9 @@ repository = "https://github.com/standard-ai/sendfd" documentation = "https://docs.rs/sendfd" readme = "README.mkd" +[features] +borrowed-fd = [] + [dependencies] libc = "0.2" tokio = { version = "1.0.0", features = [ "net" ], optional = true } diff --git a/src/lib.rs b/src/lib.rs index d49e055..f315bca 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,8 @@ extern crate libc; #[cfg(feature = "tokio")] extern crate tokio; +#[cfg(feature = "borrowed-fd")] +use std::os::fd::BorrowedFd; use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::net; use std::{alloc, io, mem, ptr}; @@ -16,6 +18,9 @@ pub mod changelog; pub trait SendWithFd { /// Send the bytes and the file descriptors. fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result; + /// Send the bytes and the file descriptors. + #[cfg(feature = "borrowed-fd")] + fn send_with_borrowed_fd(&self, bytes: &[u8], fds: &[BorrowedFd<'_>]) -> io::Result; } /// An extension trait that enables receiving associated file descriptors along with the data. @@ -77,7 +82,7 @@ unsafe fn construct_msghdr_for( /// A common implementation of `sendmsg` that sends provided bytes with ancillary file descriptors /// over either a datagram or stream unix socket. -fn send_with_fd(socket: RawFd, bs: &[u8], fds: &[RawFd]) -> io::Result { +fn send_with_fd(socket: RawFd, bs: &[u8], fds: &[F]) -> io::Result { unsafe { let mut iov = libc::iovec { // NB: this casts *const to *mut, and in doing so we trust the OS to be a good citizen @@ -99,7 +104,7 @@ fn send_with_fd(socket: RawFd, bs: &[u8], fds: &[RawFd]) -> io::Result { let cmsg_data = libc::CMSG_DATA(cmsg_header) as *mut RawFd; for (i, fd) in fds.iter().enumerate() { - ptr::write_unaligned(cmsg_data.add(i), *fd); + ptr::write_unaligned(cmsg_data.add(i), fd.as_raw_fd()); } let count = libc::sendmsg(socket, &msghdr as *const _, 0); if count < 0 { @@ -181,6 +186,15 @@ impl SendWithFd for net::UnixStream { fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result { send_with_fd(self.as_raw_fd(), bytes, fds) } + + /// Send the bytes and the file descriptors as a stream. + /// + /// Neither is guaranteed to be received by the other end in a single chunk and + /// may arrive entirely independently. + #[cfg(feature = "borrowed-fd")] + fn send_with_borrowed_fd(&self, bytes: &[u8], fds: &[BorrowedFd<'_>]) -> io::Result { + send_with_fd(self.as_raw_fd(), bytes, fds) + } } #[cfg(feature = "tokio")] @@ -193,6 +207,17 @@ impl SendWithFd for tokio::net::UnixStream { fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result { self.try_io(Interest::WRITABLE, || send_with_fd(self.as_raw_fd(), bytes, fds)) } + + /// Send the bytes and the file descriptors as a stream. + /// + /// Neither is guaranteed to be received by the other end in a single chunk and + /// may arrive entirely independently. + #[cfg(feature = "borrowed-fd")] + fn send_with_borrowed_fd(&self, bytes: &[u8], fds: &[BorrowedFd<'_>]) -> io::Result { + self.try_io(Interest::WRITABLE, || { + send_with_fd(self.as_raw_fd(), bytes, fds) + }) + } } #[cfg(feature = "tokio")] @@ -206,6 +231,16 @@ impl SendWithFd for tokio::net::unix::WriteHalf<'_> { let unix_stream: &tokio::net::UnixStream = self.as_ref(); unix_stream.send_with_fd(bytes, fds) } + + /// Send the bytes and the file descriptors as a stream. + /// + /// Neither is guaranteed to be received by the other end in a single chunk and + /// may arrive entirely independently. + #[cfg(feature = "borrowed-fd")] + fn send_with_borrowed_fd(&self, bytes: &[u8], fds: &[BorrowedFd<'_>]) -> io::Result { + let unix_stream: &tokio::net::UnixStream = self.as_ref(); + unix_stream.send_with_borrowed_fd(bytes, fds) + } } impl SendWithFd for net::UnixDatagram { @@ -217,6 +252,16 @@ impl SendWithFd for net::UnixDatagram { fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result { send_with_fd(self.as_raw_fd(), bytes, fds) } + + /// Send the bytes and the file descriptors as a single packet. + /// + /// It is guaranteed that the bytes and the associated file descriptors will arrive at the same + /// time, however the receiver end may not receive the full message if its buffers are too + /// small. + #[cfg(feature = "borrowed-fd")] + fn send_with_borrowed_fd(&self, bytes: &[u8], fds: &[BorrowedFd<'_>]) -> io::Result { + send_with_fd(self.as_raw_fd(), bytes, fds) + } } #[cfg(feature = "tokio")] @@ -230,6 +275,18 @@ impl SendWithFd for tokio::net::UnixDatagram { fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result { self.try_io(Interest::WRITABLE, || send_with_fd(self.as_raw_fd(), bytes, fds)) } + + /// Send the bytes and the file descriptors as a single packet. + /// + /// It is guaranteed that the bytes and the associated file descriptors will arrive at the same + /// time, however the receiver end may not receive the full message if its buffers are too + /// small. + #[cfg(feature = "borrowed-fd")] + fn send_with_borrowed_fd(&self, bytes: &[u8], fds: &[BorrowedFd<'_>]) -> io::Result { + self.try_io(Interest::WRITABLE, || { + send_with_fd(self.as_raw_fd(), bytes, fds) + }) + } } impl RecvWithFd for net::UnixStream { @@ -441,4 +498,44 @@ mod tests { panic!("expected an error when sending a junk file descriptor"); } } + + #[cfg(feature = "borrowed-fd")] + #[test] + fn borrowed_fd() { + use std::os::fd::AsFd; + + let (l, r) = net::UnixStream::pair().expect("create UnixStream pair"); + let sent_bytes = b"hello world!"; + let sent_fds = [l.as_fd(), r.as_fd()]; + assert_eq!( + l.send_with_borrowed_fd(&sent_bytes[..], &sent_fds[..]) + .expect("send should be successful"), + sent_bytes.len() + ); + let mut recv_bytes = [0; 128]; + let mut recv_fds = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + assert_eq!( + r.recv_with_fd(&mut recv_bytes, &mut recv_fds) + .expect("recv should be successful"), + (sent_bytes.len(), sent_fds.len()) + ); + assert_eq!(recv_bytes[..sent_bytes.len()], sent_bytes[..]); + for (&sent, &recvd) in sent_fds.iter().zip(&recv_fds[..]) { + // Modify the sent resource and check if the received resource has been modified the + // same way. + let expected_value = Some(std::time::Duration::from_secs(42)); + unsafe { + let s = net::UnixStream::from(sent.try_clone_to_owned().unwrap()); + s.set_read_timeout(expected_value) + .expect("set read timeout"); + std::mem::forget(s); + assert_eq!( + net::UnixStream::from_raw_fd(recvd) + .read_timeout() + .expect("get read timeout"), + expected_value + ); + } + } + } }