diff --git a/.gitignore b/.gitignore index ef3bcf0..547e323 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ test-results/ playwright-report/ blob-report/ playwright/.cache/ +.DS_Store diff --git a/Cargo.lock b/Cargo.lock index db5b4ec..85df94a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "addr2line" @@ -17,15 +17,6 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" -[[package]] -name = "atomic-polyfill" -version = "1.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cf2bce30dfe09ef0bfaef228b9d414faaf7e563035494d7fe092dba54b300f4" -dependencies = [ - "critical-section", -] - [[package]] name = "autocfg" version = "1.4.0" @@ -66,28 +57,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] -name = "cobs" +name = "crunchy" version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67ba02a97a2bd10f4b59b25c7973101c79642302776489e030cd13cdab09ed15" - -[[package]] -name = "critical-section" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b" - -[[package]] -name = "embedded-io" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef1a6892d9eef45c8fa6b9e0086428a2cca8491aca8f787c534a3d6d0bcb3ced" - -[[package]] -name = "embedded-io" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edd0f118536f44f5ccd48bcb8b111bdc3de888b58c74639dfb034a357d0f206d" +checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" [[package]] name = "futures" @@ -198,26 +171,13 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] -name = "hash32" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0c35f58762feb77d74ebe43bdbc3210f09be9fe6742234d573bacc26ed92b67" -dependencies = [ - "byteorder", -] - -[[package]] -name = "heapless" -version = "0.7.17" +name = "half" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cdc6457c0eb62c71aac4bc17216026d8410337c4126773b9c5daba343f17964f" +checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" dependencies = [ - "atomic-polyfill", - "hash32", - "rustc_version", - "serde", - "spin", - "stable_deref_trait", + "cfg-if", + "crunchy", ] [[package]] @@ -235,16 +195,6 @@ version = "0.2.164" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f" -[[package]] -name = "lock_api" -version = "0.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" -dependencies = [ - "autocfg", - "scopeguard", -] - [[package]] name = "log" version = "0.4.22" @@ -294,15 +244,13 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] -name = "postcard" -version = "1.0.10" +name = "pot" +version = "3.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f7f0a8d620d71c457dd1d47df76bb18960378da56af4527aaa10f515eee732e" +checksum = "bf741fa415952eb20f27fbc210dc85f31cc7cdc80aa3ce81d5e27d28a6f45dc2" dependencies = [ - "cobs", - "embedded-io 0.4.0", - "embedded-io 0.6.1", - "heapless", + "byteorder", + "half", "serde", ] @@ -369,27 +317,6 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" -[[package]] -name = "rustc_version" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" -dependencies = [ - "semver", -] - -[[package]] -name = "scopeguard" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" - -[[package]] -name = "semver" -version = "1.0.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" - [[package]] name = "send_wrapper" version = "0.6.0" @@ -445,21 +372,6 @@ dependencies = [ "autocfg", ] -[[package]] -name = "spin" -version = "0.9.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" -dependencies = [ - "lock_api", -] - -[[package]] -name = "stable_deref_trait" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" - [[package]] name = "syn" version = "2.0.89" @@ -586,7 +498,8 @@ version = "0.1.2" dependencies = [ "futures", "js-sys", - "postcard", + "log", + "pot", "send_wrapper", "serde", "serde-wasm-bindgen", @@ -620,6 +533,7 @@ dependencies = [ name = "wasmworker-proc-macro" version = "0.1.1" dependencies = [ + "log", "quote", "syn", ] diff --git a/Cargo.toml b/Cargo.toml index ac746c4..42502a9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ readme = "README.md" keywords = ["webworker", "parallelism", "wasm"] [workspace.dependencies] -wasmworker = { version = "0.1", path = ".", features = ["serde"]} +wasmworker = { version = "0.1", path = ".", features = ["serde"] } wasmworker-proc-macro = { version = "0.1", path = "proc-macro" } [package] @@ -30,7 +30,7 @@ keywords.workspace = true [dependencies] futures = "0.3" js-sys = { version = "0.3" } -postcard = { version = "1.0", features = ["alloc"] } +pot = "3.0" send_wrapper = "0.6" serde = { version = "1.0", features = ["derive"] } serde_bytes = "0.11" @@ -39,6 +39,7 @@ thiserror = "2.0" tokio = { version = "1.4", features = ["sync"] } wasm-bindgen = "0.2" wasm-bindgen-futures = "0.4" +log = "*" [dependencies.web-sys] features = [ @@ -52,6 +53,9 @@ features = [ "BlobPropertyBag", "Url", "Navigator", + "MessagePort", + "MessageChannel", + "Response", ] version = "0.3" diff --git a/demo/src/lib.rs b/demo/src/lib.rs index 921f71c..755ffc1 100644 --- a/demo/src/lib.rs +++ b/demo/src/lib.rs @@ -23,7 +23,7 @@ async fn worker() -> &'static WebWorker { WORKER .get_or_init(move || async { SendWrapper::new( - WebWorker::with_path(None, None) + WebWorker::with_path(None, None, None) .await .expect_throw("Couldn't instantiate WebWorker"), ) diff --git a/proc-macro/Cargo.toml b/proc-macro/Cargo.toml index de6a22f..b8939ee 100644 --- a/proc-macro/Cargo.toml +++ b/proc-macro/Cargo.toml @@ -17,3 +17,4 @@ proc-macro = true [dependencies] syn = { version = "2.0", features = ["full", "extra-traits"] } quote = "1.0" +log = "*" diff --git a/proc-macro/src/lib.rs b/proc-macro/src/lib.rs index 78e53bc..577d1ab 100644 --- a/proc-macro/src/lib.rs +++ b/proc-macro/src/lib.rs @@ -4,14 +4,36 @@ use proc_macro::TokenStream; use quote::{format_ident, quote}; -use syn::{parse_macro_input, ItemFn}; +use syn::{parse_macro_input, token::Async, ItemFn}; /// A procedural macro that exports a function for use with a webworker. #[proc_macro_attribute] pub fn webworker_fn(_attr: TokenStream, item: TokenStream) -> TokenStream { - let input = parse_macro_input!(item as ItemFn); + let mut input = parse_macro_input!(item as ItemFn); let fn_name = &input.sig.ident; + // make the input function always async + input.sig.asyncness = Some(Async::default()); + + // the core function should have 2 inputs always + let output = if input.sig.inputs.len() == 1 { + let func_vis = &input.vis; // like pub + let func_block = &input.block; // { some statement or expression here } + let func_name = &input.sig.ident; // function nameinput + let func_generics = &input.sig.generics; + let func_inputs = &input.sig.inputs; + let func_output = &input.sig.output; + quote! { + #func_vis async fn #func_name #func_generics(#func_inputs, channel: Option) #func_output { + #func_block + } + } + } else { + quote! { + #input + } + }; + // Generate a module with the wrapper function let wrapper_fn_name = format_ident!("__webworker_{}", fn_name); let mod_code = quote! { @@ -19,9 +41,11 @@ pub fn webworker_fn(_attr: TokenStream, item: TokenStream) -> TokenStream { pub const __WEBWORKER: () = (); const _: () = { #[wasm_bindgen::prelude::wasm_bindgen] - pub fn #wrapper_fn_name(arg: Box<[u8]>) -> Box<[u8]> { + pub async fn #wrapper_fn_name(arg: Box<[u8]>, port: wasm_bindgen::JsValue) -> Box<[u8]> { + use wasm_bindgen::JsCast; let arg = wasmworker::convert::from_bytes(&arg); - let res = super::#fn_name(arg); + let channel = port.dyn_into::().ok().map(wasmworker::Channel::from); + let res = super::#fn_name(arg, channel).await; let res = wasmworker::convert::to_bytes(&res); res } @@ -31,7 +55,7 @@ pub fn webworker_fn(_attr: TokenStream, item: TokenStream) -> TokenStream { // Combine everything into the final output let expanded = quote! { - #input + #output #mod_code }; diff --git a/src/channel.rs b/src/channel.rs new file mode 100644 index 0000000..064b284 --- /dev/null +++ b/src/channel.rs @@ -0,0 +1,95 @@ +use std::{cell::RefCell, rc::Rc}; + +use serde::{de::DeserializeOwned, Serialize}; +use tokio::sync::mpsc; +use wasm_bindgen::prelude::*; +use web_sys::{MessageChannel, MessageEvent, MessagePort}; + +use crate::{ + convert::{from_bytes, to_bytes}, + error::InitError, +}; + +/// An internal type for the callback. +type Callback = dyn FnMut(MessageEvent); + +#[derive(Clone)] +pub struct Channel { + /// The message queue to await / incoming messages + messages: Rc>>, + /// The internal message port to send and receive data + port: MessagePort, + // The callback handle for the messages + //_callback: Closure, +} + +impl Channel { + /// Create two Channels to communicate between the WebWorker and the main application + /// The first channel is supposed to be used by the main application, the second one for the WebWorker. + /// When a message is send to one channel, it can be read from the second one and vice versa. + pub fn new() -> Result<(Self, MessagePort), InitError> { + // the message channel which creates two ports, which can be transfered to a WebWorker + let channel = MessageChannel::new().map_err(|e| InitError::ChannelCreation(e))?; + Ok((Self::from(channel.port1()), channel.port2())) + } + + /// Handle messages received by the port and forwards them into the message stream + fn on_message_callback( + sender: mpsc::UnboundedSender, + ) -> Closure { + Closure::new(move |event: MessageEvent| { + let _ = sender.send(event.data()); + }) + } + + /// Receives the next value for this receiver. + /// This method returns None if the channel has been closed and there are no remaining messages in the channel’s buffer. This indicates that no further values can ever be received from this Receiver. The channel is closed when all senders have been dropped, or when [close] is called. + /// If there are no messages in the channel’s buffer, but the channel has not yet been closed, this method will sleep until a message is sent or the channel is closed. + pub async fn recv(&self) -> Option { + let bytes = self.recv_bytes().await?; + Some(from_bytes(&bytes)) + } + + /// Receives the next value for this receiver. + pub async fn recv_bytes(&self) -> Option> { + let mut messages = self.messages.borrow_mut(); + let value = messages.recv().await?; + let array = js_sys::Uint8Array::new(&value); + Some(array.to_vec().into_boxed_slice()) + } + + /// Send a value to the receiver + pub fn send(&self, msg: &T) { + let bytes = to_bytes(msg); + self.send_bytes(&bytes); + } + + /// Send byte values to the receiver + pub fn send_bytes(&self, bytes: &Box<[u8]>) { + let array = js_sys::Uint8Array::new_with_length(bytes.len() as u32); + array.copy_from(&bytes); + self.port + .post_message(&array) + .expect("Channel is already closed"); + } +} + +impl From for Channel { + /// Create a new Channel from a MessagePort + fn from(port: MessagePort) -> Self { + // the internal message sender / receiver to have the messages easy available in rust + let (sender, receiver) = mpsc::unbounded_channel(); + + // handle messages which have been received by the other port + let callback_handle = Self::on_message_callback(sender); + port.set_onmessage(Some(callback_handle.as_ref().unchecked_ref())); + callback_handle.forget(); + + // Return the channel + Self { + messages: Rc::new(RefCell::new(receiver)), + port, + //_callback: callback_handle, + } + } +} diff --git a/src/convert.rs b/src/convert.rs index 43d7cb3..2dd1823 100644 --- a/src/convert.rs +++ b/src/convert.rs @@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize}; /// It is used internally to prepare values before sending them to a worker /// or back to the main thread via `postMessage`. pub fn to_bytes(value: &T) -> Box<[u8]> { - postcard::to_allocvec(value) + pot::to_vec(value) .expect("WebWorker serialization failed") .into() } @@ -13,5 +13,5 @@ pub fn to_bytes(value: &T) -> Box<[u8]> { /// It is used internally to prepare values after receiving them from a worker /// or the main thread via `postMessage`. pub fn from_bytes<'de, T: Deserialize<'de>>(bytes: &'de [u8]) -> T { - postcard::from_bytes(bytes).expect("WebWorker deserialization failed") + pot::from_slice(bytes).expect("WebWorker deserialization failed") } diff --git a/src/error.rs b/src/error.rs index ac845a9..8f3784d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -20,4 +20,7 @@ pub enum InitError { /// an invalid path. The path should point to the glue file generated by wasm-bindgen. #[error("WebWorker module loading error: {0:?}")] WebWorkerModuleLoading(String), + /// This error covers errors during the `new MessageChannel()` command. + #[error("Channel creation error: {0:?}")] + ChannelCreation(JsValue), } diff --git a/src/func.rs b/src/func.rs index f2a03f1..4752141 100644 --- a/src/func.rs +++ b/src/func.rs @@ -1,3 +1,7 @@ +use futures::future::LocalBoxFuture; + +use crate::Channel; + /// This struct describes the function to be called by the worker. /// It also ensures type safety, when constructed using the [`crate::webworker!`] macro. pub struct WebWorkerFn { @@ -5,7 +9,7 @@ pub struct WebWorkerFn { /// The worker will automatically add the `__webworker_` prefix. pub(crate) name: &'static str, /// The original function, which can be used as a fallback. - pub(crate) func: fn(T) -> R, + pub(crate) func: fn(T, Option) -> LocalBoxFuture<'static, R>, } impl Clone for WebWorkerFn { @@ -22,7 +26,10 @@ impl WebWorkerFn { /// has the right type or is exposed to the worker. /// /// Instead use the [`crate::webworker!`] macro to create an instance of this type. - pub fn new_unchecked(func_name: &'static str, f: fn(T) -> R) -> Self { + pub fn new_unchecked( + func_name: &'static str, + f: fn(T, Option) -> LocalBoxFuture<'static, R>, + ) -> Self { Self { name: func_name, func: f, @@ -47,6 +54,6 @@ impl WebWorkerFn { macro_rules! webworker { ($name:ident) => {{ let _ = $name::__WEBWORKER; - $crate::func::WebWorkerFn::new_unchecked(stringify!($name), $name) + $crate::func::WebWorkerFn::new_unchecked(stringify!($name), |a, b| Box::pin($name(a, b))) }}; } diff --git a/src/global.rs b/src/global.rs index a381f30..28e6cc6 100644 --- a/src/global.rs +++ b/src/global.rs @@ -40,6 +40,23 @@ pub async fn init_worker_pool(options: WorkerPoolOptions) { .await; } +/// JavaScript-accessible function to initialize an optimized worker pool globally. +/// This creates a worker pool that precompiles and shares WASM across all workers +/// for optimal bandwidth usage. +/// +/// ```js +/// import init, { initOptimizedWorkerPool } from "./pkg/webapp.js"; +/// +/// await init(); +/// await initOptimizedWorkerPool(); +/// ``` +#[wasm_bindgen(js_name = initOptimizedWorkerPool)] +pub async fn init_optimized_worker_pool() { + let mut options = WorkerPoolOptions::default(); + options.precompile_wasm = Some(true); + init_worker_pool(options).await; +} + /// This function accesses the default worker pool. /// If [`init_worker_pool`] has not been manually called, /// this function will initialize the worker pool prior to returning it. @@ -49,7 +66,7 @@ pub async fn worker_pool() -> &'static WebWorkerPool { WORKER_POOL .get_or_init(|| async { SendWrapper::new( - WebWorkerPool::new() + WebWorkerPool::with_options(WorkerPoolOptions::default()) .await .expect_throw("Couldn't instantiate worker pool"), ) diff --git a/src/iter_ext/mod.rs b/src/iter_ext/mod.rs index 51cd1f7..5f24a85 100644 --- a/src/iter_ext/mod.rs +++ b/src/iter_ext/mod.rs @@ -32,7 +32,7 @@ where R: Serialize + for<'de> Deserialize<'de>, { let pool = worker_pool().await; - join_all(self.map(|arg| pool.run_internal(func, arg))).await + join_all(self.map(|arg| pool.run_internal(func, arg, None))).await } /// The `try_par_map` function will attempt to parallelize a map operation on the default @@ -59,7 +59,8 @@ where if has_worker_pool() { self.par_map(func).await } else { - self.map(|item| (func.func)(item.into())).collect() + self.map(|item| futures::executor::block_on((func.func)(item.into(), None))) + .collect() } } } diff --git a/src/lib.rs b/src/lib.rs index 107cf6a..d96a61d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,16 @@ #![doc = include_str!("../README.md")] #![allow(clippy::borrowed_box)] -pub use global::{has_worker_pool, init_worker_pool, worker_pool}; +pub use channel::Channel; +pub use global::{has_worker_pool, init_optimized_worker_pool, init_worker_pool, worker_pool}; + pub use pool::WorkerPoolOptions; +pub use web_sys::MessagePort; pub use webworker::WebWorker; +// Re-export WebWorkerPool from pool module +pub use pool::WebWorkerPool; + +mod channel; pub mod convert; pub mod error; pub mod func; diff --git a/src/pool/mod.rs b/src/pool/mod.rs index a447043..2848b7f 100644 --- a/src/pool/mod.rs +++ b/src/pool/mod.rs @@ -5,7 +5,9 @@ use js_sys::wasm_bindgen::{prelude::wasm_bindgen, UnwrapThrowExt}; use scheduler::Scheduler; pub use scheduler::Strategy; use serde::{Deserialize, Serialize}; -use web_sys::window; + +use wasm_bindgen_futures::JsFuture; +use web_sys::{window, MessagePort}; use crate::{error::InitError, func::WebWorkerFn, WebWorker}; @@ -25,10 +27,16 @@ pub struct WorkerPoolOptions { /// [`crate::WebWorker::with_path`] lists more details on when this path /// should be manually configured. pub path: Option, + pub path_bg: Option, /// The strategy to be used by the worker pool. pub strategy: Option, /// The number of workers that will be spawned. This defaults to `navigator.hardwareConcurrency`. pub num_workers: Option, + /// Whether to precompile and share the WASM module across workers for bandwidth optimization. + /// This reduces the number of WASM fetches from N (one per worker) to 1 (shared across all workers). + pub precompile_wasm: Option, + /// Pre-compiled WASM module to share across workers. Internal use only. + pub(crate) wasm_module: Option, } #[wasm_bindgen] @@ -46,6 +54,10 @@ impl WorkerPoolOptions { self.path.as_deref() } + fn path_bg(&self) -> Option<&str> { + self.path_bg.as_deref() + } + /// Returns the configured strategy or the default strategy. fn strategy(&self) -> Strategy { self.strategy.unwrap_or_default() @@ -82,6 +94,8 @@ pub struct WebWorkerPool { workers: Vec, /// The internal scheduler that is used to distribute the tasks. scheduler: Scheduler, + /// Pre-compiled WASM module shared across workers + wasm_module: Option, } impl WebWorkerPool { @@ -128,10 +142,23 @@ impl WebWorkerPool { /// Initializes a worker pool with the given [`WorkerPoolOptions`]. /// This async function might return an [`InitError`] if one of the workers /// cannot be initialized, as described in [`WebWorker::new`]. - pub async fn with_options(options: WorkerPoolOptions) -> Result { + pub async fn with_options(mut options: WorkerPoolOptions) -> Result { + // Pre-compile WASM module if explicitly requested or not already provided + let wasm_module = + if options.wasm_module.is_none() && options.precompile_wasm.unwrap_or(false) { + Some(Self::precompile_wasm(&options).await?) + } else { + options.wasm_module.take() + }; + let worker_inits = (0..options.num_workers()).map(|_| { // Do not impose a task limit. - WebWorker::with_path(options.path(), None) + WebWorker::with_path_and_module( + options.path(), + options.path_bg(), + None, + wasm_module.clone(), + ) }); let workers = join_all(worker_inits).await; let workers = workers.into_iter().collect::, _>>()?; @@ -139,6 +166,7 @@ impl WebWorkerPool { Ok(Self { workers, scheduler: Scheduler::new(options.strategy()), + wasm_module, }) } @@ -159,7 +187,21 @@ impl WebWorkerPool { T: Serialize + for<'de> Deserialize<'de>, R: Serialize + for<'de> Deserialize<'de>, { - self.run_internal(func, arg).await + self.run_internal(func, arg, None).await + } + + #[cfg(feature = "serde")] + pub async fn run_with_channel( + &self, + func: WebWorkerFn, + arg: &T, + port: MessagePort, + ) -> R + where + T: Serialize + for<'de> Deserialize<'de>, + R: Serialize + for<'de> Deserialize<'de>, + { + self.run_internal(func, arg, Some(port)).await } /// This function can outsource a task on a [`WebWorkerPool`] which has `Box<[u8]>` both as input and output. @@ -178,12 +220,17 @@ impl WebWorkerPool { func: WebWorkerFn, Box<[u8]>>, arg: &Box<[u8]>, ) -> Box<[u8]> { - self.run_internal(func, arg).await + self.run_internal(func, arg, None).await } /// Determines the worker to run the task on using the scheduler /// and runs the task. - pub(crate) async fn run_internal(&self, func: WebWorkerFn, arg: A) -> R + pub(crate) async fn run_internal( + &self, + func: WebWorkerFn, + arg: A, + port: Option, + ) -> R where A: Borrow, T: Serialize + for<'de> Deserialize<'de>, @@ -191,7 +238,7 @@ impl WebWorkerPool { { let worker_id = self.scheduler.schedule(self); self.workers[worker_id] - .run_internal(func, arg.borrow()) + .run_internal(func, arg.borrow(), port) .await } @@ -204,4 +251,82 @@ impl WebWorkerPool { pub fn num_workers(&self) -> usize { self.workers.len() } + + /// Create a worker pool with a pre-compiled WASM module for optimal bandwidth usage. + /// This method pre-compiles the WASM module once and shares it across all workers, + /// reducing bandwidth usage compared to each worker loading the WASM independently. + pub async fn with_precompiled_wasm() -> Result { + let mut options = WorkerPoolOptions::default(); + options.precompile_wasm = Some(true); + Self::with_options(options).await + } + + /// Pre-compile the WASM module for sharing across workers. + /// + /// This function fetches and compiles the WASM module once, which can then be + /// shared across all workers to reduce bandwidth usage. + /// + /// Path resolution: + /// - If `path_bg` is provided, it should be the full URL to the WASM file + /// - If `path` is provided, assumes standard wasm-bindgen naming (_bg.wasm suffix) + /// - Otherwise, infers path from the current module location + async fn precompile_wasm( + options: &WorkerPoolOptions, + ) -> Result { + use wasm_bindgen::JsCast; + + // Get the WASM path - if path_bg is provided, use it directly since it should be the WASM URL + let wasm_path = if let Some(bg_path) = options.path_bg() { + // path_bg should already be the WASM URL (e.g., "http://localhost:8080/webapp_bg.wasm") + bg_path.to_string() + } else if let Some(js_path) = options.path() { + // Convert main JS path to WASM path (typically add _bg.wasm) + if js_path.ends_with(".js") { + js_path.replace(".js", "_bg.wasm") + } else { + format!("{}_bg.wasm", js_path) + } + } else { + // Use default path inference from the main JS module + let js_path = crate::webworker::js::main_js().as_string().unwrap_throw(); + if js_path.ends_with(".js") { + js_path.replace(".js", "_bg.wasm") + } else { + format!("{}_bg.wasm", js_path) + } + }; + + // Fetch the WASM file + use wasm_bindgen::UnwrapThrowExt; + let window = web_sys::window().unwrap_throw(); + let resp_value = JsFuture::from(window.fetch_with_str(&wasm_path)) + .await + .map_err(|e| { + InitError::WebWorkerModuleLoading(format!( + "Failed to fetch WASM from '{}': {:?}. Check that path_bg points to the correct WASM file URL.", + wasm_path, e + )) + })?; + let resp: web_sys::Response = resp_value.unchecked_into(); + + let array_buffer = JsFuture::from(resp.array_buffer().unwrap_throw()) + .await + .map_err(|e| { + InitError::WebWorkerModuleLoading(format!( + "Failed to read WASM bytes from '{}': {:?}", + wasm_path, e + )) + })?; + + // Compile the WASM module + let compile_promise = js_sys::WebAssembly::compile(&array_buffer); + let module_value = JsFuture::from(compile_promise).await.map_err(|e| { + InitError::WebWorkerModuleLoading(format!( + "Failed to compile WASM from '{}': {:?}. This usually means the file is not a valid WASM binary or the URL returned an error page.", + wasm_path, e + )) + })?; + + Ok(module_value.into()) + } } diff --git a/src/webworker/js.rs b/src/webworker/js.rs index b3a6cc6..b46767d 100644 --- a/src/webworker/js.rs +++ b/src/webworker/js.rs @@ -18,7 +18,7 @@ console.debug('Initializing worker'); return; } - await mod.default(); + await mod.default({{wasm_bg}}); self.postMessage({ success: true }); console.debug('Worker started'); @@ -34,7 +34,7 @@ console.debug('Initializing worker'); return; } - const worker_result = await fn(arg); + const worker_result = await fn(arg, event.ports[0]); // Send response back to be handled by callback in main thread. console.debug('Send worker result'); @@ -54,3 +54,59 @@ pub(crate) fn main_js() -> JsString { URL.with(Clone::clone) } + +/// The initialization code for workers that receive a pre-compiled WASM module +pub(crate) const WORKER_JS_WITH_PRECOMPILED: &str = r#" +console.debug('Initializing worker with pre-compiled WASM'); + +let wasmModule = null; +let mod = null; +let initHandler = null; + +// Listen for the pre-compiled WASM module +initHandler = async function(event) { + const data = event.data; + + if (data.type === 'wasm_module') { + console.debug('Received pre-compiled WASM module'); + wasmModule = data.module; + + // Now initialize with the pre-compiled module + try { + mod = await import('{{wasm}}'); + await mod.default({ module_or_path: wasmModule }); + self.postMessage({ success: true }); + console.debug('Worker started with pre-compiled WASM'); + } catch (e) { + console.error('Unable to initialize with pre-compiled WASM', e); + self.postMessage({ success: false, message: e.toString() }); + return; + } + + // Remove this listener and add the task handler + self.removeEventListener('message', initHandler); + + // Add the main message handler for tasks + self.addEventListener('message', async event => { + console.debug('Received worker event'); + const { id, func_name, arg } = event.data; + + const webworker_func_name = `__webworker_${func_name}`; + const fn = mod[webworker_func_name]; + if (!fn) { + console.error(`Function '${func_name}' is not exported.`); + self.postMessage({ id: id, response: null }); + return; + } + + const worker_result = await fn(arg, event.ports[0]); + + // Send response back to be handled by callback in main thread. + console.debug('Send worker result'); + self.postMessage({ id: id, response: worker_result }); + }); + } +}; + +self.addEventListener('message', initHandler); +"#; diff --git a/src/webworker/mod.rs b/src/webworker/mod.rs index 5986c30..83f8705 100644 --- a/src/webworker/mod.rs +++ b/src/webworker/mod.rs @@ -1,5 +1,5 @@ pub use worker::*; mod com; -mod js; +pub mod js; mod worker; diff --git a/src/webworker/worker.rs b/src/webworker/worker.rs index 289406c..3973464 100644 --- a/src/webworker/worker.rs +++ b/src/webworker/worker.rs @@ -10,7 +10,9 @@ use js_sys::Array; use serde::{Deserialize, Serialize}; use tokio::sync::{oneshot, Semaphore}; use wasm_bindgen::{prelude::Closure, JsCast, JsValue, UnwrapThrowExt}; -use web_sys::{Blob, BlobPropertyBag, MessageEvent, Url, Worker, WorkerOptions, WorkerType}; +use web_sys::{ + Blob, BlobPropertyBag, MessageEvent, MessagePort, Url, Worker, WorkerOptions, WorkerType, +}; use crate::{ convert::{from_bytes, to_bytes}, @@ -51,7 +53,11 @@ impl WebWorker { /// This function takes the [`WORKER_JS`] and creates the corresponding /// worker blob after inserting the given path. /// If no `wasm_path` is provided, the [`main_js()`] path is used. - fn worker_blob(wasm_path: Option<&str>) -> String { + fn worker_blob( + wasm_path: Option<&str>, + wasm_bg_path: Option<&str>, + has_precompiled_module: bool, + ) -> String { let blob_options = BlobPropertyBag::new(); blob_options.set_type("application/javascript"); @@ -62,9 +68,22 @@ impl WebWorker { wasm_path_owned.as_ref().unwrap_throw() }); + let wasm_bg_path = match wasm_bg_path { + Some(path) => format!("{{module_or_path: '{path}'}}"), + None => "undefined".to_string(), + }; + + let worker_js = if has_precompiled_module { + super::js::WORKER_JS_WITH_PRECOMPILED + } else { + super::js::WORKER_JS + }; + let code = Array::new(); code.push(&JsValue::from_str( - &WORKER_JS.replace("{{wasm}}", wasm_path), + &worker_js + .replace("{{wasm}}", wasm_path) + .replace("{{wasm_bg}}", &wasm_bg_path), )); Url::create_object_url_with_blob( @@ -78,7 +97,7 @@ impl WebWorker { /// This can fail with an [`InitError`], for example, if the automatically inferred /// path for the wasm-bindgen glue is wrong. pub async fn new(task_limit: Option) -> Result { - Self::with_path(None, task_limit).await + Self::with_path(None, None, task_limit).await } /// Create a new [`WebWorker`] with an optional limit on the number of tasks queued. @@ -90,15 +109,45 @@ impl WebWorker { /// If a wrong path is given, a [`InitError`] will be returned. pub async fn with_path( main_js: Option<&str>, + main_bg_js: Option<&str>, task_limit: Option, + ) -> Result { + Self::with_path_and_module(main_js, main_bg_js, task_limit, None).await + } + + /// Create a new [`WebWorker`] with an optional limit on the number of tasks queued + /// and an optional pre-compiled WASM module. + pub async fn with_path_and_module( + main_js: Option<&str>, + main_bg_js: Option<&str>, + task_limit: Option, + wasm_module: Option, ) -> Result { // Create worker let worker_options = WorkerOptions::new(); worker_options.set_type(WorkerType::Module); - let script_url = WebWorker::worker_blob(main_js); + let script_url = WebWorker::worker_blob(main_js, main_bg_js, wasm_module.is_some()); + let worker = Worker::new_with_options(&script_url, &worker_options) .map_err(InitError::WebWorkerCreation)?; + // Send pre-compiled WASM module if provided + if let Some(module) = wasm_module { + let init_msg = js_sys::Object::new(); + js_sys::Reflect::set( + &init_msg, + &JsValue::from_str("type"), + &JsValue::from_str("wasm_module"), + ) + .expect_throw("Could not set type"); + js_sys::Reflect::set(&init_msg, &JsValue::from_str("module"), &module) + .expect_throw("Could not set module"); + + worker + .post_message(&init_msg) + .expect_throw("Could not send WASM module to worker"); + } + // Wait until worker is initialized. let (tx, rx) = oneshot::channel(); let handler = Closure::once(move |event: MessageEvent| { @@ -171,7 +220,21 @@ impl WebWorker { T: Serialize + for<'de> Deserialize<'de>, R: Serialize + for<'de> Deserialize<'de>, { - self.run_internal(func, arg).await + self.run_internal(func, arg, None).await + } + + #[cfg(feature = "serde")] + pub async fn run_with_channel( + &self, + func: WebWorkerFn, + arg: &T, + port: MessagePort, + ) -> R + where + T: Serialize + for<'de> Deserialize<'de>, + R: Serialize + for<'de> Deserialize<'de>, + { + self.run_internal(func, arg, Some(port)).await } /// This function differs from [`WebWorker::run`] by returning early if the given task limit is reached. @@ -192,7 +255,7 @@ impl WebWorker { T: Serialize + for<'de> Deserialize<'de>, R: Serialize + for<'de> Deserialize<'de>, { - self.try_run_internal(func, arg).await + self.try_run_internal(func, arg, None).await } /// This function can outsource a task on a [`WebWorker`] which has `Box<[u8]>` both as input and output. @@ -214,7 +277,7 @@ impl WebWorker { func: WebWorkerFn, Box<[u8]>>, arg: &Box<[u8]>, ) -> Box<[u8]> { - self.run_internal(func, arg).await + self.run_internal(func, arg, None).await } /// This function differs from [`WebWorker::run_bytes`] by returning early if the given task limit is reached. @@ -236,7 +299,7 @@ impl WebWorker { func: WebWorkerFn, Box<[u8]>>, arg: &Box<[u8]>, ) -> Result, Full> { - self.try_run_internal(func, arg).await + self.try_run_internal(func, arg, None).await } /// Internal function to schedule a task to the worker. @@ -245,6 +308,7 @@ impl WebWorker { &self, func: WebWorkerFn, arg: &T, + port: Option, ) -> Result where T: Serialize + for<'de> Deserialize<'de>, @@ -261,11 +325,16 @@ impl WebWorker { }; // Convert arg and result. - Ok(self.force_run(func.name, arg).await) + Ok(self.force_run(func.name, arg, port).await) } /// Internal function to schedule a task to the worker. - pub(crate) async fn run_internal(&self, func: WebWorkerFn, arg: &T) -> R + pub(crate) async fn run_internal( + &self, + func: WebWorkerFn, + arg: &T, + port: Option, + ) -> R where T: Serialize + for<'de> Deserialize<'de>, R: Serialize + for<'de> Deserialize<'de>, @@ -278,13 +347,18 @@ impl WebWorker { }; // Convert arg and result. - self.force_run(func.name, arg).await + self.force_run(func.name, arg, port).await } /// This function handles the communication with the worker /// after the task limit has been checked. /// It also handles (de)serialization. - async fn force_run(&self, func_name: &'static str, arg: &T) -> R + async fn force_run( + &self, + func_name: &'static str, + arg: &T, + port: Option, + ) -> R where T: Serialize + for<'de> Deserialize<'de>, R: Serialize + for<'de> Deserialize<'de>, @@ -301,11 +375,26 @@ impl WebWorker { let (sender, receiver) = oneshot::channel(); self.open_tasks.borrow_mut().insert(id, sender); - self.worker - .post_message( - &serde_wasm_bindgen::to_value(&request).expect_throw("Could not serialize request"), - ) - .expect_throw("WebWorker gone"); + // send the task to the webworker, either with a port or without one + if let Some(port) = port { + let transfer = Array::new(); + transfer.push(&port); + + self.worker + .post_message_with_transfer( + &serde_wasm_bindgen::to_value(&request) + .expect_throw("Could not serialize request"), + &transfer, + ) + .expect_throw("WebWorker gone"); + } else { + self.worker + .post_message( + &serde_wasm_bindgen::to_value(&request) + .expect_throw("Could not serialize request"), + ) + .expect_throw("WebWorker gone"); + } // Handle result. let res = receiver diff --git a/test/src/raw.rs b/test/src/raw.rs index 0b420f3..893169f 100644 --- a/test/src/raw.rs +++ b/test/src/raw.rs @@ -11,7 +11,7 @@ pub fn sort(mut v: Box<[u8]>) -> Box<[u8]> { } pub(crate) async fn can_handle_invalid_paths() { - let worker = WebWorker::with_path(Some("something"), None).await; + let worker = WebWorker::with_path(Some("something"), None, None).await; if !matches!(worker, Err(InitError::WebWorkerModuleLoading(_))) { throw_str("Should have failed initialization with wrong path"); }