From daafd608dd3d3e82b4bfaffb8f63d311e94ddd96 Mon Sep 17 00:00:00 2001 From: StellarisW Date: Fri, 9 Aug 2024 11:09:55 +0800 Subject: [PATCH] feat(http): support websocket server (#481) --- Cargo.lock | 102 +++++ Cargo.toml | 1 + volo-http/Cargo.toml | 8 +- volo-http/src/error/server.rs | 68 +++- volo-http/src/server/mod.rs | 6 +- volo-http/src/server/utils/mod.rs | 5 + volo-http/src/server/utils/ws.rs | 631 ++++++++++++++++++++++++++++++ 7 files changed, 817 insertions(+), 4 deletions(-) create mode 100644 volo-http/src/server/utils/ws.rs diff --git a/Cargo.lock b/Cargo.lock index 8912c9c1..524857b7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -295,6 +295,15 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bumpalo" version = "3.16.0" @@ -443,6 +452,15 @@ version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" +[[package]] +name = "cpufeatures" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" +dependencies = [ + "libc", +] + [[package]] name = "crc32fast" version = "1.4.2" @@ -477,6 +495,16 @@ version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "dashmap" version = "5.5.3" @@ -516,6 +544,16 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "directories" version = "5.0.1" @@ -839,6 +877,16 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.2.15" @@ -2722,6 +2770,17 @@ dependencies = [ "unsafe-libyaml", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -3092,6 +3151,18 @@ dependencies = [ "tokio-stream", ] +[[package]] +name = "tokio-tungstenite" +version = "0.23.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6989540ced10490aaf14e6bad2e3d33728a2813310a0c71d1574304c49631cd" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.11" @@ -3317,6 +3388,30 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "tungstenite" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e2e2ce1e47ed2994fd43b04c8f618008d4cabdd5ee34027cf14f9d918edd9c8" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http 1.1.0", + "httparse", + "log", + "rand", + "sha1", + "thiserror", + "utf-8", +] + +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + [[package]] name = "unicase" version = "2.7.0" @@ -3414,6 +3509,12 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "10477a8879b7fc3bb4fb54fa6adcfd6191de561b13d5413fec9cc0239fd1c882" +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8parse" version = "0.2.2" @@ -3617,6 +3718,7 @@ dependencies = [ "tokio-native-tls", "tokio-rustls 0.25.0", "tokio-test", + "tokio-tungstenite", "tokio-util", "tracing", "volo", diff --git a/Cargo.toml b/Cargo.toml index 2eebea75..2abc2623 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -133,6 +133,7 @@ webpki-roots = "0.26" tokio-rustls = "0.25" native-tls = "0.2" tokio-native-tls = "0.3" +tokio-tungstenite="0.23.1" [profile.release] opt-level = 3 diff --git a/volo-http/Cargo.toml b/volo-http/Cargo.toml index 65e492af..bab4953d 100644 --- a/volo-http/Cargo.toml +++ b/volo-http/Cargo.toml @@ -59,9 +59,13 @@ tokio = { workspace = true, features = [ tokio-util = { workspace = true, features = ["io"] } tracing.workspace = true +# =====optional===== # server optional matchit = { workspace = true, optional = true } +# protocol optional +tokio-tungstenite = { workspace = true, optional = true } + # tls optional tokio-rustls = { workspace = true, optional = true } tokio-native-tls = { workspace = true, optional = true } @@ -86,11 +90,13 @@ default = [] default_client = ["client", "json"] default_server = ["server", "query", "form", "json"] -full = ["client", "server", "rustls", "cookie", "query", "form", "json", "tls"] +full = ["client", "server", "rustls", "cookie", "query", "form", "json", "tls", "ws"] client = ["hyper/client", "hyper/http1"] # client core server = ["hyper/server", "hyper/http1", "dep:matchit"] # server core +ws = ["dep:tokio-tungstenite"] + tls = ["rustls"] __tls = [] rustls = ["__tls", "dep:tokio-rustls", "volo/rustls"] diff --git a/volo-http/src/error/server.rs b/volo-http/src/error/server.rs index 3be6c308..eacaa0a2 100644 --- a/volo-http/src/error/server.rs +++ b/volo-http/src/error/server.rs @@ -76,7 +76,7 @@ impl Error for GenericRejectionError {} impl GenericRejectionError { /// Convert the [`GenericRejectionError`] to the corresponding [`StatusCode`] - pub fn to_status_code(self) -> StatusCode { + pub fn to_status_code(&self) -> StatusCode { match self { Self::BodyCollectionError => StatusCode::INTERNAL_SERVER_ERROR, Self::InvalidContentType => StatusCode::UNSUPPORTED_MEDIA_TYPE, @@ -99,3 +99,69 @@ pub fn body_collection_error() -> ExtractBodyError { pub fn invalid_content_type() -> ExtractBodyError { ExtractBodyError::Generic(GenericRejectionError::InvalidContentType) } + +/// Rejection used for [`WebSocketUpgrade`](crate::server::utils::WebSocketUpgrade). +#[derive(Debug)] +#[non_exhaustive] +pub enum WebSocketUpgradeRejectionError { + /// The request method must be `GET` + MethodNotGet, + /// The HTTP version is not supported + InvalidHttpVersion, + /// The `Connection` header is invalid + InvalidConnectionHeader, + /// The `Upgrade` header is invalid + InvalidUpgradeHeader, + /// The `Sec-WebSocket-Version` header is invalid + InvalidWebSocketVersionHeader, + /// The `Sec-WebSocket-Key` header is missing + WebSocketKeyHeaderMissing, + /// The connection is not upgradable + ConnectionNotUpgradable, +} + +impl WebSocketUpgradeRejectionError { + /// Convert the [`WebSocketUpgradeRejectionError`] to the corresponding [`StatusCode`] + fn to_status_code(&self) -> StatusCode { + match self { + Self::MethodNotGet => StatusCode::METHOD_NOT_ALLOWED, + Self::InvalidHttpVersion => StatusCode::HTTP_VERSION_NOT_SUPPORTED, + Self::InvalidConnectionHeader => StatusCode::BAD_REQUEST, + Self::InvalidUpgradeHeader => StatusCode::BAD_REQUEST, + Self::InvalidWebSocketVersionHeader => StatusCode::BAD_REQUEST, + Self::WebSocketKeyHeaderMissing => StatusCode::BAD_REQUEST, + Self::ConnectionNotUpgradable => StatusCode::UPGRADE_REQUIRED, + } + } +} + +impl Error for WebSocketUpgradeRejectionError {} + +impl fmt::Display for WebSocketUpgradeRejectionError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::MethodNotGet => write!(f, "Request method must be 'GET'"), + Self::InvalidHttpVersion => { + write!(f, "Http version not support, only support HTTP 1.1 for now") + } + Self::InvalidConnectionHeader => { + write!(f, "Connection header did not include 'upgrade'") + } + Self::InvalidUpgradeHeader => write!(f, "`Upgrade` header did not include 'websocket'"), + Self::InvalidWebSocketVersionHeader => { + write!(f, "`Sec-WebSocket-Version` header did not include '13'") + } + Self::WebSocketKeyHeaderMissing => write!(f, "`Sec-WebSocket-Key` header missing"), + Self::ConnectionNotUpgradable => write!( + f, + "WebSocket request couldn't be upgraded since no upgrade state was present" + ), + } + } +} + +impl IntoResponse for WebSocketUpgradeRejectionError { + fn into_response(self) -> ServerResponse { + self.to_status_code().into_response() + } +} diff --git a/volo-http/src/server/mod.rs b/volo-http/src/server/mod.rs index 7bcd3d7a..56959031 100644 --- a/volo-http/src/server/mod.rs +++ b/volo-http/src/server/mod.rs @@ -443,13 +443,15 @@ async fn serve_conn( let notified = exit_notify.notified(); tokio::pin!(notified); - let mut http_conn = server.serve_connection(TokioIo::new(conn), service); + let mut http_conn = server + .serve_connection(TokioIo::new(conn), service) + .with_upgrades(); tokio::select! { _ = &mut notified => { tracing::trace!("[VOLO] closing a pending connection"); // Graceful shutdown. - hyper::server::conn::http1::Connection::graceful_shutdown( + hyper::server::conn::http1::UpgradeableConnection::graceful_shutdown( Pin::new(&mut http_conn) ); // Continue to poll this connection until shutdown can finish. diff --git a/volo-http/src/server/utils/mod.rs b/volo-http/src/server/utils/mod.rs index d495286b..c965ce17 100644 --- a/volo-http/src/server/utils/mod.rs +++ b/volo-http/src/server/utils/mod.rs @@ -5,3 +5,8 @@ mod serve_dir; pub use file_response::FileResponse; pub use serve_dir::ServeDir; + +#[cfg(feature = "ws")] +pub mod ws; +#[cfg(feature = "ws")] +pub use self::ws::{Config as WebSocketConfig, Message, WebSocket, WebSocketUpgrade}; diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs new file mode 100644 index 00000000..8dbb7652 --- /dev/null +++ b/volo-http/src/server/utils/ws.rs @@ -0,0 +1,631 @@ +//! Handle WebSocket connection +//! +//! This module provides utilities for setting up and handling WebSocket connections, including +//! configuring WebSocket options, setting protocols, and upgrading connections. +//! +//! It uses [`hyper::upgrade::OnUpgrade`] to upgrade the connection. +//! +//! # Example +//! +//! ```rust +//! use std::convert::Infallible; +//! +//! use futures_util::{SinkExt, StreamExt}; +//! use volo_http::{ +//! response::ServerResponse, +//! server::{ +//! route::get, +//! utils::{Message, WebSocket, WebSocketUpgrade}, +//! }, +//! Router, +//! }; +//! +//! async fn handle_socket(mut socket: WebSocket) { +//! while let Some(Ok(msg)) = socket.next().await { +//! match msg { +//! Message::Text(_) => { +//! socket.send(msg).await.unwrap(); +//! } +//! _ => {} +//! } +//! } +//! } +//! +//! async fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse { +//! ws.on_upgrade(handle_socket) +//! } +//! +//! let app: Router = Router::new().route("/ws", get(ws_handler)); +//! ``` + +use std::{borrow::Cow, fmt::Formatter, future::Future}; + +use http::{request::Parts, HeaderMap, HeaderName, HeaderValue}; +use hyper::Error; +use hyper_util::rt::TokioIo; +use tokio_tungstenite::{ + tungstenite::{ + self, + handshake::derive_accept_key, + protocol::{self, WebSocketConfig}, + }, + WebSocketStream, +}; + +use crate::{ + body::Body, context::ServerContext, error::server::WebSocketUpgradeRejectionError, + response::ServerResponse, server::extract::FromContext, +}; + +/// WebSocketStream used In handler Request +pub type WebSocket = WebSocketStream>; +/// alias of [`tungstenite::Message`] +pub type Message = tungstenite::Message; + +/// WebSocket Request headers for establishing a WebSocket connection. +struct Headers { + /// The `Sec-WebSocket-Key` request header value + /// used for compute 'Sec-WebSocket-Accept' response header value + sec_websocket_key: HeaderValue, + /// The `Sec-WebSocket-Protocol` request header value + /// specify [`Callback`] method depend on the protocol + sec_websocket_protocol: Option, +} + +impl std::fmt::Debug for Headers { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Headers") + .field("sec_websocket_key", &self.sec_websocket_protocol) + .field("sec_websocket_protocol", &self.sec_websocket_protocol) + .finish_non_exhaustive() + } +} + +/// WebSocket config +#[derive(Default)] +pub struct Config { + /// WebSocket config for transport (alias of + /// [`WebSocketConfig`](tungstenite::protocol::WebSocketConfig)) e.g. max write buffer size + transport: WebSocketConfig, + /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response. + /// use [`WebSocketUpgrade::protocols`] to set server supported protocols + protocols: Vec, +} + +impl Config { + /// Create Default Config + pub fn new() -> Self { + Config { + transport: WebSocketConfig::default(), + protocols: Vec::new(), + } + } + + /// Set server supported protocols. + /// + /// This will filter protocols in request header `Sec-WebSocket-Protocol` + /// and will set the first server supported protocol in [`http::header::Sec-WebSocket-Protocol`] + /// in response + /// + /// # Example + /// + /// ```rust + /// use volo_http::server::utils::WebSocketConfig; + /// + /// let config = WebSocketConfig::new().set_protocols(["graphql-ws", "graphql-transport-ws"]); + /// ``` + pub fn set_protocols(mut self, protocols: I) -> Self + where + I: IntoIterator, + I::Item: Into>, + { + self.protocols = protocols + .into_iter() + .map(Into::into) + .map(|protocol| match protocol { + Cow::Owned(s) => s, + Cow::Borrowed(s) => s.to_string(), + }) + .collect(); + self + } + + /// Set transport config + /// + /// e.g. write buffer size + /// + /// ```rust + /// use tokio_tungstenite::tungstenite::protocol::WebSocketConfig as WebSocketTransConfig; + /// use volo_http::server::utils::WebSocketConfig; + /// + /// let config = WebSocketConfig::new().set_transport(WebSocketTransConfig { + /// write_buffer_size: 128 * 1024, + /// ..<_>::default() + /// }); + /// ``` + pub fn set_transport(mut self, config: WebSocketConfig) -> Self { + self.transport = config; + self + } +} + +impl std::fmt::Debug for Config { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Config") + .field("transport", &self.transport) + .field("protocols", &self.protocols) + .finish_non_exhaustive() + } +} + +/// Callback fn that processes [`WebSocket`] +pub trait Callback: Send + 'static { + /// Called when a connection upgrade succeeds + fn call(self, _: WebSocket) -> impl Future + Send; +} + +impl Callback for C +where + Fut: Future + Send + 'static, + C: FnOnce(WebSocket) -> Fut + Send + Copy + 'static, +{ + async fn call(self, websocket: WebSocket) { + self(websocket).await; + } +} + +/// What to do when a connection upgrade fails. +/// +/// See [`WebSocketUpgrade::on_failed_upgrade`] for more details. +pub trait OnFailedUpgrade: Send + 'static { + /// Called when a connection upgrade fails. + fn call(self, error: Error); +} + +impl OnFailedUpgrade for F +where + F: FnOnce(Error) + Send + 'static, +{ + fn call(self, error: Error) { + self(error) + } +} + +/// The default `OnFailedUpgrade` used by `WebSocketUpgrade`. +/// +/// It simply ignores the error. +#[non_exhaustive] +#[derive(Debug)] +pub struct DefaultOnFailedUpgrade; + +impl OnFailedUpgrade for DefaultOnFailedUpgrade { + #[inline] + fn call(self, _error: Error) {} +} + +/// The default `Callback` used by `WebSocketUpgrade`. +/// +/// It simply ignores the socket. +#[derive(Clone)] +pub struct DefaultCallback; +impl Callback for DefaultCallback { + #[inline] + async fn call(self, _: WebSocket) {} +} + +/// Handler request for establishing WebSocket connection +/// +/// # Constrains: +/// +/// The extractor only supports for the request that has the method [`GET`](http::Method::GET) +/// and contains certain header values. +/// +/// See more details in [`WebSocketUpgrade::from_context`] +/// +/// # Usage +/// +/// ```rust +/// use volo_http::{response::ServerResponse, server::utils::WebSocketUpgrade}; +/// +/// fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse { +/// ws.on_upgrade(|socket| async { unimplemented!() }) +/// } +/// ``` +pub struct WebSocketUpgrade { + config: Config, + on_failed_upgrade: F, + on_upgrade: hyper::upgrade::OnUpgrade, + headers: Headers, +} + +impl std::fmt::Debug for WebSocketUpgrade { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("WebSocketUpgrade") + .field("config", &self.config) + .field("headers", &self.headers) + .finish_non_exhaustive() + } +} + +impl WebSocketUpgrade +where + F: OnFailedUpgrade, +{ + /// Set WebSocket config + /// + /// # Example + /// + /// ```rust + /// use tokio_tungstenite::tungstenite::protocol::WebSocketConfig as WebSocketTransConfig; + /// use volo_http::{ + /// response::ServerResponse, + /// server::utils::{WebSocketConfig, WebSocketUpgrade}, + /// }; + /// + /// async fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse { + /// ws.set_config( + /// WebSocketConfig::new() + /// .set_protocols(["graphql-ws", "graphql-transport-ws"]) + /// .set_transport(WebSocketTransConfig { + /// write_buffer_size: 128 * 1024, + /// ..<_>::default() + /// }), + /// ) + /// .on_upgrade(|socket| async {}) + /// } + /// ``` + pub fn set_config(mut self, config: Config) -> Self { + self.config = config; + self + } + + /// Provide a callback to call if upgrading the connection fails. + /// + /// The connection upgrade is performed in a background task. + /// If that fails this callback will be called. + /// + /// By default, any errors will be silently ignored. + /// + /// # Example + /// + /// ```rust + /// use std::collections::HashMap; + /// + /// use volo_http::{ + /// response::ServerResponse, + /// server::utils::{WebSocket, WebSocketConfig, WebSocketUpgrade}, + /// }; + /// + /// async fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse { + /// ws.on_failed_upgrade(|error| unimplemented!()) + /// .on_upgrade(|socket| async {}) + /// } + /// ``` + pub fn on_failed_upgrade(self, callback: F1) -> WebSocketUpgrade + where + F1: OnFailedUpgrade, + { + WebSocketUpgrade { + config: self.config, + on_failed_upgrade: callback, + on_upgrade: self.on_upgrade, + headers: self.headers, + } + } + + /// Finalize upgrading the connection and call the provided callback + /// + /// If request protocol is matched, it will use `callback` to handle the connection stream data + pub fn on_upgrade(self, callback: C) -> ServerResponse + where + Fut: Future + Send + 'static, + C: FnOnce(WebSocket) -> Fut + Send + Sync + 'static, + { + let on_upgrade = self.on_upgrade; + let config = self.config.transport; + let on_failed_upgrade = self.on_failed_upgrade; + + let protocol = self + .headers + .sec_websocket_protocol + .clone() + .as_ref() + .and_then(|p| p.to_str().ok()) + .and_then(|req_protocols| { + self.config.protocols.iter().find(|protocol| { + req_protocols + .split(',') + .any(|req_protocol| req_protocol == *protocol) + }) + }); + + tokio::spawn(async move { + let upgraded = match on_upgrade.await { + Ok(upgraded) => upgraded, + Err(err) => { + on_failed_upgrade.call(err); + return; + } + }; + let upgraded = TokioIo::new(upgraded); + + let socket = + WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config)) + .await; + + callback(socket).await; + }); + + let mut builder = ServerResponse::builder() + .status(http::StatusCode::SWITCHING_PROTOCOLS) + .header(http::header::CONNECTION, "upgrade") + .header(http::header::UPGRADE, "websocket") + .header( + http::header::SEC_WEBSOCKET_ACCEPT, + derive_accept_key(self.headers.sec_websocket_key.as_bytes()), + ); + + if let Some(protocol) = protocol { + builder = builder.header(http::header::SEC_WEBSOCKET_PROTOCOL, protocol); + } + + builder.body(Body::empty()).unwrap() + } +} + +fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool { + let header = if let Some(header) = headers.get(&key) { + header + } else { + return false; + }; + + if let Ok(header) = std::str::from_utf8(header.as_bytes()) { + header.to_ascii_lowercase().contains(value) + } else { + false + } +} + +fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool { + if let Some(header) = headers.get(&key) { + header.as_bytes().eq_ignore_ascii_case(value.as_bytes()) + } else { + false + } +} + +impl FromContext for WebSocketUpgrade { + type Rejection = WebSocketUpgradeRejectionError; + + async fn from_context( + _cx: &mut ServerContext, + parts: &mut Parts, + ) -> Result { + if parts.method != http::Method::GET { + return Err(WebSocketUpgradeRejectionError::MethodNotGet); + } + if parts.version < http::Version::HTTP_11 { + return Err(WebSocketUpgradeRejectionError::InvalidHttpVersion); + } + + if !header_contains(&parts.headers, http::header::CONNECTION, "upgrade") { + return Err(WebSocketUpgradeRejectionError::InvalidConnectionHeader); + } + + if !header_eq(&parts.headers, http::header::UPGRADE, "websocket") { + return Err(WebSocketUpgradeRejectionError::InvalidUpgradeHeader); + } + + if !header_eq(&parts.headers, http::header::SEC_WEBSOCKET_VERSION, "13") { + return Err(WebSocketUpgradeRejectionError::InvalidWebSocketVersionHeader); + } + + let sec_websocket_key = parts + .headers + .get(http::header::SEC_WEBSOCKET_KEY) + .ok_or(WebSocketUpgradeRejectionError::WebSocketKeyHeaderMissing)? + .clone(); + + let on_upgrade = parts + .extensions + .remove::() + .ok_or(WebSocketUpgradeRejectionError::ConnectionNotUpgradable)?; + + let sec_websocket_protocol = parts + .headers + .get(http::header::SEC_WEBSOCKET_PROTOCOL) + .cloned(); + + Ok(Self { + config: Default::default(), + headers: Headers { + sec_websocket_key, + sec_websocket_protocol, + }, + on_failed_upgrade: DefaultOnFailedUpgrade, + on_upgrade, + }) + } +} + +#[cfg(test)] +mod websocket_tests { + use std::net; + + use futures_util::{SinkExt, StreamExt}; + use http::Uri; + use motore::Service; + use tokio::net::TcpStream; + use tokio_tungstenite::{ + tungstenite::{client::IntoClientRequest, ClientRequestBuilder}, + MaybeTlsStream, + }; + use volo::net::Address; + + use super::*; + use crate::{ + server::{ + route::{get, Route}, + test_helpers::empty_cx, + }, + Router, Server, + }; + + async fn run_ws_handler( + addr: Address, + handler: C, + req: R, + ) -> (WebSocketStream>, ServerResponse) + where + R: IntoClientRequest + Unpin, + Fut: Future + Send + 'static, + C: FnOnce(WebSocketUpgrade) -> Fut + Send + Sync + Clone + 'static, + { + let app = Router::new().route("/echo", get(handler)); + + tokio::spawn(async move { + Server::new(app).run(addr).await.unwrap(); + }); + + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + let (socket, resp) = tokio_tungstenite::connect_async(req).await.unwrap(); + + (socket, resp.map(|resp| resp.unwrap_or_default().into())) + } + + #[tokio::test] + async fn reject_unupgradable_requests() { + let route: Route = Route::new(get( + |ws: Result| { + let rejection = ws.unwrap_err(); + assert!(matches!( + rejection, + WebSocketUpgradeRejectionError::ConnectionNotUpgradable, + )); + std::future::ready(()) + }, + )); + + let req = http::Request::builder() + .version(http::Version::HTTP_11) + .method(http::Method::GET) + .header("upgrade", "websocket") + .header("connection", "Upgrade") + .header("sec-websocket-key", "6D69KGBOr4Re+Nj6zx9aQA==") + .header("sec-websocket-version", "13") + .body(Body::empty()) + .unwrap(); + + let mut cx = empty_cx(); + + let resp = route.call(&mut cx, req).await.unwrap(); + + assert_eq!(resp.status(), http::StatusCode::OK); + } + + #[tokio::test] + async fn reject_non_get_requests() { + let route: Route = Route::new(get( + |ws: Result| { + let rejection = ws.unwrap_err(); + assert!(matches!( + rejection, + WebSocketUpgradeRejectionError::MethodNotGet, + )); + std::future::ready(()) + }, + )); + + let req = http::Request::builder() + .method(http::Method::POST) + .body(Body::empty()) + .unwrap(); + + let mut cx = empty_cx(); + + let resp = route.call(&mut cx, req).await.unwrap(); + + assert_eq!(resp.status(), http::StatusCode::METHOD_NOT_ALLOWED); + } + + #[tokio::test] + async fn set_protocols() { + let addr = Address::Ip(net::SocketAddr::new( + net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)), + 25230, + )); + + let builder = ClientRequestBuilder::new( + format!("ws://{}/echo", addr.clone()) + .parse::() + .unwrap(), + ) + .with_sub_protocol("graphql-ws"); + + let (_, resp) = run_ws_handler( + addr.clone(), + |ws: WebSocketUpgrade| async { + ws.set_config(Config::new().set_protocols(["graphql-ws", "test-protocol"])) + .on_upgrade(|_| async {}) + }, + builder, + ) + .await; + + assert_eq!( + resp.headers() + .get(http::header::SEC_WEBSOCKET_PROTOCOL) + .unwrap(), + "graphql-ws" + ); + } + + #[tokio::test] + async fn success_on_upgrade() { + async fn handle_socket(mut socket: WebSocket) { + while let Some(Ok(msg)) = socket.next().await { + match msg { + Message::Text(_) + | Message::Binary(_) + | Message::Close(_) + | Message::Frame(_) => { + if socket.send(msg).await.is_err() { + break; + } + } + Message::Ping(_) | Message::Pong(_) => {} + } + } + } + + let addr = Address::Ip(net::SocketAddr::new( + net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)), + 25231, + )); + + let builder = ClientRequestBuilder::new( + format!("ws://{}/echo", addr.clone()) + .parse::() + .unwrap(), + ); + + let (mut ws_stream, _resp) = run_ws_handler( + addr.clone(), + |ws: WebSocketUpgrade| std::future::ready(ws.on_upgrade(handle_socket)), + builder, + ) + .await; + + let input = Message::Text("foobar".to_owned()); + ws_stream.send(input.clone()).await.unwrap(); + let output = ws_stream.next().await.unwrap().unwrap(); + assert_eq!(input, output); + + let input = Message::Ping("foobar".to_owned().into_bytes()); + ws_stream.send(input).await.unwrap(); + let output = ws_stream.next().await.unwrap().unwrap(); + assert_eq!(output, Message::Pong("foobar".to_owned().into_bytes())); + } +}