diff --git a/Cargo.lock b/Cargo.lock index 56a9795fd8..b109ca64f7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -136,6 +136,16 @@ version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6" +[[package]] +name = "async-attributes" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3203e79f4dd9bdda415ed03cf14dae5a2bf775c683a00f94e9cd1faf0f596e5" +dependencies = [ + "quote", + "syn 1.0.96", +] + [[package]] name = "async-channel" version = "1.8.0" @@ -228,6 +238,7 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "52580991739c5cdb36cde8b2a516371c0a3b70dda36d916cc08b82372916808c" dependencies = [ + "async-attributes", "async-channel", "async-global-executor", "async-io", diff --git a/zellij-server/src/plugins/wasm_bridge.rs b/zellij-server/src/plugins/wasm_bridge.rs index b89cc73bed..7851e870c7 100644 --- a/zellij-server/src/plugins/wasm_bridge.rs +++ b/zellij-server/src/plugins/wasm_bridge.rs @@ -17,7 +17,6 @@ use zellij_utils::async_channel::Sender; use zellij_utils::async_std::task::{self, JoinHandle}; use zellij_utils::consts::ZELLIJ_CACHE_DIR; use zellij_utils::data::{PermissionStatus, PermissionType}; -use zellij_utils::downloader::download::Download; use zellij_utils::downloader::Downloader; use zellij_utils::input::permission::PermissionCache; use zellij_utils::notify_debouncer_full::{notify::RecommendedWatcher, Debouncer, FileIdMap}; @@ -166,22 +165,15 @@ impl WasmBridge { let mut loading_indication = LoadingIndication::new(plugin_name.clone()); if let RunPluginLocation::Remote(url) = &plugin.location { - let download = Download::from(url); - - let hash: String = PortableHash::default() - .hash128(download.url.as_bytes()) + let file_name: String = PortableHash::default() + .hash128(url.as_bytes()) .iter() .map(ToString::to_string) .collect(); - let plugin_directory = ZELLIJ_CACHE_DIR.join(hash); - - // The plugin path is determined by the hash of the plugin URL in the cache directory. - plugin.path = plugin_directory.join(&download.file_name); - - let downloader = Downloader::new(plugin_directory); - match downloader.fetch(&download).await { - Ok(_) => {}, + let downloader = Downloader::new(ZELLIJ_CACHE_DIR.to_path_buf()); + match downloader.download(url, Some(&file_name)).await { + Ok(_) => plugin.path = ZELLIJ_CACHE_DIR.join(&file_name), Err(e) => handle_plugin_loading_failure( &senders, plugin_id, diff --git a/zellij-utils/Cargo.toml b/zellij-utils/Cargo.toml index 68588aa9bf..d7d038b331 100644 --- a/zellij-utils/Cargo.toml +++ b/zellij-utils/Cargo.toml @@ -51,7 +51,7 @@ termwiz = "0.20.0" log4rs = "1.2.0" signal-hook = "0.3" interprocess = "1.2.1" -async-std = { version = "1.3.0", features = ["unstable"] } +async-std = { version = "1.3.0", features = ["unstable", "attributes"] } notify-debouncer-full = "0.1.0" humantime = "2.1.0" futures = "0.3.28" diff --git a/zellij-utils/src/downloader.rs b/zellij-utils/src/downloader.rs new file mode 100644 index 0000000000..aca0c8cde4 --- /dev/null +++ b/zellij-utils/src/downloader.rs @@ -0,0 +1,172 @@ +use async_std::{ + fs, + io::{ReadExt, WriteExt}, + stream::StreamExt, +}; +use std::path::PathBuf; +use surf::Client; +use thiserror::Error; +use url::Url; + +#[derive(Error, Debug)] +pub enum DownloaderError { + #[error("RequestError: {0}")] + Request(surf::Error), + #[error("IoError: {0}")] + Io(#[source] std::io::Error), + #[error("File name cannot be found in URL: {0}")] + NotFoundFileName(String), +} + +#[derive(Debug)] +pub struct Downloader { + client: Client, + location: PathBuf, +} + +impl Default for Downloader { + fn default() -> Self { + Self { + client: surf::client().with(surf::middleware::Redirect::default()), + location: PathBuf::from(""), + } + } +} + +impl Downloader { + pub fn new(location: PathBuf) -> Self { + Self { + client: surf::client().with(surf::middleware::Redirect::default()), + location, + } + } + + pub async fn download( + &self, + url: &str, + file_name: Option<&str>, + ) -> Result<(), DownloaderError> { + let file_name = match file_name { + Some(name) => name.to_string(), + None => self.parse_name(url)?, + }; + + let file_path = self.location.join(file_name.as_str()); + if file_path.exists() { + log::debug!("File already exists: {:?}", file_path); + return Ok(()); + } + + let file_part_path = self.location.join(format!("{}.part", file_name)); + let (mut target, file_part_size) = { + if file_part_path.exists() { + let file_part = fs::OpenOptions::new() + .append(true) + .write(true) + .open(&file_part_path) + .await + .map_err(|e| DownloaderError::Io(e))?; + + let file_part_size = file_part + .metadata() + .await + .map_err(|e| DownloaderError::Io(e))? + .len(); + + log::debug!("Resuming download from {} bytes", file_part_size); + + (file_part, file_part_size) + } else { + let file_part = fs::File::create(&file_part_path) + .await + .map_err(|e| DownloaderError::Io(e))?; + + (file_part, 0) + } + }; + + let res = self + .client + .get(url) + .header("Content-Type", "application/octet-stream") + .header("Range", format!("bytes={}-", file_part_size)) + .await + .map_err(|e| DownloaderError::Request(e))?; + + let mut stream = res.bytes(); + while let Some(byte) = stream.next().await { + let byte = byte.map_err(|e| DownloaderError::Io(e))?; + target + .write(&[byte]) + .await + .map_err(|e| DownloaderError::Io(e))?; + } + + log::debug!("Download complete: {:?}", file_part_path); + + fs::rename(file_part_path, file_path) + .await + .map_err(|e| DownloaderError::Io(e))?; + + Ok(()) + } + + fn parse_name(&self, url: &str) -> Result { + Url::parse(url) + .map_err(|_| DownloaderError::NotFoundFileName(url.to_string()))? + .path_segments() + .ok_or_else(|| DownloaderError::NotFoundFileName(url.to_string()))? + .last() + .ok_or_else(|| DownloaderError::NotFoundFileName(url.to_string())) + .map(|s| s.to_string()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use tempfile::tempdir; + + #[ignore] + #[async_std::test] + async fn test_download_ok() { + let location = tempdir().expect("Failed to create temp directory"); + let location_path = location.path(); + + let downloader = Downloader::new(location_path.to_path_buf()); + let result = downloader + .download( + "https://github.com/imsnif/monocle/releases/download/0.39.0/monocle.wasm", + Some("monocle.wasm"), + ) + .await + .is_ok(); + + assert!(result); + assert!(location_path.join("monocle.wasm").exists()); + + location.close().expect("Failed to close temp directory"); + } + + #[ignore] + #[async_std::test] + async fn test_download_without_file_name() { + let location = tempdir().expect("Failed to create temp directory"); + let location_path = location.path(); + + let downloader = Downloader::new(location_path.to_path_buf()); + let result = downloader + .download( + "https://github.com/imsnif/multitask/releases/download/0.38.2v2/multitask.wasm", + None, + ) + .await + .is_ok(); + + assert!(result); + assert!(location_path.join("multitask.wasm").exists()); + + location.close().expect("Failed to close temp directory"); + } +} diff --git a/zellij-utils/src/downloader/download.rs b/zellij-utils/src/downloader/download.rs deleted file mode 100644 index d665f7e54c..0000000000 --- a/zellij-utils/src/downloader/download.rs +++ /dev/null @@ -1,49 +0,0 @@ -use serde::{Deserialize, Serialize}; -use surf::Url; - -#[derive(Debug, Default, Clone, Deserialize, Serialize, PartialEq, Eq, Hash)] -pub struct Download { - pub url: String, - pub file_name: String, -} - -impl Download { - pub fn from(url: &str) -> Self { - match Url::parse(url) { - Ok(u) => u - .path_segments() - .map_or_else(Download::default, |segments| { - let file_name = segments.last().unwrap_or("").to_string(); - - Download { - url: url.to_string(), - file_name, - } - }), - Err(_) => Download::default(), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_from_download() { - let download = Download::from("https://github.com/example/plugin.wasm"); - assert_eq!(download.url, "https://github.com/example/plugin.wasm"); - assert_eq!(download.file_name, "plugin.wasm"); - } - - #[test] - fn test_empty_download() { - let d1 = Download::from("https://example.com"); - assert_eq!(d1.url, "https://example.com"); - assert_eq!(d1.file_name, ""); - - let d2 = Download::from("github.com"); - assert_eq!(d2.url, ""); - assert_eq!(d2.file_name, ""); - } -} diff --git a/zellij-utils/src/downloader/mod.rs b/zellij-utils/src/downloader/mod.rs deleted file mode 100644 index b0b2771ddb..0000000000 --- a/zellij-utils/src/downloader/mod.rs +++ /dev/null @@ -1,147 +0,0 @@ -pub mod download; - -use async_std::{ - fs::{create_dir_all, File}, - io::{ReadExt, WriteExt}, - stream, task, -}; -use futures::{StreamExt, TryStreamExt}; -use std::path::PathBuf; -use surf::Client; -use thiserror::Error; - -use self::download::Download; - -#[derive(Error, Debug)] -pub enum DownloaderError { - #[error("RequestError: {0}")] - Request(surf::Error), - #[error("StatusError: {0}, StatusCode: {1}")] - Status(String, surf::StatusCode), - #[error("IoError: {0}")] - Io(#[source] std::io::Error), - #[error("IoPathError: {0}, File: {1}")] - IoPath(std::io::Error, PathBuf), -} - -#[derive(Default, Debug)] -pub struct Downloader { - client: Client, - directory: PathBuf, -} - -impl Downloader { - pub fn new(directory: PathBuf) -> Self { - Self { - client: surf::client().with(surf::middleware::Redirect::default()), - directory, - } - } - - pub fn set_directory(&mut self, directory: PathBuf) { - self.directory = directory; - } - - pub fn download(&self, downloads: &[Download]) -> Vec> { - task::block_on(async { - stream::from_iter(downloads) - .map(|download| self.fetch(download)) - .buffer_unordered(4) - .collect::>() - .await - }) - } - - pub async fn fetch(&self, download: &Download) -> Result<(), DownloaderError> { - let mut file_size: usize = 0; - - let file_path = self.directory.join(&download.file_name); - - if file_path.exists() { - file_size = match file_path.metadata() { - Ok(metadata) => metadata.len() as usize, - Err(e) => return Err(DownloaderError::IoPath(e, file_path)), - } - } - - let response = self - .client - .get(&download.url) - .await - .map_err(|e| DownloaderError::Request(e))?; - let status = response.status(); - - if status.is_client_error() || status.is_server_error() { - return Err(DownloaderError::Status( - status.canonical_reason().to_string(), - status, - )); - } - - let length = response.len().unwrap_or(0); - if length > 0 && length == file_size { - return Ok(()); - } - - let mut dest = { - create_dir_all(&self.directory) - .await - .map_err(|e| DownloaderError::IoPath(e, self.directory.clone()))?; - File::create(&file_path) - .await - .map_err(|e| DownloaderError::IoPath(e, file_path))? - }; - - let mut bytes = response.bytes(); - while let Some(byte) = bytes.try_next().await.map_err(DownloaderError::Io)? { - dest.write_all(&[byte]).await.map_err(DownloaderError::Io)?; - } - - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use tempfile::tempdir; - - #[test] - #[ignore] - fn test_fetch_plugin() { - let dir = tempdir().expect("could not get temp dir"); - let dir_path = dir.path(); - - let downloader = Downloader::new(dir_path.to_path_buf()); - let dl = Download::from( - "https://github.com/imsnif/monocle/releases/download/0.37.2/monocle.wasm", - ); - - let result = task::block_on(downloader.fetch(&dl)); - - assert!(result.is_ok()); - } - - #[test] - #[ignore] - fn test_download_plugins() { - let dir = tempdir().expect("could not get temp dir"); - let dir_path = dir.path(); - - let downloader = Downloader::new(dir_path.to_path_buf()); - let downloads = vec![ - Download::from( - "https://github.com/imsnif/monocle/releases/download/0.37.2/monocle.wasm", - ), - Download::from( - "https://github.com/imsnif/multitask/releases/download/0.38.2/multitask.wasm", - ), - ]; - - let results = downloader.download(&downloads); - for result in results { - assert!(result.is_ok()) - } - } -}