diff --git a/Cargo.lock b/Cargo.lock index 7206802f..3222bf84 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3164,9 +3164,9 @@ dependencies = [ [[package]] name = "tokio-tungstenite" -version = "0.23.1" +version = "0.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6989540ced10490aaf14e6bad2e3d33728a2813310a0c71d1574304c49631cd" +checksum = "edc5f74e248dc973e0dbb7b74c7e0d6fcc301c694ff50049504004ef4d0cdcd9" dependencies = [ "futures-util", "log", @@ -3409,9 +3409,9 @@ checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" [[package]] name = "tungstenite" -version = "0.23.0" +version = "0.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e2e2ce1e47ed2994fd43b04c8f618008d4cabdd5ee34027cf14f9d918edd9c8" +checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a" dependencies = [ "byteorder", "bytes", @@ -3737,6 +3737,7 @@ dependencies = [ "tokio-tungstenite", "tokio-util", "tracing", + "tungstenite", "volo", ] diff --git a/Cargo.toml b/Cargo.toml index 284634a1..6886bb77 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -133,7 +133,9 @@ webpki-roots = "0.26" tokio-rustls = "0.25" native-tls = "0.2" tokio-native-tls = "0.3" -tokio-tungstenite = "0.23" + +tungstenite = "0.24" +tokio-tungstenite = "0.24" [profile.release] opt-level = 3 diff --git a/volo-http/Cargo.toml b/volo-http/Cargo.toml index dfaec1f1..825262ab 100644 --- a/volo-http/Cargo.toml +++ b/volo-http/Cargo.toml @@ -60,10 +60,12 @@ tokio-util = { workspace = true, features = ["io"] } tracing.workspace = true # =====optional===== + # server optional matchit = { workspace = true, optional = true } # protocol optional +tungstenite = { workspace = true, optional = true } tokio-tungstenite = { workspace = true, optional = true } # tls optional @@ -95,7 +97,7 @@ full = ["client", "server", "rustls", "cookie", "query", "form", "json", "tls", client = ["hyper/client", "hyper/http1"] # client core server = ["hyper/server", "hyper/http1", "dep:matchit"] # server core -ws = ["dep:tokio-tungstenite"] +ws = ["dep:tungstenite", "dep:tokio-tungstenite"] tls = ["rustls"] __tls = [] diff --git a/volo-http/src/error/server.rs b/volo-http/src/error/server.rs index d99a4bbb..d517f3ce 100644 --- a/volo-http/src/error/server.rs +++ b/volo-http/src/error/server.rs @@ -99,69 +99,3 @@ pub fn body_collection_error() -> ExtractBodyError { pub fn invalid_content_type() -> ExtractBodyError { ExtractBodyError::Generic(GenericRejectionError::InvalidContentType) } - -/// Rejection used for [`WebSocketUpgrade`](crate::server::utils::ws::WebSocketUpgrade). -#[derive(Debug)] -#[non_exhaustive] -pub enum WebSocketUpgradeRejectionError { - /// The request method must be `GET` - MethodNotGet, - /// The HTTP version is not supported - InvalidHttpVersion, - /// The `Connection` header is invalid - InvalidConnectionHeader, - /// The `Upgrade` header is invalid - InvalidUpgradeHeader, - /// The `Sec-WebSocket-Version` header is invalid - InvalidWebSocketVersionHeader, - /// The `Sec-WebSocket-Key` header is missing - WebSocketKeyHeaderMissing, - /// The connection is not upgradable - ConnectionNotUpgradable, -} - -impl WebSocketUpgradeRejectionError { - /// Convert the [`WebSocketUpgradeRejectionError`] to the corresponding [`StatusCode`] - fn to_status_code(&self) -> StatusCode { - match self { - Self::MethodNotGet => StatusCode::METHOD_NOT_ALLOWED, - Self::InvalidHttpVersion => StatusCode::HTTP_VERSION_NOT_SUPPORTED, - Self::InvalidConnectionHeader => StatusCode::BAD_REQUEST, - Self::InvalidUpgradeHeader => StatusCode::BAD_REQUEST, - Self::InvalidWebSocketVersionHeader => StatusCode::BAD_REQUEST, - Self::WebSocketKeyHeaderMissing => StatusCode::BAD_REQUEST, - Self::ConnectionNotUpgradable => StatusCode::UPGRADE_REQUIRED, - } - } -} - -impl Error for WebSocketUpgradeRejectionError {} - -impl fmt::Display for WebSocketUpgradeRejectionError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::MethodNotGet => write!(f, "Request method must be 'GET'"), - Self::InvalidHttpVersion => { - write!(f, "Http version not support, only support HTTP 1.1 for now") - } - Self::InvalidConnectionHeader => { - write!(f, "Connection header did not include 'upgrade'") - } - Self::InvalidUpgradeHeader => write!(f, "`Upgrade` header did not include 'websocket'"), - Self::InvalidWebSocketVersionHeader => { - write!(f, "`Sec-WebSocket-Version` header did not include '13'") - } - Self::WebSocketKeyHeaderMissing => write!(f, "`Sec-WebSocket-Key` header missing"), - Self::ConnectionNotUpgradable => write!( - f, - "WebSocket request couldn't be upgraded since no upgrade state was present" - ), - } - } -} - -impl IntoResponse for WebSocketUpgradeRejectionError { - fn into_response(self) -> ServerResponse { - self.to_status_code().into_response() - } -} diff --git a/volo-http/src/server/utils/mod.rs b/volo-http/src/server/utils/mod.rs index 6ad1f88d..a1a456d1 100644 --- a/volo-http/src/server/utils/mod.rs +++ b/volo-http/src/server/utils/mod.rs @@ -7,5 +7,3 @@ pub use file_response::FileResponse; pub use serve_dir::ServeDir; #[cfg(feature = "ws")] pub mod ws; -#[cfg(feature = "ws")] -pub use self::ws::{Config as WebSocketConfig, Message, WebSocket, WebSocketUpgrade}; diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index 8dbb7652..f4f584fa 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -1,23 +1,20 @@ -//! Handle WebSocket connection +//! WebSocket implementation for server. //! //! This module provides utilities for setting up and handling WebSocket connections, including -//! configuring WebSocket options, setting protocols, and upgrading connections. -//! -//! It uses [`hyper::upgrade::OnUpgrade`] to upgrade the connection. +//! configuring WebSocket options, setting protocols and upgrading connections. //! //! # Example //! -//! ```rust +//! ``` //! use std::convert::Infallible; //! -//! use futures_util::{SinkExt, StreamExt}; +//! use futures_util::{sink::SinkExt, stream::StreamExt}; //! use volo_http::{ //! response::ServerResponse, //! server::{ -//! route::get, -//! utils::{Message, WebSocket, WebSocketUpgrade}, +//! route::{get, Router}, +//! utils::ws::{Message, WebSocket, WebSocketUpgrade}, //! }, -//! Router, //! }; //! //! async fn handle_socket(mut socket: WebSocket) { @@ -37,542 +34,702 @@ //! //! let app: Router = Router::new().route("/ws", get(ws_handler)); //! ``` +//! +//! See [`WebSocketUpgrade`] and [`WebSocket`] for more details. -use std::{borrow::Cow, fmt::Formatter, future::Future}; +use std::{ + borrow::Cow, + fmt, + future::Future, + ops::{Deref, DerefMut}, +}; -use http::{request::Parts, HeaderMap, HeaderName, HeaderValue}; -use hyper::Error; +use ahash::AHashSet; +use http::{ + header, + header::{HeaderMap, HeaderName, HeaderValue}, + method::Method, + request::Parts, + status::StatusCode, + version::Version, +}; use hyper_util::rt::TokioIo; -use tokio_tungstenite::{ - tungstenite::{ - self, - handshake::derive_accept_key, - protocol::{self, WebSocketConfig}, - }, - WebSocketStream, +use tokio_tungstenite::WebSocketStream; +pub use tungstenite::Message; +use tungstenite::{ + handshake::derive_accept_key, + protocol::{self, WebSocketConfig}, }; use crate::{ - body::Body, context::ServerContext, error::server::WebSocketUpgradeRejectionError, - response::ServerResponse, server::extract::FromContext, + body::Body, + context::ServerContext, + response::ServerResponse, + server::{extract::FromContext, IntoResponse}, }; -/// WebSocketStream used In handler Request -pub type WebSocket = WebSocketStream>; -/// alias of [`tungstenite::Message`] -pub type Message = tungstenite::Message; - -/// WebSocket Request headers for establishing a WebSocket connection. -struct Headers { - /// The `Sec-WebSocket-Key` request header value - /// used for compute 'Sec-WebSocket-Accept' response header value +const HEADERVALUE_UPGRADE: HeaderValue = HeaderValue::from_static("upgrade"); +const HEADERVALUE_WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket"); + +/// Handler request for establishing WebSocket connection. +/// +/// [`WebSocketUpgrade`] can be passed as an argument to a handler, which will be called if the +/// http connection making the request can be upgraded to a websocket connection. +/// +/// [`WebSocketUpgrade`] must be used with [`WebSocketUpgrade::on_upgrade`] and a websocket +/// handler, [`WebSocketUpgrade::on_upgrade`] will return a [`ServerResponse`] for the client and +/// the connection will then be upgraded later. +/// +/// # Example +/// +/// ``` +/// use volo_http::{response::ServerResponse, server::utils::ws::WebSocketUpgrade}; +/// +/// fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse { +/// ws.on_upgrade(|socket| async { todo!() }) +/// } +/// ``` +#[must_use] +pub struct WebSocketUpgrade { + config: WebSocketConfig, + protocol: Option, sec_websocket_key: HeaderValue, - /// The `Sec-WebSocket-Protocol` request header value - /// specify [`Callback`] method depend on the protocol sec_websocket_protocol: Option, + on_upgrade: hyper::upgrade::OnUpgrade, + on_failed_upgrade: F, } -impl std::fmt::Debug for Headers { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Headers") - .field("sec_websocket_key", &self.sec_websocket_protocol) - .field("sec_websocket_protocol", &self.sec_websocket_protocol) - .finish_non_exhaustive() - } -} - -/// WebSocket config -#[derive(Default)] -pub struct Config { - /// WebSocket config for transport (alias of - /// [`WebSocketConfig`](tungstenite::protocol::WebSocketConfig)) e.g. max write buffer size - transport: WebSocketConfig, - /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response. - /// use [`WebSocketUpgrade::protocols`] to set server supported protocols - protocols: Vec, -} - -impl Config { - /// Create Default Config - pub fn new() -> Self { - Config { - transport: WebSocketConfig::default(), - protocols: Vec::new(), - } - } - - /// Set server supported protocols. +impl WebSocketUpgrade { + /// The target minimum size of the write buffer to reach before writing the data to the + /// underlying stream. /// - /// This will filter protocols in request header `Sec-WebSocket-Protocol` - /// and will set the first server supported protocol in [`http::header::Sec-WebSocket-Protocol`] - /// in response + /// The default value is 128 KiB. /// - /// # Example + /// If set to `0` each message will be eagerly written to the underlying stream. It is often + /// more optimal to allow them to buffer a little, hence the default value. /// - /// ```rust - /// use volo_http::server::utils::WebSocketConfig; + /// Note: [`flush`] will always fully write the buffer regardless. /// - /// let config = WebSocketConfig::new().set_protocols(["graphql-ws", "graphql-transport-ws"]); - /// ``` - pub fn set_protocols(mut self, protocols: I) -> Self - where - I: IntoIterator, - I::Item: Into>, - { - self.protocols = protocols - .into_iter() - .map(Into::into) - .map(|protocol| match protocol { - Cow::Owned(s) => s, - Cow::Borrowed(s) => s.to_string(), - }) - .collect(); + /// [`flush`]: futures_util::sink::SinkExt::flush + pub fn write_buffer_size(mut self, size: usize) -> Self { + self.config.write_buffer_size = size; self } - /// Set transport config + /// The max size of the write buffer in bytes. Setting this can provide backpressure + /// in the case the write buffer is filling up due to write errors. /// - /// e.g. write buffer size + /// The default value is unlimited. /// - /// ```rust - /// use tokio_tungstenite::tungstenite::protocol::WebSocketConfig as WebSocketTransConfig; - /// use volo_http::server::utils::WebSocketConfig; + /// Note: The write buffer only builds up past [`write_buffer_size`](Self::write_buffer_size) + /// when writes to the underlying stream are failing. So the **write buffer can not + /// fill up if you are not observing write errors even if not flushing**. /// - /// let config = WebSocketConfig::new().set_transport(WebSocketTransConfig { - /// write_buffer_size: 128 * 1024, - /// ..<_>::default() - /// }); - /// ``` - pub fn set_transport(mut self, config: WebSocketConfig) -> Self { - self.transport = config; + /// Note: Should always be at least [`write_buffer_size + 1 message`](Self::write_buffer_size) + /// and probably a little more depending on error handling strategy. + pub fn max_write_buffer_size(mut self, max: usize) -> Self { + self.config.max_write_buffer_size = max; self } -} -impl std::fmt::Debug for Config { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Config") - .field("transport", &self.transport) - .field("protocols", &self.protocols) - .finish_non_exhaustive() + /// The maximum size of an incoming message. + /// + /// `None` means no size limit. + /// + /// The default value is 64 MiB, which should be reasonably big for all normal use-cases but + /// small enough to prevent memory eating by a malicious user. + pub fn max_message_size(mut self, max: Option) -> Self { + self.config.max_message_size = max; + self } -} -/// Callback fn that processes [`WebSocket`] -pub trait Callback: Send + 'static { - /// Called when a connection upgrade succeeds - fn call(self, _: WebSocket) -> impl Future + Send; -} - -impl Callback for C -where - Fut: Future + Send + 'static, - C: FnOnce(WebSocket) -> Fut + Send + Copy + 'static, -{ - async fn call(self, websocket: WebSocket) { - self(websocket).await; + /// The maximum size of a single incoming message frame. + /// + /// `None` means no size limit. + /// + /// The limit is for frame payload NOT including the frame header. + /// + /// The default value is 16 MiB, which should be reasonably big for all normal use-cases but + /// small enough to prevent memory eating by a malicious user. + pub fn max_frame_size(mut self, max: Option) -> Self { + self.config.max_frame_size = max; + self } -} - -/// What to do when a connection upgrade fails. -/// -/// See [`WebSocketUpgrade::on_failed_upgrade`] for more details. -pub trait OnFailedUpgrade: Send + 'static { - /// Called when a connection upgrade fails. - fn call(self, error: Error); -} -impl OnFailedUpgrade for F -where - F: FnOnce(Error) + Send + 'static, -{ - fn call(self, error: Error) { - self(error) + /// If server to accept unmasked frames. + /// + /// When set to `true`, the server will accept and handle unmasked frames from the client. + /// + /// According to the RFC 6455, the server must close the connection to the client in such + /// cases, however it seems like there are some popular libraries that are sending unmasked + /// frames, ignoring the RFC. + /// + /// By default this option is set to `false`, i.e. according to RFC 6455. + pub fn accept_unmasked_frames(mut self, accept: bool) -> Self { + self.config.accept_unmasked_frames = accept; + self } -} - -/// The default `OnFailedUpgrade` used by `WebSocketUpgrade`. -/// -/// It simply ignores the error. -#[non_exhaustive] -#[derive(Debug)] -pub struct DefaultOnFailedUpgrade; - -impl OnFailedUpgrade for DefaultOnFailedUpgrade { - #[inline] - fn call(self, _error: Error) {} -} - -/// The default `Callback` used by `WebSocketUpgrade`. -/// -/// It simply ignores the socket. -#[derive(Clone)] -pub struct DefaultCallback; -impl Callback for DefaultCallback { - #[inline] - async fn call(self, _: WebSocket) {} -} -/// Handler request for establishing WebSocket connection -/// -/// # Constrains: -/// -/// The extractor only supports for the request that has the method [`GET`](http::Method::GET) -/// and contains certain header values. -/// -/// See more details in [`WebSocketUpgrade::from_context`] -/// -/// # Usage -/// -/// ```rust -/// use volo_http::{response::ServerResponse, server::utils::WebSocketUpgrade}; -/// -/// fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse { -/// ws.on_upgrade(|socket| async { unimplemented!() }) -/// } -/// ``` -pub struct WebSocketUpgrade { - config: Config, - on_failed_upgrade: F, - on_upgrade: hyper::upgrade::OnUpgrade, - headers: Headers, -} + fn get_protocol(&mut self, protocols: I) -> Option + where + I: IntoIterator, + I::Item: Into>, + { + let req_protocols = self + .sec_websocket_protocol + .as_ref()? + .to_str() + .ok()? + .split(',') + .map(str::trim) + .collect::>(); + for protocol in protocols.into_iter().map(Into::into) { + if req_protocols.contains(protocol.as_ref()) { + let protocol = match protocol { + Cow::Owned(s) => HeaderValue::from_str(&s).ok()?, + Cow::Borrowed(s) => HeaderValue::from_static(s), + }; + return Some(protocol); + } + } -impl std::fmt::Debug for WebSocketUpgrade { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("WebSocketUpgrade") - .field("config", &self.config) - .field("headers", &self.headers) - .finish_non_exhaustive() + None } -} -impl WebSocketUpgrade -where - F: OnFailedUpgrade, -{ - /// Set WebSocket config + /// Set available protocols for [`Sec-WebSocket-Protocol`][mdn]. /// - /// # Example + /// If the protocol in [`Sec-WebSocket-Protocol`][mdn] matches any protocol, the upgrade + /// response will insert [`Sec-WebSocket-Protocol`][mdn] and [`WebSocket`] will contain the + /// protocol name. /// - /// ```rust - /// use tokio_tungstenite::tungstenite::protocol::WebSocketConfig as WebSocketTransConfig; - /// use volo_http::{ - /// response::ServerResponse, - /// server::utils::{WebSocketConfig, WebSocketUpgrade}, - /// }; + /// Note that if the client offers multiple protocols that the server supports, the server will + /// pick the first one in the list. /// - /// async fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse { - /// ws.set_config( - /// WebSocketConfig::new() - /// .set_protocols(["graphql-ws", "graphql-transport-ws"]) - /// .set_transport(WebSocketTransConfig { - /// write_buffer_size: 128 * 1024, - /// ..<_>::default() - /// }), - /// ) - /// .on_upgrade(|socket| async {}) - /// } - /// ``` - pub fn set_config(mut self, config: Config) -> Self { - self.config = config; + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-WebSocket-Protocol + pub fn protocols(mut self, protocols: I) -> Self + where + I: IntoIterator, + I::Item: Into>, + { + self.protocol = self.get_protocol(protocols); self } /// Provide a callback to call if upgrading the connection fails. /// - /// The connection upgrade is performed in a background task. - /// If that fails this callback will be called. + /// The connection upgrade is performed in a background task. If that fails this callback will + /// be called. /// /// By default, any errors will be silently ignored. /// /// # Example /// - /// ```rust - /// use std::collections::HashMap; - /// + /// ``` /// use volo_http::{ /// response::ServerResponse, - /// server::utils::{WebSocket, WebSocketConfig, WebSocketUpgrade}, + /// server::{ + /// route::{get, Router}, + /// utils::ws::{WebSocket, WebSocketUpgrade}, + /// }, /// }; /// /// async fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse { - /// ws.on_failed_upgrade(|error| unimplemented!()) - /// .on_upgrade(|socket| async {}) + /// ws.on_failed_upgrade(|err| eprintln!("Failed to upgrade connection, err: {err}")) + /// .on_upgrade(|socket| async { todo!() }) /// } + /// + /// let router: Router = Router::new().route("/ws", get(ws_handler)); /// ``` - pub fn on_failed_upgrade(self, callback: F1) -> WebSocketUpgrade + pub fn on_failed_upgrade(self, callback: F2) -> WebSocketUpgrade where - F1: OnFailedUpgrade, + F2: OnFailedUpgrade, { WebSocketUpgrade { config: self.config, - on_failed_upgrade: callback, + protocol: self.protocol, + sec_websocket_key: self.sec_websocket_key, + sec_websocket_protocol: self.sec_websocket_protocol, on_upgrade: self.on_upgrade, - headers: self.headers, + on_failed_upgrade: callback, } } /// Finalize upgrading the connection and call the provided callback /// - /// If request protocol is matched, it will use `callback` to handle the connection stream data - pub fn on_upgrade(self, callback: C) -> ServerResponse + /// If request protocol is matched, it will use `callback` to handle the connection stream + /// data. + /// + /// The callback function should be an async function with [`WebSocket`] as parameter. + /// + /// # Example + /// + /// ``` + /// use futures_util::{sink::SinkExt, stream::StreamExt}; + /// use volo_http::{ + /// response::ServerResponse, + /// server::{ + /// route::{get, Router}, + /// utils::ws::{WebSocket, WebSocketUpgrade}, + /// }, + /// }; + /// + /// async fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse { + /// ws.on_upgrade(|mut socket| async move { + /// while let Some(Ok(msg)) = socket.next().await { + /// if msg.is_ping() || msg.is_pong() { + /// continue; + /// } + /// if socket.send(msg).await.is_err() { + /// break; + /// } + /// } + /// }) + /// } + /// + /// let router: Router = Router::new().route("/ws", get(ws_handler)); + /// ``` + pub fn on_upgrade(self, callback: C) -> ServerResponse where - Fut: Future + Send + 'static, - C: FnOnce(WebSocket) -> Fut + Send + Sync + 'static, + C: FnOnce(WebSocket) -> Fut + Send + 'static, + Fut: Future + Send, + F: OnFailedUpgrade + Send + 'static, { - let on_upgrade = self.on_upgrade; - let config = self.config.transport; - let on_failed_upgrade = self.on_failed_upgrade; - - let protocol = self - .headers - .sec_websocket_protocol - .clone() - .as_ref() - .and_then(|p| p.to_str().ok()) - .and_then(|req_protocols| { - self.config.protocols.iter().find(|protocol| { - req_protocols - .split(',') - .any(|req_protocol| req_protocol == *protocol) - }) - }); - - tokio::spawn(async move { - let upgraded = match on_upgrade.await { + let protocol = self.protocol.clone(); + let fut = async move { + let upgraded = match self.on_upgrade.await { Ok(upgraded) => upgraded, Err(err) => { - on_failed_upgrade.call(err); + self.on_failed_upgrade.call(WebSocketError::Upgrade(err)); return; } }; let upgraded = TokioIo::new(upgraded); - let socket = - WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config)) - .await; + let socket = WebSocketStream::from_raw_socket( + upgraded, + protocol::Role::Server, + Some(self.config), + ) + .await; + let socket = WebSocket { + inner: socket, + protocol, + }; callback(socket).await; - }); - - let mut builder = ServerResponse::builder() - .status(http::StatusCode::SWITCHING_PROTOCOLS) - .header(http::header::CONNECTION, "upgrade") - .header(http::header::UPGRADE, "websocket") - .header( - http::header::SEC_WEBSOCKET_ACCEPT, - derive_accept_key(self.headers.sec_websocket_key.as_bytes()), - ); - - if let Some(protocol) = protocol { - builder = builder.header(http::header::SEC_WEBSOCKET_PROTOCOL, protocol); + }; + + let mut resp = ServerResponse::new(Body::empty()); + *resp.status_mut() = StatusCode::SWITCHING_PROTOCOLS; + resp.headers_mut() + .insert(header::CONNECTION, HEADERVALUE_UPGRADE); + resp.headers_mut() + .insert(header::UPGRADE, HEADERVALUE_WEBSOCKET); + let Ok(accept_key) = + HeaderValue::from_str(&derive_accept_key(self.sec_websocket_key.as_bytes())) + else { + return StatusCode::BAD_REQUEST.into_response(); + }; + resp.headers_mut() + .insert(header::SEC_WEBSOCKET_ACCEPT, accept_key); + if let Some(protocol) = self.protocol { + if let Ok(protocol) = HeaderValue::from_bytes(protocol.as_bytes()) { + resp.headers_mut() + .insert(header::SEC_WEBSOCKET_PROTOCOL, protocol); + } } - builder.body(Body::empty()).unwrap() + tokio::spawn(fut); + + resp } } fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool { - let header = if let Some(header) = headers.get(&key) { - header - } else { + let Some(header) = headers.get(&key) else { return false; }; - - if let Ok(header) = std::str::from_utf8(header.as_bytes()) { - header.to_ascii_lowercase().contains(value) - } else { - false - } + let Ok(header) = simdutf8::basic::from_utf8(header.as_bytes()) else { + return false; + }; + header.to_ascii_lowercase().contains(value) } fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool { - if let Some(header) = headers.get(&key) { - header.as_bytes().eq_ignore_ascii_case(value.as_bytes()) - } else { - false - } + let Some(header) = headers.get(&key) else { + return false; + }; + header.as_bytes().eq_ignore_ascii_case(value.as_bytes()) } impl FromContext for WebSocketUpgrade { type Rejection = WebSocketUpgradeRejectionError; async fn from_context( - _cx: &mut ServerContext, + _: &mut ServerContext, parts: &mut Parts, ) -> Result { - if parts.method != http::Method::GET { + if parts.method != Method::GET { return Err(WebSocketUpgradeRejectionError::MethodNotGet); } - if parts.version < http::Version::HTTP_11 { + if parts.version < Version::HTTP_11 { return Err(WebSocketUpgradeRejectionError::InvalidHttpVersion); } - if !header_contains(&parts.headers, http::header::CONNECTION, "upgrade") { + // The `Connection` may be multiple values separated by comma, so we should use + // `header_contains` rather than `header_eq` here. + // + // ref: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Connection + if !header_contains(&parts.headers, header::CONNECTION, "upgrade") { return Err(WebSocketUpgradeRejectionError::InvalidConnectionHeader); } - if !header_eq(&parts.headers, http::header::UPGRADE, "websocket") { + if !header_eq(&parts.headers, header::UPGRADE, "websocket") { return Err(WebSocketUpgradeRejectionError::InvalidUpgradeHeader); } - if !header_eq(&parts.headers, http::header::SEC_WEBSOCKET_VERSION, "13") { + if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") { return Err(WebSocketUpgradeRejectionError::InvalidWebSocketVersionHeader); } let sec_websocket_key = parts .headers - .get(http::header::SEC_WEBSOCKET_KEY) + .get(header::SEC_WEBSOCKET_KEY) .ok_or(WebSocketUpgradeRejectionError::WebSocketKeyHeaderMissing)? .clone(); + let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned(); + let on_upgrade = parts .extensions .remove::() - .ok_or(WebSocketUpgradeRejectionError::ConnectionNotUpgradable)?; - - let sec_websocket_protocol = parts - .headers - .get(http::header::SEC_WEBSOCKET_PROTOCOL) - .cloned(); + .expect("`OnUpgrade` is unavailable, maybe something wrong with `hyper`"); Ok(Self { config: Default::default(), - headers: Headers { - sec_websocket_key, - sec_websocket_protocol, - }, - on_failed_upgrade: DefaultOnFailedUpgrade, + protocol: None, + sec_websocket_key, + sec_websocket_protocol, on_upgrade, + on_failed_upgrade: DefaultOnFailedUpgrade, }) } } -#[cfg(test)] -mod websocket_tests { - use std::net; +/// WebSocketStream used In handler Request +pub struct WebSocket { + inner: WebSocketStream>, + protocol: Option, +} - use futures_util::{SinkExt, StreamExt}; - use http::Uri; - use motore::Service; - use tokio::net::TcpStream; - use tokio_tungstenite::{ - tungstenite::{client::IntoClientRequest, ClientRequestBuilder}, - MaybeTlsStream, - }; - use volo::net::Address; +impl WebSocket { + /// Get protocol of current websocket. + /// + /// The value of protocol is from [`Sec-WebSocket-Protocol`][mdn] and + /// [`WebSocketUpgrade::protocols`] will pick one if there is any protocol that the server + /// gived. + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-WebSocket-Protocol + pub fn protocol(&self) -> Option<&str> { + simdutf8::basic::from_utf8(self.protocol.as_ref()?.as_bytes()).ok() + } +} - use super::*; - use crate::{ - server::{ - route::{get, Route}, - test_helpers::empty_cx, - }, - Router, Server, - }; +impl Deref for WebSocket { + type Target = WebSocketStream>; - async fn run_ws_handler( - addr: Address, - handler: C, - req: R, - ) -> (WebSocketStream>, ServerResponse) - where - R: IntoClientRequest + Unpin, - Fut: Future + Send + 'static, - C: FnOnce(WebSocketUpgrade) -> Fut + Send + Sync + Clone + 'static, - { - let app = Router::new().route("/echo", get(handler)); + fn deref(&self) -> &Self::Target { + &self.inner + } +} - tokio::spawn(async move { - Server::new(app).run(addr).await.unwrap(); - }); +impl DerefMut for WebSocket { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} - tokio::time::sleep(std::time::Duration::from_secs(1)).await; +/// Error type when using [`WebSocket`]. +#[derive(Debug)] +pub enum WebSocketError { + /// Error from [`hyper`] when calling [`OnUpgrade.await`][OnUpgrade] for upgrade a HTTP + /// connection to a WebSocket connection. + /// + /// [OnUpgrade]: hyper::upgrade::OnUpgrade + Upgrade(hyper::Error), +} - let (socket, resp) = tokio_tungstenite::connect_async(req).await.unwrap(); +impl fmt::Display for WebSocketError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Upgrade(err) => write!(f, "failed to upgrade: {err}"), + } + } +} + +impl std::error::Error for WebSocketError {} - (socket, resp.map(|resp| resp.unwrap_or_default().into())) +/// What to do when a connection upgrade fails. +/// +/// See [`WebSocketUpgrade::on_failed_upgrade`] for more details. +pub trait OnFailedUpgrade { + /// Called when a connection upgrade fails. + fn call(self, error: WebSocketError); +} + +impl OnFailedUpgrade for F +where + F: FnOnce(WebSocketError), +{ + fn call(self, error: WebSocketError) { + self(error) } +} - #[tokio::test] - async fn reject_unupgradable_requests() { - let route: Route = Route::new(get( - |ws: Result| { - let rejection = ws.unwrap_err(); - assert!(matches!( - rejection, - WebSocketUpgradeRejectionError::ConnectionNotUpgradable, - )); - std::future::ready(()) - }, - )); +/// The default `OnFailedUpgrade` used by `WebSocketUpgrade`. +/// +/// It simply ignores the error. +#[derive(Debug)] +pub struct DefaultOnFailedUpgrade; - let req = http::Request::builder() - .version(http::Version::HTTP_11) - .method(http::Method::GET) - .header("upgrade", "websocket") - .header("connection", "Upgrade") - .header("sec-websocket-key", "6D69KGBOr4Re+Nj6zx9aQA==") - .header("sec-websocket-version", "13") - .body(Body::empty()) - .unwrap(); +impl OnFailedUpgrade for DefaultOnFailedUpgrade { + fn call(self, _: WebSocketError) {} +} - let mut cx = empty_cx(); +/// [`Error`]s while extracting [`WebSocketUpgrade`]. +/// +/// [`Error`]: std::error::Error +/// [`WebSocketUpgrade`]: crate::server::utils::ws::WebSocketUpgrade +#[derive(Debug)] +pub enum WebSocketUpgradeRejectionError { + /// The request method must be `GET` + MethodNotGet, + /// The HTTP version is not supported + InvalidHttpVersion, + /// The `Connection` header is invalid + InvalidConnectionHeader, + /// The `Upgrade` header is invalid + InvalidUpgradeHeader, + /// The `Sec-WebSocket-Version` header is invalid + InvalidWebSocketVersionHeader, + /// The `Sec-WebSocket-Key` header is missing + WebSocketKeyHeaderMissing, +} - let resp = route.call(&mut cx, req).await.unwrap(); +impl WebSocketUpgradeRejectionError { + /// Convert the [`WebSocketUpgradeRejectionError`] to the corresponding [`StatusCode`] + fn to_status_code(&self) -> StatusCode { + match self { + Self::MethodNotGet => StatusCode::METHOD_NOT_ALLOWED, + Self::InvalidHttpVersion => StatusCode::HTTP_VERSION_NOT_SUPPORTED, + Self::InvalidConnectionHeader => StatusCode::UPGRADE_REQUIRED, + Self::InvalidUpgradeHeader => StatusCode::BAD_REQUEST, + Self::InvalidWebSocketVersionHeader => StatusCode::BAD_REQUEST, + Self::WebSocketKeyHeaderMissing => StatusCode::BAD_REQUEST, + } + } +} - assert_eq!(resp.status(), http::StatusCode::OK); +impl std::error::Error for WebSocketUpgradeRejectionError {} + +impl fmt::Display for WebSocketUpgradeRejectionError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::MethodNotGet => f.write_str("Request method must be `GET`"), + Self::InvalidHttpVersion => f.write_str("HTTP version not support"), + Self::InvalidConnectionHeader => { + f.write_str("Header `Connection` does not include `upgrade`") + } + Self::InvalidUpgradeHeader => f.write_str("Header `Upgrade` is not `websocket`"), + Self::InvalidWebSocketVersionHeader => { + f.write_str("Header `Sec-WebSocket-Version` is not `13`") + } + Self::WebSocketKeyHeaderMissing => f.write_str("Header `Sec-WebSocket-Key` is missing"), + } } +} - #[tokio::test] - async fn reject_non_get_requests() { - let route: Route = Route::new(get( - |ws: Result| { - let rejection = ws.unwrap_err(); - assert!(matches!( - rejection, - WebSocketUpgradeRejectionError::MethodNotGet, - )); - std::future::ready(()) - }, - )); +impl IntoResponse for WebSocketUpgradeRejectionError { + fn into_response(self) -> ServerResponse { + self.to_status_code().into_response() + } +} - let req = http::Request::builder() - .method(http::Method::POST) - .body(Body::empty()) +#[cfg(test)] +mod websocket_tests { + use std::{ + convert::Infallible, + net::{IpAddr, Ipv4Addr, SocketAddr}, + str::FromStr, + }; + + use futures_util::{sink::SinkExt, stream::StreamExt}; + use http::uri::Uri; + use motore::service::Service; + use tokio::net::TcpStream; + use tokio_tungstenite::MaybeTlsStream; + use tungstenite::ClientRequestBuilder; + use volo::net::Address; + + use super::*; + use crate::{request::ServerRequest, server::test_helpers, Server}; + + fn simple_parts() -> Parts { + let req = ServerRequest::builder() + .method(Method::GET) + .version(Version::HTTP_11) + .header(header::HOST, "localhost") + .header(header::CONNECTION, super::HEADERVALUE_UPGRADE) + .header(header::UPGRADE, super::HEADERVALUE_WEBSOCKET) + .header(header::SEC_WEBSOCKET_KEY, "6D69KGBOr4Re+Nj6zx9aQA==") + .header(header::SEC_WEBSOCKET_VERSION, "13") + .body(()) .unwrap(); + req.into_parts().0 + } - let mut cx = empty_cx(); + async fn run_ws_handler( + service: S, + sub_protocol: Option<&'static str>, + port: u16, + ) -> ( + WebSocketStream>, + ServerResponse>>, + ) + where + S: Service + + Send + + Sync + + 'static, + { + let addr = Address::Ip(SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + port, + )); + tokio::spawn(Server::new(service).run(addr.clone())); - let resp = route.call(&mut cx, req).await.unwrap(); + tokio::time::sleep(std::time::Duration::from_secs(1)).await; - assert_eq!(resp.status(), http::StatusCode::METHOD_NOT_ALLOWED); + let mut req = ClientRequestBuilder::new(Uri::from_str(&format!("ws://{addr}/")).unwrap()); + if let Some(sub_protocol) = sub_protocol { + req = req.with_sub_protocol(sub_protocol); + } + tokio_tungstenite::connect_async(req).await.unwrap() } #[tokio::test] - async fn set_protocols() { - let addr = Address::Ip(net::SocketAddr::new( - net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)), - 25230, - )); + async fn rejection() { + { + let mut parts = simple_parts(); + parts.method = Method::POST; + let res = + WebSocketUpgrade::from_context(&mut test_helpers::empty_cx(), &mut parts).await; + assert!(matches!( + res, + Err(WebSocketUpgradeRejectionError::MethodNotGet) + )); + } + { + let mut parts = simple_parts(); + parts.version = Version::HTTP_10; + let res = + WebSocketUpgrade::from_context(&mut test_helpers::empty_cx(), &mut parts).await; + assert!(matches!( + res, + Err(WebSocketUpgradeRejectionError::InvalidHttpVersion) + )); + } + { + let mut parts = simple_parts(); + parts.headers.remove(header::CONNECTION); + let res = + WebSocketUpgrade::from_context(&mut test_helpers::empty_cx(), &mut parts).await; + assert!(matches!( + res, + Err(WebSocketUpgradeRejectionError::InvalidConnectionHeader) + )); + } + { + let mut parts = simple_parts(); + parts.headers.remove(header::CONNECTION); + parts + .headers + .insert(header::CONNECTION, HeaderValue::from_static("downgrade")); + let res = + WebSocketUpgrade::from_context(&mut test_helpers::empty_cx(), &mut parts).await; + assert!(matches!( + res, + Err(WebSocketUpgradeRejectionError::InvalidConnectionHeader) + )); + } + { + let mut parts = simple_parts(); + parts.headers.remove(header::UPGRADE); + let res = + WebSocketUpgrade::from_context(&mut test_helpers::empty_cx(), &mut parts).await; + assert!(matches!( + res, + Err(WebSocketUpgradeRejectionError::InvalidUpgradeHeader) + )); + } + { + let mut parts = simple_parts(); + parts.headers.remove(header::UPGRADE); + parts + .headers + .insert(header::UPGRADE, HeaderValue::from_static("supersocket")); + let res = + WebSocketUpgrade::from_context(&mut test_helpers::empty_cx(), &mut parts).await; + assert!(matches!( + res, + Err(WebSocketUpgradeRejectionError::InvalidUpgradeHeader) + )); + } + { + let mut parts = simple_parts(); + parts.headers.remove(header::SEC_WEBSOCKET_VERSION); + let res = + WebSocketUpgrade::from_context(&mut test_helpers::empty_cx(), &mut parts).await; + assert!(matches!( + res, + Err(WebSocketUpgradeRejectionError::InvalidWebSocketVersionHeader) + )); + } + { + let mut parts = simple_parts(); + parts.headers.remove(header::SEC_WEBSOCKET_VERSION); + parts.headers.insert( + header::SEC_WEBSOCKET_VERSION, + HeaderValue::from_static("114514"), + ); + let res = + WebSocketUpgrade::from_context(&mut test_helpers::empty_cx(), &mut parts).await; + assert!(matches!( + res, + Err(WebSocketUpgradeRejectionError::InvalidWebSocketVersionHeader) + )); + } + { + let mut parts = simple_parts(); + parts.headers.remove(header::SEC_WEBSOCKET_KEY); + let res = + WebSocketUpgrade::from_context(&mut test_helpers::empty_cx(), &mut parts).await; + assert!(matches!( + res, + Err(WebSocketUpgradeRejectionError::WebSocketKeyHeaderMissing) + )); + } + } - let builder = ClientRequestBuilder::new( - format!("ws://{}/echo", addr.clone()) - .parse::() - .unwrap(), - ) - .with_sub_protocol("graphql-ws"); - - let (_, resp) = run_ws_handler( - addr.clone(), - |ws: WebSocketUpgrade| async { - ws.set_config(Config::new().set_protocols(["graphql-ws", "test-protocol"])) - .on_upgrade(|_| async {}) - }, - builder, - ) - .await; + #[tokio::test] + async fn protocol_test() { + async fn handler(ws: WebSocketUpgrade) -> ServerResponse { + ws.protocols(["soap", "wmap", "graphql-ws", "chat"]) + .on_upgrade(|_| async {}) + } + + let (_, resp) = + run_ws_handler(test_helpers::to_service(handler), Some("graphql-ws"), 25230).await; assert_eq!( resp.headers() @@ -584,39 +741,23 @@ mod websocket_tests { #[tokio::test] async fn success_on_upgrade() { - async fn handle_socket(mut socket: WebSocket) { + async fn echo(mut socket: WebSocket) { while let Some(Ok(msg)) = socket.next().await { - match msg { - Message::Text(_) - | Message::Binary(_) - | Message::Close(_) - | Message::Frame(_) => { - if socket.send(msg).await.is_err() { - break; - } - } - Message::Ping(_) | Message::Pong(_) => {} + if msg.is_ping() || msg.is_pong() { + continue; + } + if socket.send(msg).await.is_err() { + break; } } } - let addr = Address::Ip(net::SocketAddr::new( - net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)), - 25231, - )); - - let builder = ClientRequestBuilder::new( - format!("ws://{}/echo", addr.clone()) - .parse::() - .unwrap(), - ); + async fn handler(ws: WebSocketUpgrade) -> ServerResponse { + ws.on_upgrade(echo) + } - let (mut ws_stream, _resp) = run_ws_handler( - addr.clone(), - |ws: WebSocketUpgrade| std::future::ready(ws.on_upgrade(handle_socket)), - builder, - ) - .await; + let (mut ws_stream, _) = + run_ws_handler(test_helpers::to_service(handler), None, 25231).await; let input = Message::Text("foobar".to_owned()); ws_stream.send(input.clone()).await.unwrap();