Skip to content

Commit 1641e22

Browse files
authored
Merge pull request #139 from jbr/validate-before-continue
don't send 100-continue until the body has been read from
2 parents 9738d53 + de54e57 commit 1641e22

File tree

5 files changed

+174
-52
lines changed

5 files changed

+174
-52
lines changed

Cargo.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ log = "0.4"
2323

2424
[dev-dependencies]
2525
pretty_assertions = "0.6.1"
26-
async-std = { version = "1.4.0", features = ["unstable", "attributes"] }
26+
async-std = { version = "1.6.2", features = ["unstable", "attributes"] }
2727
tempfile = "3.1.0"
2828
async-test = "1.0.0"
29+
duplexify = "1.2.1"
30+
async-dup = "1.2.1"

src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ const MAX_HEAD_LENGTH: usize = 8 * 1024;
106106

107107
mod chunked;
108108
mod date;
109+
mod read_notifier;
109110

110111
pub mod client;
111112
pub mod server;

src/read_notifier.rs

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
use std::fmt;
2+
use std::pin::Pin;
3+
use std::task::{Context, Poll};
4+
5+
use async_std::io::{self, BufRead, Read};
6+
use async_std::sync::Sender;
7+
8+
pin_project_lite::pin_project! {
9+
/// ReadNotifier forwards [`async_std::io::Read`] and
10+
/// [`async_std::io::BufRead`] to an inner reader. When the
11+
/// ReadNotifier is read from (using `Read`, `ReadExt`, or
12+
/// `BufRead` methods), it sends a single message containing `()`
13+
/// on the channel.
14+
pub(crate) struct ReadNotifier<B> {
15+
#[pin]
16+
reader: B,
17+
sender: Sender<()>,
18+
has_been_read: bool
19+
}
20+
}
21+
22+
impl<B> fmt::Debug for ReadNotifier<B> {
23+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24+
f.debug_struct("ReadNotifier")
25+
.field("read", &self.has_been_read)
26+
.finish()
27+
}
28+
}
29+
30+
impl<B: BufRead> ReadNotifier<B> {
31+
pub(crate) fn new(reader: B, sender: Sender<()>) -> Self {
32+
Self {
33+
reader,
34+
sender,
35+
has_been_read: false,
36+
}
37+
}
38+
}
39+
40+
impl<B: BufRead> BufRead for ReadNotifier<B> {
41+
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
42+
self.project().reader.poll_fill_buf(cx)
43+
}
44+
45+
fn consume(self: Pin<&mut Self>, amt: usize) {
46+
self.project().reader.consume(amt)
47+
}
48+
}
49+
50+
impl<B: Read> Read for ReadNotifier<B> {
51+
fn poll_read(
52+
self: Pin<&mut Self>,
53+
cx: &mut Context<'_>,
54+
buf: &mut [u8],
55+
) -> Poll<io::Result<usize>> {
56+
let this = self.project();
57+
58+
if !*this.has_been_read {
59+
if let Ok(()) = this.sender.try_send(()) {
60+
*this.has_been_read = true;
61+
};
62+
}
63+
64+
this.reader.poll_read(cx, buf)
65+
}
66+
}

src/server/decode.rs

+29-51
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,23 @@
33
use std::str::FromStr;
44

55
use async_std::io::{BufReader, Read, Write};
6-
use async_std::prelude::*;
6+
use async_std::{prelude::*, sync, task};
77
use http_types::headers::{CONTENT_LENGTH, EXPECT, TRANSFER_ENCODING};
88
use http_types::{ensure, ensure_eq, format_err};
99
use http_types::{Body, Method, Request, Url};
1010

1111
use crate::chunked::ChunkedDecoder;
12+
use crate::read_notifier::ReadNotifier;
1213
use crate::{MAX_HEADERS, MAX_HEAD_LENGTH};
1314

1415
const LF: u8 = b'\n';
1516

1617
/// The number returned from httparse when the request is HTTP 1.1
1718
const HTTP_1_1_VERSION: u8 = 1;
1819

20+
const CONTINUE_HEADER_VALUE: &str = "100-continue";
21+
const CONTINUE_RESPONSE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n";
22+
1923
/// Decode an HTTP request on the server.
2024
pub async fn decode<IO>(mut io: IO) -> http_types::Result<Option<Request>>
2125
where
@@ -76,8 +80,6 @@ where
7680
req.insert_header(header.name, std::str::from_utf8(header.value)?);
7781
}
7882

79-
handle_100_continue(&req, &mut io).await?;
80-
8183
let content_length = req.header(CONTENT_LENGTH);
8284
let transfer_encoding = req.header(TRANSFER_ENCODING);
8385

@@ -86,11 +88,32 @@ where
8688
"Unexpected Content-Length header"
8789
);
8890

91+
// Establish a channel to wait for the body to be read. This
92+
// allows us to avoid sending 100-continue in situations that
93+
// respond without reading the body, saving clients from uploading
94+
// their body.
95+
let (body_read_sender, body_read_receiver) = sync::channel(1);
96+
97+
if Some(CONTINUE_HEADER_VALUE) == req.header(EXPECT).map(|h| h.as_str()) {
98+
task::spawn(async move {
99+
// If the client expects a 100-continue header, spawn a
100+
// task to wait for the first read attempt on the body.
101+
if let Ok(()) = body_read_receiver.recv().await {
102+
io.write_all(CONTINUE_RESPONSE).await.ok();
103+
};
104+
// Since the sender is moved into the Body, this task will
105+
// finish when the client disconnects, whether or not
106+
// 100-continue was sent.
107+
});
108+
}
109+
89110
// Check for Transfer-Encoding
90111
if let Some(encoding) = transfer_encoding {
91112
if encoding.last().as_str() == "chunked" {
92113
let trailer_sender = req.send_trailers();
93-
let reader = BufReader::new(ChunkedDecoder::new(reader, trailer_sender));
114+
let reader = ChunkedDecoder::new(reader, trailer_sender);
115+
let reader = BufReader::new(reader);
116+
let reader = ReadNotifier::new(reader, body_read_sender);
94117
req.set_body(Body::from_reader(reader, None));
95118
return Ok(Some(req));
96119
}
@@ -100,7 +123,8 @@ where
100123
// Check for Content-Length.
101124
if let Some(len) = content_length {
102125
let len = len.last().as_str().parse::<usize>()?;
103-
req.set_body(Body::from_reader(reader.take(len as u64), Some(len)));
126+
let reader = ReadNotifier::new(reader.take(len as u64), body_read_sender);
127+
req.set_body(Body::from_reader(reader, Some(len)));
104128
}
105129

106130
Ok(Some(req))
@@ -129,20 +153,6 @@ fn url_from_httparse_req(req: &httparse::Request<'_, '_>) -> http_types::Result<
129153
}
130154
}
131155

132-
const EXPECT_HEADER_VALUE: &str = "100-continue";
133-
const EXPECT_RESPONSE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n";
134-
135-
async fn handle_100_continue<IO>(req: &Request, io: &mut IO) -> http_types::Result<()>
136-
where
137-
IO: Write + Unpin,
138-
{
139-
if let Some(EXPECT_HEADER_VALUE) = req.header(EXPECT).map(|h| h.as_str()) {
140-
io.write_all(EXPECT_RESPONSE).await?;
141-
}
142-
143-
Ok(())
144-
}
145-
146156
#[cfg(test)]
147157
mod tests {
148158
use super::*;
@@ -207,36 +217,4 @@ mod tests {
207217
},
208218
)
209219
}
210-
211-
#[test]
212-
fn handle_100_continue_does_nothing_with_no_expect_header() {
213-
let request = Request::new(Method::Get, Url::parse("x:").unwrap());
214-
let mut io = async_std::io::Cursor::new(vec![]);
215-
let result = async_std::task::block_on(handle_100_continue(&request, &mut io));
216-
assert_eq!(std::str::from_utf8(&io.into_inner()).unwrap(), "");
217-
assert!(result.is_ok());
218-
}
219-
220-
#[test]
221-
fn handle_100_continue_sends_header_if_expects_is_exactly_right() {
222-
let mut request = Request::new(Method::Get, Url::parse("x:").unwrap());
223-
request.append_header("expect", "100-continue");
224-
let mut io = async_std::io::Cursor::new(vec![]);
225-
let result = async_std::task::block_on(handle_100_continue(&request, &mut io));
226-
assert_eq!(
227-
std::str::from_utf8(&io.into_inner()).unwrap(),
228-
"HTTP/1.1 100 Continue\r\n\r\n"
229-
);
230-
assert!(result.is_ok());
231-
}
232-
233-
#[test]
234-
fn handle_100_continue_does_nothing_if_expects_header_is_wrong() {
235-
let mut request = Request::new(Method::Get, Url::parse("x:").unwrap());
236-
request.append_header("expect", "110-extensions-not-allowed");
237-
let mut io = async_std::io::Cursor::new(vec![]);
238-
let result = async_std::task::block_on(handle_100_continue(&request, &mut io));
239-
assert_eq!(std::str::from_utf8(&io.into_inner()).unwrap(), "");
240-
assert!(result.is_ok());
241-
}
242220
}

tests/continue.rs

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
use async_dup::{Arc, Mutex};
2+
use async_std::io::{Cursor, SeekFrom};
3+
use async_std::{prelude::*, task};
4+
use duplexify::Duplex;
5+
use http_types::Result;
6+
use std::time::Duration;
7+
8+
const REQUEST_WITH_EXPECT: &[u8] = b"POST / HTTP/1.1\r\n\
9+
Host: example.com\r\n\
10+
Content-Length: 10\r\n\
11+
Expect: 100-continue\r\n\r\n";
12+
13+
const SLEEP_DURATION: Duration = std::time::Duration::from_millis(100);
14+
#[async_std::test]
15+
async fn test_with_expect_when_reading_body() -> Result<()> {
16+
let client_str: Vec<u8> = REQUEST_WITH_EXPECT.to_vec();
17+
let server_str: Vec<u8> = vec![];
18+
19+
let mut client = Arc::new(Mutex::new(Cursor::new(client_str)));
20+
let server = Arc::new(Mutex::new(Cursor::new(server_str)));
21+
22+
let mut request = async_h1::server::decode(Duplex::new(client.clone(), server.clone()))
23+
.await?
24+
.unwrap();
25+
26+
task::sleep(SLEEP_DURATION).await; //prove we're not just testing before we've written
27+
28+
{
29+
let lock = server.lock();
30+
assert_eq!("", std::str::from_utf8(lock.get_ref())?); //we haven't written yet
31+
};
32+
33+
let mut buf = vec![0u8; 1];
34+
let bytes = request.read(&mut buf).await?; //this triggers the 100-continue even though there's nothing to read yet
35+
assert_eq!(bytes, 0); // normally we'd actually be waiting for the end of the buffer, but this lets us test this sequentially
36+
37+
task::sleep(SLEEP_DURATION).await; // just long enough to wait for the channel and io
38+
39+
{
40+
let lock = server.lock();
41+
assert_eq!(
42+
"HTTP/1.1 100 Continue\r\n\r\n",
43+
std::str::from_utf8(lock.get_ref())?
44+
);
45+
};
46+
47+
client.write_all(b"0123456789").await?;
48+
client
49+
.seek(SeekFrom::Start(REQUEST_WITH_EXPECT.len() as u64))
50+
.await?;
51+
52+
assert_eq!("0123456789", request.body_string().await?);
53+
54+
Ok(())
55+
}
56+
57+
#[async_std::test]
58+
async fn test_without_expect_when_not_reading_body() -> Result<()> {
59+
let client_str: Vec<u8> = REQUEST_WITH_EXPECT.to_vec();
60+
let server_str: Vec<u8> = vec![];
61+
62+
let client = Arc::new(Mutex::new(Cursor::new(client_str)));
63+
let server = Arc::new(Mutex::new(Cursor::new(server_str)));
64+
65+
async_h1::server::decode(Duplex::new(client.clone(), server.clone()))
66+
.await?
67+
.unwrap();
68+
69+
task::sleep(SLEEP_DURATION).await; // just long enough to wait for the channel
70+
71+
let server_lock = server.lock();
72+
assert_eq!("", std::str::from_utf8(server_lock.get_ref())?); // we haven't written 100-continue
73+
74+
Ok(())
75+
}

0 commit comments

Comments
 (0)