Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions axdevice_base/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
#![no_std]
#![feature(trait_alias)]
// trait_upcasting has been stabilized in Rust 1.86, but we still need a while to update the minimum
// Rust version of Axvisor.
#![allow(stable_features)]
#![feature(trait_upcasting)]
#![allow(incomplete_features)]
#![feature(generic_const_exprs)]

Expand All @@ -12,7 +16,9 @@

extern crate alloc;

use alloc::{string::String, vec::Vec};
use alloc::{string::String, sync::Arc, vec::Vec};
use core::any::Any;

use axaddrspace::{
GuestPhysAddrRange,
device::{AccessWidth, DeviceAddrRange, PortRange, SysRegAddrRange},
Expand All @@ -39,7 +45,7 @@ pub struct EmulatedDeviceConfig {
}

/// [`BaseDeviceOps`] is the trait that all emulated devices must implement.
pub trait BaseDeviceOps<R: DeviceAddrRange> {
pub trait BaseDeviceOps<R: DeviceAddrRange>: Any {
/// Returns the type of the emulated device.
fn emu_type(&self) -> EmuDeviceType;
/// Returns the address range of the emulated device.
Expand All @@ -50,6 +56,17 @@ pub trait BaseDeviceOps<R: DeviceAddrRange> {
fn handle_write(&self, addr: R::Addr, width: AccessWidth, val: usize) -> AxResult;
}

/// Determines whether the given device is of type `T` and calls the provided function `f` with a
/// reference to the device if it is.
pub fn map_device_of_type<T: BaseDeviceOps<R>, R: DeviceAddrRange, U, F: FnOnce(&T) -> U>(
device: &Arc<dyn BaseDeviceOps<R>>,
f: F,
) -> Option<U> {
let any_arc: Arc<dyn Any> = device.clone();

any_arc.downcast_ref::<T>().map(f)
}

// trait aliases are limited yet: https://github.com/rust-lang/rfcs/pull/3437
/// [`BaseMmioDeviceOps`] is the trait that all emulated MMIO devices must implement.
/// It is a trait alias of [`BaseDeviceOps`] with [`GuestPhysAddrRange`] as the address range.
Expand All @@ -60,3 +77,6 @@ pub trait BaseSysRegDeviceOps = BaseDeviceOps<SysRegAddrRange>;
/// [`BasePortDeviceOps`] is the trait that all emulated port devices must implement.
/// It is a trait alias of [`BaseDeviceOps`] with [`PortRange`] as the address range.
pub trait BasePortDeviceOps = BaseDeviceOps<PortRange>;

#[cfg(test)]
mod test;
75 changes: 75 additions & 0 deletions axdevice_base/src/test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use alloc::vec;
use alloc::{sync::Arc, vec::Vec};
use axaddrspace::{GuestPhysAddr, GuestPhysAddrRange, device::AccessWidth};
use axerrno::AxResult;

use crate::{BaseDeviceOps, EmuDeviceType, map_device_of_type};

const DEVICE_A_TEST_METHOD_ANSWER: usize = 42;

struct DeviceA;

impl BaseDeviceOps<GuestPhysAddrRange> for DeviceA {
fn emu_type(&self) -> EmuDeviceType {
EmuDeviceType::Dummy
}

fn address_range(&self) -> GuestPhysAddrRange {
(0x1000..0x2000).try_into().unwrap()
}

fn handle_read(&self, addr: GuestPhysAddr, _width: AccessWidth) -> AxResult<usize> {
Ok(addr.as_usize())
}

fn handle_write(&self, _addr: GuestPhysAddr, _width: AccessWidth, _val: usize) -> AxResult {
Ok(())
}
}

impl DeviceA {
/// A test method unique to DeviceA.
pub fn test_method(&self) -> usize {
DEVICE_A_TEST_METHOD_ANSWER
}
}

struct DeviceB;

impl BaseDeviceOps<GuestPhysAddrRange> for DeviceB {
fn emu_type(&self) -> EmuDeviceType {
EmuDeviceType::Dummy
}

fn address_range(&self) -> GuestPhysAddrRange {
(0x2000..0x3000).try_into().unwrap()
}

fn handle_read(&self, addr: GuestPhysAddr, _width: AccessWidth) -> AxResult<usize> {
Ok(addr.as_usize())
}

fn handle_write(&self, _addr: GuestPhysAddr, _width: AccessWidth, _val: usize) -> AxResult {
Ok(())
}
}

#[test]
fn test_device_type_test() {
let devices: Vec<Arc<dyn BaseDeviceOps<GuestPhysAddrRange>>> =
vec![Arc::new(DeviceA), Arc::new(DeviceB)];

let mut device_a_found = false;
for device in devices {
assert_eq!(
device.handle_read(0x2000.into(), AccessWidth::Byte),
Ok(0x2000)
);

if let Some(answer) = map_device_of_type(&device, |d: &DeviceA| d.test_method()) {
assert_eq!(answer, DEVICE_A_TEST_METHOD_ANSWER);
device_a_found = true;
}
}
assert!(device_a_found, "DeviceA was not found");
}