From 7b141ea62fda22abeadad3f83a896958d01897ff Mon Sep 17 00:00:00 2001 From: StellarisW Date: Tue, 6 Aug 2024 23:33:00 +0800 Subject: [PATCH 01/31] feat(http): support websocket server --- Cargo.lock | 102 +++++ Cargo.toml | 1 + volo-http/Cargo.toml | 11 +- volo-http/src/error/server.rs | 57 +++ volo-http/src/server/extract.rs | 1 + volo-http/src/server/mod.rs | 6 +- volo-http/src/server/utils/mod.rs | 4 + volo-http/src/server/utils/ws.rs | 678 ++++++++++++++++++++++++++++++ 8 files changed, 856 insertions(+), 4 deletions(-) create mode 100644 volo-http/src/server/utils/ws.rs diff --git a/Cargo.lock b/Cargo.lock index a8f24dc6..c484751a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -270,6 +270,15 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bumpalo" version = "3.16.0" @@ -418,6 +427,15 @@ version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" +[[package]] +name = "cpufeatures" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" +dependencies = [ + "libc", +] + [[package]] name = "crc32fast" version = "1.4.2" @@ -452,6 +470,16 @@ version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "dashmap" version = "5.5.3" @@ -491,6 +519,16 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "directories" version = "5.0.1" @@ -809,6 +847,16 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.2.15" @@ -2621,6 +2669,17 @@ dependencies = [ "unsafe-libyaml", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -2967,6 +3026,18 @@ dependencies = [ "tokio-stream", ] +[[package]] +name = "tokio-tungstenite" +version = "0.23.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6989540ced10490aaf14e6bad2e3d33728a2813310a0c71d1574304c49631cd" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.11" @@ -3192,6 +3263,30 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "tungstenite" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e2e2ce1e47ed2994fd43b04c8f618008d4cabdd5ee34027cf14f9d918edd9c8" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http 1.1.0", + "httparse", + "log", + "rand", + "sha1", + "thiserror", + "utf-8", +] + +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + [[package]] name = "unicase" version = "2.7.0" @@ -3289,6 +3384,12 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "10477a8879b7fc3bb4fb54fa6adcfd6191de561b13d5413fec9cc0239fd1c882" +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8parse" version = "0.2.2" @@ -3492,6 +3593,7 @@ dependencies = [ "tokio-native-tls", "tokio-rustls 0.25.0", "tokio-test", + "tokio-tungstenite", "tokio-util", "tracing", "volo", diff --git a/Cargo.toml b/Cargo.toml index daacb052..da2532ef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -130,6 +130,7 @@ webpki-roots = "0.26" tokio-rustls = "0.25" native-tls = "0.2" tokio-native-tls = "0.3" +tokio-tungstenite="0.23.1" [profile.release] opt-level = 3 diff --git a/volo-http/Cargo.toml b/volo-http/Cargo.toml index 65e492af..fee48862 100644 --- a/volo-http/Cargo.toml +++ b/volo-http/Cargo.toml @@ -59,9 +59,13 @@ tokio = { workspace = true, features = [ tokio-util = { workspace = true, features = ["io"] } tracing.workspace = true +# =====optional===== # server optional matchit = { workspace = true, optional = true } +# protocol optional +tokio-tungstenite = { workspace = true, optional = true } + # tls optional tokio-rustls = { workspace = true, optional = true } tokio-native-tls = { workspace = true, optional = true } @@ -88,8 +92,11 @@ default_server = ["server", "query", "form", "json"] full = ["client", "server", "rustls", "cookie", "query", "form", "json", "tls"] -client = ["hyper/client", "hyper/http1"] # client core -server = ["hyper/server", "hyper/http1", "dep:matchit"] # server core +client = ["hyper/client", "hyper/http1", "ws"] # client core +server = ["hyper/server", "hyper/http1", "dep:matchit", "ws"] # server core + +protocol = ["ws"] +ws = ["dep:tokio-tungstenite"] tls = ["rustls"] __tls = [] diff --git a/volo-http/src/error/server.rs b/volo-http/src/error/server.rs index 3be6c308..ed0f7450 100644 --- a/volo-http/src/error/server.rs +++ b/volo-http/src/error/server.rs @@ -99,3 +99,60 @@ pub fn body_collection_error() -> ExtractBodyError { pub fn invalid_content_type() -> ExtractBodyError { ExtractBodyError::Generic(GenericRejectionError::InvalidContentType) } + +#[derive(Debug)] +#[non_exhaustive] +pub enum WebSocketUpgradeRejectionError { + MethodNotGet, + InvalidHttpVersion, + InvalidConnectionHeader, + InvalidUpgradeHeader, + InvalidWebSocketVersionHeader, + WebSocketKeyHeaderMissing, + ConnectionNotUpgradable, +} + +impl WebSocketUpgradeRejectionError { + pub fn to_status_code(self) -> StatusCode { + match self { + Self::MethodNotGet => StatusCode::METHOD_NOT_ALLOWED, + Self::InvalidHttpVersion => StatusCode::HTTP_VERSION_NOT_SUPPORTED, + Self::InvalidConnectionHeader => StatusCode::BAD_REQUEST, + Self::InvalidUpgradeHeader => StatusCode::BAD_REQUEST, + Self::InvalidWebSocketVersionHeader => StatusCode::BAD_REQUEST, + Self::WebSocketKeyHeaderMissing => StatusCode::BAD_REQUEST, + Self::ConnectionNotUpgradable => StatusCode::UPGRADE_REQUIRED, + } + } +} + +impl Error for WebSocketUpgradeRejectionError {} + +impl fmt::Display for WebSocketUpgradeRejectionError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::MethodNotGet => write!(f, "Request method must be 'GET'"), + Self::InvalidHttpVersion => { + write!(f, "Http version not support, only support HTTP 1.1 for now") + } + Self::InvalidConnectionHeader => { + write!(f, "Connection header did not include 'upgrade'") + } + Self::InvalidUpgradeHeader => write!(f, "`Upgrade` header did not include 'websocket'"), + Self::InvalidWebSocketVersionHeader => { + write!(f, "`Sec-WebSocket-Version` header did not include '13'") + } + Self::WebSocketKeyHeaderMissing => write!(f, "`Sec-WebSocket-Key` header missing"), + Self::ConnectionNotUpgradable => write!( + f, + "WebSocket request couldn't be upgraded since no upgrade state was present" + ), + } + } +} + +impl IntoResponse for WebSocketUpgradeRejectionError { + fn into_response(self) -> ServerResponse { + self.to_status_code().into_response() + } +} diff --git a/volo-http/src/server/extract.rs b/volo-http/src/server/extract.rs index 35430dad..1e7830ae 100644 --- a/volo-http/src/server/extract.rs +++ b/volo-http/src/server/extract.rs @@ -18,6 +18,7 @@ use hyper::body::Incoming; use mime::Mime; use volo::{context::Context, net::Address}; +pub use super::utils::{Message, WebSocket, WebSocketConfig, WebSocketUpgrade}; use super::IntoResponse; use crate::{ context::ServerContext, diff --git a/volo-http/src/server/mod.rs b/volo-http/src/server/mod.rs index 7bcd3d7a..56959031 100644 --- a/volo-http/src/server/mod.rs +++ b/volo-http/src/server/mod.rs @@ -443,13 +443,15 @@ async fn serve_conn( let notified = exit_notify.notified(); tokio::pin!(notified); - let mut http_conn = server.serve_connection(TokioIo::new(conn), service); + let mut http_conn = server + .serve_connection(TokioIo::new(conn), service) + .with_upgrades(); tokio::select! { _ = &mut notified => { tracing::trace!("[VOLO] closing a pending connection"); // Graceful shutdown. - hyper::server::conn::http1::Connection::graceful_shutdown( + hyper::server::conn::http1::UpgradeableConnection::graceful_shutdown( Pin::new(&mut http_conn) ); // Continue to poll this connection until shutdown can finish. diff --git a/volo-http/src/server/utils/mod.rs b/volo-http/src/server/utils/mod.rs index d495286b..07e268ac 100644 --- a/volo-http/src/server/utils/mod.rs +++ b/volo-http/src/server/utils/mod.rs @@ -5,3 +5,7 @@ mod serve_dir; pub use file_response::FileResponse; pub use serve_dir::ServeDir; + +pub mod ws; + +pub use self::ws::{Config as WebSocketConfig, Message, WebSocket, WebSocketUpgrade}; diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs new file mode 100644 index 00000000..8c725041 --- /dev/null +++ b/volo-http/src/server/utils/ws.rs @@ -0,0 +1,678 @@ +//! Handle WebSocket connections. +//! +//! # Example + +use std::{borrow::Cow, collections::HashMap, fmt::Formatter, future::Future}; + +use http::{header, request::Parts, HeaderMap, HeaderName, HeaderValue, StatusCode}; +use hyper::Error; +use hyper_util::rt::TokioIo; +use tokio_tungstenite::{ + tungstenite::{ + self, + handshake::derive_accept_key, + protocol::{self, WebSocketConfig}, + }, + WebSocketStream, +}; + +use crate::{ + body::Body, context::ServerContext, error::server::WebSocketUpgradeRejectionError, + response::ServerResponse, server::extract::FromContext, +}; + +pub type WebSocket = WebSocketStream>; +pub type Message = tungstenite::Message; + +/// WebSocket headers that will be used for the upgrade request. +struct Headers { + /// The `Sec-WebSocket-Key` request header value + /// used for compute 'Sec-WebSocket-Accept' response header value + sec_websocket_key: HeaderValue, + /// The `Sec-WebSocket-Protocol` request header value + /// specify [`Callback`] method depend on protocol + sec_websocket_protocol: Option, +} + +impl std::fmt::Debug for Headers { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Headers") + .field("sec_websocket_key", &self.sec_websocket_protocol) + .field("sec_websocket_protocol", &self.sec_websocket_protocol) + .finish_non_exhaustive() + } +} + +pub struct Config { + /// WebSocket config for transport (alias of [`tungstenite::protocol::WebSocketConfig`]) + /// e.g. max write buffer size + transport: WebSocketConfig, + /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response. + /// use [`WebSocketUpgrade::protocols`] to set server supported protocols + protocols: Option>, +} + +impl Config { + pub fn new() -> Self { + Config { + transport: WebSocketConfig::default(), + protocols: None, + } + } + + /// Set server supported protocols + /// This will filter protocols in request header `Sec-WebSocket-Protocol` + /// will set the first server supported protocol in [`http::header::Sec-WebSocket-Protocol`] in + /// response + /// ```rust + /// use volo_http::server::utils::WebSocketConfig; + /// + /// let config = WebSocketConfig::new() + /// .set_protocols([ + /// "graphql-ws", + /// "graphql-transport-ws", + /// ]); + pub fn set_protocols(mut self, protocols: I) -> Self + where + I: IntoIterator, + I::Item: Into>, + { + self.protocols = Some( + protocols + .into_iter() + .map(Into::into) + .map(|protocol| match protocol { + Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(), + Cow::Borrowed(s) => HeaderValue::from_static(s), + }) + .collect(), + ); + self + } + + /// Set transport config + /// e.g. write buffer size + /// ```rust + /// use volo_http::server::utils::WebSocketConfig; + /// use tokio_tungstenite::tungstenite::protocol::{WebSocketConfig as WebSocketTransConfig}; + /// + /// let config = WebSocketConfig::new() + /// .set_transport( + /// WebSocketTransConfig{ + /// write_buffer_size: 128 * 1024, + /// ..<_>::default() + /// } + /// ); + pub fn set_transport(mut self, config: WebSocketConfig) -> Self { + self.transport = config; + self + } +} + +impl Default for Config { + fn default() -> Self { + Config { + transport: Default::default(), + protocols: None, + } + } +} + +impl std::fmt::Debug for Config { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Config") + .field("transport", &self.transport) + .field("protocols", &self.protocols) + .finish_non_exhaustive() + } +} + +/// Callback fn that processes [`WebSocket`] +pub trait Callback: Send + 'static { + fn call(self, _: WebSocket) -> impl std::future::Future + std::marker::Send; +} + +impl Callback for C +where + Fut: Future + Send + 'static, + C: FnOnce(WebSocket) -> Fut + Send + 'static, + C: Copy, +{ + async fn call(self, websocket: WebSocket) { + self(websocket).await; + } +} + +/// What to do when a connection upgrade fails. +/// +/// See [`WebSocketUpgrade::on_failed_upgrade`] for more details. +pub trait OnFailedUpgrade: Send + 'static { + fn call(self, error: Error); +} + +impl OnFailedUpgrade for F +where + F: FnOnce(Error) + Send + 'static, +{ + fn call(self, error: Error) { + self(error) + } +} + +/// The default `OnFailedUpgrade` used by `WebSocketUpgrade`. +/// +/// It simply ignores the error. +#[non_exhaustive] +#[derive(Debug)] +pub struct DefaultOnFailedUpgrade; + +impl OnFailedUpgrade for DefaultOnFailedUpgrade { + #[inline] + fn call(self, _error: Error) {} +} + +#[derive(Copy, Clone)] +pub struct DefaultCallback; +impl Callback for DefaultCallback { + #[inline] + async fn call(self, _: WebSocket) {} +} + +/// Extractor of [`FromContext`] for establishing WebSocket connection +/// +/// Constrains: +/// The extractor only supports for the request that has the method [`Method::GET`] +/// and contains certain header values. +/// See [`WebSocketUpgrade::from_context`] for more details. +pub struct WebSocketUpgrade { + config: Config, + on_protocol: HashMap, + on_failed_upgrade: F, + on_upgrade: hyper::upgrade::OnUpgrade, + headers: Headers, +} + +impl std::fmt::Debug for WebSocketUpgrade { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("WebSocketUpgrade") + .field("config", &self.config) + .field("headers", &self.headers) + .finish_non_exhaustive() + } +} + +impl WebSocketUpgrade +where + C: Callback + Clone, + F: OnFailedUpgrade, +{ + /// Set WebSocket config + /// ```rust + /// use volo_http::{ + /// response::ServerResponse, + /// server::extract::WebSocketConfig, + /// server::extract::WebSocketUpgrade, + /// }; + /// use tokio_tungstenite::tungstenite::protocol::{WebSocketConfig as WebSocketTransConfig}; + /// + /// async fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse{ + /// ws.set_config( + /// WebSocketConfig::new() + /// .set_protocols(["graphql-ws","graphql-transport-ws"]) + /// .set_transport( + /// WebSocketTransConfig{ + /// write_buffer_size: 128 * 1024, + /// ..<_>::default() + /// } + /// ) + /// ) + /// .on_upgrade(|socket| async{} ) + /// .unwrap() + /// } + pub fn set_config(mut self, config: Config) -> Self { + self.config = config; + self + } + + /// Set callback for specific protocol + /// ```rust + /// use std::collections::HashMap; + /// use volo_http::{ + /// response::ServerResponse, + /// server::extract::{ + /// WebSocketConfig, + /// WebSocketUpgrade, + /// WebSocket, + /// } + /// }; + /// + /// async fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse{ + /// ws.set_config( + /// WebSocketConfig::new() + /// .set_protocols(["graphql-ws","graphql-transport-ws"]) + /// ) + /// .on_protocol(HashMap::from([("graphql-ws",|mut socket: WebSocket| async move{})])) + /// .on_upgrade(|socket| async{} ) + /// .unwrap() + /// } + pub fn on_protocol(self, on_protocol: I) -> WebSocketUpgrade + where + I: IntoIterator, + I::Item: Into<(H, C1)>, + H: Into<&'static str>, + C1: Callback, + { + let on_protocol = + HashMap::from_iter(on_protocol.into_iter().map(Into::into).map(|(k, v)| { + let k = HeaderValue::from_str(k.into()).unwrap(); + (k, v) + })); + WebSocketUpgrade { + config: self.config, + on_protocol, + on_failed_upgrade: self.on_failed_upgrade, + on_upgrade: self.on_upgrade, + headers: self.headers, + } + } + + /// Provide a callback to call if upgrading the connection fails. + /// + /// The connection upgrade is performed in a background task. + /// If that fails this callback will be called. + /// + /// By default, any errors will be silently ignored. + /// + /// # Example + /// + /// ```rust + /// use std::collections::HashMap; + /// use volo_http::{ + /// response::ServerResponse, + /// server::extract::{ + /// WebSocketConfig, + /// WebSocketUpgrade, + /// WebSocket, + /// } + /// }; + /// + /// async fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse{ + /// ws.on_failed_upgrade(|error| { + /// unimplemented!() + /// }) + /// .on_upgrade(|socket| async{} ) + /// .unwrap() + /// } + pub fn on_failed_upgrade(self, callback: F1) -> WebSocketUpgrade { + WebSocketUpgrade { + config: self.config, + on_protocol: self.on_protocol, + on_failed_upgrade: callback, + on_upgrade: self.on_upgrade, + headers: self.headers, + } + } + + /// Finalize upgrading the connection and call the provided callback + /// if request protocol is matched, it will use callback set by + /// [`WebSocketUpgrade::on_protocol`], otherwise use `default_callback`. + pub fn on_upgrade(self, default_callback: C1) -> Result + where + Fut: Future + Send + 'static, + C1: FnOnce(WebSocket) -> Fut + Send + 'static, + C1: Send + Sync, + { + let on_upgrade = self.on_upgrade; + let config = self.config.transport; + let on_failed_upgrade = self.on_failed_upgrade; + + let protocol = self + .headers + .sec_websocket_protocol + .clone() + .as_ref() + .and_then(|p| p.to_str().ok()) + .and_then(|req_protocols| { + let binding = self.config.protocols.clone(); + binding.as_ref().map(|protocol| { + protocol + .into_iter() + .find(|protocol| { + req_protocols + .split(',') + .any(|req_protocol| req_protocol == *protocol) + }) + .unwrap() + .clone() + // protocol.clone() + }) + // .into_iter() + // .find(|protocol| { + // req_protocols + // .split(',') + // .any(|req_protocol| req_protocol == *protocol) + // }) + // .unwrap(); + // Some(protocol.clone()) + }); + + let callback = protocol + .clone() + .and_then(|protocol| self.on_protocol.get(&protocol).cloned()); + + tokio::spawn(async move { + let upgraded = match on_upgrade.await { + Ok(upgraded) => upgraded, + Err(err) => { + on_failed_upgrade.call(err); + return; + } + }; + let upgraded = TokioIo::new(upgraded); + + let socket = + WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config)) + .await; + + match callback { + Some(callback) => callback.call(socket).await, + None => default_callback(socket).await, + } + }); + + const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade"); + const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket"); + + let mut builder = ServerResponse::builder() + .status(StatusCode::SWITCHING_PROTOCOLS) + .header(header::CONNECTION, UPGRADE) + .header(header::UPGRADE, WEBSOCKET) + .header( + header::SEC_WEBSOCKET_ACCEPT, + derive_accept_key(self.headers.sec_websocket_key.as_bytes()), + ); + + if let Some(protocol) = protocol { + builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol); + } + + Ok(builder.body(Body::empty()).unwrap()) + } +} + +impl FromContext for WebSocketUpgrade { + type Rejection = WebSocketUpgradeRejectionError; + + async fn from_context( + _cx: &mut ServerContext, + parts: &mut Parts, + ) -> Result { + if parts.method != http::Method::GET { + return Err(WebSocketUpgradeRejectionError::MethodNotGet.into()); + } + if parts.version < http::Version::HTTP_11 { + return Err(WebSocketUpgradeRejectionError::InvalidHttpVersion.into()); + } + + fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool { + let header = if let Some(header) = headers.get(&key) { + header + } else { + return false; + }; + + if let Ok(header) = std::str::from_utf8(header.as_bytes()) { + header.to_ascii_lowercase().contains(value) + } else { + false + } + } + + if !header_contains(&parts.headers, http::header::CONNECTION, "upgrade") { + return Err(WebSocketUpgradeRejectionError::InvalidConnectionHeader.into()); + } + + fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool { + if let Some(header) = headers.get(&key) { + header.as_bytes().eq_ignore_ascii_case(value.as_bytes()) + } else { + false + } + } + + if !header_eq(&parts.headers, http::header::UPGRADE, "websocket") { + return Err(WebSocketUpgradeRejectionError::InvalidUpgradeHeader.into()); + } + + if !header_eq(&parts.headers, http::header::SEC_WEBSOCKET_VERSION, "13") { + return Err(WebSocketUpgradeRejectionError::InvalidWebSocketVersionHeader.into()); + } + + let sec_websocket_key = parts + .headers + .get(header::SEC_WEBSOCKET_KEY) + .ok_or(WebSocketUpgradeRejectionError::WebSocketKeyHeaderMissing)? + .clone(); + + let on_upgrade = parts + .extensions + .remove::() + .ok_or(WebSocketUpgradeRejectionError::ConnectionNotUpgradable)?; + + let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned(); + + Ok(Self { + config: Default::default(), + headers: Headers { + sec_websocket_key, + sec_websocket_protocol, + }, + on_failed_upgrade: DefaultOnFailedUpgrade, + on_upgrade, + on_protocol: HashMap::new(), + }) + } +} + +#[cfg(test)] +mod websocket_tests { + use std::{net, ops::Add}; + + use futures_util::{SinkExt, StreamExt}; + use http::{self, Uri}; + use motore::Service; + use tokio::net::TcpStream; + use tokio_tungstenite::{ + tungstenite::{client::IntoClientRequest, ClientRequestBuilder}, + MaybeTlsStream, + }; + use volo::net::Address; + + use super::*; + use crate::{ + server::{ + route::{get, Route}, + test_helpers::empty_cx, + }, + Router, Server, + }; + + async fn run_ws_handler( + addr: Address, + handler: C, + req: R, + ) -> (WebSocketStream>, ServerResponse) + where + R: IntoClientRequest + Unpin, + Fut: Future + Send + 'static, + C: FnOnce(WebSocketUpgrade) -> Fut + Send + 'static, + C: Send + Sync + Clone, + { + let app = Router::new().route("/echo", get(handler)); + + tokio::spawn(async move { + Server::new(app).run(addr).await.unwrap(); + }); + + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + let (socket, response) = tokio_tungstenite::connect_async(req).await.unwrap(); + + ( + socket, + response.map(|response| response.unwrap_or_default().into()), + ) + } + + #[tokio::test] + async fn rejects_unupgradable_requests() { + let route: Route = Route::new(get( + |ws: Result| { + let rejection = ws.unwrap_err(); + assert!(matches!( + rejection, + WebSocketUpgradeRejectionError::ConnectionNotUpgradable, + )); + std::future::ready(()) + }, + )); + + let req = http::Request::builder() + .version(http::Version::HTTP_11) + .method(http::Method::GET) + .header("upgrade", "websocket") + .header("connection", "Upgrade") + .header("sec-websocket-key", "6D69KGBOr4Re+Nj6zx9aQA==") + .header("sec-websocket-version", "13") + .body(Body::empty()) + .unwrap(); + + let mut cx = empty_cx(); + + let resp = route.call(&mut cx, req).await.unwrap(); + + assert_eq!(resp.status(), StatusCode::OK); + } + + #[tokio::test] + async fn rejects_non_get_requests() { + let route: Route = Route::new(get( + |ws: Result| { + let rejection = ws.unwrap_err(); + assert!(matches!( + rejection, + WebSocketUpgradeRejectionError::MethodNotGet, + )); + std::future::ready(()) + }, + )); + + let req = http::Request::builder() + .method(http::Method::POST) + .body(Body::empty()) + .unwrap(); + + let mut cx = empty_cx(); + + let resp = route.call(&mut cx, req).await.unwrap(); + + assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + } + + #[tokio::test] + async fn on_protocol() { + use crate::{ + response::ServerResponse, + server::extract::{WebSocketConfig, WebSocketUpgrade}, + }; + + async fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse { + ws.set_config( + WebSocketConfig::new().set_protocols(["graphql-ws", "graphql-transport-ws"]), + ) + .on_protocol(HashMap::from([( + "graphql-ws", + |mut socket: WebSocket| async move { + while let Some(Ok(msg)) = socket.next().await { + match msg { + Message::Text(text) => { + socket + .send(Message::Text(text.add("-graphql-ws"))) + .await + .unwrap(); + } + _ => {} + } + } + }, + )])) + .on_upgrade(|_| async {}) + .unwrap() + } + + let addr = Address::Ip(net::SocketAddr::new( + net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)), + 8000, + )); + + let builder = ClientRequestBuilder::new( + format!("ws://{}/echo", addr.clone()) + .parse::() + .unwrap(), + ) + .with_sub_protocol("graphql-ws"); + let (mut ws_stream, _response) = run_ws_handler(addr, ws_handler, builder).await; + + let input = Message::Text("foobar".to_owned()); + ws_stream.send(input.clone()).await.unwrap(); + let output = ws_stream.next().await.unwrap().unwrap(); + assert_eq!(output, Message::Text("foobar-graphql-ws".to_owned())); + } + + #[tokio::test] + async fn integration_test() { + async fn handle_socket(mut socket: WebSocket) { + while let Some(Ok(msg)) = socket.next().await { + match msg { + Message::Text(_) + | Message::Binary(_) + | Message::Close(_) + | Message::Frame(_) => { + if socket.send(msg).await.is_err() { + break; + } + } + Message::Ping(_) | Message::Pong(_) => {} + } + } + } + + let addr = Address::Ip(net::SocketAddr::new( + net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)), + 8001, + )); + + let builder = ClientRequestBuilder::new( + format!("ws://{}/echo", addr.clone()) + .parse::() + .unwrap(), + ); + + let (mut ws_stream, _response) = run_ws_handler( + addr, + |ws: WebSocketUpgrade| std::future::ready(ws.on_upgrade(handle_socket).unwrap()), + builder, + ) + .await; + + let input = Message::Text("foobar".to_owned()); + ws_stream.send(input.clone()).await.unwrap(); + let output = ws_stream.next().await.unwrap().unwrap(); + assert_eq!(input, output); + + let input = Message::Ping("foobar".to_owned().into_bytes()); + ws_stream.send(input).await.unwrap(); + let output = ws_stream.next().await.unwrap().unwrap(); + assert_eq!(output, Message::Pong("foobar".to_owned().into_bytes())); + } +} From 9ad6efd50b16f8e86eac009e3252a290f452e781 Mon Sep 17 00:00:00 2001 From: StellarisW Date: Tue, 6 Aug 2024 23:49:33 +0800 Subject: [PATCH 02/31] feat(http): support websocket server --- volo-http/src/error/server.rs | 9 +++++++++ volo-http/src/server/utils/ws.rs | 11 ++++++++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/volo-http/src/error/server.rs b/volo-http/src/error/server.rs index ed0f7450..53a858f6 100644 --- a/volo-http/src/error/server.rs +++ b/volo-http/src/error/server.rs @@ -100,19 +100,28 @@ pub fn invalid_content_type() -> ExtractBodyError { ExtractBodyError::Generic(GenericRejectionError::InvalidContentType) } +/// Rejection used for [`WebSocketUpgrade`](crate::server::extract::WebSocketUpgrade). #[derive(Debug)] #[non_exhaustive] pub enum WebSocketUpgradeRejectionError { + /// The request method must be `GET` MethodNotGet, + /// The HTTP version is not supported InvalidHttpVersion, + /// The `Connection` header is invalid InvalidConnectionHeader, + /// The `Upgrade` header is invalid InvalidUpgradeHeader, + /// The `Sec-WebSocket-Version` header is invalid InvalidWebSocketVersionHeader, + /// The `Sec-WebSocket-Key` header is missing WebSocketKeyHeaderMissing, + /// The connection is not upgradable ConnectionNotUpgradable, } impl WebSocketUpgradeRejectionError { + /// Convert the [`WebSocketUpgradeRejectionError`] to the corresponding [`StatusCode`] pub fn to_status_code(self) -> StatusCode { match self { Self::MethodNotGet => StatusCode::METHOD_NOT_ALLOWED, diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index 8c725041..a69a7354 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -21,7 +21,9 @@ use crate::{ response::ServerResponse, server::extract::FromContext, }; +/// WebSocketStream used In handler Request pub type WebSocket = WebSocketStream>; +/// alias of [`tungstenite::Message`] pub type Message = tungstenite::Message; /// WebSocket headers that will be used for the upgrade request. @@ -43,6 +45,7 @@ impl std::fmt::Debug for Headers { } } +/// WebSocket config pub struct Config { /// WebSocket config for transport (alias of [`tungstenite::protocol::WebSocketConfig`]) /// e.g. max write buffer size @@ -53,6 +56,7 @@ pub struct Config { } impl Config { + /// Create Default Config pub fn new() -> Self { Config { transport: WebSocketConfig::default(), @@ -129,7 +133,8 @@ impl std::fmt::Debug for Config { /// Callback fn that processes [`WebSocket`] pub trait Callback: Send + 'static { - fn call(self, _: WebSocket) -> impl std::future::Future + std::marker::Send; + /// Called when a connection upgrade succeeds + fn call(self, _: WebSocket) -> impl Future + Send; } impl Callback for C @@ -147,6 +152,7 @@ where /// /// See [`WebSocketUpgrade::on_failed_upgrade`] for more details. pub trait OnFailedUpgrade: Send + 'static { + /// Called when a connection upgrade fails. fn call(self, error: Error); } @@ -171,6 +177,9 @@ impl OnFailedUpgrade for DefaultOnFailedUpgrade { fn call(self, _error: Error) {} } +/// The default `Callback` used by `WebSocketUpgrade`. +/// +/// It simply ignores the socket. #[derive(Copy, Clone)] pub struct DefaultCallback; impl Callback for DefaultCallback { From 06e6dade9e61d62e5f8e4635f67f9500b8bb2fe5 Mon Sep 17 00:00:00 2001 From: StellarisW Date: Wed, 7 Aug 2024 11:03:29 +0800 Subject: [PATCH 03/31] feat(http): support websocket server --- volo-http/src/server/utils/ws.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index a69a7354..8cc53fc0 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -190,7 +190,7 @@ impl Callback for DefaultCallback { /// Extractor of [`FromContext`] for establishing WebSocket connection /// /// Constrains: -/// The extractor only supports for the request that has the method [`Method::GET`] +/// The extractor only supports for the request that has the method [`GET`](http::Method::GET) /// and contains certain header values. /// See [`WebSocketUpgrade::from_context`] for more details. pub struct WebSocketUpgrade { From 5ccc9ee6033ab9961baf978adecf9d3c2ab86375 Mon Sep 17 00:00:00 2001 From: StellarisW Date: Wed, 7 Aug 2024 11:13:20 +0800 Subject: [PATCH 04/31] feat(http): support websocket server --- volo-http/src/server/utils/ws.rs | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index 8cc53fc0..8cf0b263 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -46,6 +46,7 @@ impl std::fmt::Debug for Headers { } /// WebSocket config +#[derive(Default)] pub struct Config { /// WebSocket config for transport (alias of [`tungstenite::protocol::WebSocketConfig`]) /// e.g. max write buffer size @@ -113,15 +114,6 @@ impl Config { } } -impl Default for Config { - fn default() -> Self { - Config { - transport: Default::default(), - protocols: None, - } - } -} - impl std::fmt::Debug for Config { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("Config") @@ -345,7 +337,7 @@ where let binding = self.config.protocols.clone(); binding.as_ref().map(|protocol| { protocol - .into_iter() + .iter() .find(|protocol| { req_protocols .split(',') @@ -417,10 +409,10 @@ impl FromContext for WebSocketUpgrade { parts: &mut Parts, ) -> Result { if parts.method != http::Method::GET { - return Err(WebSocketUpgradeRejectionError::MethodNotGet.into()); + return Err(WebSocketUpgradeRejectionError::MethodNotGet); } if parts.version < http::Version::HTTP_11 { - return Err(WebSocketUpgradeRejectionError::InvalidHttpVersion.into()); + return Err(WebSocketUpgradeRejectionError::InvalidHttpVersion); } fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool { @@ -438,7 +430,7 @@ impl FromContext for WebSocketUpgrade { } if !header_contains(&parts.headers, http::header::CONNECTION, "upgrade") { - return Err(WebSocketUpgradeRejectionError::InvalidConnectionHeader.into()); + return Err(WebSocketUpgradeRejectionError::InvalidConnectionHeader); } fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool { @@ -450,11 +442,11 @@ impl FromContext for WebSocketUpgrade { } if !header_eq(&parts.headers, http::header::UPGRADE, "websocket") { - return Err(WebSocketUpgradeRejectionError::InvalidUpgradeHeader.into()); + return Err(WebSocketUpgradeRejectionError::InvalidUpgradeHeader); } if !header_eq(&parts.headers, http::header::SEC_WEBSOCKET_VERSION, "13") { - return Err(WebSocketUpgradeRejectionError::InvalidWebSocketVersionHeader.into()); + return Err(WebSocketUpgradeRejectionError::InvalidWebSocketVersionHeader); } let sec_websocket_key = parts From 6da65f5327f9b6be9caef5b5570a98a0f2d30ee2 Mon Sep 17 00:00:00 2001 From: StellarisW Date: Wed, 7 Aug 2024 11:19:00 +0800 Subject: [PATCH 05/31] feat(http): support websocket server --- volo-http/src/server/utils/ws.rs | 9 --------- 1 file changed, 9 deletions(-) diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index 8cf0b263..a5d8d739 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -345,16 +345,7 @@ where }) .unwrap() .clone() - // protocol.clone() }) - // .into_iter() - // .find(|protocol| { - // req_protocols - // .split(',') - // .any(|req_protocol| req_protocol == *protocol) - // }) - // .unwrap(); - // Some(protocol.clone()) }); let callback = protocol From 712586d39ef34c1c641d071dc06a55e51ddea6db Mon Sep 17 00:00:00 2001 From: StellarisW Date: Wed, 7 Aug 2024 11:27:26 +0800 Subject: [PATCH 06/31] feat(http): support websocket server --- volo-http/src/server/utils/ws.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index a5d8d739..8dd293d8 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -604,7 +604,7 @@ mod websocket_tests { let addr = Address::Ip(net::SocketAddr::new( net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)), - 8000, + 25230, )); let builder = ClientRequestBuilder::new( @@ -621,7 +621,7 @@ mod websocket_tests { assert_eq!(output, Message::Text("foobar-graphql-ws".to_owned())); } - #[tokio::test] + #[cfg(test)] async fn integration_test() { async fn handle_socket(mut socket: WebSocket) { while let Some(Ok(msg)) = socket.next().await { @@ -641,7 +641,7 @@ mod websocket_tests { let addr = Address::Ip(net::SocketAddr::new( net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)), - 8001, + 25231, )); let builder = ClientRequestBuilder::new( From 44f12aced29d81678a512ece68113431d84572ff Mon Sep 17 00:00:00 2001 From: StellarisW Date: Wed, 7 Aug 2024 15:42:40 +0800 Subject: [PATCH 07/31] feat(http): support websocket server --- volo-http/Cargo.toml | 5 +- volo-http/src/server/utils/ws.rs | 205 ++++++++++++++++++------------- 2 files changed, 125 insertions(+), 85 deletions(-) diff --git a/volo-http/Cargo.toml b/volo-http/Cargo.toml index fee48862..4210ae4b 100644 --- a/volo-http/Cargo.toml +++ b/volo-http/Cargo.toml @@ -92,10 +92,9 @@ default_server = ["server", "query", "form", "json"] full = ["client", "server", "rustls", "cookie", "query", "form", "json", "tls"] -client = ["hyper/client", "hyper/http1", "ws"] # client core -server = ["hyper/server", "hyper/http1", "dep:matchit", "ws"] # server core +client = ["hyper/client", "hyper/http1"] # client core +server = ["hyper/server", "hyper/http1", "dep:matchit"] # server core -protocol = ["ws"] ws = ["dep:tokio-tungstenite"] tls = ["rustls"] diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index 8dd293d8..7817cb6e 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -1,10 +1,47 @@ -//! Handle WebSocket connections. +//! Module for handling WebSocket connection +//! +//! +//! This module provides utilities for setting up and handling WebSocket connections, including +//! configuring WebSocket options, setting protocols, and upgrading connections. +//! +//! It uses [`hyper::upgrade::OnUpgrade`] to upgrade the connection. +//! //! //! # Example +//! +//! ```rust +//! use futures_util::{SinkExt, StreamExt}; +//! use volo_http::{ +//! response::ServerResponse, +//! server::{ +//! extract::{Message, WebSocket}, +//! route::get, +//! utils::WebSocketUpgrade, +//! }, +//! Router, +//! }; +//! +//! async fn handle_socket(mut socket: WebSocket) { +//! while let Some(Ok(msg)) = socket.next().await { +//! match msg { +//! Message::Text(text) => { +//! socket.send(msg).await.unwrap(); +//! } +//! _ => {} +//! } +//! } +//! } +//! +//! async fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse { +//! ws.on_upgrade(handle_socket) +//! } +//! +//! let app = Router::new().route("/ws", get(ws_handler)); +//! ``` use std::{borrow::Cow, collections::HashMap, fmt::Formatter, future::Future}; -use http::{header, request::Parts, HeaderMap, HeaderName, HeaderValue, StatusCode}; +use http::{request::Parts, HeaderMap, HeaderName, HeaderValue}; use hyper::Error; use hyper_util::rt::TokioIo; use tokio_tungstenite::{ @@ -69,14 +106,13 @@ impl Config { /// This will filter protocols in request header `Sec-WebSocket-Protocol` /// will set the first server supported protocol in [`http::header::Sec-WebSocket-Protocol`] in /// response + /// + /// /// ```rust /// use volo_http::server::utils::WebSocketConfig; /// - /// let config = WebSocketConfig::new() - /// .set_protocols([ - /// "graphql-ws", - /// "graphql-transport-ws", - /// ]); + /// let config = WebSocketConfig::new().set_protocols(["graphql-ws", "graphql-transport-ws"]); + /// ``` pub fn set_protocols(mut self, protocols: I) -> Self where I: IntoIterator, @@ -97,17 +133,17 @@ impl Config { /// Set transport config /// e.g. write buffer size + /// + /// /// ```rust - /// use volo_http::server::utils::WebSocketConfig; - /// use tokio_tungstenite::tungstenite::protocol::{WebSocketConfig as WebSocketTransConfig}; + /// use tokio_tungstenite::tungstenite::protocol::WebSocketConfig as WebSocketTransConfig; + /// use volo_http::server::utils::WebSocketConfig; /// - /// let config = WebSocketConfig::new() - /// .set_transport( - /// WebSocketTransConfig{ - /// write_buffer_size: 128 * 1024, - /// ..<_>::default() - /// } - /// ); + /// let config = WebSocketConfig::new().set_transport(WebSocketTransConfig { + /// write_buffer_size: 128 * 1024, + /// ..<_>::default() + /// }); + /// ``` pub fn set_transport(mut self, config: WebSocketConfig) -> Self { self.transport = config; self @@ -126,12 +162,12 @@ impl std::fmt::Debug for Config { /// Callback fn that processes [`WebSocket`] pub trait Callback: Send + 'static { /// Called when a connection upgrade succeeds - fn call(self, _: WebSocket) -> impl Future + Send; + fn call(self, _: WebSocket) -> impl Future + Send; } impl Callback for C where - Fut: Future + Send + 'static, + Fut: Future + Send + 'static, C: FnOnce(WebSocket) -> Fut + Send + 'static, C: Copy, { @@ -181,10 +217,20 @@ impl Callback for DefaultCallback { /// Extractor of [`FromContext`] for establishing WebSocket connection /// -/// Constrains: +/// **Constrains**: +/// /// The extractor only supports for the request that has the method [`GET`](http::Method::GET) /// and contains certain header values. -/// See [`WebSocketUpgrade::from_context`] for more details. +/// +/// # Usage +/// +/// ```rust +/// use volo_http::{response::ServerResponse, server::extract::WebSocketUpgrade}; +/// +/// fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse { +/// ws.on_upgrade(|socket| unimplemented!()) +/// } +/// ``` pub struct WebSocketUpgrade { config: Config, on_protocol: HashMap, @@ -228,7 +274,6 @@ where /// ) /// ) /// .on_upgrade(|socket| async{} ) - /// .unwrap() /// } pub fn set_config(mut self, config: Config) -> Self { self.config = config; @@ -254,7 +299,6 @@ where /// ) /// .on_protocol(HashMap::from([("graphql-ws",|mut socket: WebSocket| async move{})])) /// .on_upgrade(|socket| async{} ) - /// .unwrap() /// } pub fn on_protocol(self, on_protocol: I) -> WebSocketUpgrade where @@ -302,7 +346,6 @@ where /// unimplemented!() /// }) /// .on_upgrade(|socket| async{} ) - /// .unwrap() /// } pub fn on_failed_upgrade(self, callback: F1) -> WebSocketUpgrade { WebSocketUpgrade { @@ -317,11 +360,10 @@ where /// Finalize upgrading the connection and call the provided callback /// if request protocol is matched, it will use callback set by /// [`WebSocketUpgrade::on_protocol`], otherwise use `default_callback`. - pub fn on_upgrade(self, default_callback: C1) -> Result + pub fn on_upgrade(self, default_callback: C1) -> ServerResponse where - Fut: Future + Send + 'static, - C1: FnOnce(WebSocket) -> Fut + Send + 'static, - C1: Send + Sync, + Fut: Future + Send + 'static, + C1: FnOnce(WebSocket) -> Fut + Send + Sync + 'static, { let on_upgrade = self.on_upgrade; let config = self.config.transport; @@ -376,19 +418,41 @@ where const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket"); let mut builder = ServerResponse::builder() - .status(StatusCode::SWITCHING_PROTOCOLS) - .header(header::CONNECTION, UPGRADE) - .header(header::UPGRADE, WEBSOCKET) + .status(http::StatusCode::SWITCHING_PROTOCOLS) + .header(http::header::CONNECTION, UPGRADE) + .header(http::header::UPGRADE, WEBSOCKET) .header( - header::SEC_WEBSOCKET_ACCEPT, + http::header::SEC_WEBSOCKET_ACCEPT, derive_accept_key(self.headers.sec_websocket_key.as_bytes()), ); if let Some(protocol) = protocol { - builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol); + builder = builder.header(http::header::SEC_WEBSOCKET_PROTOCOL, protocol); } - Ok(builder.body(Body::empty()).unwrap()) + builder.body(Body::empty()).unwrap() + } +} + +fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool { + let header = if let Some(header) = headers.get(&key) { + header + } else { + return false; + }; + + if let Ok(header) = std::str::from_utf8(header.as_bytes()) { + header.to_ascii_lowercase().contains(value) + } else { + false + } +} + +fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool { + if let Some(header) = headers.get(&key) { + header.as_bytes().eq_ignore_ascii_case(value.as_bytes()) + } else { + false } } @@ -406,32 +470,10 @@ impl FromContext for WebSocketUpgrade { return Err(WebSocketUpgradeRejectionError::InvalidHttpVersion); } - fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool { - let header = if let Some(header) = headers.get(&key) { - header - } else { - return false; - }; - - if let Ok(header) = std::str::from_utf8(header.as_bytes()) { - header.to_ascii_lowercase().contains(value) - } else { - false - } - } - if !header_contains(&parts.headers, http::header::CONNECTION, "upgrade") { return Err(WebSocketUpgradeRejectionError::InvalidConnectionHeader); } - fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool { - if let Some(header) = headers.get(&key) { - header.as_bytes().eq_ignore_ascii_case(value.as_bytes()) - } else { - false - } - } - if !header_eq(&parts.headers, http::header::UPGRADE, "websocket") { return Err(WebSocketUpgradeRejectionError::InvalidUpgradeHeader); } @@ -442,7 +484,7 @@ impl FromContext for WebSocketUpgrade { let sec_websocket_key = parts .headers - .get(header::SEC_WEBSOCKET_KEY) + .get(http::header::SEC_WEBSOCKET_KEY) .ok_or(WebSocketUpgradeRejectionError::WebSocketKeyHeaderMissing)? .clone(); @@ -451,7 +493,7 @@ impl FromContext for WebSocketUpgrade { .remove::() .ok_or(WebSocketUpgradeRejectionError::ConnectionNotUpgradable)?; - let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned(); + let sec_websocket_protocol = parts.headers.get(http::header::SEC_WEBSOCKET_PROTOCOL).cloned(); Ok(Self { config: Default::default(), @@ -471,7 +513,7 @@ mod websocket_tests { use std::{net, ops::Add}; use futures_util::{SinkExt, StreamExt}; - use http::{self, Uri}; + use http::{Uri}; use motore::Service; use tokio::net::TcpStream; use tokio_tungstenite::{ @@ -496,7 +538,7 @@ mod websocket_tests { ) -> (WebSocketStream>, ServerResponse) where R: IntoClientRequest + Unpin, - Fut: Future + Send + 'static, + Fut: Future + Send + 'static, C: FnOnce(WebSocketUpgrade) -> Fut + Send + 'static, C: Send + Sync + Clone, { @@ -543,7 +585,7 @@ mod websocket_tests { let resp = route.call(&mut cx, req).await.unwrap(); - assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.status(), http::StatusCode::OK); } #[tokio::test] @@ -568,7 +610,7 @@ mod websocket_tests { let resp = route.call(&mut cx, req).await.unwrap(); - assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + assert_eq!(resp.status(), http::StatusCode::METHOD_NOT_ALLOWED); } #[tokio::test] @@ -582,24 +624,23 @@ mod websocket_tests { ws.set_config( WebSocketConfig::new().set_protocols(["graphql-ws", "graphql-transport-ws"]), ) - .on_protocol(HashMap::from([( - "graphql-ws", - |mut socket: WebSocket| async move { - while let Some(Ok(msg)) = socket.next().await { - match msg { - Message::Text(text) => { - socket - .send(Message::Text(text.add("-graphql-ws"))) - .await - .unwrap(); + .on_protocol(HashMap::from([( + "graphql-ws", + |mut socket: WebSocket| async move { + while let Some(Ok(msg)) = socket.next().await { + match msg { + Message::Text(text) => { + socket + .send(Message::Text(text.add("-graphql-ws"))) + .await + .unwrap(); + } + _ => {} } - _ => {} } - } - }, - )])) - .on_upgrade(|_| async {}) - .unwrap() + }, + )])) + .on_upgrade(|_| async {}) } let addr = Address::Ip(net::SocketAddr::new( @@ -612,7 +653,7 @@ mod websocket_tests { .parse::() .unwrap(), ) - .with_sub_protocol("graphql-ws"); + .with_sub_protocol("graphql-ws"); let (mut ws_stream, _response) = run_ws_handler(addr, ws_handler, builder).await; let input = Message::Text("foobar".to_owned()); @@ -621,7 +662,7 @@ mod websocket_tests { assert_eq!(output, Message::Text("foobar-graphql-ws".to_owned())); } - #[cfg(test)] + #[tokio::test] async fn integration_test() { async fn handle_socket(mut socket: WebSocket) { while let Some(Ok(msg)) = socket.next().await { @@ -651,11 +692,11 @@ mod websocket_tests { ); let (mut ws_stream, _response) = run_ws_handler( - addr, - |ws: WebSocketUpgrade| std::future::ready(ws.on_upgrade(handle_socket).unwrap()), + addr.clone(), + |ws: WebSocketUpgrade| std::future::ready(ws.on_upgrade(handle_socket)), builder, ) - .await; + .await; let input = Message::Text("foobar".to_owned()); ws_stream.send(input.clone()).await.unwrap(); From 4e160dc37bfc2c615f4741c3a7cdc0b40a113092 Mon Sep 17 00:00:00 2001 From: StellarisW Date: Wed, 7 Aug 2024 15:50:11 +0800 Subject: [PATCH 08/31] feat(http): support websocket server --- volo-http/src/server/utils/ws.rs | 49 +++++++++++++++++--------------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index 7817cb6e..eeeef41b 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -162,12 +162,12 @@ impl std::fmt::Debug for Config { /// Callback fn that processes [`WebSocket`] pub trait Callback: Send + 'static { /// Called when a connection upgrade succeeds - fn call(self, _: WebSocket) -> impl Future + Send; + fn call(self, _: WebSocket) -> impl Future + Send; } impl Callback for C where - Fut: Future + Send + 'static, + Fut: Future + Send + 'static, C: FnOnce(WebSocket) -> Fut + Send + 'static, C: Copy, { @@ -362,7 +362,7 @@ where /// [`WebSocketUpgrade::on_protocol`], otherwise use `default_callback`. pub fn on_upgrade(self, default_callback: C1) -> ServerResponse where - Fut: Future + Send + 'static, + Fut: Future + Send + 'static, C1: FnOnce(WebSocket) -> Fut + Send + Sync + 'static, { let on_upgrade = self.on_upgrade; @@ -493,7 +493,10 @@ impl FromContext for WebSocketUpgrade { .remove::() .ok_or(WebSocketUpgradeRejectionError::ConnectionNotUpgradable)?; - let sec_websocket_protocol = parts.headers.get(http::header::SEC_WEBSOCKET_PROTOCOL).cloned(); + let sec_websocket_protocol = parts + .headers + .get(http::header::SEC_WEBSOCKET_PROTOCOL) + .cloned(); Ok(Self { config: Default::default(), @@ -513,7 +516,7 @@ mod websocket_tests { use std::{net, ops::Add}; use futures_util::{SinkExt, StreamExt}; - use http::{Uri}; + use http::Uri; use motore::Service; use tokio::net::TcpStream; use tokio_tungstenite::{ @@ -538,7 +541,7 @@ mod websocket_tests { ) -> (WebSocketStream>, ServerResponse) where R: IntoClientRequest + Unpin, - Fut: Future + Send + 'static, + Fut: Future + Send + 'static, C: FnOnce(WebSocketUpgrade) -> Fut + Send + 'static, C: Send + Sync + Clone, { @@ -624,23 +627,23 @@ mod websocket_tests { ws.set_config( WebSocketConfig::new().set_protocols(["graphql-ws", "graphql-transport-ws"]), ) - .on_protocol(HashMap::from([( - "graphql-ws", - |mut socket: WebSocket| async move { - while let Some(Ok(msg)) = socket.next().await { - match msg { - Message::Text(text) => { - socket - .send(Message::Text(text.add("-graphql-ws"))) - .await - .unwrap(); - } - _ => {} + .on_protocol(HashMap::from([( + "graphql-ws", + |mut socket: WebSocket| async move { + while let Some(Ok(msg)) = socket.next().await { + match msg { + Message::Text(text) => { + socket + .send(Message::Text(text.add("-graphql-ws"))) + .await + .unwrap(); } + _ => {} } - }, - )])) - .on_upgrade(|_| async {}) + } + }, + )])) + .on_upgrade(|_| async {}) } let addr = Address::Ip(net::SocketAddr::new( @@ -653,7 +656,7 @@ mod websocket_tests { .parse::() .unwrap(), ) - .with_sub_protocol("graphql-ws"); + .with_sub_protocol("graphql-ws"); let (mut ws_stream, _response) = run_ws_handler(addr, ws_handler, builder).await; let input = Message::Text("foobar".to_owned()); @@ -696,7 +699,7 @@ mod websocket_tests { |ws: WebSocketUpgrade| std::future::ready(ws.on_upgrade(handle_socket)), builder, ) - .await; + .await; let input = Message::Text("foobar".to_owned()); ws_stream.send(input.clone()).await.unwrap(); From 618d088bc69feb6c07cdb64a9fc1af96760389da Mon Sep 17 00:00:00 2001 From: StellarisW Date: Wed, 7 Aug 2024 15:52:57 +0800 Subject: [PATCH 09/31] feat(http): support websocket server --- volo-http/src/server/utils/ws.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index eeeef41b..2476801f 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -168,8 +168,7 @@ pub trait Callback: Send + 'static { impl Callback for C where Fut: Future + Send + 'static, - C: FnOnce(WebSocket) -> Fut + Send + 'static, - C: Copy, + C: FnOnce(WebSocket) -> Fut + Send + Copy + 'static, { async fn call(self, websocket: WebSocket) { self(websocket).await; @@ -208,7 +207,7 @@ impl OnFailedUpgrade for DefaultOnFailedUpgrade { /// The default `Callback` used by `WebSocketUpgrade`. /// /// It simply ignores the socket. -#[derive(Copy, Clone)] +#[derive(Clone)] pub struct DefaultCallback; impl Callback for DefaultCallback { #[inline] From aaee04e82624f52a028c02d9be74d2175926651b Mon Sep 17 00:00:00 2001 From: StellarisW Date: Wed, 7 Aug 2024 15:55:30 +0800 Subject: [PATCH 10/31] feat(http): support websocket server --- volo-http/src/server/extract.rs | 1 + volo-http/src/server/utils/mod.rs | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/volo-http/src/server/extract.rs b/volo-http/src/server/extract.rs index 1e7830ae..350ba059 100644 --- a/volo-http/src/server/extract.rs +++ b/volo-http/src/server/extract.rs @@ -18,6 +18,7 @@ use hyper::body::Incoming; use mime::Mime; use volo::{context::Context, net::Address}; +#[cfg(feature = "ws")] pub use super::utils::{Message, WebSocket, WebSocketConfig, WebSocketUpgrade}; use super::IntoResponse; use crate::{ diff --git a/volo-http/src/server/utils/mod.rs b/volo-http/src/server/utils/mod.rs index 07e268ac..4eadb2d8 100644 --- a/volo-http/src/server/utils/mod.rs +++ b/volo-http/src/server/utils/mod.rs @@ -6,6 +6,8 @@ mod serve_dir; pub use file_response::FileResponse; pub use serve_dir::ServeDir; -pub mod ws; +#[cfg(feature = "ws")] +pub mod ws; +#[cfg(feature = "ws")] pub use self::ws::{Config as WebSocketConfig, Message, WebSocket, WebSocketUpgrade}; From e8dc55805db733e6903a206b96900a68918dc220 Mon Sep 17 00:00:00 2001 From: StellarisW Date: Wed, 7 Aug 2024 16:01:12 +0800 Subject: [PATCH 11/31] feat(http): support websocket server --- volo-http/src/server/utils/mod.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/volo-http/src/server/utils/mod.rs b/volo-http/src/server/utils/mod.rs index 4eadb2d8..c965ce17 100644 --- a/volo-http/src/server/utils/mod.rs +++ b/volo-http/src/server/utils/mod.rs @@ -6,7 +6,6 @@ mod serve_dir; pub use file_response::FileResponse; pub use serve_dir::ServeDir; - #[cfg(feature = "ws")] pub mod ws; #[cfg(feature = "ws")] From f6ab717b4212c205d36619389e9ee5bae39ed33b Mon Sep 17 00:00:00 2001 From: StellarisW Date: Wed, 7 Aug 2024 16:05:05 +0800 Subject: [PATCH 12/31] feat(http): support websocket server --- volo-http/src/server/utils/ws.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index 2476801f..e2a5bd58 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -541,8 +541,7 @@ mod websocket_tests { where R: IntoClientRequest + Unpin, Fut: Future + Send + 'static, - C: FnOnce(WebSocketUpgrade) -> Fut + Send + 'static, - C: Send + Sync + Clone, + C: FnOnce(WebSocketUpgrade) -> Fut + Send + Sync + Clone + 'static, { let app = Router::new().route("/echo", get(handler)); From 152086b8c1ae50445e213319a1b0db9085da226c Mon Sep 17 00:00:00 2001 From: StellarisW Date: Wed, 7 Aug 2024 16:37:37 +0800 Subject: [PATCH 13/31] feat(http): support websocket server --- volo-http/Cargo.toml | 2 +- volo-http/src/server/extract.rs | 2 -- volo-http/src/server/utils/ws.rs | 32 ++++++++++++++++---------------- 3 files changed, 17 insertions(+), 19 deletions(-) diff --git a/volo-http/Cargo.toml b/volo-http/Cargo.toml index 4210ae4b..bab4953d 100644 --- a/volo-http/Cargo.toml +++ b/volo-http/Cargo.toml @@ -90,7 +90,7 @@ default = [] default_client = ["client", "json"] default_server = ["server", "query", "form", "json"] -full = ["client", "server", "rustls", "cookie", "query", "form", "json", "tls"] +full = ["client", "server", "rustls", "cookie", "query", "form", "json", "tls", "ws"] client = ["hyper/client", "hyper/http1"] # client core server = ["hyper/server", "hyper/http1", "dep:matchit"] # server core diff --git a/volo-http/src/server/extract.rs b/volo-http/src/server/extract.rs index 350ba059..35430dad 100644 --- a/volo-http/src/server/extract.rs +++ b/volo-http/src/server/extract.rs @@ -18,8 +18,6 @@ use hyper::body::Incoming; use mime::Mime; use volo::{context::Context, net::Address}; -#[cfg(feature = "ws")] -pub use super::utils::{Message, WebSocket, WebSocketConfig, WebSocketUpgrade}; use super::IntoResponse; use crate::{ context::ServerContext, diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index e2a5bd58..129ba57c 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -1,12 +1,10 @@ //! Module for handling WebSocket connection //! -//! //! This module provides utilities for setting up and handling WebSocket connections, including //! configuring WebSocket options, setting protocols, and upgrading connections. //! //! It uses [`hyper::upgrade::OnUpgrade`] to upgrade the connection. //! -//! //! # Example //! //! ```rust @@ -14,7 +12,7 @@ //! use volo_http::{ //! response::ServerResponse, //! server::{ -//! extract::{Message, WebSocket}, +//! utils::{Message, WebSocket}, //! route::get, //! utils::WebSocketUpgrade, //! }, @@ -63,13 +61,13 @@ pub type WebSocket = WebSocketStream>; /// alias of [`tungstenite::Message`] pub type Message = tungstenite::Message; -/// WebSocket headers that will be used for the upgrade request. +/// WebSocket Request headers for establishing a WebSocket connection. struct Headers { /// The `Sec-WebSocket-Key` request header value /// used for compute 'Sec-WebSocket-Accept' response header value sec_websocket_key: HeaderValue, /// The `Sec-WebSocket-Protocol` request header value - /// specify [`Callback`] method depend on protocol + /// specify [`Callback`] method depend on the protocol sec_websocket_protocol: Option, } @@ -85,7 +83,7 @@ impl std::fmt::Debug for Headers { /// WebSocket config #[derive(Default)] pub struct Config { - /// WebSocket config for transport (alias of [`tungstenite::protocol::WebSocketConfig`]) + /// WebSocket config for transport (alias of [`WebSocketConfig`](tungstenite::protocol::WebSocketConfig)) /// e.g. max write buffer size transport: WebSocketConfig, /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response. @@ -102,12 +100,12 @@ impl Config { } } - /// Set server supported protocols + /// Set server supported protocols. + /// /// This will filter protocols in request header `Sec-WebSocket-Protocol` - /// will set the first server supported protocol in [`http::header::Sec-WebSocket-Protocol`] in + /// and will set the first server supported protocol in [`http::header::Sec-WebSocket-Protocol`] in /// response /// - /// /// ```rust /// use volo_http::server::utils::WebSocketConfig; /// @@ -160,7 +158,7 @@ impl std::fmt::Debug for Config { } /// Callback fn that processes [`WebSocket`] -pub trait Callback: Send + 'static { +trait Callback: Send + 'static { /// Called when a connection upgrade succeeds fn call(self, _: WebSocket) -> impl Future + Send; } @@ -224,7 +222,7 @@ impl Callback for DefaultCallback { /// # Usage /// /// ```rust -/// use volo_http::{response::ServerResponse, server::extract::WebSocketUpgrade}; +/// use volo_http::{response::ServerResponse, server::utils::WebSocketUpgrade}; /// /// fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse { /// ws.on_upgrade(|socket| unimplemented!()) @@ -256,8 +254,10 @@ where /// ```rust /// use volo_http::{ /// response::ServerResponse, - /// server::extract::WebSocketConfig, - /// server::extract::WebSocketUpgrade, + /// server::utils::{ + /// WebSocketConfig, + /// WebSocketUpgrade, + /// } /// }; /// use tokio_tungstenite::tungstenite::protocol::{WebSocketConfig as WebSocketTransConfig}; /// @@ -284,7 +284,7 @@ where /// use std::collections::HashMap; /// use volo_http::{ /// response::ServerResponse, - /// server::extract::{ + /// server::utils::{ /// WebSocketConfig, /// WebSocketUpgrade, /// WebSocket, @@ -333,7 +333,7 @@ where /// use std::collections::HashMap; /// use volo_http::{ /// response::ServerResponse, - /// server::extract::{ + /// server::utils::{ /// WebSocketConfig, /// WebSocketUpgrade, /// WebSocket, @@ -618,7 +618,7 @@ mod websocket_tests { async fn on_protocol() { use crate::{ response::ServerResponse, - server::extract::{WebSocketConfig, WebSocketUpgrade}, + server::utils::{WebSocketConfig, WebSocketUpgrade}, }; async fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse { From b1aab0fdbebca740eadbb0d73511f2382bf43055 Mon Sep 17 00:00:00 2001 From: StellarisW Date: Wed, 7 Aug 2024 16:38:25 +0800 Subject: [PATCH 14/31] feat(http): support websocket server --- volo-http/src/server/utils/ws.rs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index 129ba57c..bbe87cb9 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -12,9 +12,8 @@ //! use volo_http::{ //! response::ServerResponse, //! server::{ -//! utils::{Message, WebSocket}, //! route::get, -//! utils::WebSocketUpgrade, +//! utils::{Message, WebSocket, WebSocketUpgrade}, //! }, //! Router, //! }; @@ -83,8 +82,8 @@ impl std::fmt::Debug for Headers { /// WebSocket config #[derive(Default)] pub struct Config { - /// WebSocket config for transport (alias of [`WebSocketConfig`](tungstenite::protocol::WebSocketConfig)) - /// e.g. max write buffer size + /// WebSocket config for transport (alias of + /// [`WebSocketConfig`](tungstenite::protocol::WebSocketConfig)) e.g. max write buffer size transport: WebSocketConfig, /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response. /// use [`WebSocketUpgrade::protocols`] to set server supported protocols @@ -103,8 +102,8 @@ impl Config { /// Set server supported protocols. /// /// This will filter protocols in request header `Sec-WebSocket-Protocol` - /// and will set the first server supported protocol in [`http::header::Sec-WebSocket-Protocol`] in - /// response + /// and will set the first server supported protocol in [`http::header::Sec-WebSocket-Protocol`] + /// in response /// /// ```rust /// use volo_http::server::utils::WebSocketConfig; From 62cf986fb38603b5c44a3f7c3f4c170e493b9a99 Mon Sep 17 00:00:00 2001 From: StellarisW Date: Wed, 7 Aug 2024 16:52:37 +0800 Subject: [PATCH 15/31] feat(http): support websocket server --- volo-http/src/server/utils/ws.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index bbe87cb9..0f8ceb4e 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -157,7 +157,7 @@ impl std::fmt::Debug for Config { } /// Callback fn that processes [`WebSocket`] -trait Callback: Send + 'static { +pub trait Callback: Send + 'static { /// Called when a connection upgrade succeeds fn call(self, _: WebSocket) -> impl Future + Send; } @@ -295,7 +295,7 @@ where /// WebSocketConfig::new() /// .set_protocols(["graphql-ws","graphql-transport-ws"]) /// ) - /// .on_protocol(HashMap::from([("graphql-ws",|mut socket: WebSocket| async move{})])) + /// .on_protocol([("graphql-ws",|mut socket: WebSocket| async move{})]) /// .on_upgrade(|socket| async{} ) /// } pub fn on_protocol(self, on_protocol: I) -> WebSocketUpgrade @@ -624,7 +624,7 @@ mod websocket_tests { ws.set_config( WebSocketConfig::new().set_protocols(["graphql-ws", "graphql-transport-ws"]), ) - .on_protocol(HashMap::from([( + .on_protocol([( "graphql-ws", |mut socket: WebSocket| async move { while let Some(Ok(msg)) = socket.next().await { @@ -639,7 +639,7 @@ mod websocket_tests { } } }, - )])) + )]) .on_upgrade(|_| async {}) } From 3800ea5735bc94cf0839659973e1ffffa86a8304 Mon Sep 17 00:00:00 2001 From: StellarisW Date: Wed, 7 Aug 2024 17:03:54 +0800 Subject: [PATCH 16/31] feat(http): support websocket server --- volo-http/src/server/utils/ws.rs | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index 0f8ceb4e..43fdc060 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -624,22 +624,19 @@ mod websocket_tests { ws.set_config( WebSocketConfig::new().set_protocols(["graphql-ws", "graphql-transport-ws"]), ) - .on_protocol([( - "graphql-ws", - |mut socket: WebSocket| async move { - while let Some(Ok(msg)) = socket.next().await { - match msg { - Message::Text(text) => { - socket - .send(Message::Text(text.add("-graphql-ws"))) - .await - .unwrap(); - } - _ => {} + .on_protocol([("graphql-ws", |mut socket: WebSocket| async move { + while let Some(Ok(msg)) = socket.next().await { + match msg { + Message::Text(text) => { + socket + .send(Message::Text(text.add("-graphql-ws"))) + .await + .unwrap(); } + _ => {} } - }, - )]) + } + })]) .on_upgrade(|_| async {}) } From 78d36fdb5fe624801a906464b5b7ccb05acbcf2d Mon Sep 17 00:00:00 2001 From: StellarisW Date: Wed, 7 Aug 2024 17:04:18 +0800 Subject: [PATCH 17/31] feat(http): support websocket server --- volo-http/src/error/server.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/volo-http/src/error/server.rs b/volo-http/src/error/server.rs index 53a858f6..22d1cd11 100644 --- a/volo-http/src/error/server.rs +++ b/volo-http/src/error/server.rs @@ -100,7 +100,7 @@ pub fn invalid_content_type() -> ExtractBodyError { ExtractBodyError::Generic(GenericRejectionError::InvalidContentType) } -/// Rejection used for [`WebSocketUpgrade`](crate::server::extract::WebSocketUpgrade). +/// Rejection used for [`WebSocketUpgrade`](crate::server::utils::WebSocketUpgrade). #[derive(Debug)] #[non_exhaustive] pub enum WebSocketUpgradeRejectionError { From 9599b6366f4d753369932001d1c3d00148a91820 Mon Sep 17 00:00:00 2001 From: StellarisW Date: Thu, 8 Aug 2024 10:57:08 +0800 Subject: [PATCH 18/31] feat(http): support websocket server --- volo-http/src/server/utils/ws.rs | 165 ++++++------------------------- 1 file changed, 32 insertions(+), 133 deletions(-) diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index 43fdc060..37f1d9de 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -36,7 +36,7 @@ //! let app = Router::new().route("/ws", get(ws_handler)); //! ``` -use std::{borrow::Cow, collections::HashMap, fmt::Formatter, future::Future}; +use std::{borrow::Cow, fmt::Formatter, future::Future}; use http::{request::Parts, HeaderMap, HeaderName, HeaderValue}; use hyper::Error; @@ -87,7 +87,7 @@ pub struct Config { transport: WebSocketConfig, /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response. /// use [`WebSocketUpgrade::protocols`] to set server supported protocols - protocols: Option>, + protocols: Vec, } impl Config { @@ -95,7 +95,7 @@ impl Config { pub fn new() -> Self { Config { transport: WebSocketConfig::default(), - protocols: None, + protocols: Vec::new(), } } @@ -115,7 +115,7 @@ impl Config { I: IntoIterator, I::Item: Into>, { - self.protocols = Some( + self.protocols = protocols .into_iter() .map(Into::into) @@ -123,8 +123,7 @@ impl Config { Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(), Cow::Borrowed(s) => HeaderValue::from_static(s), }) - .collect(), - ); + .collect(); self } @@ -159,12 +158,12 @@ impl std::fmt::Debug for Config { /// Callback fn that processes [`WebSocket`] pub trait Callback: Send + 'static { /// Called when a connection upgrade succeeds - fn call(self, _: WebSocket) -> impl Future + Send; + fn call(self, _: WebSocket) -> impl Future + Send; } impl Callback for C where - Fut: Future + Send + 'static, + Fut: Future + Send + 'static, C: FnOnce(WebSocket) -> Fut + Send + Copy + 'static, { async fn call(self, websocket: WebSocket) { @@ -227,15 +226,14 @@ impl Callback for DefaultCallback { /// ws.on_upgrade(|socket| unimplemented!()) /// } /// ``` -pub struct WebSocketUpgrade { +pub struct WebSocketUpgrade { config: Config, - on_protocol: HashMap, on_failed_upgrade: F, on_upgrade: hyper::upgrade::OnUpgrade, headers: Headers, } -impl std::fmt::Debug for WebSocketUpgrade { +impl std::fmt::Debug for WebSocketUpgrade { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("WebSocketUpgrade") .field("config", &self.config) @@ -244,9 +242,8 @@ impl std::fmt::Debug for WebSocketUpgrade { } } -impl WebSocketUpgrade +impl WebSocketUpgrade where - C: Callback + Clone, F: OnFailedUpgrade, { /// Set WebSocket config @@ -278,47 +275,6 @@ where self } - /// Set callback for specific protocol - /// ```rust - /// use std::collections::HashMap; - /// use volo_http::{ - /// response::ServerResponse, - /// server::utils::{ - /// WebSocketConfig, - /// WebSocketUpgrade, - /// WebSocket, - /// } - /// }; - /// - /// async fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse{ - /// ws.set_config( - /// WebSocketConfig::new() - /// .set_protocols(["graphql-ws","graphql-transport-ws"]) - /// ) - /// .on_protocol([("graphql-ws",|mut socket: WebSocket| async move{})]) - /// .on_upgrade(|socket| async{} ) - /// } - pub fn on_protocol(self, on_protocol: I) -> WebSocketUpgrade - where - I: IntoIterator, - I::Item: Into<(H, C1)>, - H: Into<&'static str>, - C1: Callback, - { - let on_protocol = - HashMap::from_iter(on_protocol.into_iter().map(Into::into).map(|(k, v)| { - let k = HeaderValue::from_str(k.into()).unwrap(); - (k, v) - })); - WebSocketUpgrade { - config: self.config, - on_protocol, - on_failed_upgrade: self.on_failed_upgrade, - on_upgrade: self.on_upgrade, - headers: self.headers, - } - } - /// Provide a callback to call if upgrading the connection fails. /// /// The connection upgrade is performed in a background task. @@ -345,10 +301,12 @@ where /// }) /// .on_upgrade(|socket| async{} ) /// } - pub fn on_failed_upgrade(self, callback: F1) -> WebSocketUpgrade { + pub fn on_failed_upgrade(self, callback: F1) -> WebSocketUpgrade + where + F1: OnFailedUpgrade, + { WebSocketUpgrade { config: self.config, - on_protocol: self.on_protocol, on_failed_upgrade: callback, on_upgrade: self.on_upgrade, headers: self.headers, @@ -358,10 +316,10 @@ where /// Finalize upgrading the connection and call the provided callback /// if request protocol is matched, it will use callback set by /// [`WebSocketUpgrade::on_protocol`], otherwise use `default_callback`. - pub fn on_upgrade(self, default_callback: C1) -> ServerResponse + pub fn on_upgrade(self, callback: C) -> ServerResponse where - Fut: Future + Send + 'static, - C1: FnOnce(WebSocket) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + C: FnOnce(WebSocket) -> Fut + Send + Sync + 'static, { let on_upgrade = self.on_upgrade; let config = self.config.transport; @@ -374,24 +332,15 @@ where .as_ref() .and_then(|p| p.to_str().ok()) .and_then(|req_protocols| { - let binding = self.config.protocols.clone(); - binding.as_ref().map(|protocol| { - protocol - .iter() - .find(|protocol| { - req_protocols - .split(',') - .any(|req_protocol| req_protocol == *protocol) - }) - .unwrap() - .clone() - }) + self.config.protocols + .iter() + .find(|protocol| { + req_protocols + .split(',') + .any(|req_protocol| req_protocol == *protocol) + }) }); - let callback = protocol - .clone() - .and_then(|protocol| self.on_protocol.get(&protocol).cloned()); - tokio::spawn(async move { let upgraded = match on_upgrade.await { Ok(upgraded) => upgraded, @@ -406,10 +355,7 @@ where WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config)) .await; - match callback { - Some(callback) => callback.call(socket).await, - None => default_callback(socket).await, - } + callback(socket).await; }); const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade"); @@ -454,7 +400,7 @@ fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool } } -impl FromContext for WebSocketUpgrade { +impl FromContext for WebSocketUpgrade { type Rejection = WebSocketUpgradeRejectionError; async fn from_context( @@ -504,14 +450,13 @@ impl FromContext for WebSocketUpgrade { }, on_failed_upgrade: DefaultOnFailedUpgrade, on_upgrade, - on_protocol: HashMap::new(), }) } } #[cfg(test)] mod websocket_tests { - use std::{net, ops::Add}; + use std::{net}; use futures_util::{SinkExt, StreamExt}; use http::Uri; @@ -539,7 +484,7 @@ mod websocket_tests { ) -> (WebSocketStream>, ServerResponse) where R: IntoClientRequest + Unpin, - Fut: Future + Send + 'static, + Fut: Future + Send + 'static, C: FnOnce(WebSocketUpgrade) -> Fut + Send + Sync + Clone + 'static, { let app = Router::new().route("/echo", get(handler)); @@ -559,7 +504,7 @@ mod websocket_tests { } #[tokio::test] - async fn rejects_unupgradable_requests() { + async fn reject_unupgradable_requests() { let route: Route = Route::new(get( |ws: Result| { let rejection = ws.unwrap_err(); @@ -589,7 +534,7 @@ mod websocket_tests { } #[tokio::test] - async fn rejects_non_get_requests() { + async fn reject_non_get_requests() { let route: Route = Route::new(get( |ws: Result| { let rejection = ws.unwrap_err(); @@ -614,53 +559,7 @@ mod websocket_tests { } #[tokio::test] - async fn on_protocol() { - use crate::{ - response::ServerResponse, - server::utils::{WebSocketConfig, WebSocketUpgrade}, - }; - - async fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse { - ws.set_config( - WebSocketConfig::new().set_protocols(["graphql-ws", "graphql-transport-ws"]), - ) - .on_protocol([("graphql-ws", |mut socket: WebSocket| async move { - while let Some(Ok(msg)) = socket.next().await { - match msg { - Message::Text(text) => { - socket - .send(Message::Text(text.add("-graphql-ws"))) - .await - .unwrap(); - } - _ => {} - } - } - })]) - .on_upgrade(|_| async {}) - } - - let addr = Address::Ip(net::SocketAddr::new( - net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)), - 25230, - )); - - let builder = ClientRequestBuilder::new( - format!("ws://{}/echo", addr.clone()) - .parse::() - .unwrap(), - ) - .with_sub_protocol("graphql-ws"); - let (mut ws_stream, _response) = run_ws_handler(addr, ws_handler, builder).await; - - let input = Message::Text("foobar".to_owned()); - ws_stream.send(input.clone()).await.unwrap(); - let output = ws_stream.next().await.unwrap().unwrap(); - assert_eq!(output, Message::Text("foobar-graphql-ws".to_owned())); - } - - #[tokio::test] - async fn integration_test() { + async fn success_on_upgrade() { async fn handle_socket(mut socket: WebSocket) { while let Some(Ok(msg)) = socket.next().await { match msg { @@ -693,7 +592,7 @@ mod websocket_tests { |ws: WebSocketUpgrade| std::future::ready(ws.on_upgrade(handle_socket)), builder, ) - .await; + .await; let input = Message::Text("foobar".to_owned()); ws_stream.send(input.clone()).await.unwrap(); From 0adfbd6f01ddb16ac21c86aae18bfa8c5ad775d8 Mon Sep 17 00:00:00 2001 From: StellarisW Date: Thu, 8 Aug 2024 10:58:36 +0800 Subject: [PATCH 19/31] feat(http): support websocket server --- volo-http/src/server/utils/ws.rs | 44 +++++++++++++++----------------- 1 file changed, 20 insertions(+), 24 deletions(-) diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index 37f1d9de..5b29610e 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -115,15 +115,14 @@ impl Config { I: IntoIterator, I::Item: Into>, { - self.protocols = - protocols - .into_iter() - .map(Into::into) - .map(|protocol| match protocol { - Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(), - Cow::Borrowed(s) => HeaderValue::from_static(s), - }) - .collect(); + self.protocols = protocols + .into_iter() + .map(Into::into) + .map(|protocol| match protocol { + Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(), + Cow::Borrowed(s) => HeaderValue::from_static(s), + }) + .collect(); self } @@ -158,12 +157,12 @@ impl std::fmt::Debug for Config { /// Callback fn that processes [`WebSocket`] pub trait Callback: Send + 'static { /// Called when a connection upgrade succeeds - fn call(self, _: WebSocket) -> impl Future + Send; + fn call(self, _: WebSocket) -> impl Future + Send; } impl Callback for C where - Fut: Future + Send + 'static, + Fut: Future + Send + 'static, C: FnOnce(WebSocket) -> Fut + Send + Copy + 'static, { async fn call(self, websocket: WebSocket) { @@ -314,11 +313,10 @@ where } /// Finalize upgrading the connection and call the provided callback - /// if request protocol is matched, it will use callback set by - /// [`WebSocketUpgrade::on_protocol`], otherwise use `default_callback`. + /// if request protocol is matched, it will use `callback` to handle the connection stream data pub fn on_upgrade(self, callback: C) -> ServerResponse where - Fut: Future + Send + 'static, + Fut: Future + Send + 'static, C: FnOnce(WebSocket) -> Fut + Send + Sync + 'static, { let on_upgrade = self.on_upgrade; @@ -332,13 +330,11 @@ where .as_ref() .and_then(|p| p.to_str().ok()) .and_then(|req_protocols| { - self.config.protocols - .iter() - .find(|protocol| { - req_protocols - .split(',') - .any(|req_protocol| req_protocol == *protocol) - }) + self.config.protocols.iter().find(|protocol| { + req_protocols + .split(',') + .any(|req_protocol| req_protocol == *protocol) + }) }); tokio::spawn(async move { @@ -456,7 +452,7 @@ impl FromContext for WebSocketUpgrade { #[cfg(test)] mod websocket_tests { - use std::{net}; + use std::net; use futures_util::{SinkExt, StreamExt}; use http::Uri; @@ -484,7 +480,7 @@ mod websocket_tests { ) -> (WebSocketStream>, ServerResponse) where R: IntoClientRequest + Unpin, - Fut: Future + Send + 'static, + Fut: Future + Send + 'static, C: FnOnce(WebSocketUpgrade) -> Fut + Send + Sync + Clone + 'static, { let app = Router::new().route("/echo", get(handler)); @@ -592,7 +588,7 @@ mod websocket_tests { |ws: WebSocketUpgrade| std::future::ready(ws.on_upgrade(handle_socket)), builder, ) - .await; + .await; let input = Message::Text("foobar".to_owned()); ws_stream.send(input.clone()).await.unwrap(); From 7672b4ee1862bd644c8536ce8eee5b21b5bbd1a3 Mon Sep 17 00:00:00 2001 From: StellarisW Date: Thu, 8 Aug 2024 14:40:23 +0800 Subject: [PATCH 20/31] feat(http): support websocket server --- volo-http/src/error/server.rs | 2 +- volo-http/src/server/utils/ws.rs | 16 ++++++++++------ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/volo-http/src/error/server.rs b/volo-http/src/error/server.rs index 22d1cd11..42e18b82 100644 --- a/volo-http/src/error/server.rs +++ b/volo-http/src/error/server.rs @@ -122,7 +122,7 @@ pub enum WebSocketUpgradeRejectionError { impl WebSocketUpgradeRejectionError { /// Convert the [`WebSocketUpgradeRejectionError`] to the corresponding [`StatusCode`] - pub fn to_status_code(self) -> StatusCode { + fn to_status_code(self) -> StatusCode { match self { Self::MethodNotGet => StatusCode::METHOD_NOT_ALLOWED, Self::InvalidHttpVersion => StatusCode::HTTP_VERSION_NOT_SUPPORTED, diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index 5b29610e..417a245e 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -1,4 +1,4 @@ -//! Module for handling WebSocket connection +//! Handle WebSocket connection //! //! This module provides utilities for setting up and handling WebSocket connections, including //! configuring WebSocket options, setting protocols, and upgrading connections. @@ -8,6 +8,8 @@ //! # Example //! //! ```rust +//! use std::convert::Infallible; +//! //! use futures_util::{SinkExt, StreamExt}; //! use volo_http::{ //! response::ServerResponse, @@ -21,7 +23,7 @@ //! async fn handle_socket(mut socket: WebSocket) { //! while let Some(Ok(msg)) = socket.next().await { //! match msg { -//! Message::Text(text) => { +//! Message::Text(_) => { //! socket.send(msg).await.unwrap(); //! } //! _ => {} @@ -33,7 +35,7 @@ //! ws.on_upgrade(handle_socket) //! } //! -//! let app = Router::new().route("/ws", get(ws_handler)); +//! let app: Router = Router::new().route("/ws", get(ws_handler)); //! ``` use std::{borrow::Cow, fmt::Formatter, future::Future}; @@ -127,8 +129,8 @@ impl Config { } /// Set transport config - /// e.g. write buffer size /// + /// e.g. write buffer size /// /// ```rust /// use tokio_tungstenite::tungstenite::protocol::WebSocketConfig as WebSocketTransConfig; @@ -213,16 +215,18 @@ impl Callback for DefaultCallback { /// /// **Constrains**: /// -/// The extractor only supports for the request that has the method [`GET`](http::Method::GET) +/// The extractor only supports for the request that has the method [`GET`](http::method::GET) /// and contains certain header values. /// +/// See more details in [`WebSocketUpgrade::from_context`] +/// /// # Usage /// /// ```rust /// use volo_http::{response::ServerResponse, server::utils::WebSocketUpgrade}; /// /// fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse { -/// ws.on_upgrade(|socket| unimplemented!()) +/// ws.on_upgrade(|socket| async { unimplemented!() }) /// } /// ``` pub struct WebSocketUpgrade { From cc17a1da48458e64dbe7c593ca49b8ac5b23902c Mon Sep 17 00:00:00 2001 From: StellarisW Date: Thu, 8 Aug 2024 14:44:59 +0800 Subject: [PATCH 21/31] feat(http): support websocket server --- volo-http/src/server/utils/ws.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index 417a245e..c759f29a 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -215,7 +215,7 @@ impl Callback for DefaultCallback { /// /// **Constrains**: /// -/// The extractor only supports for the request that has the method [`GET`](http::method::GET) +/// The extractor only supports for the request that has the method [`GET`](http::Method::GET) /// and contains certain header values. /// /// See more details in [`WebSocketUpgrade::from_context`] From 4bd0e1645608f66459f7e24cfb70a98a3af7a4ff Mon Sep 17 00:00:00 2001 From: StellarisW Date: Thu, 8 Aug 2024 15:14:20 +0800 Subject: [PATCH 22/31] feat(http): support websocket server --- volo-http/src/server/utils/ws.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index c759f29a..ca0b0333 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -250,6 +250,7 @@ where F: OnFailedUpgrade, { /// Set WebSocket config + /// /// ```rust /// use volo_http::{ /// response::ServerResponse, @@ -273,6 +274,7 @@ where /// ) /// .on_upgrade(|socket| async{} ) /// } + /// ``` pub fn set_config(mut self, config: Config) -> Self { self.config = config; self @@ -317,7 +319,8 @@ where } /// Finalize upgrading the connection and call the provided callback - /// if request protocol is matched, it will use `callback` to handle the connection stream data + /// + /// If request protocol is matched, it will use `callback` to handle the connection stream data pub fn on_upgrade(self, callback: C) -> ServerResponse where Fut: Future + Send + 'static, From 05b2e8282537999ed2fd3c69f52db109f889cd23 Mon Sep 17 00:00:00 2001 From: StellarisW Date: Thu, 8 Aug 2024 15:18:04 +0800 Subject: [PATCH 23/31] feat(http): support websocket server --- volo-http/src/server/utils/ws.rs | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index ca0b0333..f2ed8794 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -252,27 +252,22 @@ where /// Set WebSocket config /// /// ```rust + /// use tokio_tungstenite::tungstenite::protocol::WebSocketConfig as WebSocketTransConfig; /// use volo_http::{ /// response::ServerResponse, - /// server::utils::{ - /// WebSocketConfig, - /// WebSocketUpgrade, - /// } + /// server::utils::{WebSocketConfig, WebSocketUpgrade}, /// }; - /// use tokio_tungstenite::tungstenite::protocol::{WebSocketConfig as WebSocketTransConfig}; /// - /// async fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse{ + /// async fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse { /// ws.set_config( /// WebSocketConfig::new() - /// .set_protocols(["graphql-ws","graphql-transport-ws"]) - /// .set_transport( - /// WebSocketTransConfig{ - /// write_buffer_size: 128 * 1024, - /// ..<_>::default() - /// } - /// ) - /// ) - /// .on_upgrade(|socket| async{} ) + /// .set_protocols(["graphql-ws", "graphql-transport-ws"]) + /// .set_transport(WebSocketTransConfig { + /// write_buffer_size: 128 * 1024, + /// ..<_>::default() + /// }), + /// ) + /// .on_upgrade(|socket| async {}) /// } /// ``` pub fn set_config(mut self, config: Config) -> Self { From d8fa4347ca2e7dadd51dfb5cef62a12a4bddcad6 Mon Sep 17 00:00:00 2001 From: StellarisW Date: Thu, 8 Aug 2024 15:45:10 +0800 Subject: [PATCH 24/31] feat(http): support websocket server --- volo-http/src/server/utils/ws.rs | 48 ++++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 9 deletions(-) diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index f2ed8794..a5e6d74e 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -89,7 +89,7 @@ pub struct Config { transport: WebSocketConfig, /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response. /// use [`WebSocketUpgrade::protocols`] to set server supported protocols - protocols: Vec, + protocols: Vec, } impl Config { @@ -121,8 +121,8 @@ impl Config { .into_iter() .map(Into::into) .map(|protocol| match protocol { - Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(), - Cow::Borrowed(s) => HeaderValue::from_static(s), + Cow::Owned(s) => s, + Cow::Borrowed(s) => s.to_string(), }) .collect(); self @@ -356,7 +356,9 @@ where callback(socket).await; }); + #[allow(clippy::declare_interior_mutable_const)] const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade"); + #[allow(clippy::declare_interior_mutable_const)] const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket"); let mut builder = ServerResponse::builder() @@ -493,12 +495,9 @@ mod websocket_tests { tokio::time::sleep(std::time::Duration::from_secs(1)).await; - let (socket, response) = tokio_tungstenite::connect_async(req).await.unwrap(); + let (socket, resp) = tokio_tungstenite::connect_async(req).await.unwrap(); - ( - socket, - response.map(|response| response.unwrap_or_default().into()), - ) + (socket, resp.map(|resp| resp.unwrap_or_default().into())) } #[tokio::test] @@ -556,6 +555,37 @@ mod websocket_tests { assert_eq!(resp.status(), http::StatusCode::METHOD_NOT_ALLOWED); } + #[tokio::test] + async fn set_protocols() { + let addr = Address::Ip(net::SocketAddr::new( + net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)), + 25230, + )); + + let builder = ClientRequestBuilder::new( + format!("ws://{}/echo", addr.clone()) + .parse::() + .unwrap(), + ) + .with_sub_protocol("graphql-ws"); + + let (_, resp) = run_ws_handler( + addr.clone(), + |ws: WebSocketUpgrade| async { + ws.set_config(Config::new().set_protocols(["graphql-ws", "test-protocol"])) + .on_upgrade(|_| async {}) + }, + builder, + ) + .await; + + assert!(resp + .headers() + .get(http::header::SEC_WEBSOCKET_PROTOCOL) + .unwrap() + .eq("graphql-ws")); + } + #[tokio::test] async fn success_on_upgrade() { async fn handle_socket(mut socket: WebSocket) { @@ -585,7 +615,7 @@ mod websocket_tests { .unwrap(), ); - let (mut ws_stream, _response) = run_ws_handler( + let (mut ws_stream, _resp) = run_ws_handler( addr.clone(), |ws: WebSocketUpgrade| std::future::ready(ws.on_upgrade(handle_socket)), builder, From 105c74102df4792622d22eb6ea3ea609de39cae8 Mon Sep 17 00:00:00 2001 From: StellarisW Date: Thu, 8 Aug 2024 15:46:16 +0800 Subject: [PATCH 25/31] feat(http): support websocket server --- volo-http/src/server/utils/ws.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index a5e6d74e..24893e82 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -579,11 +579,12 @@ mod websocket_tests { ) .await; - assert!(resp - .headers() - .get(http::header::SEC_WEBSOCKET_PROTOCOL) - .unwrap() - .eq("graphql-ws")); + assert_eq!( + resp.headers() + .get(http::header::SEC_WEBSOCKET_PROTOCOL) + .unwrap(), + "graphql-ws" + ); } #[tokio::test] From f047d681ee23a6e92f4aac006a45f4d63daeea27 Mon Sep 17 00:00:00 2001 From: StellarisW Date: Thu, 8 Aug 2024 15:53:13 +0800 Subject: [PATCH 26/31] feat(http): support websocket server --- volo-http/src/server/utils/ws.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index 24893e82..00d5d59e 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -211,7 +211,7 @@ impl Callback for DefaultCallback { async fn call(self, _: WebSocket) {} } -/// Extractor of [`FromContext`] for establishing WebSocket connection +/// Handler request for establishing WebSocket connection /// /// **Constrains**: /// From ffde3a4cd008a392f975bbaa1de06a5e0135f97d Mon Sep 17 00:00:00 2001 From: StellarisW Date: Thu, 8 Aug 2024 17:04:22 +0800 Subject: [PATCH 27/31] feat(http): support websocket server --- volo-http/src/server/utils/ws.rs | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index 00d5d59e..f88d0dd8 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -35,7 +35,7 @@ //! ws.on_upgrade(handle_socket) //! } //! -//! let app: Router = Router::new().route("/ws", get(ws_handler)); +//! let app: Router = Router::new().route("/ws", get(ws_handler)); //! ``` use std::{borrow::Cow, fmt::Formatter, future::Future}; @@ -213,7 +213,7 @@ impl Callback for DefaultCallback { /// Handler request for establishing WebSocket connection /// -/// **Constrains**: +/// # Constrains: /// /// The extractor only supports for the request that has the method [`GET`](http::Method::GET) /// and contains certain header values. @@ -356,15 +356,10 @@ where callback(socket).await; }); - #[allow(clippy::declare_interior_mutable_const)] - const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade"); - #[allow(clippy::declare_interior_mutable_const)] - const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket"); - let mut builder = ServerResponse::builder() .status(http::StatusCode::SWITCHING_PROTOCOLS) - .header(http::header::CONNECTION, UPGRADE) - .header(http::header::UPGRADE, WEBSOCKET) + .header(http::header::CONNECTION, "upgrade") + .header(http::header::UPGRADE, "websocket") .header( http::header::SEC_WEBSOCKET_ACCEPT, derive_accept_key(self.headers.sec_websocket_key.as_bytes()), From 294275468742d35d2abf9cef0e91ea5f0776e358 Mon Sep 17 00:00:00 2001 From: StellarisW Date: Thu, 8 Aug 2024 17:13:29 +0800 Subject: [PATCH 28/31] feat(http): support websocket server --- volo-http/src/error/server.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/volo-http/src/error/server.rs b/volo-http/src/error/server.rs index 42e18b82..eacaa0a2 100644 --- a/volo-http/src/error/server.rs +++ b/volo-http/src/error/server.rs @@ -76,7 +76,7 @@ impl Error for GenericRejectionError {} impl GenericRejectionError { /// Convert the [`GenericRejectionError`] to the corresponding [`StatusCode`] - pub fn to_status_code(self) -> StatusCode { + pub fn to_status_code(&self) -> StatusCode { match self { Self::BodyCollectionError => StatusCode::INTERNAL_SERVER_ERROR, Self::InvalidContentType => StatusCode::UNSUPPORTED_MEDIA_TYPE, @@ -122,7 +122,7 @@ pub enum WebSocketUpgradeRejectionError { impl WebSocketUpgradeRejectionError { /// Convert the [`WebSocketUpgradeRejectionError`] to the corresponding [`StatusCode`] - fn to_status_code(self) -> StatusCode { + fn to_status_code(&self) -> StatusCode { match self { Self::MethodNotGet => StatusCode::METHOD_NOT_ALLOWED, Self::InvalidHttpVersion => StatusCode::HTTP_VERSION_NOT_SUPPORTED, From 63a55e88d31816eab1edfa637b354cef243a0227 Mon Sep 17 00:00:00 2001 From: StellarisW Date: Thu, 8 Aug 2024 17:50:38 +0800 Subject: [PATCH 29/31] feat(http): support websocket server --- volo-http/src/server/utils/ws.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index f88d0dd8..8301999f 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -107,6 +107,8 @@ impl Config { /// and will set the first server supported protocol in [`http::header::Sec-WebSocket-Protocol`] /// in response /// + /// # Example + /// /// ```rust /// use volo_http::server::utils::WebSocketConfig; /// @@ -295,12 +297,13 @@ where /// } /// }; /// - /// async fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse{ + /// async fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse { /// ws.on_failed_upgrade(|error| { /// unimplemented!() /// }) /// .on_upgrade(|socket| async{} ) /// } + /// ``` pub fn on_failed_upgrade(self, callback: F1) -> WebSocketUpgrade where F1: OnFailedUpgrade, From 15ae4c7e7a96562d0b1842586d2328ca19ad5a86 Mon Sep 17 00:00:00 2001 From: StellarisW Date: Thu, 8 Aug 2024 17:51:35 +0800 Subject: [PATCH 30/31] feat(http): support websocket server --- volo-http/src/server/utils/ws.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index 8301999f..12f26b12 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -253,6 +253,8 @@ where { /// Set WebSocket config /// + /// # Example + /// /// ```rust /// use tokio_tungstenite::tungstenite::protocol::WebSocketConfig as WebSocketTransConfig; /// use volo_http::{ From 865439a7d0ac4ef3c40d86c90ef2683546bc1644 Mon Sep 17 00:00:00 2001 From: StellarisW Date: Thu, 8 Aug 2024 17:56:39 +0800 Subject: [PATCH 31/31] feat(http): support websocket server --- volo-http/src/server/utils/ws.rs | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index 12f26b12..8dbb7652 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -290,20 +290,15 @@ where /// /// ```rust /// use std::collections::HashMap; + /// /// use volo_http::{ /// response::ServerResponse, - /// server::utils::{ - /// WebSocketConfig, - /// WebSocketUpgrade, - /// WebSocket, - /// } + /// server::utils::{WebSocket, WebSocketConfig, WebSocketUpgrade}, /// }; /// /// async fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse { - /// ws.on_failed_upgrade(|error| { - /// unimplemented!() - /// }) - /// .on_upgrade(|socket| async{} ) + /// ws.on_failed_upgrade(|error| unimplemented!()) + /// .on_upgrade(|socket| async {}) /// } /// ``` pub fn on_failed_upgrade(self, callback: F1) -> WebSocketUpgrade