Skip to content

Commit b29fc46

Browse files
committed
feat: add {http1,http2}_only for auto conn
1 parent 16daef6 commit b29fc46

File tree

1 file changed

+143
-4
lines changed

1 file changed

+143
-4
lines changed

src/server/conn/auto.rs

+143-4
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ pub struct Builder<E> {
5858
http1: http1::Builder,
5959
#[cfg(feature = "http2")]
6060
http2: http2::Builder<E>,
61+
#[cfg(any(feature = "http1", feature = "http2"))]
62+
version: Option<Version>,
6163
#[cfg(not(feature = "http2"))]
6264
_executor: E,
6365
}
@@ -84,6 +86,8 @@ impl<E> Builder<E> {
8486
http1: http1::Builder::new(),
8587
#[cfg(feature = "http2")]
8688
http2: http2::Builder::new(executor),
89+
#[cfg(any(feature = "http1", feature = "http2"))]
90+
version: None,
8791
#[cfg(not(feature = "http2"))]
8892
_executor: executor,
8993
}
@@ -101,6 +105,26 @@ impl<E> Builder<E> {
101105
Http2Builder { inner: self }
102106
}
103107

108+
/// Only accepts HTTP/2
109+
///
110+
/// Does not do anything if used with [`serve_connection_with_upgrades`]
111+
#[cfg(feature = "http2")]
112+
pub fn http2_only(mut self) -> Self {
113+
assert!(self.version.is_none());
114+
self.version = Some(Version::H2);
115+
self
116+
}
117+
118+
/// Only accepts HTTP/1
119+
///
120+
/// Does not do anything if used with [`serve_connection_with_upgrades`]
121+
#[cfg(feature = "http1")]
122+
pub fn http1_only(mut self) -> Self {
123+
assert!(self.version.is_none());
124+
self.version = Some(Version::H1);
125+
self
126+
}
127+
104128
/// Bind a connection together with a [`Service`].
105129
pub fn serve_connection<I, S, B>(&self, io: I, service: S) -> Connection<'_, I, S, E>
106130
where
@@ -112,13 +136,28 @@ impl<E> Builder<E> {
112136
I: Read + Write + Unpin + 'static,
113137
E: HttpServerConnExec<S::Future, B>,
114138
{
115-
Connection {
116-
state: ConnState::ReadVersion {
139+
let state = match self.version {
140+
#[cfg(feature = "http1")]
141+
Some(Version::H1) => {
142+
let io = Rewind::new_buffered(io, Bytes::new());
143+
let conn = self.http1.serve_connection(io, service);
144+
ConnState::H1 { conn }
145+
}
146+
#[cfg(feature = "http2")]
147+
Some(Version::H2) => {
148+
let io = Rewind::new_buffered(io, Bytes::new());
149+
let conn = self.http2.serve_connection(io, service);
150+
ConnState::H2 { conn }
151+
}
152+
#[cfg(any(feature = "http1", feature = "http2"))]
153+
_ => ConnState::ReadVersion {
117154
read_version: read_version(io),
118155
builder: self,
119156
service: Some(service),
120157
},
121-
}
158+
};
159+
160+
Connection { state }
122161
}
123162

124163
/// Bind a connection together with a [`Service`], with the ability to
@@ -148,7 +187,7 @@ impl<E> Builder<E> {
148187
}
149188
}
150189

151-
#[derive(Copy, Clone)]
190+
#[derive(Copy, Clone, Debug)]
152191
enum Version {
153192
H1,
154193
H2,
@@ -894,6 +933,62 @@ mod tests {
894933
assert_eq!(body, BODY);
895934
}
896935

936+
#[cfg(not(miri))]
937+
#[tokio::test]
938+
async fn http2_only() {
939+
let addr = start_server_h2_only().await;
940+
let mut sender = connect_h2(addr).await;
941+
942+
let response = sender
943+
.send_request(Request::new(Empty::<Bytes>::new()))
944+
.await
945+
.unwrap();
946+
947+
let body = response.into_body().collect().await.unwrap().to_bytes();
948+
949+
assert_eq!(body, BODY);
950+
}
951+
952+
#[cfg(not(miri))]
953+
#[tokio::test]
954+
async fn http2_only_fail_if_client_is_http1() {
955+
let addr = start_server_h2_only().await;
956+
let mut sender = connect_h1(addr).await;
957+
958+
let _ = sender
959+
.send_request(Request::new(Empty::<Bytes>::new()))
960+
.await
961+
.expect_err("should fail");
962+
}
963+
964+
#[cfg(not(miri))]
965+
#[tokio::test]
966+
async fn http1_only() {
967+
let addr = start_server_h1_only().await;
968+
let mut sender = connect_h1(addr).await;
969+
970+
let response = sender
971+
.send_request(Request::new(Empty::<Bytes>::new()))
972+
.await
973+
.unwrap();
974+
975+
let body = response.into_body().collect().await.unwrap().to_bytes();
976+
977+
assert_eq!(body, BODY);
978+
}
979+
980+
#[cfg(not(miri))]
981+
#[tokio::test]
982+
async fn http1_only_fail_if_client_is_http2() {
983+
let addr = start_server_h1_only().await;
984+
let mut sender = connect_h2(addr).await;
985+
986+
let _ = sender
987+
.send_request(Request::new(Empty::<Bytes>::new()))
988+
.await
989+
.expect_err("should fail");
990+
}
991+
897992
#[cfg(not(miri))]
898993
#[tokio::test]
899994
async fn graceful_shutdown() {
@@ -980,6 +1075,50 @@ mod tests {
9801075
local_addr
9811076
}
9821077

1078+
async fn start_server_h2_only() -> SocketAddr {
1079+
let addr: SocketAddr = ([127, 0, 0, 1], 0).into();
1080+
let listener = TcpListener::bind(addr).await.unwrap();
1081+
1082+
let local_addr = listener.local_addr().unwrap();
1083+
1084+
tokio::spawn(async move {
1085+
loop {
1086+
let (stream, _) = listener.accept().await.unwrap();
1087+
let stream = TokioIo::new(stream);
1088+
tokio::task::spawn(async move {
1089+
let _ = auto::Builder::new(TokioExecutor::new())
1090+
.http2_only()
1091+
.serve_connection(stream, service_fn(hello))
1092+
.await;
1093+
});
1094+
}
1095+
});
1096+
1097+
local_addr
1098+
}
1099+
1100+
async fn start_server_h1_only() -> SocketAddr {
1101+
let addr: SocketAddr = ([127, 0, 0, 1], 0).into();
1102+
let listener = TcpListener::bind(addr).await.unwrap();
1103+
1104+
let local_addr = listener.local_addr().unwrap();
1105+
1106+
tokio::spawn(async move {
1107+
loop {
1108+
let (stream, _) = listener.accept().await.unwrap();
1109+
let stream = TokioIo::new(stream);
1110+
tokio::task::spawn(async move {
1111+
let _ = auto::Builder::new(TokioExecutor::new())
1112+
.http1_only()
1113+
.serve_connection(stream, service_fn(hello))
1114+
.await;
1115+
});
1116+
}
1117+
});
1118+
1119+
local_addr
1120+
}
1121+
9831122
async fn hello(_req: Request<body::Incoming>) -> Result<Response<Full<Bytes>>, Infallible> {
9841123
Ok(Response::new(Full::new(Bytes::from(BODY))))
9851124
}

0 commit comments

Comments
 (0)