From 0c1dfca78d8a92a1413e56ccb5d2fd4cb657d80f Mon Sep 17 00:00:00 2001 From: linxiangdong Date: Wed, 15 May 2024 15:50:29 +0800 Subject: [PATCH] add task_spawner support * use single thread task spawner to run guest code in single thread runtime * use multi-thread task spawner to run host code in multi-thread runtime * use mpsc channel to send task's return value back to caller refer to 'docs/multiple-thread-support.md' for more details Signed-off-by: linxiangdong --- docs/multiple-thread-support.md | 288 ++++++++++++++++++ fp-bindgen-support/Cargo.toml | 3 +- fp-bindgen-support/src/wasmer2_host/mod.rs | 1 + .../src/wasmer2_host/task_spawner.rs | 147 +++++++++ .../generators/rust_wasmer2_runtime/mod.rs | 63 +++- .../rust_wasmer2_wasi_runtime/mod.rs | 36 ++- 6 files changed, 523 insertions(+), 15 deletions(-) create mode 100644 docs/multiple-thread-support.md create mode 100644 fp-bindgen-support/src/wasmer2_host/task_spawner.rs diff --git a/docs/multiple-thread-support.md b/docs/multiple-thread-support.md new file mode 100644 index 00000000..4d396c47 --- /dev/null +++ b/docs/multiple-thread-support.md @@ -0,0 +1,288 @@ +# the patch to support multi-thread tokio + +In order to fix the panic in multi-thread tokio env (https://github.com/fiberplane/fp-bindgen/issues/194), I followed @arendjr's advice, and created a TaskSpawner, which will: + +1. Run tokio runtime in single thread +2. handle guest async function in this spawner +3. Handle (at least) the second half of the host async function in this spawner + +## The TaskSpawner + +The goal of this TaskSpawneris to run all async tasks in a dedicated single thread runtime. By referring to Tokio book and some help from ChatGPT, I got this TaskSpawner, which has: + +1. Use `tokio::sync::mpsc` channel to receive tasks +2. Support bidirectional communications: users could get results back, by creating a channel for each request +3. Support multiple result types: use `std::any::Any` to represent result in a generic way +4. Support spawn task from both async and sync context + +Here is the full source code + +```rust +use tokio::runtime::Builder; +use tokio::sync::mpsc; +use std::thread; +use std::pin::Pin; +use std::future::Future; +use std::any::Any; + +type BoxedFuture = Pin> + Send>>; +type Task = (BoxedFuture, mpsc::Sender>); + +#[derive(Clone)] +struct Spawner { + sender: mpsc::Sender, +} + +impl Spawner { + fn new() -> Self { + let (sender, mut receiver) = mpsc::channel::(100); + + let rt = Builder::new_current_thread() + .worker_threads(1) + .enable_all() + .build() + .unwrap(); + + std::thread::spawn(move || { + rt.block_on(async move { + loop { + if let Some((task, result_sender)) = receiver.recv().await { + let task = async move { + let result = task.await; + let _ = result_sender.send(result).await; + }; + tokio::spawn(task); + } + } + }); + }); + + Spawner { sender } + } + + // spawn and return immediatelly, used in async context + fn spawn(&self, task: F) -> mpsc::Receiver> + where + F: Future + Send + 'static, + T: Any + Send + 'static + { + let (result_sender, result_receiver) = mpsc::channel(1); + let task = Box::pin(async move {Box::new(task.await) as Box }); + let _ = self.sender.try_send((task, result_sender)); + result_receiver + } + + // spawn and wait, used in sync context + fn spawn_blocking(&self, task: F) -> mpsc::Receiver> + where + F: Future + Send + 'static, + T: Any + Send + 'static + { + let (result_sender, result_receiver) = mpsc::channel(1); + let task = Box::pin(async move {Box::new(task.await) as Box }); + let _ = self.sender.blocking_send((task, result_sender)); + result_receiver + } + // + // // spawn and wait(in async) until there is capacity, used in sync context + async fn spawn_async(&self, task: F) -> mpsc::Receiver> + where + F: Future + Send + 'static, + T: Any + Send + 'static + { + let (result_sender, result_receiver) = mpsc::channel(1); + let task = Box::pin(async move {Box::new(task.await) as Box }); + let _ = self.sender.send((task, result_sender)).await; + result_receiver + } +} + +fn main() { + let spawner = Spawner::new(); + + let spawner_clone = spawner.clone(); + let sync_code = thread::spawn(move || { + let recv = spawner_clone.spawn_blocking(async { + println!("hello from sync world"); + 3.14 + }); + recv + }); + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build().unwrap(); + + rt.block_on(async move { + let mut result_receiver = spawner.spawn(async { + println!("{:?} Hello from task", std::thread::current().id()); + tokio::time::sleep(std::time::Duration::from_secs(5)).await; + 42 + }); + // + let mut result_receiver2 = spawner.spawn_async(async { + println!("{:?} Hello from task again", std::thread::current().id()); + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + "foobar".to_owned() + }).await; + + // Get the task result + if let Some(result) = result_receiver2.recv().await { + let result = result.downcast::().unwrap(); + println!("Task result: {}", result); + } else { + println!("Task get error"); + } + + // Get the task result + if let Some(result) = result_receiver.recv().await { + let result = result.downcast::().unwrap(); + println!("Task result: {}", result); + } else { + println!("Task get error"); + } + // Keep the main thread running + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + }); + + let res = sync_code.join().unwrap().blocking_recv(); + let res = res.unwrap().downcast::().unwrap(); + println!("recv from sync block: {}", res); +} +``` + +With this TaskSpawner, I forward both host and guest async functions to this TaskSpawner, and it runs without any panic + +## Guest async function + +### Previous version + + +```rust +pub async fn guest_func(&self, arg1: String) -> Result { + let arg1 = serialize_to_vec(&arg1); + let result = self.guest_func_raw(arg1); + let result = result.await; + let result = result.map(|ref data| deserialize_from_slice(data)); + result +} +``` + +Without specifying runtime, `result.await` in line 4 will run in system default tokio runtime, and we will panic with multithread runtime, and that's why we have to stick with single thread runtime + +### Current version + +```rust +pub async fn guest_func(&self, arg1: String) -> Result { + let this = self.clone(); + let task = async move { + let arg1 = serialize_to_vec(&arg1); + let result = this.guest_func_raw(arg1); + let result = result.await; + let result = result.map(|ref data| deserialize_from_slice::(data)); + result.unwrap() + }; + let mut recv = SPAWNER.spawn_async(task).await; + match recv.recv().await { + Some(result) => Ok(*result.downcast::().unwrap()), + None => Err(InvocationError::UnexpectedReturnType), + } +} +``` + +In line 10, we forward this task to the dedicated single-thread runtime (The SPAWNER) + +## Host async function + +### Previous version + +```rust +pub fn _host_func(env: &RuntimeInstanceData, arg1: FatPtr) -> FatPtr { + let arg1 = import_from_guest::(env, arg1); + let env = env.clone(); + let async_ptr = create_future_value(&env); + let handle = tokio::runtime::Handle::current(); + handle.spawn(async move { + let result = super::host_func(arg1).await; + let result_ptr = export_to_guest(&env, &result); + env.guest_resolve_async_value(async_ptr, result_ptr); + }); + async_ptr +} +``` + +### Current version + +```rust +pub fn _host_func(env: &RuntimeInstanceData, arg1: FatPtr) -> FatPtr { + let arg1 = import_from_guest::(&env_clone, arg1); + let env = env.clone(); + let async_ptr = create_future_value(&env); + let task = async move { + let result = super::host_func(arg1).await; + let result_ptr = export_to_guest(&env, &result); + env.guest_resolve_async_value(async_ptr, result_ptr); + }; + SPAWNER.spawn(task); + async_ptr +} +``` + +The key difference is line 10. We forward this async task to runtime's dedicated single thread runtime +And with this change, at least we get expected values without panic + +### A minor glitch + +Now we still have a small problem: ALL host functions run in the same single thread runtime. In theory, we should make them run in the global multiple thread runtime, and only pass result in this dedicated single thread runtime. By adding another dedicated multiple-thread runtime, or passing the outmost tokio runtime created from main, we can guarantee that host functions run in multi-thread runtime + + +```rust +pub fn _host_func(env: &RuntimeInstanceData, arg1: FatPtr) -> FatPtr { + let arg1 = import_from_guest::(env, arg1); + + let env_clone = env.clone(); + let async_ptr = create_future_value(&env_clone); + let host_task = MT_SPAWNER.spawn_async(async move { + let result = super::host_func(arg1).await; + result + }); + + let env_clone = env.clone(); + let guest_task = async move { + let mut result = host_task.await; + let res = match result.recv().await { + Some(result) => *result.downcast::().unwrap(), + None => "xxx".to_string(), //TODO: fix this later + }; + + let result_ptr = export_to_guest(&env_clone, &res); + env_clone.guest_resolve_async_value(async_ptr, result_ptr); + }; + SG_SPAWNER.spawn(guest_task); + async_ptr +} +``` + +### the extra call in main + +We have to create the __global__ handler in main before calling `fp-bindgen` + +```rust +use fp_bindgen_support::wasmer2_host::task_spawner::GLOBAL_HANDLER; +use fp_bindgen_support::wasmer2_host::task_spawner::GlobalSpawner; + + +#[tokio::main(flavor = "multi_thread")] +async fn main() { + let global_spawner = GlobalSpawner::new(tokio::runtime::Handle::current()); + let _ = GLOBAL_HANDLER.set(global_spawner); + // ... +} + +``` + +### the performance issue + +This is just a rough implementation, mpsc channels and dynamic `std::any::Any` are definitely not the optimal solution, and it indeed shows poor performance in our production environment compared to the single threaded version. + +So if this is the right direction, we need to spend more efforts to make it performant. diff --git a/fp-bindgen-support/Cargo.toml b/fp-bindgen-support/Cargo.toml index fdb966b8..5775a00b 100644 --- a/fp-bindgen-support/Cargo.toml +++ b/fp-bindgen-support/Cargo.toml @@ -19,9 +19,10 @@ once_cell = "1" rmp-serde = "1.0.0" serde = { version = "1.0", features = ["derive"] } serde_bytes = "0.11" -wasmer = { version = "2.1", optional = true } thiserror = { version = "1.0.26", optional = true } +tokio = { version = "1", features = ["rt", "sync"] } tracing = "0.1.37" +wasmer = { version = "2.1", optional = true } [features] default = [] diff --git a/fp-bindgen-support/src/wasmer2_host/mod.rs b/fp-bindgen-support/src/wasmer2_host/mod.rs index 29755f3d..cb3094a5 100644 --- a/fp-bindgen-support/src/wasmer2_host/mod.rs +++ b/fp-bindgen-support/src/wasmer2_host/mod.rs @@ -5,3 +5,4 @@ pub mod errors; pub mod io; pub mod mem; pub mod runtime; +pub mod task_spawner; diff --git a/fp-bindgen-support/src/wasmer2_host/task_spawner.rs b/fp-bindgen-support/src/wasmer2_host/task_spawner.rs new file mode 100644 index 00000000..873548ce --- /dev/null +++ b/fp-bindgen-support/src/wasmer2_host/task_spawner.rs @@ -0,0 +1,147 @@ +use crate::common::mem::FatPtr; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use std::cell::RefCell; +use std::pin::Pin; +use std::future::Future; +use std::task::Waker; +use tokio::task::LocalSet; +use tokio::sync::mpsc::{self, Receiver, Sender}; +use tokio::runtime::Builder; +use wasmer::{LazyInit, Memory, NativeFunc}; +use std::any::Any; +use once_cell::sync::OnceCell; + +type BoxedFuture = Pin> + Send>>; +type Task = (BoxedFuture, mpsc::Sender>); + +#[derive(Clone, Debug)] +pub struct CurrentThreadSpawner { + sender: mpsc::Sender, +} + +#[derive(Clone, Default)] +pub struct GlobalSpawner { + sender: Arc>>, +} + +pub static GLOBAL_HANDLER: OnceCell = OnceCell::new(); + +impl GlobalSpawner { + pub fn new(handle: tokio::runtime::Handle) -> Self { + let (sender, mut receiver) = mpsc::channel::(100); + handle.spawn(async move { + loop { + if let Some((task, result_sender)) = receiver.recv().await { + let task = async move { + let result = task.await; + let _ = result_sender.send(result).await; + }; + tokio::spawn(task); + } + } + }); + + GlobalSpawner { sender: Arc::new(Some(sender)) } + } + + // spawn and return immediatelly, used in async context + pub fn spawn(&self, task: F) -> mpsc::Receiver> + where + F: Future + Send + 'static, + T: Any + Send + 'static + { + let (result_sender, result_receiver) = mpsc::channel(1); + let task = Box::pin(async move {Box::new(task.await) as Box }); + let _ = self.sender.clone().as_ref().clone().unwrap().try_send((task, result_sender)); + result_receiver + } + + // spawn and wait, used in sync context + pub fn spawn_blocking(&self, task: F) -> mpsc::Receiver> + where + F: Future + Send + 'static, + T: Any + Send + 'static + { + let (result_sender, result_receiver) = mpsc::channel(1); + let task = Box::pin(async move {Box::new(task.await) as Box }); + let _ = self.sender.clone().as_ref().clone().unwrap().blocking_send((task, result_sender)); + result_receiver + } + // + // // spawn and wait(in async) until there is capacity, used in sync context + pub async fn spawn_async(&self, task: F) -> mpsc::Receiver> + where + F: Future + Send + 'static, + T: Any + Send + 'static + { + let (result_sender, result_receiver) = mpsc::channel(1); + let task = Box::pin(async move {Box::new(task.await) as Box }); + let _ = self.sender.clone().as_ref().clone().unwrap().send((task, result_sender)).await; + result_receiver + } +} + +impl CurrentThreadSpawner { + pub fn new() -> Self { + let (sender, mut receiver) = mpsc::channel::(100); + + let rt = Builder::new_current_thread() + .worker_threads(1) + .enable_all() + .build() + .unwrap(); + + std::thread::spawn(move || { + rt.block_on(async move { + loop { + if let Some((task, result_sender)) = receiver.recv().await { + let task = async move { + let result = task.await; + let _ = result_sender.send(result).await; + }; + tokio::spawn(task); + } + } + }); + }); + + CurrentThreadSpawner { sender } + } + + // spawn and return immediatelly, used in async context + pub fn spawn(&self, task: F) -> mpsc::Receiver> + where + F: Future + Send + 'static, + T: Any + Send + 'static + { + let (result_sender, result_receiver) = mpsc::channel(1); + let task = Box::pin(async move {Box::new(task.await) as Box }); + let _ = self.sender.try_send((task, result_sender)); + result_receiver + } + + // spawn and wait, used in sync context + pub fn spawn_blocking(&self, task: F) -> mpsc::Receiver> + where + F: Future + Send + 'static, + T: Any + Send + 'static + { + let (result_sender, result_receiver) = mpsc::channel(1); + let task = Box::pin(async move {Box::new(task.await) as Box }); + let _ = self.sender.blocking_send((task, result_sender)); + result_receiver + } + // + // // spawn and wait(in async) until there is capacity, used in sync context + pub async fn spawn_async(&self, task: F) -> mpsc::Receiver> + where + F: Future + Send + 'static, + T: Any + Send + 'static + { + let (result_sender, result_receiver) = mpsc::channel(1); + let task = Box::pin(async move {Box::new(task.await) as Box }); + let _ = self.sender.send((task, result_sender)).await; + result_receiver + } +} diff --git a/fp-bindgen/src/generators/rust_wasmer2_runtime/mod.rs b/fp-bindgen/src/generators/rust_wasmer2_runtime/mod.rs index 3575b54d..48c31fc0 100644 --- a/fp-bindgen/src/generators/rust_wasmer2_runtime/mod.rs +++ b/fp-bindgen/src/generators/rust_wasmer2_runtime/mod.rs @@ -159,7 +159,7 @@ pub(crate) fn generate_export_function_variables<'a>( let (raw_return_wrapper, return_wrapper) = if function.is_async { ( "let result = ModuleRawFuture::new(self.env.clone(), result).await;".to_string(), - "let result = result.await;\nlet result = result.map(|ref data| deserialize_from_slice(data));".to_string(), + format!("let result = result.await;\nlet result = result.map(|ref data| deserialize_from_slice::<{return_type}>(data));").to_string(), ) } else if !function .return_type @@ -216,8 +216,33 @@ fn format_export_function(function: &Function, types: &TypeMap) -> String { return_wrapper, } = generate_export_function_variables(function, types); - format!( - r#"{doc}pub {modifiers}fn {name}(&self{args}) -> Result<{return_type}, InvocationError> {{ + if function.is_async { + format!( + r#"{doc}pub {modifiers}fn {name}(&self{args}) -> Result<{return_type}, InvocationError> {{ + let this = self.clone(); + let task = async move {{ + {serialize_args} + let result = this.{name}_raw({arg_names}); + {return_wrapper}result.unwrap() + }}; + let mut recv = CURRENT_SPAWNER.spawn_async(task).await; + match recv.recv().await {{ + Some(result) => Ok(*result.downcast::<{return_type}>().unwrap()), + None => Err(InvocationError::UnexpectedReturnType), + }} +}} +pub {modifiers}fn {name}_raw(&self{raw_args}) -> Result<{raw_return_type}, InvocationError> {{ + {serialize_raw_args}let function = self.instance + .exports + .get_native_function::<{wasm_args}, {wasm_return_type}>("__fp_gen_{name}") + .map_err(|_| InvocationError::FunctionNotExported("__fp_gen_{name}".to_owned()))?; + let result = function.call({wasm_arg_names})?; + {raw_return_wrapper}Ok(result) +}}"# + ) + } else { + format!( + r#"{doc}pub {modifiers}fn {name}(&self{args}) -> Result<{return_type}, InvocationError> {{ {serialize_args} let result = self.{name}_raw({arg_names}); {return_wrapper}result @@ -231,6 +256,7 @@ pub {modifiers}fn {name}_raw(&self{raw_args}) -> Result<{raw_return_type}, Invoc {raw_return_wrapper}Ok(result) }}"# ) + } } pub(crate) fn format_import_arg(name: &str, ty: &TypeIdent, types: &TypeMap) -> String { @@ -272,15 +298,28 @@ pub(crate) fn format_import_function(function: &Function, types: &TypeMap) -> St .join(", "); let return_wrapper = if function.is_async { + let return_type = match &function.return_type { + None => "()".to_string(), + Some(ty) => format!("{}", ty), + }; format!( - r#"let env = env.clone(); + r#"let env_clone = env.clone(); let async_ptr = create_future_value(&env); - let handle = tokio::runtime::Handle::current(); - handle.spawn(async move {{ - let result = super::{name}({arg_names}).await; - let result_ptr = export_to_guest(&env, &result); - env.guest_resolve_async_value(async_ptr, result_ptr); + let host_task = GLOBAL_HANDLER.get().unwrap().spawn_async(async move {{ + super::{name}({arg_names}).await }}); + + let guest_task = async move {{ + let mut result = host_task.await; + let res = match result.recv().await {{ + Some(result) => *result.downcast::<{return_type}>().unwrap(), + None => panic!("FXIME: host_task return error"), + }}; + + let result_ptr = export_to_guest(&env_clone, &res); + env_clone.guest_resolve_async_value(async_ptr, result_ptr); + }}; + CURRENT_SPAWNER.spawn(guest_task); async_ptr"# ) } else { @@ -350,8 +389,14 @@ use fp_bindgen_support::{{ runtime::RuntimeInstanceData, }}, }}; +use once_cell::sync::Lazy; use std::cell::RefCell; use wasmer::{{imports, Function, ImportObject, Instance, Module, Store, WasmerEnv}}; +use fp_bindgen_support::wasmer2_host::task_spawner::{{CurrentThreadSpawner, GlobalSpawner}}; +use fp_bindgen_support::wasmer2_host::task_spawner::GLOBAL_HANDLER; +static CURRENT_SPAWNER: Lazy = Lazy::new(|| {{ + CurrentThreadSpawner::new() +}}); #[derive(Clone)] pub struct Runtime {{ diff --git a/fp-bindgen/src/generators/rust_wasmer2_wasi_runtime/mod.rs b/fp-bindgen/src/generators/rust_wasmer2_wasi_runtime/mod.rs index 0c21301d..2dbe1b98 100644 --- a/fp-bindgen/src/generators/rust_wasmer2_wasi_runtime/mod.rs +++ b/fp-bindgen/src/generators/rust_wasmer2_wasi_runtime/mod.rs @@ -71,11 +71,36 @@ fn format_export_function(function: &Function, types: &TypeMap) -> String { return_wrapper, } = generate_export_function_variables(function, types); - format!( - r#"{doc}pub {modifiers}fn {name}(&self{args}) -> Result<{return_type}, InvocationError> {{ - {serialize_args} - let result = self.{name}_raw({arg_names}); - {return_wrapper}result + if function.is_async { + format!( + r#"{doc}pub {modifiers}fn {name}(&self{args}) -> Result<{return_type}, InvocationError> {{ + let this = self.clone(); + let task = async move {{ + {serialize_args} + let result = this.{name}_raw({arg_names}); + {return_wrapper}result.unwrap() + }}; + let mut recv = CURRENT_SPAWNER.spawn_async(task).await; + match recv.recv().await {{ + Some(result) => Ok(*result.downcast::<{return_type}>().unwrap()), + None => Err(InvocationError::UnexpectedReturnType), + }} +}} +pub {modifiers}fn {name}_raw(&self{raw_args}) -> Result<{raw_return_type}, InvocationError> {{ + {serialize_raw_args}let function = self.instance + .exports + .get_native_function::<{wasm_args}, {wasm_return_type}>("__fp_gen_{name}") + .map_err(|_| InvocationError::FunctionNotExported("__fp_gen_{name}".to_owned()))?; + let result = function.call({wasm_arg_names})?; + {raw_return_wrapper}Ok(result) +}}"# + ) + } else { + format!( + r#"{doc}pub {modifiers}fn {name}(&self{args}) -> Result<{return_type}, InvocationError> {{ + {serialize_args} + let result = self.{name}_raw({arg_names}); + {return_wrapper}result }} pub {modifiers}fn {name}_raw(&self{raw_args}) -> Result<{raw_return_type}, InvocationError> {{ {serialize_raw_args}let function = self.instance @@ -86,6 +111,7 @@ pub {modifiers}fn {name}_raw(&self{raw_args}) -> Result<{raw_return_type}, Invoc {raw_return_wrapper}Ok(result) }}"# ) + } } fn generate_function_bindings(