diff --git a/vm/devices/net/mana_driver/src/gdma_driver.rs b/vm/devices/net/mana_driver/src/gdma_driver.rs index 4038454766..20e64d6e53 100644 --- a/vm/devices/net/mana_driver/src/gdma_driver.rs +++ b/vm/devices/net/mana_driver/src/gdma_driver.rs @@ -9,6 +9,10 @@ use crate::queues::Eq; use crate::queues::Wq; use crate::resources::Resource; use crate::resources::ResourceArena; +use crate::save_restore::DoorbellSavedState; +use crate::save_restore::GdmaDriverSavedState; +use crate::save_restore::InterruptSavedState; +use crate::save_restore::SavedMemoryState; use anyhow::Context; use futures::FutureExt; use gdma_defs::Cqe; @@ -118,6 +122,13 @@ impl Doorbell for Bar0 { safe_intrinsics::store_fence(); self.mem.write_u64(offset as usize, value); } + + fn save(&self, doorbell_id: Option) -> DoorbellSavedState { + DoorbellSavedState { + doorbell_id: doorbell_id.unwrap(), + page_count: self.page_count(), + } + } } #[derive(Inspect)] @@ -148,6 +159,8 @@ pub struct GdmaDriver { hwc_warning_time_in_ms: u32, hwc_timeout_in_ms: u32, hwc_failure: bool, + db_id: u32, + saving: bool, } const EQ_PAGE: usize = 0; @@ -163,9 +176,17 @@ const RWQE_SIZE: u32 = 32; impl Drop for GdmaDriver { fn drop(&mut self) { + tracing::debug!(?self.saving, ?self.hwc_failure, "dropping gdma driver"); + + // Don't destroy anything if we're saving its state for restoration. + if self.saving { + return; + } + if self.hwc_failure { return; } + let data = self .bar0 .mem @@ -230,7 +251,12 @@ impl GdmaDriver { self.bar0.clone() as _ } - pub async fn new(driver: &impl Driver, mut device: T, num_vps: u32) -> anyhow::Result { + pub async fn new( + driver: &impl Driver, + mut device: T, + num_vps: u32, + dma_buffer: Option, + ) -> anyhow::Result { let bar0_mapping = device.map_bar(0)?; let bar0_len = bar0_mapping.len(); if bar0_len < size_of::() { @@ -280,11 +306,14 @@ impl GdmaDriver { ); } - let dma_client = device.dma_client(); - - let dma_buffer = dma_client - .allocate_dma_buffer(NUM_PAGES * PAGE_SIZE) - .context("failed to allocate DMA buffer")?; + let dma_buffer = if let Some(dma_buffer) = dma_buffer { + dma_buffer + } else { + let dma_client = device.dma_client(); + dma_client + .allocate_dma_buffer(NUM_PAGES * PAGE_SIZE) + .context("failed to allocate DMA buffer")? + }; let pages = dma_buffer.pfns(); @@ -475,6 +504,8 @@ impl GdmaDriver { hwc_warning_time_in_ms: HWC_WARNING_TIME_IN_MS, hwc_timeout_in_ms: HWC_TIMEOUT_DEFAULT_IN_MS, hwc_failure: false, + saving: false, + db_id, }; this.push_rqe(); @@ -499,6 +530,182 @@ impl GdmaDriver { Ok(this) } + #[allow(dead_code)] + pub async fn save(mut self) -> anyhow::Result { + self.saving = true; + + let doorbell = self.bar0.save(Some(self.db_id as u64)); + + let mut interrupt_config = Vec::new(); + for (index, interrupt) in self.interrupts.iter().enumerate() { + if interrupt.is_some() { + interrupt_config.push(InterruptSavedState { + msix_index: index as u32, + cpu: index as u32, + }); + } + } + + Ok(GdmaDriverSavedState { + mem: SavedMemoryState { + base_pfn: self.dma_buffer.pfns()[0], + len: self.dma_buffer.len(), + }, + eq: self.eq.save(), + cq: self.cq.save(), + rq: self.rq.save(), + sq: self.sq.save(), + db_id: doorbell.doorbell_id, + gpa_mkey: self.gpa_mkey, + pdid: self._pdid, + cq_armed: self.cq_armed, + eq_armed: self.eq_armed, + hwc_subscribed: self.hwc_subscribed, + eq_id_msix: self.eq_id_msix.clone(), + hwc_activity_id: self.hwc_activity_id, + num_msix: self.num_msix, + min_queue_avail: self.min_queue_avail, + link_toggle: self.link_toggle.clone(), + interrupt_config, + }) + } + + #[allow(dead_code)] + pub async fn restore( + saved_state: GdmaDriverSavedState, + mut device: T, + dma_buffer: MemoryBlock, + ) -> anyhow::Result { + tracing::info!("restoring gdma driver"); + + let bar0_mapping = device.map_bar(0)?; + let bar0_len = bar0_mapping.len(); + if bar0_len < size_of::() { + anyhow::bail!("bar0 ({} bytes) too small for reg map", bar0_mapping.len()); + } + + let mut map = RegMap::new_zeroed(); + for i in 0..size_of_val(&map) / 4 { + let v = bar0_mapping.read_u32(i * 4); + // Unmapped device memory will return -1 on reads, so check the first 32 + // bits for this condition to get a clear error message early. + if i == 0 && v == !0 { + anyhow::bail!("bar0 read returned -1, device is not present"); + } + map.as_mut_bytes()[i * 4..(i + 1) * 4].copy_from_slice(&v.to_ne_bytes()); + } + + tracing::debug!(?map, "register map on restore"); + + // Log on unknown major version numbers. This is not necessarily an + // error, so continue. + if map.major_version_number != 0 && map.major_version_number != 1 { + tracing::warn!( + major = map.major_version_number, + minor = map.minor_version_number, + micro = map.micro_version_number, + "unrecognized major version" + ); + } + + if map.vf_gdma_sriov_shared_sz != 32 { + anyhow::bail!( + "unexpected shared memory size: {}", + map.vf_gdma_sriov_shared_sz + ); + } + + if (bar0_len as u64).saturating_sub(map.vf_gdma_sriov_shared_reg_start) + < map.vf_gdma_sriov_shared_sz as u64 + { + anyhow::bail!( + "bar0 ({} bytes) too small for shared memory at {}", + bar0_mapping.len(), + map.vf_gdma_sriov_shared_reg_start + ); + } + + let doorbell_shift = map.vf_db_page_sz.trailing_zeros(); + let bar0 = Arc::new(Bar0 { + mem: bar0_mapping, + map, + doorbell_shift, + }); + + let eq = Eq::restore( + dma_buffer.subblock(0, PAGE_SIZE), + saved_state.eq, + DoorbellPage::new(bar0.clone(), saved_state.db_id as u32)?, + )?; + + let db_id = saved_state.db_id; + let cq = Cq::restore( + dma_buffer.subblock(CQ_PAGE * PAGE_SIZE, PAGE_SIZE), + saved_state.cq, + DoorbellPage::new(bar0.clone(), saved_state.db_id as u32)?, + )?; + + let rq = Wq::restore_rq( + dma_buffer.subblock(RQ_PAGE * PAGE_SIZE, PAGE_SIZE), + saved_state.rq, + DoorbellPage::new(bar0.clone(), saved_state.db_id as u32)?, + )?; + + let sq = Wq::restore_sq( + dma_buffer.subblock(SQ_PAGE * PAGE_SIZE, PAGE_SIZE), + saved_state.sq, + DoorbellPage::new(bar0.clone(), saved_state.db_id as u32)?, + )?; + + let mut interrupts = vec![None; saved_state.num_msix as usize]; + for int_state in &saved_state.interrupt_config { + let interrupt = device.map_interrupt(int_state.msix_index, int_state.cpu)?; + + interrupts[int_state.msix_index as usize] = Some(interrupt); + } + + let mut this = Self { + device: Some(device), + bar0, + dma_buffer, + interrupts, + eq, + cq, + rq, + sq, + test_events: 0, + eq_armed: saved_state.eq_armed, + cq_armed: saved_state.cq_armed, + gpa_mkey: saved_state.gpa_mkey, + _pdid: saved_state.pdid, + eq_id_msix: saved_state.eq_id_msix, + num_msix: saved_state.num_msix, + min_queue_avail: saved_state.min_queue_avail, + hwc_activity_id: saved_state.hwc_activity_id, + link_toggle: saved_state.link_toggle, + hwc_subscribed: saved_state.hwc_subscribed, + hwc_warning_time_in_ms: HWC_WARNING_TIME_IN_MS, + hwc_timeout_in_ms: HWC_TIMEOUT_DEFAULT_IN_MS, + hwc_failure: false, + saving: false, + db_id: db_id as u32, + }; + + if saved_state.hwc_subscribed { + this.hwc_subscribe(); + } + + if saved_state.eq_armed { + this.eq.arm(); + } + + if saved_state.cq_armed { + this.cq.arm(); + } + + Ok(this) + } + async fn report_hwc_timeout( &mut self, last_cmd_failed: bool, diff --git a/vm/devices/net/mana_driver/src/lib.rs b/vm/devices/net/mana_driver/src/lib.rs index a510b4b65e..2572de3a1a 100644 --- a/vm/devices/net/mana_driver/src/lib.rs +++ b/vm/devices/net/mana_driver/src/lib.rs @@ -10,5 +10,6 @@ mod gdma_driver; pub mod mana; pub mod queues; mod resources; +pub mod save_restore; #[cfg(test)] mod tests; diff --git a/vm/devices/net/mana_driver/src/mana.rs b/vm/devices/net/mana_driver/src/mana.rs index 8bfa10cd12..3af920e270 100644 --- a/vm/devices/net/mana_driver/src/mana.rs +++ b/vm/devices/net/mana_driver/src/mana.rs @@ -77,7 +77,7 @@ impl ManaDevice { num_vps: u32, max_queues_per_vport: u16, ) -> anyhow::Result { - let mut gdma = GdmaDriver::new(driver, device, num_vps).await?; + let mut gdma = GdmaDriver::new(driver, device, num_vps, None).await?; gdma.test_eq().await?; gdma.verify_vf_driver_version().await?; diff --git a/vm/devices/net/mana_driver/src/queues.rs b/vm/devices/net/mana_driver/src/queues.rs index f0c696ab32..272ba16915 100644 --- a/vm/devices/net/mana_driver/src/queues.rs +++ b/vm/devices/net/mana_driver/src/queues.rs @@ -3,6 +3,9 @@ //! Types to access work, completion, and event queues. +use crate::save_restore::CqEqSavedState; +use crate::save_restore::DoorbellSavedState; +use crate::save_restore::WqSavedState; use gdma_defs::CLIENT_OOB_8; use gdma_defs::CLIENT_OOB_24; use gdma_defs::CLIENT_OOB_32; @@ -37,6 +40,8 @@ pub trait Doorbell: Send + Sync { fn page_count(&self) -> u32; /// Write a doorbell value at page `page`, offset `address`. fn write(&self, page: u32, address: u32, value: u64); + /// Save the doorbell state. + fn save(&self, doorbell_id: Option) -> DoorbellSavedState; } struct NullDoorbell; @@ -47,6 +52,13 @@ impl Doorbell for NullDoorbell { } fn write(&self, _page: u32, _address: u32, _value: u64) {} + + fn save(&self, _doorbell_id: Option) -> DoorbellSavedState { + DoorbellSavedState { + doorbell_id: 0, + page_count: 0, + } + } } /// A single GDMA doorbell page. @@ -114,6 +126,25 @@ impl CqEq { pub fn new_cq(mem: MemoryBlock, doorbell: DoorbellPage, id: u32) -> Self { Self::new(GdmaQueueType::GDMA_CQ, DB_CQ, mem, doorbell, id) } + + /// Restores an existing completion queue. + pub fn restore( + mem: MemoryBlock, + state: CqEqSavedState, + doorbell: DoorbellPage, + ) -> anyhow::Result { + Ok(Self { + doorbell, + doorbell_addr: state.doorbell_addr, + queue_type: GdmaQueueType::GDMA_CQ, + mem, + id: state.id, + next: state.next, + size: state.size, + shift: state.shift, + _phantom: PhantomData, + }) + } } impl CqEq { @@ -121,6 +152,25 @@ impl CqEq { pub fn new_eq(mem: MemoryBlock, doorbell: DoorbellPage, id: u32) -> Self { Self::new(GdmaQueueType::GDMA_EQ, DB_EQ, mem, doorbell, id) } + + /// Restores an existing event queue. + pub fn restore( + mem: MemoryBlock, + state: CqEqSavedState, + doorbell: DoorbellPage, + ) -> anyhow::Result { + Ok(Self { + doorbell, + doorbell_addr: state.doorbell_addr, + queue_type: GdmaQueueType::GDMA_EQ, + mem, + id: state.id, + next: state.next, + size: state.size, + shift: state.shift, + _phantom: PhantomData, + }) + } } impl CqEq { @@ -147,6 +197,21 @@ impl CqEq { } } + /// Save the state of the queue for restoration after servicing. + pub fn save(&self) -> CqEqSavedState { + CqEqSavedState { + doorbell: DoorbellSavedState { + doorbell_id: self.doorbell.doorbell_id as u64, + page_count: self.doorbell.doorbell.page_count(), + }, + doorbell_addr: self.doorbell_addr, + id: self.id, + next: self.next, + size: self.size, + shift: self.shift, + } + } + /// Updates the queue ID. pub(crate) fn set_id(&mut self, id: u32) { self.id = id; @@ -284,6 +349,59 @@ impl Wq { } } + /// Save the state of the Wq for restoration after servicing + pub fn save(&self) -> WqSavedState { + WqSavedState { + doorbell: DoorbellSavedState { + doorbell_id: self.doorbell.doorbell_id as u64, + page_count: self.doorbell.doorbell.page_count(), + }, + doorbell_addr: self.doorbell_addr, + id: self.id, + head: self.head, + tail: self.tail, + mask: self.mask, + } + } + + /// Restores an existing receive work queue. + pub fn restore_rq( + mem: MemoryBlock, + state: WqSavedState, + doorbell: DoorbellPage, + ) -> anyhow::Result { + Ok(Self { + doorbell, + doorbell_addr: state.doorbell_addr, + queue_type: GdmaQueueType::GDMA_RQ, + mem, + id: state.id, + head: state.head, + tail: state.tail, + mask: state.mask, + uncommitted_count: 0, + }) + } + + /// Restores an existing send work queue. + pub fn restore_sq( + mem: MemoryBlock, + state: WqSavedState, + doorbell: DoorbellPage, + ) -> anyhow::Result { + Ok(Self { + doorbell, + doorbell_addr: state.doorbell_addr, + queue_type: GdmaQueueType::GDMA_SQ, + mem, + id: state.id, + head: state.head, + tail: state.tail, + mask: state.mask, + uncommitted_count: 0, + }) + } + /// Returns the queue ID. pub fn id(&self) -> u32 { self.id diff --git a/vm/devices/net/mana_driver/src/save_restore.rs b/vm/devices/net/mana_driver/src/save_restore.rs new file mode 100644 index 0000000000..2246706fc8 --- /dev/null +++ b/vm/devices/net/mana_driver/src/save_restore.rs @@ -0,0 +1,179 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Types to save and restore the state of a MANA device. + +use mesh::payload::Protobuf; +use std::collections::HashMap; + +/// Top level saved state for the GDMA driver's saved state +#[derive(Protobuf, Clone, Debug)] +#[mesh(package = "mana_driver")] +pub struct GdmaDriverSavedState { + /// Memory to be restored by a DMA client + #[mesh(1)] + pub mem: SavedMemoryState, + + /// EQ to be restored + #[mesh(2)] + pub eq: CqEqSavedState, + + /// CQ to be restored + #[mesh(3)] + pub cq: CqEqSavedState, + + /// RQ to be restored + #[mesh(4)] + pub rq: WqSavedState, + + /// SQ to be restored + #[mesh(5)] + pub sq: WqSavedState, + + /// Doorbell id + #[mesh(6)] + pub db_id: u64, + + /// Guest physical address memory key + #[mesh(7)] + pub gpa_mkey: u32, + + /// Protection domain id + #[mesh(8)] + pub pdid: u32, + + /// Whether the driver is subscribed to hwc + #[mesh(9)] + pub hwc_subscribed: bool, + + /// Whether the eq is armed or not + #[mesh(10)] + pub eq_armed: bool, + + /// Whether the cq is armed or not + #[mesh(11)] + pub cq_armed: bool, + + /// Event queue id to msix mapping + #[mesh(12)] + pub eq_id_msix: HashMap, + + /// The id of the hwc activity + #[mesh(13)] + pub hwc_activity_id: u32, + + /// How many msix vectors are available + #[mesh(14)] + pub num_msix: u32, + + /// Minimum number of queues available + #[mesh(15)] + pub min_queue_avail: u32, + + /// Saved interrupts for restoration + #[mesh(16)] + pub interrupt_config: Vec, + + /// Link status by vport index + #[mesh(17)] + pub link_toggle: Vec<(u32, bool)>, +} + +/// Saved state of an interrupt for restoration during servicing +#[derive(Protobuf, Clone, Debug)] +#[mesh(package = "mana_driver")] +pub struct InterruptSavedState { + /// The index in the msix table for this interrupt + #[mesh(1)] + pub msix_index: u32, + + /// Which CPU this interrupt is assigned to + #[mesh(2)] + pub cpu: u32, +} + +/// The saved state of a completion queue or event queue for restoration +/// during servicing +#[derive(Clone, Protobuf, Debug)] +#[mesh(package = "mana_driver")] +pub struct CqEqSavedState { + /// The doorbell state of the queue, which is how the device is notified + #[mesh(1)] + pub doorbell: DoorbellSavedState, + + /// The address of the doorbell register + #[mesh(2)] + pub doorbell_addr: u32, + + /// The id of the queue + #[mesh(3)] + pub id: u32, + + /// The index of the next entry in the queue + #[mesh(4)] + pub next: u32, + + /// The total size of the queue + #[mesh(5)] + pub size: u32, + + /// The bit shift value for the queue + #[mesh(6)] + pub shift: u32, +} + +/// Saved state of a doorbell for restoration during servicing +#[derive(Clone, Protobuf, Debug)] +#[mesh(package = "mana_driver")] +pub struct DoorbellSavedState { + /// The doorbell's id + #[mesh(1)] + pub doorbell_id: u64, + + /// The number of pages allocated for the doorbell + #[mesh(2)] + pub page_count: u32, +} + +/// Saved state of a work queue for restoration during servicing +#[derive(Debug, Protobuf, Clone)] +#[mesh(package = "mana_driver")] +pub struct WqSavedState { + /// The doorbell state of the queue, which is how the device is notified + #[mesh(1)] + pub doorbell: DoorbellSavedState, + + /// The address of the doorbell + #[mesh(2)] + pub doorbell_addr: u32, + + /// The id of the queue + #[mesh(3)] + pub id: u32, + + /// The head of the queue + #[mesh(4)] + pub head: u32, + + /// The tail of the queue + #[mesh(5)] + pub tail: u32, + + /// The bitmask for wrapping queue indices + #[mesh(6)] + pub mask: u32, +} + +/// Saved state for a memory region used by the driver +/// to be restored by a DMA client during servicing +#[derive(Debug, Protobuf, Clone)] +#[mesh(package = "mana_driver")] +pub struct SavedMemoryState { + /// The base page frame number of the memory region + #[mesh(1)] + pub base_pfn: u64, + + /// How long the memory region is + #[mesh(2)] + pub len: usize, +} diff --git a/vm/devices/net/mana_driver/src/tests.rs b/vm/devices/net/mana_driver/src/tests.rs index 77b2d38c84..019cd9ffa0 100644 --- a/vm/devices/net/mana_driver/src/tests.rs +++ b/vm/devices/net/mana_driver/src/tests.rs @@ -42,8 +42,12 @@ async fn test_gdma(driver: DefaultDriver) { ); let dma_client = mem.dma_client(); let device = EmulatedDevice::new(device, msi_set, dma_client); + let dma_client = device.dma_client(); + let buffer = dma_client.allocate_dma_buffer(6 * PAGE_SIZE).unwrap(); - let mut gdma = GdmaDriver::new(&driver, device, 1).await.unwrap(); + let mut gdma = GdmaDriver::new(&driver, device, 1, Some(buffer)) + .await + .unwrap(); gdma.test_eq().await.unwrap(); gdma.verify_vf_driver_version().await.unwrap(); let dev_id = gdma @@ -159,3 +163,43 @@ async fn test_gdma(driver: DefaultDriver) { .unwrap(); arena.destroy(&mut gdma).await; } + +#[async_test] +async fn test_gdma_save_restore(driver: DefaultDriver) { + let mem = DeviceTestMemory::new(128, false, "test_gdma"); + let mut msi_set = MsiInterruptSet::new(); + let device = gdma::GdmaDevice::new( + &VmTaskDriverSource::new(SingleDriverBackend::new(driver.clone())), + mem.guest_memory(), + &mut msi_set, + vec![VportConfig { + mac_address: [1, 2, 3, 4, 5, 6].into(), + endpoint: Box::new(NullEndpoint::new()), + }], + &mut ExternallyManagedMmioIntercepts, + ); + let dma_client = mem.dma_client(); + + let device = EmulatedDevice::new(device, msi_set, dma_client); + let cloned_device = device.clone(); + + let dma_client = device.dma_client(); + let gdma_buffer = dma_client.allocate_dma_buffer(6 * PAGE_SIZE).unwrap(); + + let saved_state = { + let mut gdma = GdmaDriver::new(&driver, device, 1, Some(gdma_buffer.clone())) + .await + .unwrap(); + + gdma.test_eq().await.unwrap(); + gdma.verify_vf_driver_version().await.unwrap(); + gdma.save().await.unwrap() + }; + + let mut new_gdma = GdmaDriver::restore(saved_state, cloned_device, gdma_buffer) + .await + .unwrap(); + + // Validate that the new driver still works after restoration. + new_gdma.test_eq().await.unwrap(); +} diff --git a/vm/devices/user_driver_emulated_mock/src/lib.rs b/vm/devices/user_driver_emulated_mock/src/lib.rs index fb68810cd5..0f04b9f2b5 100644 --- a/vm/devices/user_driver_emulated_mock/src/lib.rs +++ b/vm/devices/user_driver_emulated_mock/src/lib.rs @@ -39,7 +39,7 @@ use user_driver::memory::PAGE_SIZE64; /// allowing the user to control device behaviour to a certain extent. Can be used with devices such as the `NvmeController` pub struct EmulatedDevice { device: Arc>, - controller: MsiController, + controller: Arc, dma_client: Arc, bar0_len: usize, } @@ -77,6 +77,17 @@ impl MsiInterruptTarget for MsiController { } } +impl Clone for EmulatedDevice { + fn clone(&self) -> Self { + Self { + device: self.device.clone(), + controller: self.controller.clone(), + dma_client: self.dma_client.clone(), + bar0_len: self.bar0_len, + } + } +} + impl EmulatedDevice { /// Creates a new emulated device, wrapping `device` of type T, using the provided MSI Interrupt Set. Dma_client should point to memory /// shared with the device. @@ -84,6 +95,7 @@ impl EmulatedDevice { // Connect an interrupt controller. let controller = MsiController::new(msi_set.len()); msi_set.connect(&controller); + let controller = Arc::new(controller); let bars = device.probe_bar_masks(); let bar0_len = !(bars[0] & !0xf) as usize + 1;