@@ -58,6 +58,8 @@ pub struct Builder<E> {
58
58
http1 : http1:: Builder ,
59
59
#[ cfg( feature = "http2" ) ]
60
60
http2 : http2:: Builder < E > ,
61
+ #[ cfg( any( feature = "http1" , feature = "http2" ) ) ]
62
+ version : Option < Version > ,
61
63
#[ cfg( not( feature = "http2" ) ) ]
62
64
_executor : E ,
63
65
}
@@ -84,6 +86,8 @@ impl<E> Builder<E> {
84
86
http1 : http1:: Builder :: new ( ) ,
85
87
#[ cfg( feature = "http2" ) ]
86
88
http2 : http2:: Builder :: new ( executor) ,
89
+ #[ cfg( any( feature = "http1" , feature = "http2" ) ) ]
90
+ version : None ,
87
91
#[ cfg( not( feature = "http2" ) ) ]
88
92
_executor : executor,
89
93
}
@@ -101,6 +105,26 @@ impl<E> Builder<E> {
101
105
Http2Builder { inner : self }
102
106
}
103
107
108
+ /// Only accepts HTTP/2
109
+ ///
110
+ /// Does not do anything if used with [`serve_connection_with_upgrades`]
111
+ #[ cfg( feature = "http2" ) ]
112
+ pub fn http2_only ( mut self ) -> Self {
113
+ assert ! ( self . version. is_none( ) ) ;
114
+ self . version = Some ( Version :: H2 ) ;
115
+ self
116
+ }
117
+
118
+ /// Only accepts HTTP/1
119
+ ///
120
+ /// Does not do anything if used with [`serve_connection_with_upgrades`]
121
+ #[ cfg( feature = "http1" ) ]
122
+ pub fn http1_only ( mut self ) -> Self {
123
+ assert ! ( self . version. is_none( ) ) ;
124
+ self . version = Some ( Version :: H1 ) ;
125
+ self
126
+ }
127
+
104
128
/// Bind a connection together with a [`Service`].
105
129
pub fn serve_connection < I , S , B > ( & self , io : I , service : S ) -> Connection < ' _ , I , S , E >
106
130
where
@@ -112,13 +136,28 @@ impl<E> Builder<E> {
112
136
I : Read + Write + Unpin + ' static ,
113
137
E : HttpServerConnExec < S :: Future , B > ,
114
138
{
115
- Connection {
116
- state : ConnState :: ReadVersion {
139
+ let state = match self . version {
140
+ #[ cfg( feature = "http1" ) ]
141
+ Some ( Version :: H1 ) => {
142
+ let io = Rewind :: new_buffered ( io, Bytes :: new ( ) ) ;
143
+ let conn = self . http1 . serve_connection ( io, service) ;
144
+ ConnState :: H1 { conn }
145
+ }
146
+ #[ cfg( feature = "http2" ) ]
147
+ Some ( Version :: H2 ) => {
148
+ let io = Rewind :: new_buffered ( io, Bytes :: new ( ) ) ;
149
+ let conn = self . http2 . serve_connection ( io, service) ;
150
+ ConnState :: H2 { conn }
151
+ }
152
+ #[ cfg( any( feature = "http1" , feature = "http2" ) ) ]
153
+ _ => ConnState :: ReadVersion {
117
154
read_version : read_version ( io) ,
118
155
builder : self ,
119
156
service : Some ( service) ,
120
157
} ,
121
- }
158
+ } ;
159
+
160
+ Connection { state }
122
161
}
123
162
124
163
/// Bind a connection together with a [`Service`], with the ability to
@@ -148,7 +187,7 @@ impl<E> Builder<E> {
148
187
}
149
188
}
150
189
151
- #[ derive( Copy , Clone ) ]
190
+ #[ derive( Copy , Clone , Debug ) ]
152
191
enum Version {
153
192
H1 ,
154
193
H2 ,
@@ -894,6 +933,62 @@ mod tests {
894
933
assert_eq ! ( body, BODY ) ;
895
934
}
896
935
936
+ #[ cfg( not( miri) ) ]
937
+ #[ tokio:: test]
938
+ async fn http2_only ( ) {
939
+ let addr = start_server_h2_only ( ) . await ;
940
+ let mut sender = connect_h2 ( addr) . await ;
941
+
942
+ let response = sender
943
+ . send_request ( Request :: new ( Empty :: < Bytes > :: new ( ) ) )
944
+ . await
945
+ . unwrap ( ) ;
946
+
947
+ let body = response. into_body ( ) . collect ( ) . await . unwrap ( ) . to_bytes ( ) ;
948
+
949
+ assert_eq ! ( body, BODY ) ;
950
+ }
951
+
952
+ #[ cfg( not( miri) ) ]
953
+ #[ tokio:: test]
954
+ async fn http2_only_fail_if_client_is_http1 ( ) {
955
+ let addr = start_server_h2_only ( ) . await ;
956
+ let mut sender = connect_h1 ( addr) . await ;
957
+
958
+ let _ = sender
959
+ . send_request ( Request :: new ( Empty :: < Bytes > :: new ( ) ) )
960
+ . await
961
+ . expect_err ( "should fail" ) ;
962
+ }
963
+
964
+ #[ cfg( not( miri) ) ]
965
+ #[ tokio:: test]
966
+ async fn http1_only ( ) {
967
+ let addr = start_server_h1_only ( ) . await ;
968
+ let mut sender = connect_h1 ( addr) . await ;
969
+
970
+ let response = sender
971
+ . send_request ( Request :: new ( Empty :: < Bytes > :: new ( ) ) )
972
+ . await
973
+ . unwrap ( ) ;
974
+
975
+ let body = response. into_body ( ) . collect ( ) . await . unwrap ( ) . to_bytes ( ) ;
976
+
977
+ assert_eq ! ( body, BODY ) ;
978
+ }
979
+
980
+ #[ cfg( not( miri) ) ]
981
+ #[ tokio:: test]
982
+ async fn http1_only_fail_if_client_is_http2 ( ) {
983
+ let addr = start_server_h1_only ( ) . await ;
984
+ let mut sender = connect_h2 ( addr) . await ;
985
+
986
+ let _ = sender
987
+ . send_request ( Request :: new ( Empty :: < Bytes > :: new ( ) ) )
988
+ . await
989
+ . expect_err ( "should fail" ) ;
990
+ }
991
+
897
992
#[ cfg( not( miri) ) ]
898
993
#[ tokio:: test]
899
994
async fn graceful_shutdown ( ) {
@@ -980,6 +1075,50 @@ mod tests {
980
1075
local_addr
981
1076
}
982
1077
1078
+ async fn start_server_h2_only ( ) -> SocketAddr {
1079
+ let addr: SocketAddr = ( [ 127 , 0 , 0 , 1 ] , 0 ) . into ( ) ;
1080
+ let listener = TcpListener :: bind ( addr) . await . unwrap ( ) ;
1081
+
1082
+ let local_addr = listener. local_addr ( ) . unwrap ( ) ;
1083
+
1084
+ tokio:: spawn ( async move {
1085
+ loop {
1086
+ let ( stream, _) = listener. accept ( ) . await . unwrap ( ) ;
1087
+ let stream = TokioIo :: new ( stream) ;
1088
+ tokio:: task:: spawn ( async move {
1089
+ let _ = auto:: Builder :: new ( TokioExecutor :: new ( ) )
1090
+ . http2_only ( )
1091
+ . serve_connection ( stream, service_fn ( hello) )
1092
+ . await ;
1093
+ } ) ;
1094
+ }
1095
+ } ) ;
1096
+
1097
+ local_addr
1098
+ }
1099
+
1100
+ async fn start_server_h1_only ( ) -> SocketAddr {
1101
+ let addr: SocketAddr = ( [ 127 , 0 , 0 , 1 ] , 0 ) . into ( ) ;
1102
+ let listener = TcpListener :: bind ( addr) . await . unwrap ( ) ;
1103
+
1104
+ let local_addr = listener. local_addr ( ) . unwrap ( ) ;
1105
+
1106
+ tokio:: spawn ( async move {
1107
+ loop {
1108
+ let ( stream, _) = listener. accept ( ) . await . unwrap ( ) ;
1109
+ let stream = TokioIo :: new ( stream) ;
1110
+ tokio:: task:: spawn ( async move {
1111
+ let _ = auto:: Builder :: new ( TokioExecutor :: new ( ) )
1112
+ . http1_only ( )
1113
+ . serve_connection ( stream, service_fn ( hello) )
1114
+ . await ;
1115
+ } ) ;
1116
+ }
1117
+ } ) ;
1118
+
1119
+ local_addr
1120
+ }
1121
+
983
1122
async fn hello ( _req : Request < body:: Incoming > ) -> Result < Response < Full < Bytes > > , Infallible > {
984
1123
Ok ( Response :: new ( Full :: new ( Bytes :: from ( BODY ) ) ) )
985
1124
}
0 commit comments