Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vhost/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
### Changed
### Deprecated
### Fixed
- [[#338]](https://github.com/rust-vmm/vhost/pull/338) vhost: fix double-locking in Backend to Frontend request handlers

## v0.15.0

Expand Down
68 changes: 32 additions & 36 deletions vhost/src/vhost_user/backend_req.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,20 @@ use std::io;
use std::mem;
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::UnixStream;
use std::sync::{Arc, Mutex, MutexGuard};
use std::sync::{Arc, Mutex};

use super::connection::Endpoint;
use super::message::*;
use super::{Error, HandlerResult, Result, VhostUserFrontendReqHandler};

use vm_memory::ByteValued;

impl From<Error> for io::Error {
fn from(e: Error) -> Self {
io::Error::other(e)
}
}

struct BackendInternal {
sock: Endpoint<VhostUserMsgHeader<BackendReq>>,

Expand Down Expand Up @@ -85,13 +91,13 @@ impl BackendInternal {
#[derive(Clone)]
pub struct Backend {
// underlying Unix domain socket for communication
node: Arc<Mutex<BackendInternal>>,
inner: Arc<Mutex<BackendInternal>>,
}

impl Backend {
fn new(ep: Endpoint<VhostUserMsgHeader<BackendReq>>) -> Self {
Backend {
node: Arc::new(Mutex::new(BackendInternal {
inner: Arc::new(Mutex::new(BackendInternal {
sock: ep,
reply_ack_negotiated: false,
shared_object_negotiated: false,
Expand All @@ -101,21 +107,6 @@ impl Backend {
}
}

fn node(&self) -> MutexGuard<'_, BackendInternal> {
self.node.lock().unwrap()
}

fn send_message<T: ByteValued>(
&self,
request: BackendReq,
body: &T,
fds: Option<&[RawFd]>,
) -> io::Result<u64> {
self.node()
.send_message(request, body, fds)
.map_err(|e| io::Error::other(format!("{e}")))
}

/// Create a new instance from a `UnixStream` object.
pub fn from_stream(sock: UnixStream) -> Self {
Self::new(Endpoint::<VhostUserMsgHeader<BackendReq>>::from_stream(
Expand All @@ -129,46 +120,48 @@ impl Backend {
/// the "REPLY_ACK" flag will be set in the message header for every backend to frontend request
/// message.
pub fn set_reply_ack_flag(&self, enable: bool) {
self.node().reply_ack_negotiated = enable;
self.inner.lock().unwrap().reply_ack_negotiated = enable;
}

/// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_SHARED_OBJECT` protocol feature.
///
/// When the `VHOST_USER_PROTOCOL_F_SHARED_OBJECT` protocol feature has been negotiated,
/// the backend is allowed to send "SHARED_OBJECT_*" messages to the frontend.
pub fn set_shared_object_flag(&self, enable: bool) {
self.node().shared_object_negotiated = enable;
self.inner.lock().unwrap().shared_object_negotiated = enable;
}

/// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_SHMEM` protocol feature.
///
/// When the `VHOST_USER_PROTOCOL_F_SHMEM` protocol feature has been negotiated,
/// the backend is allowed to send "SHMEM_{MAP, UNMAP}" messages to the frontend.
pub fn set_shmem_flag(&self, enable: bool) {
self.node().shmem_negotiated = enable;
self.inner.lock().unwrap().shmem_negotiated = enable;
}

/// Mark endpoint as failed with specified error code.
pub fn set_failed(&self, error: i32) {
self.node().error = Some(error);
self.inner.lock().unwrap().error = Some(error);
}
}

impl VhostUserFrontendReqHandler for Backend {
/// Forward vhost-user shared-object add request to the frontend.
fn shared_object_add(&self, uuid: &VhostUserSharedMsg) -> HandlerResult<u64> {
if !self.node().shared_object_negotiated {
let mut guard = self.inner.lock().unwrap();
if !guard.shared_object_negotiated {
return Err(io::Error::other("Shared Object feature not negotiated"));
}
self.send_message(BackendReq::SHARED_OBJECT_ADD, uuid, None)
Ok(guard.send_message(BackendReq::SHARED_OBJECT_ADD, uuid, None)?)
}

/// Forward vhost-user shared-object remove request to the frontend.
fn shared_object_remove(&self, uuid: &VhostUserSharedMsg) -> HandlerResult<u64> {
if !self.node().shared_object_negotiated {
let mut guard = self.inner.lock().unwrap();
if !guard.shared_object_negotiated {
return Err(io::Error::other("Shared Object feature not negotiated"));
}
self.send_message(BackendReq::SHARED_OBJECT_REMOVE, uuid, None)
Ok(guard.send_message(BackendReq::SHARED_OBJECT_REMOVE, uuid, None)?)
}

/// Forward vhost-user shared-object lookup request to the frontend.
Expand All @@ -177,30 +170,33 @@ impl VhostUserFrontendReqHandler for Backend {
uuid: &VhostUserSharedMsg,
fd: &dyn AsRawFd,
) -> HandlerResult<u64> {
if !self.node().shared_object_negotiated {
let mut guard = self.inner.lock().unwrap();
if !guard.shared_object_negotiated {
return Err(io::Error::other("Shared Object feature not negotiated"));
}
self.send_message(
Ok(guard.send_message(
BackendReq::SHARED_OBJECT_LOOKUP,
uuid,
Some(&[fd.as_raw_fd()]),
)
)?)
}

/// Forward vhost-user memory map file request to the frontend.
fn shmem_map(&self, req: &VhostUserMMap, fd: &dyn AsRawFd) -> HandlerResult<u64> {
if !self.node().shmem_negotiated {
let mut guard = self.inner.lock().unwrap();
if !guard.shmem_negotiated {
return Err(io::Error::other("SHMEM feature not negotiated"));
}
self.send_message(BackendReq::SHMEM_MAP, req, Some(&[fd.as_raw_fd()]))
Ok(guard.send_message(BackendReq::SHMEM_MAP, req, Some(&[fd.as_raw_fd()]))?)
}

/// Forward vhost-user memory unmap file request to the frontend.
fn shmem_unmap(&self, req: &VhostUserMMap) -> HandlerResult<u64> {
if !self.node().shmem_negotiated {
let mut guard = self.inner.lock().unwrap();
if !guard.shmem_negotiated {
return Err(io::Error::other("SHMEM feature not negotiated"));
}
self.send_message(BackendReq::SHMEM_UNMAP, req, None)
Ok(guard.send_message(BackendReq::SHMEM_UNMAP, req, None)?)
}
}

Expand All @@ -221,9 +217,9 @@ mod tests {
fn test_backend_req_set_failed() {
let (_, backend) = frontend_backend_pair();

assert!(backend.node().error.is_none());
assert!(backend.inner.lock().unwrap().error.is_none());
backend.set_failed(libc::EAGAIN);
assert_eq!(backend.node().error, Some(libc::EAGAIN));
assert_eq!(backend.inner.lock().unwrap().error, Some(libc::EAGAIN));
}

#[test]
Expand All @@ -237,7 +233,7 @@ mod tests {
backend
.shared_object_remove(&VhostUserSharedMsg::default())
.unwrap_err();
backend.node().error = None;
backend.inner.lock().unwrap().error = None;
}

#[test]
Expand Down