Skip to content

Commit d8e519c

Browse files
authored
handle_tunnel_request: small code cleanup (#391)
more idiomatic, less code, better readability
1 parent 4d33b62 commit d8e519c

File tree

6 files changed

+97
-81
lines changed

6 files changed

+97
-81
lines changed

Cargo.lock

+7
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ tracing = { version = "0.1.41", features = ["log"] }
5353
url = "2.5.4"
5454
urlencoding = "2.1.3"
5555
uuid = { version = "1.11.0", features = ["v7", "serde"] }
56+
derive_more = { version = "1.0.0", features = ["display", "error"] }
5657

5758
[target.'cfg(not(target_family = "unix"))'.dependencies]
5859
crossterm = { version = "0.28.1" }

src/tunnel/server/handler_http2.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::restrictions::types::RestrictionsRules;
2-
use crate::tunnel::server::utils::{bad_request, inject_cookie};
2+
use crate::tunnel::server::utils::{bad_request, inject_cookie, HttpResponse};
33
use crate::tunnel::server::WsServer;
44
use crate::tunnel::transport;
55
use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite};
@@ -22,7 +22,7 @@ pub(super) async fn http_server_upgrade(
2222
restrict_path_prefix: Option<String>,
2323
client_addr: SocketAddr,
2424
mut req: Request<Incoming>,
25-
) -> Response<Either<String, BoxBody<Bytes, anyhow::Error>>> {
25+
) -> HttpResponse {
2626
let (remote_addr, local_rx, local_tx, need_cookie) = match server
2727
.handle_tunnel_request(restrictions, restrict_path_prefix, client_addr, &req)
2828
.await

src/tunnel/server/handler_websocket.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::restrictions::types::RestrictionsRules;
2-
use crate::tunnel::server::utils::{bad_request, inject_cookie};
2+
use crate::tunnel::server::utils::{bad_request, inject_cookie, HttpResponse};
33
use crate::tunnel::server::WsServer;
44
use crate::tunnel::transport;
55
use crate::tunnel::transport::websocket::mk_websocket_tunnel;
@@ -21,7 +21,7 @@ pub(super) async fn ws_server_upgrade(
2121
restrict_path_prefix: Option<String>,
2222
client_addr: SocketAddr,
2323
mut req: Request<Incoming>,
24-
) -> Response<Either<String, BoxBody<Bytes, anyhow::Error>>> {
24+
) -> HttpResponse {
2525
if !fastwebsockets::upgrade::is_upgrade_request(&req) {
2626
warn!("Rejecting connection with bad upgrade request: {}", req.uri());
2727
return bad_request();

src/tunnel/server/server.rs

+28-43
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ use crate::tunnel::server::handler_websocket::ws_server_upgrade;
3333
use crate::tunnel::server::reverse_tunnel::ReverseTunnelServer;
3434
use crate::tunnel::server::utils::{
3535
bad_request, extract_path_prefix, extract_tunnel_info, extract_x_forwarded_for, find_mapped_port, validate_tunnel,
36+
HttpResponse,
3637
};
3738
use crate::tunnel::tls_reloader::TlsReloader;
3839
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
@@ -89,68 +90,52 @@ impl WsServer {
8990
Pin<Box<dyn AsyncWrite + Send>>,
9091
bool,
9192
),
92-
Response<Either<String, BoxBody<Bytes, anyhow::Error>>>,
93+
HttpResponse,
9394
> {
94-
match extract_x_forwarded_for(req) {
95-
Ok(Some((x_forward_for, x_forward_for_str))) => {
96-
info!("Request X-Forwarded-For: {:?}", x_forward_for);
97-
Span::current().record("forwarded_for", x_forward_for_str);
98-
client_addr.set_ip(x_forward_for);
99-
}
100-
Ok(_) => {}
101-
Err(_err) => return Err(bad_request()),
95+
if let Some((x_forward_for, x_forward_for_str)) = extract_x_forwarded_for(req) {
96+
info!("Request X-Forwarded-For: {x_forward_for:?}");
97+
Span::current().record("forwarded_for", x_forward_for_str);
98+
client_addr.set_ip(x_forward_for);
10299
};
103100

104-
let path_prefix = match extract_path_prefix(req) {
105-
Ok(p) => p,
106-
Err(_err) => return Err(bad_request()),
107-
};
101+
let path_prefix = extract_path_prefix(req.uri().path()).map_err(|err| {
102+
warn!("Rejecting connection with {err}: {}", req.uri());
103+
bad_request()
104+
})?;
108105

109106
if let Some(restrict_path) = restrict_path_prefix {
110107
if path_prefix != restrict_path {
111108
warn!(
112-
"Client requested upgrade path '{}' does not match upgrade path restriction '{}' (mTLS, etc.)",
113-
path_prefix, restrict_path
109+
"Client requested upgrade path '{path_prefix}' does not match upgrade path restriction '{restrict_path}' (mTLS, etc.)"
114110
);
115111
return Err(bad_request());
116112
}
117113
}
118114

119-
let jwt = match extract_tunnel_info(req) {
120-
Ok(jwt) => jwt,
121-
Err(_err) => return Err(bad_request()),
122-
};
115+
let jwt = extract_tunnel_info(req)?;
123116

124117
Span::current().record("id", &jwt.claims.id);
125118
Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp));
126-
let remote = match RemoteAddr::try_from(jwt.claims) {
127-
Ok(remote) => remote,
128-
Err(err) => {
129-
warn!("Rejecting connection with bad tunnel info: {} {}", err, req.uri());
130-
return Err(bad_request());
131-
}
132-
};
119+
let remote = RemoteAddr::try_from(jwt.claims).map_err(|err| {
120+
warn!("Rejecting connection with bad tunnel info: {err} {}", req.uri());
121+
bad_request()
122+
})?;
133123

134-
let restriction = match validate_tunnel(&remote, path_prefix, &restrictions) {
135-
Some(matched_restriction) => {
136-
info!("Tunnel accepted due to matched restriction: {}", matched_restriction.name);
137-
matched_restriction
138-
}
139-
None => {
140-
warn!("Rejecting connection with not allowed destination: {:?}", remote);
141-
return Err(bad_request());
142-
}
143-
};
124+
let restriction = validate_tunnel(&remote, path_prefix, &restrictions).ok_or_else(|| {
125+
warn!("Rejecting connection with not allowed destination: {remote:?}");
126+
bad_request()
127+
})?;
128+
info!("Tunnel accepted due to matched restriction: {}", restriction.name);
144129

145130
let req_protocol = remote.protocol.clone();
146131
let inject_cookie = req_protocol.is_dynamic_reverse_tunnel();
147-
let tunnel = match self.exec_tunnel(restriction, remote, client_addr).await {
148-
Ok(ret) => ret,
149-
Err(err) => {
150-
warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());
151-
return Err(bad_request());
152-
}
153-
};
132+
let tunnel = self
133+
.exec_tunnel(restriction, remote, client_addr)
134+
.await
135+
.map_err(|err| {
136+
warn!("Rejecting connection with bad upgrade request: {err} {}", req.uri());
137+
bad_request()
138+
})?;
154139

155140
let (remote_addr, local_rx, local_tx) = tunnel;
156141
info!("connected to {:?} {}:{}", req_protocol, remote_addr.host, remote_addr.port);

src/tunnel/server/utils.rs

+57-34
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use crate::restrictions::types::{
55
use crate::tunnel::transport::{jwt_token_to_tunnel, tunnel_to_jwt_token, JwtTunnelConfig, JWT_HEADER_PREFIX};
66
use crate::tunnel::RemoteAddr;
77
use bytes::Bytes;
8+
use derive_more::{Display, Error};
89
use http_body_util::combinators::BoxBody;
910
use http_body_util::Either;
1011
use hyper::body::{Body, Incoming};
@@ -17,7 +18,9 @@ use tracing::{error, info, warn};
1718
use url::Host;
1819
use uuid::Uuid;
1920

20-
pub(super) fn bad_request() -> Response<Either<String, BoxBody<Bytes, anyhow::Error>>> {
21+
pub type HttpResponse = Response<Either<String, BoxBody<Bytes, anyhow::Error>>>;
22+
23+
pub(super) fn bad_request() -> HttpResponse {
2124
http::Response::builder()
2225
.status(StatusCode::BAD_REQUEST)
2326
.body(Either::Left("Invalid request".to_string()))
@@ -48,42 +51,41 @@ pub(super) fn find_mapped_port(req_port: u16, restriction: &RestrictionConfig) -
4851
}
4952

5053
#[inline]
51-
pub(super) fn extract_x_forwarded_for(req: &Request<Incoming>) -> Result<Option<(IpAddr, &str)>, ()> {
52-
let Some(x_forward_for) = req.headers().get("X-Forwarded-For") else {
53-
return Ok(None);
54-
};
54+
pub(super) fn extract_x_forwarded_for(req: &Request<Incoming>) -> Option<(IpAddr, &str)> {
55+
let x_forward_for = req.headers().get("X-Forwarded-For")?;
5556

5657
// X-Forwarded-For: <client>, <proxy1>, <proxy2>
5758
let x_forward_for = x_forward_for.to_str().unwrap_or_default();
5859
let x_forward_for = x_forward_for.split_once(',').map(|x| x.0).unwrap_or(x_forward_for);
5960
let ip: Option<IpAddr> = x_forward_for.parse().ok();
60-
Ok(ip.map(|ip| (ip, x_forward_for)))
61+
ip.map(|ip| (ip, x_forward_for))
6162
}
6263

6364
#[inline]
64-
pub(super) fn extract_path_prefix(req: &Request<Incoming>) -> Result<&str, ()> {
65-
let path = req.uri().path();
66-
let min_len = min(path.len(), 1);
67-
if &path[0..min_len] != "/" {
68-
warn!("Rejecting connection with bad path prefix in upgrade request: {}", req.uri());
69-
return Err(());
65+
pub(super) fn extract_path_prefix(path: &str) -> Result<&str, PathPrefixErr> {
66+
if !path.starts_with('/') {
67+
return Err(PathPrefixErr::BadPathPrefix);
7068
}
7169

72-
let Some((l, r)) = path[min_len..].split_once('/') else {
73-
warn!("Rejecting connection with bad upgrade request: {}", req.uri());
74-
return Err(());
75-
};
70+
let (l, r) = path[1..].split_once('/').ok_or(PathPrefixErr::BadUpgradeRequest)?;
7671

77-
if !r.ends_with("events") {
78-
warn!("Rejecting connection with bad upgrade request: {}", req.uri());
79-
return Err(());
72+
match r.ends_with("events") {
73+
true => Ok(l),
74+
false => Err(PathPrefixErr::BadUpgradeRequest),
8075
}
76+
}
8177

82-
Ok(l)
78+
#[derive(Debug, Display, Error)]
79+
#[cfg_attr(test, derive(PartialEq, Eq))]
80+
pub(super) enum PathPrefixErr {
81+
#[display("bad path prefix in upgrade request")]
82+
BadPathPrefix,
83+
#[display("bad upgrade request")]
84+
BadUpgradeRequest,
8385
}
8486

8587
#[inline]
86-
pub(super) fn extract_tunnel_info(req: &Request<Incoming>) -> Result<TokenData<JwtTunnelConfig>, ()> {
88+
pub(super) fn extract_tunnel_info(req: &Request<Incoming>) -> anyhow::Result<TokenData<JwtTunnelConfig>, HttpResponse> {
8789
let jwt = req
8890
.headers()
8991
.get(SEC_WEBSOCKET_PROTOCOL)
@@ -93,19 +95,13 @@ pub(super) fn extract_tunnel_info(req: &Request<Incoming>) -> Result<TokenData<J
9395
.or_else(|| req.headers().get(COOKIE).and_then(|header| header.to_str().ok()))
9496
.unwrap_or_default();
9597

96-
let jwt = match jwt_token_to_tunnel(jwt) {
97-
Ok(jwt) => jwt,
98-
err => {
99-
warn!(
100-
"error while decoding jwt for tunnel info {:?} header {:?}",
101-
err,
102-
req.headers().get(SEC_WEBSOCKET_PROTOCOL)
103-
);
104-
return Err(());
105-
}
106-
};
107-
108-
Ok(jwt)
98+
jwt_token_to_tunnel(jwt).map_err(|err| {
99+
warn!(
100+
"error while decoding jwt for tunnel info {err:?} header {:?}",
101+
req.headers().get(SEC_WEBSOCKET_PROTOCOL)
102+
);
103+
bad_request()
104+
})
109105
}
110106

111107
impl RestrictionConfig {
@@ -497,4 +493,31 @@ mod tests {
497493
assert!(!config.is_allowed(&remote));
498494
assert!(!AllowConfig::from(config.clone()).is_allowed(&remote));
499495
}
496+
497+
#[test]
498+
fn test_extract_path_prefix_happy_path() {
499+
assert_eq!(extract_path_prefix("/prefix/events"), Ok("prefix"));
500+
assert_eq!(extract_path_prefix("/prefix/a/events"), Ok("prefix"));
501+
assert_eq!(extract_path_prefix("/prefix/a/b/events"), Ok("prefix"));
502+
}
503+
504+
#[test]
505+
fn test_extract_path_prefix_no_events_suffix() {
506+
assert_eq!(extract_path_prefix("/prefix/events/"), Err(PathPrefixErr::BadUpgradeRequest));
507+
assert_eq!(extract_path_prefix("/prefix"), Err(PathPrefixErr::BadUpgradeRequest));
508+
assert_eq!(extract_path_prefix("/prefixevents"), Err(PathPrefixErr::BadUpgradeRequest));
509+
assert_eq!(extract_path_prefix("/prefix/event"), Err(PathPrefixErr::BadUpgradeRequest));
510+
assert_eq!(extract_path_prefix("/prefix/a"), Err(PathPrefixErr::BadUpgradeRequest));
511+
assert_eq!(extract_path_prefix("/prefix/a/b"), Err(PathPrefixErr::BadUpgradeRequest));
512+
}
513+
514+
#[test]
515+
fn test_extract_path_prefix_no_slash_prefix() {
516+
assert_eq!(extract_path_prefix(""), Err(PathPrefixErr::BadPathPrefix));
517+
assert_eq!(extract_path_prefix("p"), Err(PathPrefixErr::BadPathPrefix));
518+
assert_eq!(extract_path_prefix("\\"), Err(PathPrefixErr::BadPathPrefix));
519+
assert_eq!(extract_path_prefix("prefix/events"), Err(PathPrefixErr::BadPathPrefix));
520+
assert_eq!(extract_path_prefix("prefix/a/events"), Err(PathPrefixErr::BadPathPrefix));
521+
assert_eq!(extract_path_prefix("prefix/a/b/events"), Err(PathPrefixErr::BadPathPrefix));
522+
}
500523
}

0 commit comments

Comments
 (0)