diff --git a/src/cache.rs b/src/cache.rs index 7e9455e..ddb193f 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -182,7 +182,7 @@ impl LocalCache { pub fn mark_resident(&self, index: usize) -> Result<()> { let mut meta = self.meta.lock().unwrap(); meta.resident.set(index, true); - meta.persist(&self.meta_path, false) + Ok(()) } pub fn mark_resident_many(&self, indices: &[usize]) -> Result<()> { @@ -193,14 +193,14 @@ impl LocalCache { for index in indices { meta.resident.set(*index, true); } - meta.persist(&self.meta_path, false) + Ok(()) } pub fn mark_dirty(&self, index: usize) -> Result<()> { let mut meta = self.meta.lock().unwrap(); meta.resident.set(index, true); meta.dirty.set(index, true); - meta.persist(&self.meta_path, true) + Ok(()) } pub fn clear_dirty_all(&self) -> Result<()> { diff --git a/src/export.rs b/src/export.rs index 1ae1b79..3de29e1 100644 --- a/src/export.rs +++ b/src/export.rs @@ -1774,6 +1774,7 @@ mod tests { .await .unwrap(); export.write(0, b"abcd", false).await.unwrap(); + export.flush().await.unwrap(); let journal_path = dir.path().join("cache").join("snapshot.journal.json"); let staging_path = dir.path().join("cache").join("pending.delta"); diff --git a/src/nbd.rs b/src/nbd.rs index 3ff9f51..e184d25 100644 --- a/src/nbd.rs +++ b/src/nbd.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::{Semaphore, mpsc}; use crate::error::{Error, Result}; use crate::export::Export; @@ -46,6 +47,7 @@ const NBD_CMD_FLAG_FUA: u16 = 1 << 0; const MIN_BLOCK_SIZE: u32 = 1; const PREFERRED_BLOCK_SIZE: u32 = 4096; const MAX_BLOCK_SIZE: u32 = 32 * 1024 * 1024; +const MAX_INFLIGHT_REQUESTS: usize = 128; pub async fn serve_nbd(addr: std::net::SocketAddr, export: Arc) -> Result<()> { let listener = TcpListener::bind(addr).await?; @@ -81,7 +83,7 @@ async fn handle_client(mut stream: TcpStream, export: Arc) -> Result<()> if !negotiate_options(&mut stream, export.clone(), client_flags).await? { return Ok(()); } - transmission_phase(&mut stream, export).await + transmission_phase(stream, export).await } async fn handle_client_manager(mut stream: TcpStream, manager: Arc) -> Result<()> { @@ -89,7 +91,7 @@ async fn handle_client_manager(mut stream: TcpStream, manager: Arc Result { @@ -226,12 +228,23 @@ async fn negotiate_options_manager( } } -async fn transmission_phase(stream: &mut TcpStream, export: Arc) -> Result<()> { +async fn transmission_phase(stream: TcpStream, export: Arc) -> Result<()> { + let (mut reader, mut writer) = stream.into_split(); + let (reply_tx, mut reply_rx) = mpsc::channel::>(MAX_INFLIGHT_REQUESTS); + let permits = Arc::new(Semaphore::new(MAX_INFLIGHT_REQUESTS)); + + let writer_task = tokio::spawn(async move { + while let Some(reply) = reply_rx.recv().await { + writer.write_all(&reply).await?; + } + Ok::<(), Error>(()) + }); + loop { let mut header = [0_u8; 28]; - if let Err(error) = stream.read_exact(&mut header).await { + if let Err(error) = reader.read_exact(&mut header).await { if error.kind() == std::io::ErrorKind::UnexpectedEof { - return Ok(()); + break; } return Err(Error::Io(error)); } @@ -251,28 +264,81 @@ async fn transmission_phase(stream: &mut TcpStream, export: Arc) -> Resu match command { NBD_CMD_READ => { - let data = export.read(offset, len).await?; - send_reply(stream, handle, 0, Some(&data)).await?; + let permit = + permits.clone().acquire_owned().await.map_err(|_| { + Error::InvalidRequest("request semaphore closed".to_string()) + })?; + let export = export.clone(); + let reply_tx = reply_tx.clone(); + tokio::spawn(async move { + let _permit = permit; + let reply = match export.read(offset, len).await { + Ok(data) => encode_reply(handle, 0, Some(&data)), + Err(error) => { + tracing::warn!(offset, len, handle, "read request failed: {error}"); + encode_reply(handle, libc::EIO as u32, None) + } + }; + let _ = reply_tx.send(reply).await; + }); } NBD_CMD_WRITE => { let mut payload = vec![0_u8; len as usize]; - stream.read_exact(&mut payload).await?; - export - .write(offset, &payload, flags & NBD_CMD_FLAG_FUA != 0) - .await?; - send_reply(stream, handle, 0, None).await?; + reader.read_exact(&mut payload).await?; + let permit = + permits.clone().acquire_owned().await.map_err(|_| { + Error::InvalidRequest("request semaphore closed".to_string()) + })?; + let export = export.clone(); + let reply_tx = reply_tx.clone(); + let fua = flags & NBD_CMD_FLAG_FUA != 0; + tokio::spawn(async move { + let _permit = permit; + let reply = match export.write(offset, &payload, fua).await { + Ok(()) => encode_reply(handle, 0, None), + Err(error) => { + tracing::warn!(offset, len, handle, "write request failed: {error}"); + encode_reply(handle, libc::EIO as u32, None) + } + }; + let _ = reply_tx.send(reply).await; + }); } NBD_CMD_FLUSH => { - export.flush().await?; - send_reply(stream, handle, 0, None).await?; + let permit = + permits.clone().acquire_owned().await.map_err(|_| { + Error::InvalidRequest("request semaphore closed".to_string()) + })?; + let export = export.clone(); + let reply_tx = reply_tx.clone(); + tokio::spawn(async move { + let _permit = permit; + let reply = match export.flush().await { + Ok(()) => encode_reply(handle, 0, None), + Err(error) => { + tracing::warn!(handle, "flush request failed: {error}"); + encode_reply(handle, libc::EIO as u32, None) + } + }; + let _ = reply_tx.send(reply).await; + }); } - NBD_CMD_DISC => return Ok(()), + NBD_CMD_DISC => break, other => { - send_reply(stream, handle, libc::EINVAL as u32, None).await?; - return Err(Error::UnsupportedCommand(other)); + reply_tx + .send(encode_reply(handle, libc::EINVAL as u32, None)) + .await + .map_err(|_| Error::InvalidRequest("reply channel closed".to_string()))?; + tracing::warn!("unsupported NBD command {other}"); } } } + + drop(reply_tx); + writer_task + .await + .map_err(|error| Error::Io(std::io::Error::other(error.to_string())))??; + Ok(()) } fn validate_export_name(payload: &[u8], expected: &str) -> Result<()> { @@ -344,19 +410,16 @@ async fn send_option_reply( Ok(()) } -async fn send_reply( - stream: &mut TcpStream, - handle: u64, - error: u32, - payload: Option<&[u8]>, -) -> Result<()> { - stream.write_all(&NBD_REPLY_MAGIC.to_be_bytes()).await?; - stream.write_all(&error.to_be_bytes()).await?; - stream.write_all(&handle.to_be_bytes()).await?; +fn encode_reply(handle: u64, error: u32, payload: Option<&[u8]>) -> Vec { + let payload_len = payload.map_or(0, <[u8]>::len); + let mut reply = Vec::with_capacity(16 + payload_len); + reply.extend_from_slice(&NBD_REPLY_MAGIC.to_be_bytes()); + reply.extend_from_slice(&error.to_be_bytes()); + reply.extend_from_slice(&handle.to_be_bytes()); if let Some(payload) = payload { - stream.write_all(payload).await?; + reply.extend_from_slice(payload); } - Ok(()) + reply } fn export_flags() -> u16 { @@ -439,7 +502,7 @@ mod tests { use super::{ InfoRequest, NBD_INFO_BLOCK_SIZE, NBD_INFO_DESCRIPTION, NBD_INFO_EXPORT, NBD_INFO_NAME, - parse_info_request, requested_infos, + NBD_REPLY_MAGIC, encode_reply, parse_info_request, requested_infos, }; #[test] @@ -469,4 +532,14 @@ mod tests { ] ); } + + #[test] + fn encode_reply_formats_header_and_payload() { + let payload = [1_u8, 2, 3, 4]; + let reply = encode_reply(7, 5, Some(&payload)); + assert_eq!(&reply[0..4], &NBD_REPLY_MAGIC.to_be_bytes()); + assert_eq!(&reply[4..8], &5_u32.to_be_bytes()); + assert_eq!(&reply[8..16], &7_u64.to_be_bytes()); + assert_eq!(&reply[16..], &payload); + } } diff --git a/tests/intergration/run_test.sh b/tests/intergration/run_test.sh index e6b2b15..c094f7a 100755 --- a/tests/intergration/run_test.sh +++ b/tests/intergration/run_test.sh @@ -167,6 +167,8 @@ run_fio_job() { } modprobe nbd max_part=16 +# cleanup any leftover state from previous runs to ensure a clean slate for the integration test +cleanup start_server echo "Creating ${NBD_SERVER_VM01}"