1
- //! Handle WebSocket connections.
1
+ //! Module for handling WebSocket connection
2
+ //!
3
+ //!
4
+ //! This module provides utilities for setting up and handling WebSocket connections, including
5
+ //! configuring WebSocket options, setting protocols, and upgrading connections.
6
+ //!
7
+ //! It uses [`hyper::upgrade::OnUpgrade`] to upgrade the connection.
8
+ //!
2
9
//!
3
10
//! # Example
11
+ //!
12
+ //! ```rust
13
+ //! use futures_util::{SinkExt, StreamExt};
14
+ //! use volo_http::{
15
+ //! response::ServerResponse,
16
+ //! server::{
17
+ //! extract::{Message, WebSocket},
18
+ //! route::get,
19
+ //! utils::WebSocketUpgrade,
20
+ //! },
21
+ //! Router,
22
+ //! };
23
+ //!
24
+ //! async fn handle_socket(mut socket: WebSocket) {
25
+ //! while let Some(Ok(msg)) = socket.next().await {
26
+ //! match msg {
27
+ //! Message::Text(text) => {
28
+ //! socket.send(msg).await.unwrap();
29
+ //! }
30
+ //! _ => {}
31
+ //! }
32
+ //! }
33
+ //! }
34
+ //!
35
+ //! async fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse {
36
+ //! ws.on_upgrade(handle_socket)
37
+ //! }
38
+ //!
39
+ //! let app = Router::new().route("/ws", get(ws_handler));
40
+ //! ```
4
41
5
42
use std:: { borrow:: Cow , collections:: HashMap , fmt:: Formatter , future:: Future } ;
6
43
7
- use http:: { header , request:: Parts , HeaderMap , HeaderName , HeaderValue , StatusCode } ;
44
+ use http:: { request:: Parts , HeaderMap , HeaderName , HeaderValue } ;
8
45
use hyper:: Error ;
9
46
use hyper_util:: rt:: TokioIo ;
10
47
use tokio_tungstenite:: {
@@ -69,14 +106,13 @@ impl Config {
69
106
/// This will filter protocols in request header `Sec-WebSocket-Protocol`
70
107
/// will set the first server supported protocol in [`http::header::Sec-WebSocket-Protocol`] in
71
108
/// response
109
+ ///
110
+ ///
72
111
/// ```rust
73
112
/// use volo_http::server::utils::WebSocketConfig;
74
113
///
75
- /// let config = WebSocketConfig::new()
76
- /// .set_protocols([
77
- /// "graphql-ws",
78
- /// "graphql-transport-ws",
79
- /// ]);
114
+ /// let config = WebSocketConfig::new().set_protocols(["graphql-ws", "graphql-transport-ws"]);
115
+ /// ```
80
116
pub fn set_protocols < I > ( mut self , protocols : I ) -> Self
81
117
where
82
118
I : IntoIterator ,
@@ -97,17 +133,17 @@ impl Config {
97
133
98
134
/// Set transport config
99
135
/// e.g. write buffer size
136
+ ///
137
+ ///
100
138
/// ```rust
101
- /// use volo_http::server::utils ::WebSocketConfig;
102
- /// use tokio_tungstenite::tungstenite::protocol::{ WebSocketConfig as WebSocketTransConfig} ;
139
+ /// use tokio_tungstenite::tungstenite::protocol ::WebSocketConfig as WebSocketTransConfig ;
140
+ /// use volo_http::server::utils:: WebSocketConfig;
103
141
///
104
- /// let config = WebSocketConfig::new()
105
- /// .set_transport(
106
- /// WebSocketTransConfig{
107
- /// write_buffer_size: 128 * 1024,
108
- /// ..<_>::default()
109
- /// }
110
- /// );
142
+ /// let config = WebSocketConfig::new().set_transport(WebSocketTransConfig {
143
+ /// write_buffer_size: 128 * 1024,
144
+ /// ..<_>::default()
145
+ /// });
146
+ /// ```
111
147
pub fn set_transport ( mut self , config : WebSocketConfig ) -> Self {
112
148
self . transport = config;
113
149
self
@@ -126,12 +162,12 @@ impl std::fmt::Debug for Config {
126
162
/// Callback fn that processes [`WebSocket`]
127
163
pub trait Callback : Send + ' static {
128
164
/// Called when a connection upgrade succeeds
129
- fn call ( self , _: WebSocket ) -> impl Future < Output = ( ) > + Send ;
165
+ fn call ( self , _: WebSocket ) -> impl Future < Output = ( ) > + Send ;
130
166
}
131
167
132
168
impl < Fut , C > Callback for C
133
169
where
134
- Fut : Future < Output = ( ) > + Send + ' static ,
170
+ Fut : Future < Output = ( ) > + Send + ' static ,
135
171
C : FnOnce ( WebSocket ) -> Fut + Send + ' static ,
136
172
C : Copy ,
137
173
{
@@ -181,10 +217,20 @@ impl Callback for DefaultCallback {
181
217
182
218
/// Extractor of [`FromContext`] for establishing WebSocket connection
183
219
///
184
- /// Constrains:
220
+ /// **Constrains**:
221
+ ///
185
222
/// The extractor only supports for the request that has the method [`GET`](http::Method::GET)
186
223
/// and contains certain header values.
187
- /// See [`WebSocketUpgrade::from_context`] for more details.
224
+ ///
225
+ /// # Usage
226
+ ///
227
+ /// ```rust
228
+ /// use volo_http::{response::ServerResponse, server::extract::WebSocketUpgrade};
229
+ ///
230
+ /// fn ws_handler(ws: WebSocketUpgrade) -> ServerResponse {
231
+ /// ws.on_upgrade(|socket| unimplemented!())
232
+ /// }
233
+ /// ```
188
234
pub struct WebSocketUpgrade < C = DefaultCallback , F = DefaultOnFailedUpgrade > {
189
235
config : Config ,
190
236
on_protocol : HashMap < HeaderValue , C > ,
@@ -228,7 +274,6 @@ where
228
274
/// )
229
275
/// )
230
276
/// .on_upgrade(|socket| async{} )
231
- /// .unwrap()
232
277
/// }
233
278
pub fn set_config ( mut self , config : Config ) -> Self {
234
279
self . config = config;
@@ -254,7 +299,6 @@ where
254
299
/// )
255
300
/// .on_protocol(HashMap::from([("graphql-ws",|mut socket: WebSocket| async move{})]))
256
301
/// .on_upgrade(|socket| async{} )
257
- /// .unwrap()
258
302
/// }
259
303
pub fn on_protocol < H , I , C1 > ( self , on_protocol : I ) -> WebSocketUpgrade < C1 , F >
260
304
where
@@ -302,7 +346,6 @@ where
302
346
/// unimplemented!()
303
347
/// })
304
348
/// .on_upgrade(|socket| async{} )
305
- /// .unwrap()
306
349
/// }
307
350
pub fn on_failed_upgrade < F1 > ( self , callback : F1 ) -> WebSocketUpgrade < C , F1 > {
308
351
WebSocketUpgrade {
@@ -317,11 +360,10 @@ where
317
360
/// Finalize upgrading the connection and call the provided callback
318
361
/// if request protocol is matched, it will use callback set by
319
362
/// [`WebSocketUpgrade::on_protocol`], otherwise use `default_callback`.
320
- pub fn on_upgrade < Fut , C1 > ( self , default_callback : C1 ) -> Result < ServerResponse , Error >
363
+ pub fn on_upgrade < Fut , C1 > ( self , default_callback : C1 ) -> ServerResponse
321
364
where
322
- Fut : Future < Output = ( ) > + Send + ' static ,
323
- C1 : FnOnce ( WebSocket ) -> Fut + Send + ' static ,
324
- C1 : Send + Sync ,
365
+ Fut : Future < Output =( ) > + Send + ' static ,
366
+ C1 : FnOnce ( WebSocket ) -> Fut + Send + Sync + ' static ,
325
367
{
326
368
let on_upgrade = self . on_upgrade ;
327
369
let config = self . config . transport ;
@@ -376,19 +418,41 @@ where
376
418
const WEBSOCKET : HeaderValue = HeaderValue :: from_static ( "websocket" ) ;
377
419
378
420
let mut builder = ServerResponse :: builder ( )
379
- . status ( StatusCode :: SWITCHING_PROTOCOLS )
380
- . header ( header:: CONNECTION , UPGRADE )
381
- . header ( header:: UPGRADE , WEBSOCKET )
421
+ . status ( http :: StatusCode :: SWITCHING_PROTOCOLS )
422
+ . header ( http :: header:: CONNECTION , UPGRADE )
423
+ . header ( http :: header:: UPGRADE , WEBSOCKET )
382
424
. header (
383
- header:: SEC_WEBSOCKET_ACCEPT ,
425
+ http :: header:: SEC_WEBSOCKET_ACCEPT ,
384
426
derive_accept_key ( self . headers . sec_websocket_key . as_bytes ( ) ) ,
385
427
) ;
386
428
387
429
if let Some ( protocol) = protocol {
388
- builder = builder. header ( header:: SEC_WEBSOCKET_PROTOCOL , protocol) ;
430
+ builder = builder. header ( http :: header:: SEC_WEBSOCKET_PROTOCOL , protocol) ;
389
431
}
390
432
391
- Ok ( builder. body ( Body :: empty ( ) ) . unwrap ( ) )
433
+ builder. body ( Body :: empty ( ) ) . unwrap ( )
434
+ }
435
+ }
436
+
437
+ fn header_contains ( headers : & HeaderMap , key : HeaderName , value : & ' static str ) -> bool {
438
+ let header = if let Some ( header) = headers. get ( & key) {
439
+ header
440
+ } else {
441
+ return false ;
442
+ } ;
443
+
444
+ if let Ok ( header) = std:: str:: from_utf8 ( header. as_bytes ( ) ) {
445
+ header. to_ascii_lowercase ( ) . contains ( value)
446
+ } else {
447
+ false
448
+ }
449
+ }
450
+
451
+ fn header_eq ( headers : & HeaderMap , key : HeaderName , value : & ' static str ) -> bool {
452
+ if let Some ( header) = headers. get ( & key) {
453
+ header. as_bytes ( ) . eq_ignore_ascii_case ( value. as_bytes ( ) )
454
+ } else {
455
+ false
392
456
}
393
457
}
394
458
@@ -406,32 +470,10 @@ impl FromContext for WebSocketUpgrade<DefaultCallback> {
406
470
return Err ( WebSocketUpgradeRejectionError :: InvalidHttpVersion ) ;
407
471
}
408
472
409
- fn header_contains ( headers : & HeaderMap , key : HeaderName , value : & ' static str ) -> bool {
410
- let header = if let Some ( header) = headers. get ( & key) {
411
- header
412
- } else {
413
- return false ;
414
- } ;
415
-
416
- if let Ok ( header) = std:: str:: from_utf8 ( header. as_bytes ( ) ) {
417
- header. to_ascii_lowercase ( ) . contains ( value)
418
- } else {
419
- false
420
- }
421
- }
422
-
423
473
if !header_contains ( & parts. headers , http:: header:: CONNECTION , "upgrade" ) {
424
474
return Err ( WebSocketUpgradeRejectionError :: InvalidConnectionHeader ) ;
425
475
}
426
476
427
- fn header_eq ( headers : & HeaderMap , key : HeaderName , value : & ' static str ) -> bool {
428
- if let Some ( header) = headers. get ( & key) {
429
- header. as_bytes ( ) . eq_ignore_ascii_case ( value. as_bytes ( ) )
430
- } else {
431
- false
432
- }
433
- }
434
-
435
477
if !header_eq ( & parts. headers , http:: header:: UPGRADE , "websocket" ) {
436
478
return Err ( WebSocketUpgradeRejectionError :: InvalidUpgradeHeader ) ;
437
479
}
@@ -442,7 +484,7 @@ impl FromContext for WebSocketUpgrade<DefaultCallback> {
442
484
443
485
let sec_websocket_key = parts
444
486
. headers
445
- . get ( header:: SEC_WEBSOCKET_KEY )
487
+ . get ( http :: header:: SEC_WEBSOCKET_KEY )
446
488
. ok_or ( WebSocketUpgradeRejectionError :: WebSocketKeyHeaderMissing ) ?
447
489
. clone ( ) ;
448
490
@@ -451,7 +493,7 @@ impl FromContext for WebSocketUpgrade<DefaultCallback> {
451
493
. remove :: < hyper:: upgrade:: OnUpgrade > ( )
452
494
. ok_or ( WebSocketUpgradeRejectionError :: ConnectionNotUpgradable ) ?;
453
495
454
- let sec_websocket_protocol = parts. headers . get ( header:: SEC_WEBSOCKET_PROTOCOL ) . cloned ( ) ;
496
+ let sec_websocket_protocol = parts. headers . get ( http :: header:: SEC_WEBSOCKET_PROTOCOL ) . cloned ( ) ;
455
497
456
498
Ok ( Self {
457
499
config : Default :: default ( ) ,
@@ -471,7 +513,7 @@ mod websocket_tests {
471
513
use std:: { net, ops:: Add } ;
472
514
473
515
use futures_util:: { SinkExt , StreamExt } ;
474
- use http:: { self , Uri } ;
516
+ use http:: { Uri } ;
475
517
use motore:: Service ;
476
518
use tokio:: net:: TcpStream ;
477
519
use tokio_tungstenite:: {
@@ -496,7 +538,7 @@ mod websocket_tests {
496
538
) -> ( WebSocketStream < MaybeTlsStream < TcpStream > > , ServerResponse )
497
539
where
498
540
R : IntoClientRequest + Unpin ,
499
- Fut : Future < Output = ServerResponse > + Send + ' static ,
541
+ Fut : Future < Output = ServerResponse > + Send + ' static ,
500
542
C : FnOnce ( WebSocketUpgrade ) -> Fut + Send + ' static ,
501
543
C : Send + Sync + Clone ,
502
544
{
@@ -543,7 +585,7 @@ mod websocket_tests {
543
585
544
586
let resp = route. call ( & mut cx, req) . await . unwrap ( ) ;
545
587
546
- assert_eq ! ( resp. status( ) , StatusCode :: OK ) ;
588
+ assert_eq ! ( resp. status( ) , http :: StatusCode :: OK ) ;
547
589
}
548
590
549
591
#[ tokio:: test]
@@ -568,7 +610,7 @@ mod websocket_tests {
568
610
569
611
let resp = route. call ( & mut cx, req) . await . unwrap ( ) ;
570
612
571
- assert_eq ! ( resp. status( ) , StatusCode :: METHOD_NOT_ALLOWED ) ;
613
+ assert_eq ! ( resp. status( ) , http :: StatusCode :: METHOD_NOT_ALLOWED ) ;
572
614
}
573
615
574
616
#[ tokio:: test]
@@ -582,24 +624,23 @@ mod websocket_tests {
582
624
ws. set_config (
583
625
WebSocketConfig :: new ( ) . set_protocols ( [ "graphql-ws" , "graphql-transport-ws" ] ) ,
584
626
)
585
- . on_protocol ( HashMap :: from ( [ (
586
- "graphql-ws" ,
587
- |mut socket : WebSocket | async move {
588
- while let Some ( Ok ( msg) ) = socket. next ( ) . await {
589
- match msg {
590
- Message :: Text ( text) => {
591
- socket
592
- . send ( Message :: Text ( text. add ( "-graphql-ws" ) ) )
593
- . await
594
- . unwrap ( ) ;
627
+ . on_protocol ( HashMap :: from ( [ (
628
+ "graphql-ws" ,
629
+ |mut socket : WebSocket | async move {
630
+ while let Some ( Ok ( msg) ) = socket. next ( ) . await {
631
+ match msg {
632
+ Message :: Text ( text) => {
633
+ socket
634
+ . send ( Message :: Text ( text. add ( "-graphql-ws" ) ) )
635
+ . await
636
+ . unwrap ( ) ;
637
+ }
638
+ _ => { }
595
639
}
596
- _ => { }
597
640
}
598
- }
599
- } ,
600
- ) ] ) )
601
- . on_upgrade ( |_| async { } )
602
- . unwrap ( )
641
+ } ,
642
+ ) ] ) )
643
+ . on_upgrade ( |_| async { } )
603
644
}
604
645
605
646
let addr = Address :: Ip ( net:: SocketAddr :: new (
@@ -612,7 +653,7 @@ mod websocket_tests {
612
653
. parse :: < Uri > ( )
613
654
. unwrap ( ) ,
614
655
)
615
- . with_sub_protocol ( "graphql-ws" ) ;
656
+ . with_sub_protocol ( "graphql-ws" ) ;
616
657
let ( mut ws_stream, _response) = run_ws_handler ( addr, ws_handler, builder) . await ;
617
658
618
659
let input = Message :: Text ( "foobar" . to_owned ( ) ) ;
@@ -621,7 +662,7 @@ mod websocket_tests {
621
662
assert_eq ! ( output, Message :: Text ( "foobar-graphql-ws" . to_owned( ) ) ) ;
622
663
}
623
664
624
- #[ cfg ( test) ]
665
+ #[ tokio :: test]
625
666
async fn integration_test ( ) {
626
667
async fn handle_socket ( mut socket : WebSocket ) {
627
668
while let Some ( Ok ( msg) ) = socket. next ( ) . await {
@@ -651,11 +692,11 @@ mod websocket_tests {
651
692
) ;
652
693
653
694
let ( mut ws_stream, _response) = run_ws_handler (
654
- addr,
655
- |ws : WebSocketUpgrade | std:: future:: ready ( ws. on_upgrade ( handle_socket) . unwrap ( ) ) ,
695
+ addr. clone ( ) ,
696
+ |ws : WebSocketUpgrade | std:: future:: ready ( ws. on_upgrade ( handle_socket) ) ,
656
697
builder,
657
698
)
658
- . await ;
699
+ . await ;
659
700
660
701
let input = Message :: Text ( "foobar" . to_owned ( ) ) ;
661
702
ws_stream. send ( input. clone ( ) ) . await . unwrap ( ) ;
0 commit comments