diff --git a/axdevice_base/src/lib.rs b/axdevice_base/src/lib.rs index 731b2a3..4603496 100644 --- a/axdevice_base/src/lib.rs +++ b/axdevice_base/src/lib.rs @@ -1,5 +1,9 @@ #![no_std] #![feature(trait_alias)] +// trait_upcasting has been stabilized in Rust 1.86, but we still need a while to update the minimum +// Rust version of Axvisor. +#![allow(stable_features)] +#![feature(trait_upcasting)] #![allow(incomplete_features)] #![feature(generic_const_exprs)] @@ -12,7 +16,9 @@ extern crate alloc; -use alloc::{string::String, vec::Vec}; +use alloc::{string::String, sync::Arc, vec::Vec}; +use core::any::Any; + use axaddrspace::{ GuestPhysAddrRange, device::{AccessWidth, DeviceAddrRange, PortRange, SysRegAddrRange}, @@ -39,7 +45,7 @@ pub struct EmulatedDeviceConfig { } /// [`BaseDeviceOps`] is the trait that all emulated devices must implement. -pub trait BaseDeviceOps { +pub trait BaseDeviceOps: Any { /// Returns the type of the emulated device. fn emu_type(&self) -> EmuDeviceType; /// Returns the address range of the emulated device. @@ -50,6 +56,17 @@ pub trait BaseDeviceOps { fn handle_write(&self, addr: R::Addr, width: AccessWidth, val: usize) -> AxResult; } +/// Determines whether the given device is of type `T` and calls the provided function `f` with a +/// reference to the device if it is. +pub fn map_device_of_type, R: DeviceAddrRange, U, F: FnOnce(&T) -> U>( + device: &Arc>, + f: F, +) -> Option { + let any_arc: Arc = device.clone(); + + any_arc.downcast_ref::().map(f) +} + // trait aliases are limited yet: https://github.com/rust-lang/rfcs/pull/3437 /// [`BaseMmioDeviceOps`] is the trait that all emulated MMIO devices must implement. /// It is a trait alias of [`BaseDeviceOps`] with [`GuestPhysAddrRange`] as the address range. @@ -60,3 +77,6 @@ pub trait BaseSysRegDeviceOps = BaseDeviceOps; /// [`BasePortDeviceOps`] is the trait that all emulated port devices must implement. /// It is a trait alias of [`BaseDeviceOps`] with [`PortRange`] as the address range. pub trait BasePortDeviceOps = BaseDeviceOps; + +#[cfg(test)] +mod test; diff --git a/axdevice_base/src/test.rs b/axdevice_base/src/test.rs new file mode 100644 index 0000000..49fb2bf --- /dev/null +++ b/axdevice_base/src/test.rs @@ -0,0 +1,75 @@ +use alloc::vec; +use alloc::{sync::Arc, vec::Vec}; +use axaddrspace::{GuestPhysAddr, GuestPhysAddrRange, device::AccessWidth}; +use axerrno::AxResult; + +use crate::{BaseDeviceOps, EmuDeviceType, map_device_of_type}; + +const DEVICE_A_TEST_METHOD_ANSWER: usize = 42; + +struct DeviceA; + +impl BaseDeviceOps for DeviceA { + fn emu_type(&self) -> EmuDeviceType { + EmuDeviceType::Dummy + } + + fn address_range(&self) -> GuestPhysAddrRange { + (0x1000..0x2000).try_into().unwrap() + } + + fn handle_read(&self, addr: GuestPhysAddr, _width: AccessWidth) -> AxResult { + Ok(addr.as_usize()) + } + + fn handle_write(&self, _addr: GuestPhysAddr, _width: AccessWidth, _val: usize) -> AxResult { + Ok(()) + } +} + +impl DeviceA { + /// A test method unique to DeviceA. + pub fn test_method(&self) -> usize { + DEVICE_A_TEST_METHOD_ANSWER + } +} + +struct DeviceB; + +impl BaseDeviceOps for DeviceB { + fn emu_type(&self) -> EmuDeviceType { + EmuDeviceType::Dummy + } + + fn address_range(&self) -> GuestPhysAddrRange { + (0x2000..0x3000).try_into().unwrap() + } + + fn handle_read(&self, addr: GuestPhysAddr, _width: AccessWidth) -> AxResult { + Ok(addr.as_usize()) + } + + fn handle_write(&self, _addr: GuestPhysAddr, _width: AccessWidth, _val: usize) -> AxResult { + Ok(()) + } +} + +#[test] +fn test_device_type_test() { + let devices: Vec>> = + vec![Arc::new(DeviceA), Arc::new(DeviceB)]; + + let mut device_a_found = false; + for device in devices { + assert_eq!( + device.handle_read(0x2000.into(), AccessWidth::Byte), + Ok(0x2000) + ); + + if let Some(answer) = map_device_of_type(&device, |d: &DeviceA| d.test_method()) { + assert_eq!(answer, DEVICE_A_TEST_METHOD_ANSWER); + device_a_found = true; + } + } + assert!(device_a_found, "DeviceA was not found"); +}