diff --git a/vhost/CHANGELOG.md b/vhost/CHANGELOG.md index 4867f172..f72ae24f 100644 --- a/vhost/CHANGELOG.md +++ b/vhost/CHANGELOG.md @@ -8,6 +8,7 @@ ### Changed ### Deprecated ### Fixed +- [[#338]](https://github.com/rust-vmm/vhost/pull/338) vhost: fix double-locking in Backend to Frontend request handlers ## v0.15.0 diff --git a/vhost/src/vhost_user/backend_req.rs b/vhost/src/vhost_user/backend_req.rs index 026d4e14..3b28bcfc 100644 --- a/vhost/src/vhost_user/backend_req.rs +++ b/vhost/src/vhost_user/backend_req.rs @@ -5,7 +5,7 @@ use std::io; use std::mem; use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::net::UnixStream; -use std::sync::{Arc, Mutex, MutexGuard}; +use std::sync::{Arc, Mutex}; use super::connection::Endpoint; use super::message::*; @@ -13,6 +13,12 @@ use super::{Error, HandlerResult, Result, VhostUserFrontendReqHandler}; use vm_memory::ByteValued; +impl From for io::Error { + fn from(e: Error) -> Self { + io::Error::other(e) + } +} + struct BackendInternal { sock: Endpoint>, @@ -85,13 +91,13 @@ impl BackendInternal { #[derive(Clone)] pub struct Backend { // underlying Unix domain socket for communication - node: Arc>, + inner: Arc>, } impl Backend { fn new(ep: Endpoint>) -> Self { Backend { - node: Arc::new(Mutex::new(BackendInternal { + inner: Arc::new(Mutex::new(BackendInternal { sock: ep, reply_ack_negotiated: false, shared_object_negotiated: false, @@ -101,21 +107,6 @@ impl Backend { } } - fn node(&self) -> MutexGuard<'_, BackendInternal> { - self.node.lock().unwrap() - } - - fn send_message( - &self, - request: BackendReq, - body: &T, - fds: Option<&[RawFd]>, - ) -> io::Result { - self.node() - .send_message(request, body, fds) - .map_err(|e| io::Error::other(format!("{e}"))) - } - /// Create a new instance from a `UnixStream` object. pub fn from_stream(sock: UnixStream) -> Self { Self::new(Endpoint::>::from_stream( @@ -129,7 +120,7 @@ impl Backend { /// the "REPLY_ACK" flag will be set in the message header for every backend to frontend request /// message. pub fn set_reply_ack_flag(&self, enable: bool) { - self.node().reply_ack_negotiated = enable; + self.inner.lock().unwrap().reply_ack_negotiated = enable; } /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_SHARED_OBJECT` protocol feature. @@ -137,7 +128,7 @@ impl Backend { /// When the `VHOST_USER_PROTOCOL_F_SHARED_OBJECT` protocol feature has been negotiated, /// the backend is allowed to send "SHARED_OBJECT_*" messages to the frontend. pub fn set_shared_object_flag(&self, enable: bool) { - self.node().shared_object_negotiated = enable; + self.inner.lock().unwrap().shared_object_negotiated = enable; } /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_SHMEM` protocol feature. @@ -145,30 +136,32 @@ impl Backend { /// When the `VHOST_USER_PROTOCOL_F_SHMEM` protocol feature has been negotiated, /// the backend is allowed to send "SHMEM_{MAP, UNMAP}" messages to the frontend. pub fn set_shmem_flag(&self, enable: bool) { - self.node().shmem_negotiated = enable; + self.inner.lock().unwrap().shmem_negotiated = enable; } /// Mark endpoint as failed with specified error code. pub fn set_failed(&self, error: i32) { - self.node().error = Some(error); + self.inner.lock().unwrap().error = Some(error); } } impl VhostUserFrontendReqHandler for Backend { /// Forward vhost-user shared-object add request to the frontend. fn shared_object_add(&self, uuid: &VhostUserSharedMsg) -> HandlerResult { - if !self.node().shared_object_negotiated { + let mut guard = self.inner.lock().unwrap(); + if !guard.shared_object_negotiated { return Err(io::Error::other("Shared Object feature not negotiated")); } - self.send_message(BackendReq::SHARED_OBJECT_ADD, uuid, None) + Ok(guard.send_message(BackendReq::SHARED_OBJECT_ADD, uuid, None)?) } /// Forward vhost-user shared-object remove request to the frontend. fn shared_object_remove(&self, uuid: &VhostUserSharedMsg) -> HandlerResult { - if !self.node().shared_object_negotiated { + let mut guard = self.inner.lock().unwrap(); + if !guard.shared_object_negotiated { return Err(io::Error::other("Shared Object feature not negotiated")); } - self.send_message(BackendReq::SHARED_OBJECT_REMOVE, uuid, None) + Ok(guard.send_message(BackendReq::SHARED_OBJECT_REMOVE, uuid, None)?) } /// Forward vhost-user shared-object lookup request to the frontend. @@ -177,30 +170,33 @@ impl VhostUserFrontendReqHandler for Backend { uuid: &VhostUserSharedMsg, fd: &dyn AsRawFd, ) -> HandlerResult { - if !self.node().shared_object_negotiated { + let mut guard = self.inner.lock().unwrap(); + if !guard.shared_object_negotiated { return Err(io::Error::other("Shared Object feature not negotiated")); } - self.send_message( + Ok(guard.send_message( BackendReq::SHARED_OBJECT_LOOKUP, uuid, Some(&[fd.as_raw_fd()]), - ) + )?) } /// Forward vhost-user memory map file request to the frontend. fn shmem_map(&self, req: &VhostUserMMap, fd: &dyn AsRawFd) -> HandlerResult { - if !self.node().shmem_negotiated { + let mut guard = self.inner.lock().unwrap(); + if !guard.shmem_negotiated { return Err(io::Error::other("SHMEM feature not negotiated")); } - self.send_message(BackendReq::SHMEM_MAP, req, Some(&[fd.as_raw_fd()])) + Ok(guard.send_message(BackendReq::SHMEM_MAP, req, Some(&[fd.as_raw_fd()]))?) } /// Forward vhost-user memory unmap file request to the frontend. fn shmem_unmap(&self, req: &VhostUserMMap) -> HandlerResult { - if !self.node().shmem_negotiated { + let mut guard = self.inner.lock().unwrap(); + if !guard.shmem_negotiated { return Err(io::Error::other("SHMEM feature not negotiated")); } - self.send_message(BackendReq::SHMEM_UNMAP, req, None) + Ok(guard.send_message(BackendReq::SHMEM_UNMAP, req, None)?) } } @@ -221,9 +217,9 @@ mod tests { fn test_backend_req_set_failed() { let (_, backend) = frontend_backend_pair(); - assert!(backend.node().error.is_none()); + assert!(backend.inner.lock().unwrap().error.is_none()); backend.set_failed(libc::EAGAIN); - assert_eq!(backend.node().error, Some(libc::EAGAIN)); + assert_eq!(backend.inner.lock().unwrap().error, Some(libc::EAGAIN)); } #[test] @@ -237,7 +233,7 @@ mod tests { backend .shared_object_remove(&VhostUserSharedMsg::default()) .unwrap_err(); - backend.node().error = None; + backend.inner.lock().unwrap().error = None; } #[test]