Skip to content

Commit 44f12ac

Browse files
committed
feat(http): support websocket server
1 parent 712586d commit 44f12ac

File tree

2 files changed

+125
-85
lines changed

2 files changed

+125
-85
lines changed

volo-http/Cargo.toml

+2-3
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,9 @@ default_server = ["server", "query", "form", "json"]
9292

9393
full = ["client", "server", "rustls", "cookie", "query", "form", "json", "tls"]
9494

95-
client = ["hyper/client", "hyper/http1", "ws"] # client core
96-
server = ["hyper/server", "hyper/http1", "dep:matchit", "ws"] # server core
95+
client = ["hyper/client", "hyper/http1"] # client core
96+
server = ["hyper/server", "hyper/http1", "dep:matchit"] # server core
9797

98-
protocol = ["ws"]
9998
ws = ["dep:tokio-tungstenite"]
10099

101100
tls = ["rustls"]

volo-http/src/server/utils/ws.rs

+123-82
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,47 @@
1-
//! Handle WebSocket connections.
1+
//! Module for handling WebSocket connection
2+
//!
3+
//!
4+
//! This module provides utilities for setting up and handling WebSocket connections, including
5+
//! configuring WebSocket options, setting protocols, and upgrading connections.
6+
//!
7+
//! It uses [`hyper::upgrade::OnUpgrade`] to upgrade the connection.
8+
//!
29
//!
310
//! # Example
11+
//!
12+
//! ```rust
13+
//! use futures_util::{SinkExt, StreamExt};
14+
//! use volo_http::{
15+
//! response::ServerResponse,
16+
//! server::{
17+
//! extract::{Message, WebSocket},
18+
//! route::get,
19+
//! utils::WebSocketUpgrade,
20+
//! },
21+
//! Router,
22+
//! };
23+
//!
24+
//! async fn handle_socket(mut socket: WebSocket) {
25+
//! while let Some(Ok(msg)) = socket.next().await {
26+
//! match msg {
27+
//! Message::Text(text) => {
28+
//! socket.send(msg).await.unwrap();
29+
//! }
30+
//! _ => {}
31+
//! }
32+
//! }
33+
//! }
34+
//!
35+
//! async fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse {
36+
//! ws.on_upgrade(handle_socket)
37+
//! }
38+
//!
39+
//! let app = Router::new().route("/ws", get(ws_handler));
40+
//! ```
441
542
use std::{borrow::Cow, collections::HashMap, fmt::Formatter, future::Future};
643

7-
use http::{header, request::Parts, HeaderMap, HeaderName, HeaderValue, StatusCode};
44+
use http::{request::Parts, HeaderMap, HeaderName, HeaderValue};
845
use hyper::Error;
946
use hyper_util::rt::TokioIo;
1047
use tokio_tungstenite::{
@@ -69,14 +106,13 @@ impl Config {
69106
/// This will filter protocols in request header `Sec-WebSocket-Protocol`
70107
/// will set the first server supported protocol in [`http::header::Sec-WebSocket-Protocol`] in
71108
/// response
109+
///
110+
///
72111
/// ```rust
73112
/// use volo_http::server::utils::WebSocketConfig;
74113
///
75-
/// let config = WebSocketConfig::new()
76-
/// .set_protocols([
77-
/// "graphql-ws",
78-
/// "graphql-transport-ws",
79-
/// ]);
114+
/// let config = WebSocketConfig::new().set_protocols(["graphql-ws", "graphql-transport-ws"]);
115+
/// ```
80116
pub fn set_protocols<I>(mut self, protocols: I) -> Self
81117
where
82118
I: IntoIterator,
@@ -97,17 +133,17 @@ impl Config {
97133

98134
/// Set transport config
99135
/// e.g. write buffer size
136+
///
137+
///
100138
/// ```rust
101-
/// use volo_http::server::utils::WebSocketConfig;
102-
/// use tokio_tungstenite::tungstenite::protocol::{WebSocketConfig as WebSocketTransConfig};
139+
/// use tokio_tungstenite::tungstenite::protocol::WebSocketConfig as WebSocketTransConfig;
140+
/// use volo_http::server::utils::WebSocketConfig;
103141
///
104-
/// let config = WebSocketConfig::new()
105-
/// .set_transport(
106-
/// WebSocketTransConfig{
107-
/// write_buffer_size: 128 * 1024,
108-
/// ..<_>::default()
109-
/// }
110-
/// );
142+
/// let config = WebSocketConfig::new().set_transport(WebSocketTransConfig {
143+
/// write_buffer_size: 128 * 1024,
144+
/// ..<_>::default()
145+
/// });
146+
/// ```
111147
pub fn set_transport(mut self, config: WebSocketConfig) -> Self {
112148
self.transport = config;
113149
self
@@ -126,12 +162,12 @@ impl std::fmt::Debug for Config {
126162
/// Callback fn that processes [`WebSocket`]
127163
pub trait Callback: Send + 'static {
128164
/// Called when a connection upgrade succeeds
129-
fn call(self, _: WebSocket) -> impl Future<Output = ()> + Send;
165+
fn call(self, _: WebSocket) -> impl Future<Output=()> + Send;
130166
}
131167

132168
impl<Fut, C> Callback for C
133169
where
134-
Fut: Future<Output = ()> + Send + 'static,
170+
Fut: Future<Output=()> + Send + 'static,
135171
C: FnOnce(WebSocket) -> Fut + Send + 'static,
136172
C: Copy,
137173
{
@@ -181,10 +217,20 @@ impl Callback for DefaultCallback {
181217

182218
/// Extractor of [`FromContext`] for establishing WebSocket connection
183219
///
184-
/// Constrains:
220+
/// **Constrains**:
221+
///
185222
/// The extractor only supports for the request that has the method [`GET`](http::Method::GET)
186223
/// and contains certain header values.
187-
/// See [`WebSocketUpgrade::from_context`] for more details.
224+
///
225+
/// # Usage
226+
///
227+
/// ```rust
228+
/// use volo_http::{response::ServerResponse, server::extract::WebSocketUpgrade};
229+
///
230+
/// fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse {
231+
/// ws.on_upgrade(|socket| unimplemented!())
232+
/// }
233+
/// ```
188234
pub struct WebSocketUpgrade<C = DefaultCallback, F = DefaultOnFailedUpgrade> {
189235
config: Config,
190236
on_protocol: HashMap<HeaderValue, C>,
@@ -228,7 +274,6 @@ where
228274
/// )
229275
/// )
230276
/// .on_upgrade(|socket| async{} )
231-
/// .unwrap()
232277
/// }
233278
pub fn set_config(mut self, config: Config) -> Self {
234279
self.config = config;
@@ -254,7 +299,6 @@ where
254299
/// )
255300
/// .on_protocol(HashMap::from([("graphql-ws",|mut socket: WebSocket| async move{})]))
256301
/// .on_upgrade(|socket| async{} )
257-
/// .unwrap()
258302
/// }
259303
pub fn on_protocol<H, I, C1>(self, on_protocol: I) -> WebSocketUpgrade<C1, F>
260304
where
@@ -302,7 +346,6 @@ where
302346
/// unimplemented!()
303347
/// })
304348
/// .on_upgrade(|socket| async{} )
305-
/// .unwrap()
306349
/// }
307350
pub fn on_failed_upgrade<F1>(self, callback: F1) -> WebSocketUpgrade<C, F1> {
308351
WebSocketUpgrade {
@@ -317,11 +360,10 @@ where
317360
/// Finalize upgrading the connection and call the provided callback
318361
/// if request protocol is matched, it will use callback set by
319362
/// [`WebSocketUpgrade::on_protocol`], otherwise use `default_callback`.
320-
pub fn on_upgrade<Fut, C1>(self, default_callback: C1) -> Result<ServerResponse, Error>
363+
pub fn on_upgrade<Fut, C1>(self, default_callback: C1) -> ServerResponse
321364
where
322-
Fut: Future<Output = ()> + Send + 'static,
323-
C1: FnOnce(WebSocket) -> Fut + Send + 'static,
324-
C1: Send + Sync,
365+
Fut: Future<Output=()> + Send + 'static,
366+
C1: FnOnce(WebSocket) -> Fut + Send + Sync + 'static,
325367
{
326368
let on_upgrade = self.on_upgrade;
327369
let config = self.config.transport;
@@ -376,19 +418,41 @@ where
376418
const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
377419

378420
let mut builder = ServerResponse::builder()
379-
.status(StatusCode::SWITCHING_PROTOCOLS)
380-
.header(header::CONNECTION, UPGRADE)
381-
.header(header::UPGRADE, WEBSOCKET)
421+
.status(http::StatusCode::SWITCHING_PROTOCOLS)
422+
.header(http::header::CONNECTION, UPGRADE)
423+
.header(http::header::UPGRADE, WEBSOCKET)
382424
.header(
383-
header::SEC_WEBSOCKET_ACCEPT,
425+
http::header::SEC_WEBSOCKET_ACCEPT,
384426
derive_accept_key(self.headers.sec_websocket_key.as_bytes()),
385427
);
386428

387429
if let Some(protocol) = protocol {
388-
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
430+
builder = builder.header(http::header::SEC_WEBSOCKET_PROTOCOL, protocol);
389431
}
390432

391-
Ok(builder.body(Body::empty()).unwrap())
433+
builder.body(Body::empty()).unwrap()
434+
}
435+
}
436+
437+
fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
438+
let header = if let Some(header) = headers.get(&key) {
439+
header
440+
} else {
441+
return false;
442+
};
443+
444+
if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
445+
header.to_ascii_lowercase().contains(value)
446+
} else {
447+
false
448+
}
449+
}
450+
451+
fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
452+
if let Some(header) = headers.get(&key) {
453+
header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
454+
} else {
455+
false
392456
}
393457
}
394458

@@ -406,32 +470,10 @@ impl FromContext for WebSocketUpgrade<DefaultCallback> {
406470
return Err(WebSocketUpgradeRejectionError::InvalidHttpVersion);
407471
}
408472

409-
fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
410-
let header = if let Some(header) = headers.get(&key) {
411-
header
412-
} else {
413-
return false;
414-
};
415-
416-
if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
417-
header.to_ascii_lowercase().contains(value)
418-
} else {
419-
false
420-
}
421-
}
422-
423473
if !header_contains(&parts.headers, http::header::CONNECTION, "upgrade") {
424474
return Err(WebSocketUpgradeRejectionError::InvalidConnectionHeader);
425475
}
426476

427-
fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
428-
if let Some(header) = headers.get(&key) {
429-
header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
430-
} else {
431-
false
432-
}
433-
}
434-
435477
if !header_eq(&parts.headers, http::header::UPGRADE, "websocket") {
436478
return Err(WebSocketUpgradeRejectionError::InvalidUpgradeHeader);
437479
}
@@ -442,7 +484,7 @@ impl FromContext for WebSocketUpgrade<DefaultCallback> {
442484

443485
let sec_websocket_key = parts
444486
.headers
445-
.get(header::SEC_WEBSOCKET_KEY)
487+
.get(http::header::SEC_WEBSOCKET_KEY)
446488
.ok_or(WebSocketUpgradeRejectionError::WebSocketKeyHeaderMissing)?
447489
.clone();
448490

@@ -451,7 +493,7 @@ impl FromContext for WebSocketUpgrade<DefaultCallback> {
451493
.remove::<hyper::upgrade::OnUpgrade>()
452494
.ok_or(WebSocketUpgradeRejectionError::ConnectionNotUpgradable)?;
453495

454-
let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
496+
let sec_websocket_protocol = parts.headers.get(http::header::SEC_WEBSOCKET_PROTOCOL).cloned();
455497

456498
Ok(Self {
457499
config: Default::default(),
@@ -471,7 +513,7 @@ mod websocket_tests {
471513
use std::{net, ops::Add};
472514

473515
use futures_util::{SinkExt, StreamExt};
474-
use http::{self, Uri};
516+
use http::{Uri};
475517
use motore::Service;
476518
use tokio::net::TcpStream;
477519
use tokio_tungstenite::{
@@ -496,7 +538,7 @@ mod websocket_tests {
496538
) -> (WebSocketStream<MaybeTlsStream<TcpStream>>, ServerResponse)
497539
where
498540
R: IntoClientRequest + Unpin,
499-
Fut: Future<Output = ServerResponse> + Send + 'static,
541+
Fut: Future<Output=ServerResponse> + Send + 'static,
500542
C: FnOnce(WebSocketUpgrade) -> Fut + Send + 'static,
501543
C: Send + Sync + Clone,
502544
{
@@ -543,7 +585,7 @@ mod websocket_tests {
543585

544586
let resp = route.call(&mut cx, req).await.unwrap();
545587

546-
assert_eq!(resp.status(), StatusCode::OK);
588+
assert_eq!(resp.status(), http::StatusCode::OK);
547589
}
548590

549591
#[tokio::test]
@@ -568,7 +610,7 @@ mod websocket_tests {
568610

569611
let resp = route.call(&mut cx, req).await.unwrap();
570612

571-
assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
613+
assert_eq!(resp.status(), http::StatusCode::METHOD_NOT_ALLOWED);
572614
}
573615

574616
#[tokio::test]
@@ -582,24 +624,23 @@ mod websocket_tests {
582624
ws.set_config(
583625
WebSocketConfig::new().set_protocols(["graphql-ws", "graphql-transport-ws"]),
584626
)
585-
.on_protocol(HashMap::from([(
586-
"graphql-ws",
587-
|mut socket: WebSocket| async move {
588-
while let Some(Ok(msg)) = socket.next().await {
589-
match msg {
590-
Message::Text(text) => {
591-
socket
592-
.send(Message::Text(text.add("-graphql-ws")))
593-
.await
594-
.unwrap();
627+
.on_protocol(HashMap::from([(
628+
"graphql-ws",
629+
|mut socket: WebSocket| async move {
630+
while let Some(Ok(msg)) = socket.next().await {
631+
match msg {
632+
Message::Text(text) => {
633+
socket
634+
.send(Message::Text(text.add("-graphql-ws")))
635+
.await
636+
.unwrap();
637+
}
638+
_ => {}
595639
}
596-
_ => {}
597640
}
598-
}
599-
},
600-
)]))
601-
.on_upgrade(|_| async {})
602-
.unwrap()
641+
},
642+
)]))
643+
.on_upgrade(|_| async {})
603644
}
604645

605646
let addr = Address::Ip(net::SocketAddr::new(
@@ -612,7 +653,7 @@ mod websocket_tests {
612653
.parse::<Uri>()
613654
.unwrap(),
614655
)
615-
.with_sub_protocol("graphql-ws");
656+
.with_sub_protocol("graphql-ws");
616657
let (mut ws_stream, _response) = run_ws_handler(addr, ws_handler, builder).await;
617658

618659
let input = Message::Text("foobar".to_owned());
@@ -621,7 +662,7 @@ mod websocket_tests {
621662
assert_eq!(output, Message::Text("foobar-graphql-ws".to_owned()));
622663
}
623664

624-
#[cfg(test)]
665+
#[tokio::test]
625666
async fn integration_test() {
626667
async fn handle_socket(mut socket: WebSocket) {
627668
while let Some(Ok(msg)) = socket.next().await {
@@ -651,11 +692,11 @@ mod websocket_tests {
651692
);
652693

653694
let (mut ws_stream, _response) = run_ws_handler(
654-
addr,
655-
|ws: WebSocketUpgrade| std::future::ready(ws.on_upgrade(handle_socket).unwrap()),
695+
addr.clone(),
696+
|ws: WebSocketUpgrade| std::future::ready(ws.on_upgrade(handle_socket)),
656697
builder,
657698
)
658-
.await;
699+
.await;
659700

660701
let input = Message::Text("foobar".to_owned());
661702
ws_stream.send(input.clone()).await.unwrap();

0 commit comments

Comments
 (0)