Skip to content

Commit c0d3485

Browse files
committed
handle_tunnel_request: small code cleanup
more idiomatic, less code, better readability
1 parent ace8389 commit c0d3485

File tree

5 files changed

+44
-72
lines changed

5 files changed

+44
-72
lines changed

src/protocols/tcp/server.rs

-1
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,6 @@ mod tests {
233233
use super::*;
234234
use futures_util::pin_mut;
235235
use std::borrow::Cow;
236-
use std::net::SocketAddr;
237236
use testcontainers::core::WaitFor;
238237
use testcontainers::runners::AsyncRunner;
239238
use testcontainers::{ContainerAsync, Image, ImageExt};

src/tunnel/server/handler_http2.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::restrictions::types::RestrictionsRules;
2-
use crate::tunnel::server::utils::{bad_request, inject_cookie};
2+
use crate::tunnel::server::utils::{bad_request, inject_cookie, HttpResponse};
33
use crate::tunnel::server::WsServer;
44
use crate::tunnel::transport;
55
use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite};
@@ -22,7 +22,7 @@ pub(super) async fn http_server_upgrade(
2222
restrict_path_prefix: Option<String>,
2323
client_addr: SocketAddr,
2424
mut req: Request<Incoming>,
25-
) -> Response<Either<String, BoxBody<Bytes, anyhow::Error>>> {
25+
) -> HttpResponse {
2626
let (remote_addr, local_rx, local_tx, need_cookie) = match server
2727
.handle_tunnel_request(restrictions, restrict_path_prefix, client_addr, &req)
2828
.await

src/tunnel/server/handler_websocket.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::restrictions::types::RestrictionsRules;
2-
use crate::tunnel::server::utils::{bad_request, inject_cookie};
2+
use crate::tunnel::server::utils::{bad_request, inject_cookie, HttpResponse};
33
use crate::tunnel::server::WsServer;
44
use crate::tunnel::transport;
55
use crate::tunnel::transport::websocket::mk_websocket_tunnel;
@@ -21,7 +21,7 @@ pub(super) async fn ws_server_upgrade(
2121
restrict_path_prefix: Option<String>,
2222
client_addr: SocketAddr,
2323
mut req: Request<Incoming>,
24-
) -> Response<Either<String, BoxBody<Bytes, anyhow::Error>>> {
24+
) -> HttpResponse {
2525
if !fastwebsockets::upgrade::is_upgrade_request(&req) {
2626
warn!("Rejecting connection with bad upgrade request: {}", req.uri());
2727
return bad_request();

src/tunnel/server/server.rs

+22-43
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ use crate::tunnel::server::handler_websocket::ws_server_upgrade;
3333
use crate::tunnel::server::reverse_tunnel::ReverseTunnelServer;
3434
use crate::tunnel::server::utils::{
3535
bad_request, extract_path_prefix, extract_tunnel_info, extract_x_forwarded_for, find_mapped_port, validate_tunnel,
36+
HttpResponse,
3637
};
3738
use crate::tunnel::tls_reloader::TlsReloader;
3839
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
@@ -89,68 +90,46 @@ impl WsServer {
8990
Pin<Box<dyn AsyncWrite + Send>>,
9091
bool,
9192
),
92-
Response<Either<String, BoxBody<Bytes, anyhow::Error>>>,
93+
HttpResponse,
9394
> {
94-
match extract_x_forwarded_for(req) {
95-
Ok(Some((x_forward_for, x_forward_for_str))) => {
96-
info!("Request X-Forwarded-For: {:?}", x_forward_for);
97-
Span::current().record("forwarded_for", x_forward_for_str);
98-
client_addr.set_ip(x_forward_for);
99-
}
100-
Ok(_) => {}
101-
Err(_err) => return Err(bad_request()),
95+
if let Some((x_forward_for, x_forward_for_str)) = extract_x_forwarded_for(req) {
96+
info!("Request X-Forwarded-For: {x_forward_for:?}");
97+
Span::current().record("forwarded_for", x_forward_for_str);
98+
client_addr.set_ip(x_forward_for);
10299
};
103100

104-
let path_prefix = match extract_path_prefix(req) {
105-
Ok(p) => p,
106-
Err(_err) => return Err(bad_request()),
107-
};
101+
let path_prefix = extract_path_prefix(req)?;
108102

109103
if let Some(restrict_path) = restrict_path_prefix {
110104
if path_prefix != restrict_path {
111105
warn!(
112-
"Client requested upgrade path '{}' does not match upgrade path restriction '{}' (mTLS, etc.)",
113-
path_prefix, restrict_path
106+
"Client requested upgrade path '{path_prefix}' does not match upgrade path restriction '{restrict_path}' (mTLS, etc.)"
114107
);
115108
return Err(bad_request());
116109
}
117110
}
118111

119-
let jwt = match extract_tunnel_info(req) {
120-
Ok(jwt) => jwt,
121-
Err(_err) => return Err(bad_request()),
122-
};
112+
let jwt = extract_tunnel_info(req)?;
123113

124114
Span::current().record("id", &jwt.claims.id);
125115
Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp));
126-
let remote = match RemoteAddr::try_from(jwt.claims) {
127-
Ok(remote) => remote,
128-
Err(err) => {
129-
warn!("Rejecting connection with bad tunnel info: {} {}", err, req.uri());
130-
return Err(bad_request());
131-
}
132-
};
116+
let remote = RemoteAddr::try_from(jwt.claims)
117+
.inspect_err(|err| warn!("Rejecting connection with bad tunnel info: {err} {}", req.uri()))
118+
.map_err(|_| bad_request())?;
133119

134-
let restriction = match validate_tunnel(&remote, path_prefix, &restrictions) {
135-
Some(matched_restriction) => {
136-
info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name);
137-
matched_restriction
138-
}
139-
None => {
140-
warn!("Rejecting connection with not allowed destination: {:?}", remote);
141-
return Err(bad_request());
142-
}
143-
};
120+
let restriction = validate_tunnel(&remote, path_prefix, &restrictions).ok_or_else(|| {
121+
warn!("Rejecting connection with not allowed destination: {remote:?}");
122+
bad_request()
123+
})?;
124+
info!("Tunnel accepted due to matched restriction: {}", restriction.name);
144125

145126
let req_protocol = remote.protocol.clone();
146127
let inject_cookie = req_protocol.is_dynamic_reverse_tunnel();
147-
let tunnel = match self.exec_tunnel(restriction, remote, client_addr).await {
148-
Ok(ret) => ret,
149-
Err(err) => {
150-
warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());
151-
return Err(bad_request());
152-
}
153-
};
128+
let tunnel = self
129+
.exec_tunnel(restriction, remote, client_addr)
130+
.await
131+
.inspect_err(|err| warn!("Rejecting connection with bad upgrade request: {err} {}", req.uri()))
132+
.map_err(|_| bad_request())?;
154133

155134
let (remote_addr, local_rx, local_tx) = tunnel;
156135
info!("connected to {:?} {}:{}", req_protocol, remote_addr.host, remote_addr.port);

src/tunnel/server/utils.rs

+18-24
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ use tracing::{error, info, warn};
1717
use url::Host;
1818
use uuid::Uuid;
1919

20-
pub(super) fn bad_request() -> Response<Either<String, BoxBody<Bytes, anyhow::Error>>> {
20+
pub type HttpResponse = Response<Either<String, BoxBody<Bytes, anyhow::Error>>>;
21+
22+
pub(super) fn bad_request() -> HttpResponse {
2123
http::Response::builder()
2224
.status(StatusCode::BAD_REQUEST)
2325
.body(Either::Left("Invalid request".to_string()))
@@ -48,42 +50,40 @@ pub(super) fn find_mapped_port(req_port: u16, restriction: &RestrictionConfig) -
4850
}
4951

5052
#[inline]
51-
pub(super) fn extract_x_forwarded_for(req: &Request<Incoming>) -> Result<Option<(IpAddr, &str)>, ()> {
52-
let Some(x_forward_for) = req.headers().get("X-Forwarded-For") else {
53-
return Ok(None);
54-
};
53+
pub(super) fn extract_x_forwarded_for(req: &Request<Incoming>) -> Option<(IpAddr, &str)> {
54+
let x_forward_for = req.headers().get("X-Forwarded-For")?;
5555

5656
// X-Forwarded-For: <client>, <proxy1>, <proxy2>
5757
let x_forward_for = x_forward_for.to_str().unwrap_or_default();
5858
let x_forward_for = x_forward_for.split_once(',').map(|x| x.0).unwrap_or(x_forward_for);
5959
let ip: Option<IpAddr> = x_forward_for.parse().ok();
60-
Ok(ip.map(|ip| (ip, x_forward_for)))
60+
ip.map(|ip| (ip, x_forward_for))
6161
}
6262

6363
#[inline]
64-
pub(super) fn extract_path_prefix(req: &Request<Incoming>) -> Result<&str, ()> {
64+
pub(super) fn extract_path_prefix(req: &Request<Incoming>) -> Result<&str, HttpResponse> {
6565
let path = req.uri().path();
6666
let min_len = min(path.len(), 1);
6767
if &path[0..min_len] != "/" {
6868
warn!("Rejecting connection with bad path prefix in upgrade request: {}", req.uri());
69-
return Err(());
69+
return Err(bad_request());
7070
}
7171

7272
let Some((l, r)) = path[min_len..].split_once('/') else {
7373
warn!("Rejecting connection with bad upgrade request: {}", req.uri());
74-
return Err(());
74+
return Err(bad_request());
7575
};
7676

7777
if !r.ends_with("events") {
7878
warn!("Rejecting connection with bad upgrade request: {}", req.uri());
79-
return Err(());
79+
return Err(bad_request());
8080
}
8181

8282
Ok(l)
8383
}
8484

8585
#[inline]
86-
pub(super) fn extract_tunnel_info(req: &Request<Incoming>) -> Result<TokenData<JwtTunnelConfig>, ()> {
86+
pub(super) fn extract_tunnel_info(req: &Request<Incoming>) -> anyhow::Result<TokenData<JwtTunnelConfig>, HttpResponse> {
8787
let jwt = req
8888
.headers()
8989
.get(SEC_WEBSOCKET_PROTOCOL)
@@ -93,19 +93,13 @@ pub(super) fn extract_tunnel_info(req: &Request<Incoming>) -> Result<TokenData<J
9393
.or_else(|| req.headers().get(COOKIE).and_then(|header| header.to_str().ok()))
9494
.unwrap_or_default();
9595

96-
let jwt = match jwt_token_to_tunnel(jwt) {
97-
Ok(jwt) => jwt,
98-
err => {
99-
warn!(
100-
"error while decoding jwt for tunnel info {:?} header {:?}",
101-
err,
102-
req.headers().get(SEC_WEBSOCKET_PROTOCOL)
103-
);
104-
return Err(());
105-
}
106-
};
107-
108-
Ok(jwt)
96+
jwt_token_to_tunnel(jwt).map_err(|err| {
97+
warn!(
98+
"error while decoding jwt for tunnel info {err:?} header {:?}",
99+
req.headers().get(SEC_WEBSOCKET_PROTOCOL)
100+
);
101+
bad_request()
102+
})
109103
}
110104

111105
impl RestrictionConfig {

0 commit comments

Comments
 (0)