Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mmap/sbrk/brk implementation and integration into glibc/wasmtime #74

Merged
merged 11 commits into from
Jan 5, 2025
270 changes: 215 additions & 55 deletions src/RawPOSIX/src/interface/mem.rs

Large diffs are not rendered by default.

32 changes: 18 additions & 14 deletions src/RawPOSIX/src/safeposix/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ pub fn lind_syscall_api(
let mut fildes = arg5 as i32;
let off = arg6 as i64;

interface::mmap_handler(cageid, addr, len, prot, flags, fildes, off)
interface::mmap_handler(cageid, addr, len, prot, flags, fildes, off) as i32
}

PREAD_SYSCALL => {
Expand Down Expand Up @@ -1204,33 +1204,32 @@ pub fn lind_syscall_api(
cage.waitpid_syscall(pid, status, options)
}


SBRK_SYSCALL => {
let brk = arg1 as i32;

interface::sbrk_handler(cageid, brk) as i32
}

BRK_SYSCALL => {
let brk = arg1 as u32;

interface::sbrk_handler(cageid, brk)
interface::brk_handler(cageid, brk)
}

_ => -1, // Return -1 for unknown syscalls
};
ret
}

// initilize the vmmap, invoked by wasmtime
pub fn lind_cage_vmmap_init(cageid: u64) {
let cage = interface::cagetable_getref(cageid);
let mut vmmap = cage.vmmap.write();
vmmap.add_entry(VmmapEntry::new(0, 0x30, PROT_WRITE | PROT_READ, 0 /* not sure about this field */, (MAP_PRIVATE | MAP_ANONYMOUS) as i32, false, 0, 0, cageid, MemoryBackingType::Anonymous));
// BUG: currently need to insert an entry at the end to indicate the end of memory space. This should be fixed soon so that
// no dummy entries are required to be inserted
vmmap.add_entry(VmmapEntry::new(1 << 18, 1, PROT_NONE, 0 /* not sure about this field */, (MAP_PRIVATE | MAP_ANONYMOUS) as i32, false, 0, 0, cageid, MemoryBackingType::Anonymous));
ret
}

// set the wasm linear memory base address to vmmap
pub fn set_base_address(cageid: u64, base_address: i64) {
pub fn init_vmmap_helper(cageid: u64, base_address: usize, program_break: Option<u32>) {
let cage = interface::cagetable_getref(cageid);
let mut vmmap = cage.vmmap.write();
vmmap.set_base_address(base_address);
if program_break.is_some() {
vmmap.set_program_break(program_break.unwrap());
}
}

// clone the cage memory. Invoked by wasmtime after cage is forked
Expand All @@ -1241,6 +1240,11 @@ pub fn fork_vmmap_helper(parent_cageid: u64, child_cageid: u64) {
let child_vmmap = child_cage.vmmap.read();

interface::fork_vmmap(&parent_vmmap, &child_vmmap);

// update program break for child
drop(child_vmmap);
let mut child_vmmap = child_cage.vmmap.write();
child_vmmap.set_program_break(parent_vmmap.program_break);
}

#[no_mangle]
Expand Down
12 changes: 6 additions & 6 deletions src/RawPOSIX/src/safeposix/syscalls/fs_calls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,7 @@ impl Cage {
flags: i32,
virtual_fd: i32,
off: i64
) -> i64 {
) -> usize {
if virtual_fd != -1 {
match fdtables::translate_virtual_fd(self.cageid, virtual_fd as u64) {
Ok(kernel_fd) => {
Expand All @@ -789,13 +789,13 @@ impl Cage {

// Check if mmap failed and return the appropriate error if so
if ret == -1 {
return syscall_error(Errno::EINVAL, "mmap", "mmap failed with invalid flags") as i64;
return syscall_error(Errno::EINVAL, "mmap", "mmap failed with invalid flags") as usize;
}

ret
ret as usize
},
Err(_e) => {
return syscall_error(Errno::EBADF, "mmap", "Bad File Descriptor") as i64;
return syscall_error(Errno::EBADF, "mmap", "Bad File Descriptor") as usize;
}
}
} else {
Expand All @@ -805,10 +805,10 @@ impl Cage {
};
// Check if mmap failed and return the appropriate error if so
if ret == -1 {
return syscall_error(Errno::EINVAL, "mmap", "mmap failed with invalid flags") as i64;
return syscall_error(Errno::EINVAL, "mmap", "mmap failed with invalid flags") as usize;
}

ret
ret as usize
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/RawPOSIX/src/safeposix/syscalls/sys_calls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ impl Cage {
sigset: newsigset,
main_threadid: interface::RustAtomicU64::new(0),
interval_timer: interface::IntervalTimer::new(child_cageid),
vmmap: interface::RustLock::new(Vmmap::new()), // Initialize empty virtual memory map for new process
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are we sure that this deep copies on clone?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually cannot be 100% sure about this. clone method should be designed to deep-copying itself by convention from what I've read, but if the developer (in our case, developer of nodit library or its sub-dependent) is breaking the convention and doing some hacky things here, it would probably be hard for us to know. The documentation for the NoditMap only says it returns a copy of the value without mentioning shallow or deep.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is something worth just adding a test for @ChinmayShringi

vmmap: interface::RustLock::new(new_vmmap), // Initialize empty virtual memory map for new process
zombies: interface::RustLock::new(vec![]),
child_num: interface::RustAtomicU64::new(0),
};
Expand Down
143 changes: 66 additions & 77 deletions src/RawPOSIX/src/safeposix/vmmap.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
use crate::constants::{
PROT_NONE, PROT_READ, PROT_WRITE, PROT_EXEC,
MAP_SHARED, MAP_PRIVATE, MAP_FIXED, MAP_ANONYMOUS,
MAP_FAILED
MAP_ANONYMOUS, MAP_FAILED, MAP_FIXED, MAP_PRIVATE, MAP_SHARED, PAGESHIFT, PROT_EXEC, PROT_NONE, PROT_READ, PROT_WRITE
};
use std::io;
use nodit::NoditMap;
use nodit::{interval::ie, Interval};
use crate::fdtables;
use crate::safeposix::cage::syscall_error;
use crate::safeposix::cage::Errno;

const DEFAULT_VMMAP_SIZE: u32 = 1 << (32 - PAGESHIFT);

/// Used to identify whether the vmmap entry is backed anonymously,
/// by an fd, or by a shared memory segment
///
Expand Down Expand Up @@ -243,8 +244,11 @@ pub struct Vmmap {
pub entries: NoditMap<u32, Interval<u32>, VmmapEntry>, // Keyed by `page_num`
pub cached_entry: Option<VmmapEntry>, // TODO: is this still needed?
// Use Option for safety
pub base_address: Option<i64>, // wasm base address. None means uninitialized yet
pub base_address: Option<usize>, // wasm base address. None means uninitialized yet

pub start_address: u32, // start address of valid vmmap address range
pub end_address: u32, // end address of valid vmmap address range
pub program_break: u32, // program break (i.e. heap bottom) of the memory
}

#[allow(dead_code)]
Expand All @@ -255,7 +259,10 @@ impl Vmmap {
Vmmap {
entries: NoditMap::new(),
cached_entry: None,
base_address: None
base_address: None,
start_address: 0,
end_address: DEFAULT_VMMAP_SIZE,
program_break: 0,
}
}

Expand Down Expand Up @@ -287,20 +294,28 @@ impl Vmmap {
///
/// Arguments:
/// - base_address: The base address to set
pub fn set_base_address(&mut self, base_address: i64) {
pub fn set_base_address(&mut self, base_address: usize) {
// Store the provided base address
self.base_address = Some(base_address);
}

/// Sets the program break for the memory
///
/// Arguments:
/// - program_break: The program break to set
pub fn set_program_break(&mut self, program_break: u32) {
self.program_break = program_break;
}

/// Converts a user address to a system address
///
/// Arguments:
/// - address: User space address to convert
///
/// Returns the corresponding system address
pub fn user_to_sys(&self, address: i32) -> i64 {
pub fn user_to_sys(&self, address: u32) -> usize {
// Add base address to user address to get system address
address as i64 + self.base_address.unwrap()
address as usize + self.base_address.unwrap()
}

/// Converts a system address to a user address
Expand All @@ -309,9 +324,9 @@ impl Vmmap {
/// - address: System address to convert
///
/// Returns the corresponding user space address
pub fn sys_to_user(&self, address: i64) -> i32 {
pub fn sys_to_user(&self, address: usize) -> u32 {
// Subtract base address from system address to get user address
(address as i64 - self.base_address.unwrap()) as i32
(address as usize - self.base_address.unwrap()) as u32
}

// Visits each entry in the vmmap, applying a visitor function to each entry
Expand Down Expand Up @@ -814,24 +829,17 @@ impl VmmapOps for Vmmap {
/// - Some(Interval) containing the found space
/// - None if no suitable space found
fn find_space(&self, npages: u32) -> Option<Interval<u32>> {
let start = self.first_entry();
let end = self.last_entry();
let start = self.start_address;
let end = self.end_address;

if start == None || end == None {
return None;
} else {
let start_unwrapped = start.unwrap().0.start();
let end_unwrapped = end.unwrap().0.end();

let desired_space = npages + 1; // TODO: check if this is correct
let desired_space = npages + 1; // TODO: check if this is correct

for gap in self
.entries
.gaps_trimmed(ie(start_unwrapped, end_unwrapped))
{
if gap.end() - gap.start() >= desired_space {
return Some(gap);
}
for gap in self
.entries
.gaps_trimmed(ie(start, end))
{
if gap.end() - gap.start() >= desired_space {
return Some(gap);
}
}

Expand All @@ -852,19 +860,13 @@ impl VmmapOps for Vmmap {
/// - None if no suitable space found
fn find_space_above_hint(&self, npages: u32, hint: u32) -> Option<Interval<u32>> {
let start = hint;
let end = self.last_entry();
let end = self.end_address;

if end == None {
return None;
} else {
let end_unwrapped = end.unwrap().0.end();

let desired_space = npages + 1; // TODO: check if this is correct
let desired_space = npages + 1; // TODO: check if this is correct

for gap in self.entries.gaps_trimmed(ie(start, end_unwrapped)) {
if gap.end() - gap.start() >= desired_space {
return Some(gap);
}
for gap in self.entries.gaps_trimmed(ie(start, end)) {
if gap.end() - gap.start() >= desired_space {
return Some(gap);
}
}

Expand All @@ -888,31 +890,24 @@ impl VmmapOps for Vmmap {
/// - Rounds page numbers up to alignment boundaries
/// - Handles alignment constraints for start and end addresses
fn find_map_space(&self, num_pages: u32, pages_per_map: u32) -> Option<Interval<u32>> {
let start = self.first_entry();
let end = self.last_entry();
let start = self.start_address;
let end = self.end_address;

if start == None || end == None {
return None;
} else {
let start_unwrapped = start.unwrap().0.start();
let end_unwrapped = end.unwrap().0.end();

let rounded_num_pages =
self.round_page_num_up_to_map_multiple(num_pages, pages_per_map);
let rounded_num_pages =
self.round_page_num_up_to_map_multiple(num_pages, pages_per_map);

for gap in self
.entries
.gaps_trimmed(ie(start_unwrapped, end_unwrapped))
{
let aligned_start_page =
self.trunc_page_num_down_to_map_multiple(gap.start(), pages_per_map);
let aligned_end_page =
self.round_page_num_up_to_map_multiple(gap.end(), pages_per_map);

let gap_size = aligned_end_page - aligned_start_page;
if gap_size >= rounded_num_pages {
return Some(ie(aligned_end_page - rounded_num_pages, aligned_end_page));
}
for gap in self
.entries
.gaps_trimmed(ie(start, end))
{
let aligned_start_page =
self.trunc_page_num_down_to_map_multiple(gap.start(), pages_per_map);
let aligned_end_page =
self.round_page_num_up_to_map_multiple(gap.end(), pages_per_map);

let gap_size = aligned_end_page - aligned_start_page;
if gap_size >= rounded_num_pages {
return Some(ie(aligned_end_page - rounded_num_pages, aligned_end_page));
}
}

Expand Down Expand Up @@ -944,26 +939,20 @@ impl VmmapOps for Vmmap {
hint: u32,
) -> Option<Interval<u32>> {
let start = hint;
let end = self.last_entry();
let end = self.end_address;

if end == None {
return None;
} else {
let end_unwrapped = end.unwrap().0.end();

let rounded_num_pages =
self.round_page_num_up_to_map_multiple(num_pages, pages_per_map);
let rounded_num_pages =
self.round_page_num_up_to_map_multiple(num_pages, pages_per_map);

for gap in self.entries.gaps_trimmed(ie(start, end_unwrapped)) {
let aligned_start_page =
self.trunc_page_num_down_to_map_multiple(gap.start(), pages_per_map);
let aligned_end_page =
self.round_page_num_up_to_map_multiple(gap.end(), pages_per_map);
for gap in self.entries.gaps_trimmed(ie(start, end)) {
let aligned_start_page =
self.trunc_page_num_down_to_map_multiple(gap.start(), pages_per_map);
let aligned_end_page =
self.round_page_num_up_to_map_multiple(gap.end(), pages_per_map);

let gap_size = aligned_end_page - aligned_start_page;
if gap_size >= rounded_num_pages {
return Some(ie(aligned_end_page - rounded_num_pages, aligned_end_page));
}
let gap_size = aligned_end_page - aligned_start_page;
if gap_size >= rounded_num_pages {
return Some(ie(aligned_end_page - rounded_num_pages, aligned_end_page));
}
}

Expand Down
30 changes: 19 additions & 11 deletions src/glibc/lind_syscall/lind_syscall.c
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,23 @@ int __imported_wasi_snapshot_preview1_lind_syscall(unsigned int callnumber, unsi
// handled here instead
int lind_syscall (unsigned int callnumber, unsigned long long callname, unsigned long long arg1, unsigned long long arg2, unsigned long long arg3, unsigned long long arg4, unsigned long long arg5, unsigned long long arg6)
{
int ret = __imported_wasi_snapshot_preview1_lind_syscall(callnumber, callname, arg1, arg2, arg3, arg4, arg5, arg6);
// handle the errno
if(ret < 0)
{
errno = -ret;
}
else
{
errno = 0;
}
return ret;
int ret = __imported_wasi_snapshot_preview1_lind_syscall(callnumber, callname, arg1, arg2, arg3, arg4, arg5, arg6);
// handle the errno
// in rawposix, we use -errno as the return value to indicate the error
// but this may cause some issues for mmap syscall, because mmap syscall
// is returning an 32-bit address, which may overflow the int type (i32)
// luckily we can handle this easily because the return value of mmap is always
// multiple of pages (typically 4096) even when overflow, therefore we can distinguish
// the errno and mmap result by simply checking if the return value is
// within the valid errno range
if(ret < 0 && ret > -256)
{
errno = -ret;
return -1;
}
else
{
errno = 0;
}
return ret;
}
Loading