@@ -5,6 +5,7 @@ use crate::restrictions::types::{
5
5
use crate :: tunnel:: transport:: { jwt_token_to_tunnel, tunnel_to_jwt_token, JwtTunnelConfig , JWT_HEADER_PREFIX } ;
6
6
use crate :: tunnel:: RemoteAddr ;
7
7
use bytes:: Bytes ;
8
+ use derive_more:: { Display , Error } ;
8
9
use http_body_util:: combinators:: BoxBody ;
9
10
use http_body_util:: Either ;
10
11
use hyper:: body:: { Body , Incoming } ;
@@ -17,7 +18,9 @@ use tracing::{error, info, warn};
17
18
use url:: Host ;
18
19
use uuid:: Uuid ;
19
20
20
- pub ( super ) fn bad_request ( ) -> Response < Either < String , BoxBody < Bytes , anyhow:: Error > > > {
21
+ pub type HttpResponse = Response < Either < String , BoxBody < Bytes , anyhow:: Error > > > ;
22
+
23
+ pub ( super ) fn bad_request ( ) -> HttpResponse {
21
24
http:: Response :: builder ( )
22
25
. status ( StatusCode :: BAD_REQUEST )
23
26
. body ( Either :: Left ( "Invalid request" . to_string ( ) ) )
@@ -48,42 +51,41 @@ pub(super) fn find_mapped_port(req_port: u16, restriction: &RestrictionConfig) -
48
51
}
49
52
50
53
#[ inline]
51
- pub ( super ) fn extract_x_forwarded_for ( req : & Request < Incoming > ) -> Result < Option < ( IpAddr , & str ) > , ( ) > {
52
- let Some ( x_forward_for) = req. headers ( ) . get ( "X-Forwarded-For" ) else {
53
- return Ok ( None ) ;
54
- } ;
54
+ pub ( super ) fn extract_x_forwarded_for ( req : & Request < Incoming > ) -> Option < ( IpAddr , & str ) > {
55
+ let x_forward_for = req. headers ( ) . get ( "X-Forwarded-For" ) ?;
55
56
56
57
// X-Forwarded-For: <client>, <proxy1>, <proxy2>
57
58
let x_forward_for = x_forward_for. to_str ( ) . unwrap_or_default ( ) ;
58
59
let x_forward_for = x_forward_for. split_once ( ',' ) . map ( |x| x. 0 ) . unwrap_or ( x_forward_for) ;
59
60
let ip: Option < IpAddr > = x_forward_for. parse ( ) . ok ( ) ;
60
- Ok ( ip. map ( |ip| ( ip, x_forward_for) ) )
61
+ ip. map ( |ip| ( ip, x_forward_for) )
61
62
}
62
63
63
64
#[ inline]
64
- pub ( super ) fn extract_path_prefix ( req : & Request < Incoming > ) -> Result < & str , ( ) > {
65
- let path = req. uri ( ) . path ( ) ;
66
- let min_len = min ( path. len ( ) , 1 ) ;
67
- if & path[ 0 ..min_len] != "/" {
68
- warn ! ( "Rejecting connection with bad path prefix in upgrade request: {}" , req. uri( ) ) ;
69
- return Err ( ( ) ) ;
65
+ pub ( super ) fn extract_path_prefix ( path : & str ) -> Result < & str , PathPrefixErr > {
66
+ if !path. starts_with ( '/' ) {
67
+ return Err ( PathPrefixErr :: BadPathPrefix ) ;
70
68
}
71
69
72
- let Some ( ( l, r) ) = path[ min_len..] . split_once ( '/' ) else {
73
- warn ! ( "Rejecting connection with bad upgrade request: {}" , req. uri( ) ) ;
74
- return Err ( ( ) ) ;
75
- } ;
70
+ let ( l, r) = path[ 1 ..] . split_once ( '/' ) . ok_or ( PathPrefixErr :: BadUpgradeRequest ) ?;
76
71
77
- if ! r. ends_with ( "events" ) {
78
- warn ! ( "Rejecting connection with bad upgrade request: {}" , req . uri ( ) ) ;
79
- return Err ( ( ) ) ;
72
+ match r. ends_with ( "events" ) {
73
+ true => Ok ( l ) ,
74
+ false => Err ( PathPrefixErr :: BadUpgradeRequest ) ,
80
75
}
76
+ }
81
77
82
- Ok ( l)
78
+ #[ derive( Debug , Display , Error ) ]
79
+ #[ cfg_attr( test, derive( PartialEq , Eq ) ) ]
80
+ pub ( super ) enum PathPrefixErr {
81
+ #[ display( "bad path prefix in upgrade request" ) ]
82
+ BadPathPrefix ,
83
+ #[ display( "bad upgrade request" ) ]
84
+ BadUpgradeRequest ,
83
85
}
84
86
85
87
#[ inline]
86
- pub ( super ) fn extract_tunnel_info ( req : & Request < Incoming > ) -> Result < TokenData < JwtTunnelConfig > , ( ) > {
88
+ pub ( super ) fn extract_tunnel_info ( req : & Request < Incoming > ) -> anyhow :: Result < TokenData < JwtTunnelConfig > , HttpResponse > {
87
89
let jwt = req
88
90
. headers ( )
89
91
. get ( SEC_WEBSOCKET_PROTOCOL )
@@ -93,19 +95,13 @@ pub(super) fn extract_tunnel_info(req: &Request<Incoming>) -> Result<TokenData<J
93
95
. or_else ( || req. headers ( ) . get ( COOKIE ) . and_then ( |header| header. to_str ( ) . ok ( ) ) )
94
96
. unwrap_or_default ( ) ;
95
97
96
- let jwt = match jwt_token_to_tunnel ( jwt) {
97
- Ok ( jwt) => jwt,
98
- err => {
99
- warn ! (
100
- "error while decoding jwt for tunnel info {:?} header {:?}" ,
101
- err,
102
- req. headers( ) . get( SEC_WEBSOCKET_PROTOCOL )
103
- ) ;
104
- return Err ( ( ) ) ;
105
- }
106
- } ;
107
-
108
- Ok ( jwt)
98
+ jwt_token_to_tunnel ( jwt) . map_err ( |err| {
99
+ warn ! (
100
+ "error while decoding jwt for tunnel info {err:?} header {:?}" ,
101
+ req. headers( ) . get( SEC_WEBSOCKET_PROTOCOL )
102
+ ) ;
103
+ bad_request ( )
104
+ } )
109
105
}
110
106
111
107
impl RestrictionConfig {
@@ -497,4 +493,31 @@ mod tests {
497
493
assert ! ( !config. is_allowed( & remote) ) ;
498
494
assert ! ( !AllowConfig :: from( config. clone( ) ) . is_allowed( & remote) ) ;
499
495
}
496
+
497
+ #[ test]
498
+ fn test_extract_path_prefix_happy_path ( ) {
499
+ assert_eq ! ( extract_path_prefix( "/prefix/events" ) , Ok ( "prefix" ) ) ;
500
+ assert_eq ! ( extract_path_prefix( "/prefix/a/events" ) , Ok ( "prefix" ) ) ;
501
+ assert_eq ! ( extract_path_prefix( "/prefix/a/b/events" ) , Ok ( "prefix" ) ) ;
502
+ }
503
+
504
+ #[ test]
505
+ fn test_extract_path_prefix_no_events_suffix ( ) {
506
+ assert_eq ! ( extract_path_prefix( "/prefix/events/" ) , Err ( PathPrefixErr :: BadUpgradeRequest ) ) ;
507
+ assert_eq ! ( extract_path_prefix( "/prefix" ) , Err ( PathPrefixErr :: BadUpgradeRequest ) ) ;
508
+ assert_eq ! ( extract_path_prefix( "/prefixevents" ) , Err ( PathPrefixErr :: BadUpgradeRequest ) ) ;
509
+ assert_eq ! ( extract_path_prefix( "/prefix/event" ) , Err ( PathPrefixErr :: BadUpgradeRequest ) ) ;
510
+ assert_eq ! ( extract_path_prefix( "/prefix/a" ) , Err ( PathPrefixErr :: BadUpgradeRequest ) ) ;
511
+ assert_eq ! ( extract_path_prefix( "/prefix/a/b" ) , Err ( PathPrefixErr :: BadUpgradeRequest ) ) ;
512
+ }
513
+
514
+ #[ test]
515
+ fn test_extract_path_prefix_no_slash_prefix ( ) {
516
+ assert_eq ! ( extract_path_prefix( "" ) , Err ( PathPrefixErr :: BadPathPrefix ) ) ;
517
+ assert_eq ! ( extract_path_prefix( "p" ) , Err ( PathPrefixErr :: BadPathPrefix ) ) ;
518
+ assert_eq ! ( extract_path_prefix( "\\ " ) , Err ( PathPrefixErr :: BadPathPrefix ) ) ;
519
+ assert_eq ! ( extract_path_prefix( "prefix/events" ) , Err ( PathPrefixErr :: BadPathPrefix ) ) ;
520
+ assert_eq ! ( extract_path_prefix( "prefix/a/events" ) , Err ( PathPrefixErr :: BadPathPrefix ) ) ;
521
+ assert_eq ! ( extract_path_prefix( "prefix/a/b/events" ) , Err ( PathPrefixErr :: BadPathPrefix ) ) ;
522
+ }
500
523
}
0 commit comments