Skip to content

Commit e86bbaf

Browse files
committed
h2-support,h2-tests: add tools to ensure wake
This commit adds wrappers around futures::future helpers and augments TestFuture to ensure that the underlying futures are notified before they are polled. This helps to catch bugs where there are missing notify calls or bad handling of the waker. The commit then extends the tests to use these helpers instead of the library functions from futures. It also ammends the client_requests::recv_too_big_headers test to no longer use the tokio spawned tasks that were added in hyperium#791.
1 parent 4782856 commit e86bbaf

11 files changed

+151
-38
lines changed

tests/h2-support/src/future_ext.rs

+137-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
use futures::FutureExt;
1+
use futures::{FutureExt, TryFuture};
22
use std::future::Future;
33
use std::pin::Pin;
4-
use std::task::{Context, Poll};
4+
use std::sync::atomic::AtomicBool;
5+
use std::sync::Arc;
6+
use std::task::{Context, Poll, Wake, Waker};
57

68
/// Future extension helpers that are useful for tests
79
pub trait TestFuture: Future {
@@ -15,9 +17,140 @@ pub trait TestFuture: Future {
1517
{
1618
Drive {
1719
driver: self,
18-
future: Box::pin(other),
20+
future: other.wakened(),
1921
}
2022
}
23+
24+
fn wakened(self) -> Wakened<Self>
25+
where
26+
Self: Sized,
27+
{
28+
Wakened {
29+
future: Box::pin(self),
30+
woken: Arc::new(AtomicBool::new(true)),
31+
}
32+
}
33+
}
34+
35+
/// Wraps futures::future::join to ensure that the futures are only polled if they are woken.
36+
pub fn join<Fut1, Fut2>(
37+
future1: Fut1,
38+
future2: Fut2,
39+
) -> futures::future::Join<Wakened<Fut1>, Wakened<Fut2>>
40+
where
41+
Fut1: Future,
42+
Fut2: Future,
43+
{
44+
futures::future::join(future1.wakened(), future2.wakened())
45+
}
46+
47+
/// Wraps futures::future::join3 to ensure that the futures are only polled if they are woken.
48+
pub fn join3<Fut1, Fut2, Fut3>(
49+
future1: Fut1,
50+
future2: Fut2,
51+
future3: Fut3,
52+
) -> futures::future::Join3<Wakened<Fut1>, Wakened<Fut2>, Wakened<Fut3>>
53+
where
54+
Fut1: Future,
55+
Fut2: Future,
56+
Fut3: Future,
57+
{
58+
futures::future::join3(future1.wakened(), future2.wakened(), future3.wakened())
59+
}
60+
61+
/// Wraps futures::future::join4 to ensure that the futures are only polled if they are woken.
62+
pub fn join4<Fut1, Fut2, Fut3, Fut4>(
63+
future1: Fut1,
64+
future2: Fut2,
65+
future3: Fut3,
66+
future4: Fut4,
67+
) -> futures::future::Join4<Wakened<Fut1>, Wakened<Fut2>, Wakened<Fut3>, Wakened<Fut4>>
68+
where
69+
Fut1: Future,
70+
Fut2: Future,
71+
Fut3: Future,
72+
Fut4: Future,
73+
{
74+
futures::future::join4(
75+
future1.wakened(),
76+
future2.wakened(),
77+
future3.wakened(),
78+
future4.wakened(),
79+
)
80+
}
81+
82+
/// Wraps futures::future::try_join to ensure that the futures are only polled if they are woken.
83+
pub fn try_join<Fut1, Fut2>(
84+
future1: Fut1,
85+
future2: Fut2,
86+
) -> futures::future::TryJoin<Wakened<Fut1>, Wakened<Fut2>>
87+
where
88+
Fut1: futures::future::TryFuture + Future,
89+
Fut2: Future,
90+
Wakened<Fut1>: futures::future::TryFuture,
91+
Wakened<Fut2>: futures::future::TryFuture<Error = <Wakened<Fut1> as TryFuture>::Error>,
92+
{
93+
futures::future::try_join(future1.wakened(), future2.wakened())
94+
}
95+
96+
/// Wraps futures::future::select to ensure that the futures are only polled if they are woken.
97+
pub fn select<A, B>(future1: A, future2: B) -> futures::future::Select<Wakened<A>, Wakened<B>>
98+
where
99+
A: Future + Unpin,
100+
B: Future + Unpin,
101+
{
102+
futures::future::select(future1.wakened(), future2.wakened())
103+
}
104+
105+
/// Wraps futures::future::join_all to ensure that the futures are only polled if they are woken.
106+
pub fn join_all<I>(iter: I) -> futures::future::JoinAll<Wakened<I::Item>>
107+
where
108+
I: IntoIterator,
109+
I::Item: Future,
110+
{
111+
futures::future::join_all(iter.into_iter().map(|f| f.wakened()))
112+
}
113+
114+
/// A future that only polls the inner future if it has been woken (after the initial poll).
115+
pub struct Wakened<T> {
116+
future: Pin<Box<T>>,
117+
woken: Arc<AtomicBool>,
118+
}
119+
120+
/// A future that only polls the inner future if it has been woken (after the initial poll).
121+
impl<T> Future for Wakened<T>
122+
where
123+
T: Future,
124+
{
125+
type Output = T::Output;
126+
127+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
128+
let this = self.get_mut();
129+
if !this.woken.load(std::sync::atomic::Ordering::SeqCst) {
130+
return Poll::Pending;
131+
}
132+
this.woken.store(false, std::sync::atomic::Ordering::SeqCst);
133+
let my_waker = IfWokenWaker {
134+
inner: cx.waker().clone(),
135+
wakened: this.woken.clone(),
136+
};
137+
let my_waker = Arc::new(my_waker).into();
138+
let mut cx = Context::from_waker(&my_waker);
139+
this.future.as_mut().poll(&mut cx)
140+
}
141+
}
142+
143+
impl Wake for IfWokenWaker {
144+
fn wake(self: Arc<Self>) {
145+
self.wakened
146+
.store(true, std::sync::atomic::Ordering::SeqCst);
147+
self.inner.wake_by_ref();
148+
}
149+
}
150+
151+
struct IfWokenWaker {
152+
inner: Waker,
153+
wakened: Arc<AtomicBool>,
21154
}
22155

23156
impl<T: Future> TestFuture for T {}
@@ -29,7 +162,7 @@ impl<T: Future> TestFuture for T {}
29162
/// This is useful for H2 futures that also require the connection to be polled.
30163
pub struct Drive<'a, T, U> {
31164
driver: &'a mut T,
32-
future: Pin<Box<U>>,
165+
future: Wakened<U>,
33166
}
34167

35168
impl<'a, T, U> Future for Drive<'a, T, U>

tests/h2-support/src/prelude.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ pub use {bytes, futures, http, tokio::io as tokio_io, tracing, tracing_subscribe
3535
pub use futures::{Future, Sink, Stream};
3636

3737
// And our Future extensions
38-
pub use super::future_ext::TestFuture;
38+
pub use super::future_ext::{TestFuture, join, join_all, select, join3, join4, try_join};
3939

4040
// Our client_ext helpers
4141
pub use super::client_ext::SendRequestExt;

tests/h2-tests/tests/client_request.rs

+8-14
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use futures::future::{join, join_all, ready, select, Either};
1+
use futures::future::{ready, Either};
22
use futures::stream::FuturesUnordered;
33
use futures::StreamExt;
44
use h2_support::prelude::*;
@@ -849,7 +849,7 @@ async fn recv_too_big_headers() {
849849
};
850850

851851
let client = async move {
852-
let (mut client, conn) = client::Builder::new()
852+
let (mut client, mut conn) = client::Builder::new()
853853
.max_header_list_size(10)
854854
.handshake::<_, Bytes>(io)
855855
.await
@@ -863,30 +863,24 @@ async fn recv_too_big_headers() {
863863
let req1 = client.send_request(request, true);
864864
// Spawn tasks to ensure that the error wakes up tasks that are blocked
865865
// waiting for a response.
866-
let req1 = tokio::spawn(async move {
866+
let req1 = async move {
867867
let err = req1.expect("send_request").0.await.expect_err("response1");
868868
assert_eq!(err.reason(), Some(Reason::REFUSED_STREAM));
869-
});
869+
};
870870

871871
let request = Request::builder()
872872
.uri("https://http2.akamai.com/")
873873
.body(())
874874
.unwrap();
875875

876876
let req2 = client.send_request(request, true);
877-
let req2 = tokio::spawn(async move {
877+
let req2 = async move {
878878
let err = req2.expect("send_request").0.await.expect_err("response2");
879879
assert_eq!(err.reason(), Some(Reason::REFUSED_STREAM));
880-
});
880+
};
881881

882-
let conn = tokio::spawn(async move {
883-
conn.await.expect("client");
884-
});
885-
for err in join_all([req1, req2, conn]).await {
886-
if let Some(err) = err.err().and_then(|err| err.try_into_panic().ok()) {
887-
std::panic::resume_unwind(err);
888-
}
889-
}
882+
883+
conn.drive(join(req1, req2)).await;
890884
};
891885

892886
join(srv, client).await;

tests/h2-tests/tests/codec_read.rs

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use futures::future::join;
21
use h2_support::prelude::*;
32

43
#[tokio::test]

tests/h2-tests/tests/codec_write.rs

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use futures::future::join;
21
use h2_support::prelude::*;
32

43
#[tokio::test]

tests/h2-tests/tests/flow_control.rs

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use futures::future::{join, join4};
21
use futures::{StreamExt, TryStreamExt};
32
use h2_support::prelude::*;
43
use h2_support::util::yield_once;

tests/h2-tests/tests/ping_pong.rs

-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
use futures::channel::oneshot;
2-
use futures::future::join;
32
use futures::StreamExt;
43
use h2_support::assert_ping;
54
use h2_support::prelude::*;

tests/h2-tests/tests/prioritization.rs

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use futures::future::{join, select};
21
use futures::{pin_mut, FutureExt, StreamExt};
32

43
use h2_support::prelude::*;

tests/h2-tests/tests/push_promise.rs

+4-12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
use std::iter::FromIterator;
2-
3-
use futures::{future::join, FutureExt as _, StreamExt, TryStreamExt};
1+
use futures::{StreamExt, TryStreamExt};
42
use h2_support::prelude::*;
53

64
#[tokio::test]
@@ -52,15 +50,9 @@ async fn recv_push_works() {
5250
let ps: Vec<_> = p.collect().await;
5351
assert_eq!(1, ps.len())
5452
};
55-
// Use a FuturesUnordered to poll both tasks but only poll them
56-
// if they have been notified.
57-
let tasks = futures::stream::FuturesUnordered::from_iter([
58-
check_resp_status.boxed(),
59-
check_pushed_response.boxed(),
60-
])
61-
.collect::<()>();
62-
63-
h2.drive(tasks).await;
53+
54+
h2.drive(join(check_resp_status, check_pushed_response))
55+
.await;
6456
};
6557

6658
join(mock, h2).await;

tests/h2-tests/tests/server.rs

-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#![deny(warnings)]
22

3-
use futures::future::join;
43
use futures::StreamExt;
54
use h2_support::prelude::*;
65
use tokio::io::AsyncWriteExt;

tests/h2-tests/tests/stream_states.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#![deny(warnings)]
22

3-
use futures::future::{join, join3, lazy, try_join};
3+
use futures::future::lazy;
44
use futures::{FutureExt, StreamExt, TryStreamExt};
55
use h2_support::prelude::*;
66
use h2_support::util::yield_once;

0 commit comments

Comments
 (0)