Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ impl LocalCache {
pub fn mark_resident(&self, index: usize) -> Result<()> {
let mut meta = self.meta.lock().unwrap();
meta.resident.set(index, true);
meta.persist(&self.meta_path, false)
Ok(())
}

pub fn mark_resident_many(&self, indices: &[usize]) -> Result<()> {
Expand All @@ -193,14 +193,14 @@ impl LocalCache {
for index in indices {
meta.resident.set(*index, true);
}
meta.persist(&self.meta_path, false)
Ok(())
}

pub fn mark_dirty(&self, index: usize) -> Result<()> {
let mut meta = self.meta.lock().unwrap();
meta.resident.set(index, true);
meta.dirty.set(index, true);
meta.persist(&self.meta_path, true)
Ok(())
}

pub fn clear_dirty_all(&self) -> Result<()> {
Expand Down
1 change: 1 addition & 0 deletions src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1774,6 +1774,7 @@ mod tests {
.await
.unwrap();
export.write(0, b"abcd", false).await.unwrap();
export.flush().await.unwrap();

let journal_path = dir.path().join("cache").join("snapshot.journal.json");
let staging_path = dir.path().join("cache").join("pending.delta");
Expand Down
131 changes: 102 additions & 29 deletions src/nbd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::sync::Arc;

use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{Semaphore, mpsc};

use crate::error::{Error, Result};
use crate::export::Export;
Expand Down Expand Up @@ -46,6 +47,7 @@ const NBD_CMD_FLAG_FUA: u16 = 1 << 0;
const MIN_BLOCK_SIZE: u32 = 1;
const PREFERRED_BLOCK_SIZE: u32 = 4096;
const MAX_BLOCK_SIZE: u32 = 32 * 1024 * 1024;
const MAX_INFLIGHT_REQUESTS: usize = 128;

pub async fn serve_nbd(addr: std::net::SocketAddr, export: Arc<Export>) -> Result<()> {
let listener = TcpListener::bind(addr).await?;
Expand Down Expand Up @@ -81,15 +83,15 @@ async fn handle_client(mut stream: TcpStream, export: Arc<Export>) -> Result<()>
if !negotiate_options(&mut stream, export.clone(), client_flags).await? {
return Ok(());
}
transmission_phase(&mut stream, export).await
transmission_phase(stream, export).await
}

async fn handle_client_manager(mut stream: TcpStream, manager: Arc<ExportManager>) -> Result<()> {
let client_flags = send_handshake(&mut stream).await?;
let Some(export) = negotiate_options_manager(&mut stream, manager, client_flags).await? else {
return Ok(());
};
transmission_phase(&mut stream, export).await
transmission_phase(stream, export).await
}

async fn send_handshake(stream: &mut TcpStream) -> Result<u32> {
Expand Down Expand Up @@ -226,12 +228,23 @@ async fn negotiate_options_manager(
}
}

async fn transmission_phase(stream: &mut TcpStream, export: Arc<Export>) -> Result<()> {
async fn transmission_phase(stream: TcpStream, export: Arc<Export>) -> Result<()> {
let (mut reader, mut writer) = stream.into_split();
let (reply_tx, mut reply_rx) = mpsc::channel::<Vec<u8>>(MAX_INFLIGHT_REQUESTS);
let permits = Arc::new(Semaphore::new(MAX_INFLIGHT_REQUESTS));

let writer_task = tokio::spawn(async move {
while let Some(reply) = reply_rx.recv().await {
writer.write_all(&reply).await?;
}
Ok::<(), Error>(())
});

loop {
let mut header = [0_u8; 28];
if let Err(error) = stream.read_exact(&mut header).await {
if let Err(error) = reader.read_exact(&mut header).await {
if error.kind() == std::io::ErrorKind::UnexpectedEof {
return Ok(());
break;
}
return Err(Error::Io(error));
}
Expand All @@ -251,28 +264,81 @@ async fn transmission_phase(stream: &mut TcpStream, export: Arc<Export>) -> Resu

match command {
NBD_CMD_READ => {
let data = export.read(offset, len).await?;
send_reply(stream, handle, 0, Some(&data)).await?;
let permit =
permits.clone().acquire_owned().await.map_err(|_| {
Error::InvalidRequest("request semaphore closed".to_string())
})?;
let export = export.clone();
let reply_tx = reply_tx.clone();
tokio::spawn(async move {
let _permit = permit;
let reply = match export.read(offset, len).await {
Ok(data) => encode_reply(handle, 0, Some(&data)),
Err(error) => {
tracing::warn!(offset, len, handle, "read request failed: {error}");
encode_reply(handle, libc::EIO as u32, None)
}
};
let _ = reply_tx.send(reply).await;
});
}
NBD_CMD_WRITE => {
let mut payload = vec![0_u8; len as usize];
stream.read_exact(&mut payload).await?;
export
.write(offset, &payload, flags & NBD_CMD_FLAG_FUA != 0)
.await?;
send_reply(stream, handle, 0, None).await?;
reader.read_exact(&mut payload).await?;
let permit =
permits.clone().acquire_owned().await.map_err(|_| {
Error::InvalidRequest("request semaphore closed".to_string())
})?;
let export = export.clone();
let reply_tx = reply_tx.clone();
let fua = flags & NBD_CMD_FLAG_FUA != 0;
tokio::spawn(async move {
let _permit = permit;
let reply = match export.write(offset, &payload, fua).await {
Ok(()) => encode_reply(handle, 0, None),
Err(error) => {
tracing::warn!(offset, len, handle, "write request failed: {error}");
encode_reply(handle, libc::EIO as u32, None)
}
};
let _ = reply_tx.send(reply).await;
});
}
NBD_CMD_FLUSH => {
export.flush().await?;
send_reply(stream, handle, 0, None).await?;
let permit =
permits.clone().acquire_owned().await.map_err(|_| {
Error::InvalidRequest("request semaphore closed".to_string())
})?;
let export = export.clone();
let reply_tx = reply_tx.clone();
tokio::spawn(async move {
let _permit = permit;
let reply = match export.flush().await {
Ok(()) => encode_reply(handle, 0, None),
Err(error) => {
tracing::warn!(handle, "flush request failed: {error}");
encode_reply(handle, libc::EIO as u32, None)
}
};
let _ = reply_tx.send(reply).await;
});
}
NBD_CMD_DISC => return Ok(()),
NBD_CMD_DISC => break,
other => {
send_reply(stream, handle, libc::EINVAL as u32, None).await?;
return Err(Error::UnsupportedCommand(other));
reply_tx
.send(encode_reply(handle, libc::EINVAL as u32, None))
.await
.map_err(|_| Error::InvalidRequest("reply channel closed".to_string()))?;
tracing::warn!("unsupported NBD command {other}");
}
}
}

drop(reply_tx);
writer_task
.await
.map_err(|error| Error::Io(std::io::Error::other(error.to_string())))??;
Ok(())
}

fn validate_export_name(payload: &[u8], expected: &str) -> Result<()> {
Expand Down Expand Up @@ -344,19 +410,16 @@ async fn send_option_reply(
Ok(())
}

async fn send_reply(
stream: &mut TcpStream,
handle: u64,
error: u32,
payload: Option<&[u8]>,
) -> Result<()> {
stream.write_all(&NBD_REPLY_MAGIC.to_be_bytes()).await?;
stream.write_all(&error.to_be_bytes()).await?;
stream.write_all(&handle.to_be_bytes()).await?;
fn encode_reply(handle: u64, error: u32, payload: Option<&[u8]>) -> Vec<u8> {
let payload_len = payload.map_or(0, <[u8]>::len);
let mut reply = Vec::with_capacity(16 + payload_len);
reply.extend_from_slice(&NBD_REPLY_MAGIC.to_be_bytes());
reply.extend_from_slice(&error.to_be_bytes());
reply.extend_from_slice(&handle.to_be_bytes());
if let Some(payload) = payload {
stream.write_all(payload).await?;
reply.extend_from_slice(payload);
}
Ok(())
reply
}

fn export_flags() -> u16 {
Expand Down Expand Up @@ -439,7 +502,7 @@ mod tests {

use super::{
InfoRequest, NBD_INFO_BLOCK_SIZE, NBD_INFO_DESCRIPTION, NBD_INFO_EXPORT, NBD_INFO_NAME,
parse_info_request, requested_infos,
NBD_REPLY_MAGIC, encode_reply, parse_info_request, requested_infos,
};

#[test]
Expand Down Expand Up @@ -469,4 +532,14 @@ mod tests {
]
);
}

#[test]
fn encode_reply_formats_header_and_payload() {
let payload = [1_u8, 2, 3, 4];
let reply = encode_reply(7, 5, Some(&payload));
assert_eq!(&reply[0..4], &NBD_REPLY_MAGIC.to_be_bytes());
assert_eq!(&reply[4..8], &5_u32.to_be_bytes());
assert_eq!(&reply[8..16], &7_u64.to_be_bytes());
assert_eq!(&reply[16..], &payload);
}
}
2 changes: 2 additions & 0 deletions tests/intergration/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ run_fio_job() {
}

modprobe nbd max_part=16
# cleanup any leftover state from previous runs to ensure a clean slate for the integration test
cleanup
start_server

echo "Creating ${NBD_SERVER_VM01}"
Expand Down
Loading