Skip to content

Commit daafd60

Browse files
authored
feat(http): support websocket server (#481)
1 parent 0c15227 commit daafd60

File tree

7 files changed

+817
-4
lines changed

7 files changed

+817
-4
lines changed

Cargo.lock

+102
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ webpki-roots = "0.26"
133133
tokio-rustls = "0.25"
134134
native-tls = "0.2"
135135
tokio-native-tls = "0.3"
136+
tokio-tungstenite="0.23.1"
136137

137138
[profile.release]
138139
opt-level = 3

volo-http/Cargo.toml

+7-1
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,13 @@ tokio = { workspace = true, features = [
5959
tokio-util = { workspace = true, features = ["io"] }
6060
tracing.workspace = true
6161

62+
# =====optional=====
6263
# server optional
6364
matchit = { workspace = true, optional = true }
6465

66+
# protocol optional
67+
tokio-tungstenite = { workspace = true, optional = true }
68+
6569
# tls optional
6670
tokio-rustls = { workspace = true, optional = true }
6771
tokio-native-tls = { workspace = true, optional = true }
@@ -86,11 +90,13 @@ default = []
8690
default_client = ["client", "json"]
8791
default_server = ["server", "query", "form", "json"]
8892

89-
full = ["client", "server", "rustls", "cookie", "query", "form", "json", "tls"]
93+
full = ["client", "server", "rustls", "cookie", "query", "form", "json", "tls", "ws"]
9094

9195
client = ["hyper/client", "hyper/http1"] # client core
9296
server = ["hyper/server", "hyper/http1", "dep:matchit"] # server core
9397

98+
ws = ["dep:tokio-tungstenite"]
99+
94100
tls = ["rustls"]
95101
__tls = []
96102
rustls = ["__tls", "dep:tokio-rustls", "volo/rustls"]

volo-http/src/error/server.rs

+67-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ impl Error for GenericRejectionError {}
7676

7777
impl GenericRejectionError {
7878
/// Convert the [`GenericRejectionError`] to the corresponding [`StatusCode`]
79-
pub fn to_status_code(self) -> StatusCode {
79+
pub fn to_status_code(&self) -> StatusCode {
8080
match self {
8181
Self::BodyCollectionError => StatusCode::INTERNAL_SERVER_ERROR,
8282
Self::InvalidContentType => StatusCode::UNSUPPORTED_MEDIA_TYPE,
@@ -99,3 +99,69 @@ pub fn body_collection_error() -> ExtractBodyError {
9999
pub fn invalid_content_type() -> ExtractBodyError {
100100
ExtractBodyError::Generic(GenericRejectionError::InvalidContentType)
101101
}
102+
103+
/// Rejection used for [`WebSocketUpgrade`](crate::server::utils::WebSocketUpgrade).
104+
#[derive(Debug)]
105+
#[non_exhaustive]
106+
pub enum WebSocketUpgradeRejectionError {
107+
/// The request method must be `GET`
108+
MethodNotGet,
109+
/// The HTTP version is not supported
110+
InvalidHttpVersion,
111+
/// The `Connection` header is invalid
112+
InvalidConnectionHeader,
113+
/// The `Upgrade` header is invalid
114+
InvalidUpgradeHeader,
115+
/// The `Sec-WebSocket-Version` header is invalid
116+
InvalidWebSocketVersionHeader,
117+
/// The `Sec-WebSocket-Key` header is missing
118+
WebSocketKeyHeaderMissing,
119+
/// The connection is not upgradable
120+
ConnectionNotUpgradable,
121+
}
122+
123+
impl WebSocketUpgradeRejectionError {
124+
/// Convert the [`WebSocketUpgradeRejectionError`] to the corresponding [`StatusCode`]
125+
fn to_status_code(&self) -> StatusCode {
126+
match self {
127+
Self::MethodNotGet => StatusCode::METHOD_NOT_ALLOWED,
128+
Self::InvalidHttpVersion => StatusCode::HTTP_VERSION_NOT_SUPPORTED,
129+
Self::InvalidConnectionHeader => StatusCode::BAD_REQUEST,
130+
Self::InvalidUpgradeHeader => StatusCode::BAD_REQUEST,
131+
Self::InvalidWebSocketVersionHeader => StatusCode::BAD_REQUEST,
132+
Self::WebSocketKeyHeaderMissing => StatusCode::BAD_REQUEST,
133+
Self::ConnectionNotUpgradable => StatusCode::UPGRADE_REQUIRED,
134+
}
135+
}
136+
}
137+
138+
impl Error for WebSocketUpgradeRejectionError {}
139+
140+
impl fmt::Display for WebSocketUpgradeRejectionError {
141+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
142+
match self {
143+
Self::MethodNotGet => write!(f, "Request method must be 'GET'"),
144+
Self::InvalidHttpVersion => {
145+
write!(f, "Http version not support, only support HTTP 1.1 for now")
146+
}
147+
Self::InvalidConnectionHeader => {
148+
write!(f, "Connection header did not include 'upgrade'")
149+
}
150+
Self::InvalidUpgradeHeader => write!(f, "`Upgrade` header did not include 'websocket'"),
151+
Self::InvalidWebSocketVersionHeader => {
152+
write!(f, "`Sec-WebSocket-Version` header did not include '13'")
153+
}
154+
Self::WebSocketKeyHeaderMissing => write!(f, "`Sec-WebSocket-Key` header missing"),
155+
Self::ConnectionNotUpgradable => write!(
156+
f,
157+
"WebSocket request couldn't be upgraded since no upgrade state was present"
158+
),
159+
}
160+
}
161+
}
162+
163+
impl IntoResponse for WebSocketUpgradeRejectionError {
164+
fn into_response(self) -> ServerResponse {
165+
self.to_status_code().into_response()
166+
}
167+
}

volo-http/src/server/mod.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -443,13 +443,15 @@ async fn serve_conn<S>(
443443
let notified = exit_notify.notified();
444444
tokio::pin!(notified);
445445

446-
let mut http_conn = server.serve_connection(TokioIo::new(conn), service);
446+
let mut http_conn = server
447+
.serve_connection(TokioIo::new(conn), service)
448+
.with_upgrades();
447449

448450
tokio::select! {
449451
_ = &mut notified => {
450452
tracing::trace!("[VOLO] closing a pending connection");
451453
// Graceful shutdown.
452-
hyper::server::conn::http1::Connection::graceful_shutdown(
454+
hyper::server::conn::http1::UpgradeableConnection::graceful_shutdown(
453455
Pin::new(&mut http_conn)
454456
);
455457
// Continue to poll this connection until shutdown can finish.

volo-http/src/server/utils/mod.rs

+5
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,8 @@ mod serve_dir;
55

66
pub use file_response::FileResponse;
77
pub use serve_dir::ServeDir;
8+
9+
#[cfg(feature = "ws")]
10+
pub mod ws;
11+
#[cfg(feature = "ws")]
12+
pub use self::ws::{Config as WebSocketConfig, Message, WebSocket, WebSocketUpgrade};

0 commit comments

Comments
 (0)