1
- use futures:: FutureExt ;
1
+ use futures:: { FutureExt , TryFuture } ;
2
2
use std:: future:: Future ;
3
3
use std:: pin:: Pin ;
4
- use std:: task:: { Context , Poll } ;
4
+ use std:: sync:: atomic:: AtomicBool ;
5
+ use std:: sync:: Arc ;
6
+ use std:: task:: { Context , Poll , Wake , Waker } ;
5
7
6
8
/// Future extension helpers that are useful for tests
7
9
pub trait TestFuture : Future {
@@ -15,9 +17,140 @@ pub trait TestFuture: Future {
15
17
{
16
18
Drive {
17
19
driver : self ,
18
- future : Box :: pin ( other) ,
20
+ future : other. wakened ( ) ,
19
21
}
20
22
}
23
+
24
+ fn wakened ( self ) -> Wakened < Self >
25
+ where
26
+ Self : Sized ,
27
+ {
28
+ Wakened {
29
+ future : Box :: pin ( self ) ,
30
+ woken : Arc :: new ( AtomicBool :: new ( true ) ) ,
31
+ }
32
+ }
33
+ }
34
+
35
+ /// Wraps futures::future::join to ensure that the futures are only polled if they are woken.
36
+ pub fn join < Fut1 , Fut2 > (
37
+ future1 : Fut1 ,
38
+ future2 : Fut2 ,
39
+ ) -> futures:: future:: Join < Wakened < Fut1 > , Wakened < Fut2 > >
40
+ where
41
+ Fut1 : Future ,
42
+ Fut2 : Future ,
43
+ {
44
+ futures:: future:: join ( future1. wakened ( ) , future2. wakened ( ) )
45
+ }
46
+
47
+ /// Wraps futures::future::join3 to ensure that the futures are only polled if they are woken.
48
+ pub fn join3 < Fut1 , Fut2 , Fut3 > (
49
+ future1 : Fut1 ,
50
+ future2 : Fut2 ,
51
+ future3 : Fut3 ,
52
+ ) -> futures:: future:: Join3 < Wakened < Fut1 > , Wakened < Fut2 > , Wakened < Fut3 > >
53
+ where
54
+ Fut1 : Future ,
55
+ Fut2 : Future ,
56
+ Fut3 : Future ,
57
+ {
58
+ futures:: future:: join3 ( future1. wakened ( ) , future2. wakened ( ) , future3. wakened ( ) )
59
+ }
60
+
61
+ /// Wraps futures::future::join4 to ensure that the futures are only polled if they are woken.
62
+ pub fn join4 < Fut1 , Fut2 , Fut3 , Fut4 > (
63
+ future1 : Fut1 ,
64
+ future2 : Fut2 ,
65
+ future3 : Fut3 ,
66
+ future4 : Fut4 ,
67
+ ) -> futures:: future:: Join4 < Wakened < Fut1 > , Wakened < Fut2 > , Wakened < Fut3 > , Wakened < Fut4 > >
68
+ where
69
+ Fut1 : Future ,
70
+ Fut2 : Future ,
71
+ Fut3 : Future ,
72
+ Fut4 : Future ,
73
+ {
74
+ futures:: future:: join4 (
75
+ future1. wakened ( ) ,
76
+ future2. wakened ( ) ,
77
+ future3. wakened ( ) ,
78
+ future4. wakened ( ) ,
79
+ )
80
+ }
81
+
82
+ /// Wraps futures::future::try_join to ensure that the futures are only polled if they are woken.
83
+ pub fn try_join < Fut1 , Fut2 > (
84
+ future1 : Fut1 ,
85
+ future2 : Fut2 ,
86
+ ) -> futures:: future:: TryJoin < Wakened < Fut1 > , Wakened < Fut2 > >
87
+ where
88
+ Fut1 : futures:: future:: TryFuture + Future ,
89
+ Fut2 : Future ,
90
+ Wakened < Fut1 > : futures:: future:: TryFuture ,
91
+ Wakened < Fut2 > : futures:: future:: TryFuture < Error = <Wakened < Fut1 > as TryFuture >:: Error > ,
92
+ {
93
+ futures:: future:: try_join ( future1. wakened ( ) , future2. wakened ( ) )
94
+ }
95
+
96
+ /// Wraps futures::future::select to ensure that the futures are only polled if they are woken.
97
+ pub fn select < A , B > ( future1 : A , future2 : B ) -> futures:: future:: Select < Wakened < A > , Wakened < B > >
98
+ where
99
+ A : Future + Unpin ,
100
+ B : Future + Unpin ,
101
+ {
102
+ futures:: future:: select ( future1. wakened ( ) , future2. wakened ( ) )
103
+ }
104
+
105
+ /// Wraps futures::future::join_all to ensure that the futures are only polled if they are woken.
106
+ pub fn join_all < I > ( iter : I ) -> futures:: future:: JoinAll < Wakened < I :: Item > >
107
+ where
108
+ I : IntoIterator ,
109
+ I :: Item : Future ,
110
+ {
111
+ futures:: future:: join_all ( iter. into_iter ( ) . map ( |f| f. wakened ( ) ) )
112
+ }
113
+
114
+ /// A future that only polls the inner future if it has been woken (after the initial poll).
115
+ pub struct Wakened < T > {
116
+ future : Pin < Box < T > > ,
117
+ woken : Arc < AtomicBool > ,
118
+ }
119
+
120
+ /// A future that only polls the inner future if it has been woken (after the initial poll).
121
+ impl < T > Future for Wakened < T >
122
+ where
123
+ T : Future ,
124
+ {
125
+ type Output = T :: Output ;
126
+
127
+ fn poll ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
128
+ let this = self . get_mut ( ) ;
129
+ if !this. woken . load ( std:: sync:: atomic:: Ordering :: SeqCst ) {
130
+ return Poll :: Pending ;
131
+ }
132
+ this. woken . store ( false , std:: sync:: atomic:: Ordering :: SeqCst ) ;
133
+ let my_waker = IfWokenWaker {
134
+ inner : cx. waker ( ) . clone ( ) ,
135
+ wakened : this. woken . clone ( ) ,
136
+ } ;
137
+ let my_waker = Arc :: new ( my_waker) . into ( ) ;
138
+ let mut cx = Context :: from_waker ( & my_waker) ;
139
+ this. future . as_mut ( ) . poll ( & mut cx)
140
+ }
141
+ }
142
+
143
+ impl Wake for IfWokenWaker {
144
+ fn wake ( self : Arc < Self > ) {
145
+ self . wakened
146
+ . store ( true , std:: sync:: atomic:: Ordering :: SeqCst ) ;
147
+ self . inner . wake_by_ref ( ) ;
148
+ }
149
+ }
150
+
151
+ struct IfWokenWaker {
152
+ inner : Waker ,
153
+ wakened : Arc < AtomicBool > ,
21
154
}
22
155
23
156
impl < T : Future > TestFuture for T { }
@@ -29,7 +162,7 @@ impl<T: Future> TestFuture for T {}
29
162
/// This is useful for H2 futures that also require the connection to be polled.
30
163
pub struct Drive < ' a , T , U > {
31
164
driver : & ' a mut T ,
32
- future : Pin < Box < U > > ,
165
+ future : Wakened < U > ,
33
166
}
34
167
35
168
impl < ' a , T , U > Future for Drive < ' a , T , U >
0 commit comments