Skip to content

Commit 112d15d

Browse files
committed
feat(pool): add a Singleton pool type
1 parent 3021828 commit 112d15d

File tree

4 files changed

+306
-0
lines changed

4 files changed

+306
-0
lines changed

Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ futures-util = { version = "0.3.16", default-features = false, features = ["allo
4242
http-body-util = "0.1.0"
4343
tokio = { version = "1", features = ["macros", "test-util", "signal"] }
4444
tokio-test = "0.4"
45+
tower-test = "0.4"
4546
pretty_env_logger = "0.5"
4647

4748
[target.'cfg(any(target_os = "linux", target_os = "macos"))'.dev-dependencies]
@@ -60,6 +61,7 @@ default = []
6061
full = [
6162
"client",
6263
"client-legacy",
64+
"client-pool",
6365
"client-proxy",
6466
"client-proxy-system",
6567
"server",
@@ -74,6 +76,7 @@ full = [
7476

7577
client = ["hyper/client", "tokio/net", "dep:tracing", "dep:futures-channel", "dep:tower-service"]
7678
client-legacy = ["client", "dep:socket2", "tokio/sync", "dep:libc", "dep:futures-util"]
79+
client-pool = []
7780
client-proxy = ["client", "dep:base64", "dep:ipnet", "dep:percent-encoding"]
7881
client-proxy-system = ["dep:system-configuration", "dep:windows-registry"]
7982

src/client/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,8 @@
44
#[cfg(feature = "client-legacy")]
55
pub mod legacy;
66

7+
#[cfg(feature = "client-pool")]
8+
pub mod pool;
9+
710
#[cfg(feature = "client-proxy")]
811
pub mod proxy;

src/client/pool/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
//! Composable pool services
2+
3+
mod singleton;
4+
5+
pub use self::singleton::Singleton;

src/client/pool/singleton.rs

Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
use std::future::Future;
2+
use std::pin::Pin;
3+
use std::sync::{Arc, Mutex, Weak};
4+
use std::task::{self, Poll};
5+
6+
use futures_core::ready;
7+
use pin_project_lite::pin_project;
8+
use tokio::sync::oneshot;
9+
use tower_service::Service;
10+
11+
type BoxError = Box<dyn std::error::Error + Send + Sync>;
12+
13+
/// A singleton pool over an inner service.
14+
#[derive(Clone, Debug)]
15+
pub struct Singleton<M, Dst>
16+
where
17+
M: Service<Dst>,
18+
{
19+
mk_svc: M,
20+
state: Arc<Mutex<State<M::Response>>>,
21+
}
22+
23+
pin_project! {
24+
#[project = SingletonFutureProj]
25+
pub enum SingletonFuture<F, S> {
26+
Driving {
27+
#[pin]
28+
future: F,
29+
singleton: DitchGuard<S>,
30+
},
31+
Waiting {
32+
rx: oneshot::Receiver<S>,
33+
},
34+
Made {
35+
svc: Option<S>,
36+
},
37+
}
38+
}
39+
40+
struct DitchGuard<S>(Weak<Mutex<State<S>>>);
41+
42+
#[derive(Debug)]
43+
enum State<S> {
44+
Empty,
45+
Making(Vec<oneshot::Sender<S>>),
46+
Made(S),
47+
}
48+
49+
impl<M, Target> Singleton<M, Target>
50+
where
51+
M: Service<Target>,
52+
M::Response: Clone,
53+
{
54+
/// Create a new singleton pool over an inner make service.
55+
pub fn new(mk_svc: M) -> Self {
56+
Singleton {
57+
mk_svc,
58+
state: Arc::new(Mutex::new(State::Empty)),
59+
}
60+
}
61+
62+
// pub fn reset?
63+
// pub fn retain?
64+
}
65+
66+
impl<M, Target> Service<Target> for Singleton<M, Target>
67+
where
68+
M: Service<Target>,
69+
M::Response: Clone,
70+
M::Error: Into<BoxError>,
71+
{
72+
type Response = M::Response;
73+
type Error = SingletonError;
74+
type Future = SingletonFuture<M::Future, M::Response>;
75+
76+
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
77+
if let State::Empty = *self.state.lock().unwrap() {
78+
return self
79+
.mk_svc
80+
.poll_ready(cx)
81+
.map_err(|e| SingletonError(e.into()));
82+
}
83+
Poll::Ready(Ok(()))
84+
}
85+
86+
fn call(&mut self, dst: Target) -> Self::Future {
87+
let mut locked = self.state.lock().unwrap();
88+
match *locked {
89+
State::Empty => {
90+
let fut = self.mk_svc.call(dst);
91+
*locked = State::Making(Vec::new());
92+
SingletonFuture::Driving {
93+
future: fut,
94+
singleton: DitchGuard(Arc::downgrade(&self.state)),
95+
}
96+
}
97+
State::Making(ref mut waiters) => {
98+
let (tx, rx) = oneshot::channel();
99+
waiters.push(tx);
100+
SingletonFuture::Waiting { rx }
101+
}
102+
State::Made(ref svc) => SingletonFuture::Made {
103+
svc: Some(svc.clone()),
104+
},
105+
}
106+
}
107+
}
108+
109+
impl<F, S, E> Future for SingletonFuture<F, S>
110+
where
111+
F: Future<Output = Result<S, E>>,
112+
E: Into<BoxError>,
113+
S: Clone,
114+
{
115+
type Output = Result<S, SingletonError>;
116+
117+
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
118+
match self.project() {
119+
SingletonFutureProj::Driving { future, singleton } => {
120+
match ready!(future.poll(cx)) {
121+
Ok(svc) => {
122+
if let Some(state) = singleton.0.upgrade() {
123+
let mut locked = state.lock().unwrap();
124+
singleton.0 = Weak::new();
125+
match std::mem::replace(&mut *locked, State::Made(svc.clone())) {
126+
State::Making(waiters) => {
127+
for tx in waiters {
128+
let _ = tx.send(svc.clone());
129+
}
130+
}
131+
State::Empty | State::Made(_) => {
132+
// shouldn't happen!
133+
}
134+
}
135+
}
136+
Poll::Ready(Ok(svc))
137+
}
138+
Err(e) => {
139+
if let Some(state) = singleton.0.upgrade() {
140+
let mut locked = state.lock().unwrap();
141+
singleton.0 = Weak::new();
142+
*locked = State::Empty;
143+
}
144+
Poll::Ready(Err(SingletonError(e.into())))
145+
}
146+
}
147+
}
148+
SingletonFutureProj::Waiting { rx } => match ready!(Pin::new(rx).poll(cx)) {
149+
Ok(svc) => Poll::Ready(Ok(svc)),
150+
Err(_canceled) => Poll::Ready(Err(SingletonError(Canceled.into()))),
151+
},
152+
SingletonFutureProj::Made { svc } => Poll::Ready(Ok(svc.take().unwrap())),
153+
}
154+
}
155+
}
156+
157+
impl<S> Drop for DitchGuard<S> {
158+
fn drop(&mut self) {
159+
if let Some(state) = self.0.upgrade() {
160+
if let Ok(mut locked) = state.lock() {
161+
*locked = State::Empty;
162+
}
163+
}
164+
}
165+
}
166+
167+
// An opaque error type. By not exposing the type, nor being specifically
168+
// Box<dyn Error>, we can _change_ the type once we no longer need the Canceled
169+
// error type. This will be possible with the refactor to baton passing.
170+
#[derive(Debug)]
171+
pub struct SingletonError(BoxError);
172+
173+
impl std::fmt::Display for SingletonError {
174+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175+
f.write_str("singleton connection error")
176+
}
177+
}
178+
179+
impl std::error::Error for SingletonError {
180+
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
181+
Some(&*self.0)
182+
}
183+
}
184+
185+
#[derive(Debug)]
186+
struct Canceled;
187+
188+
impl std::fmt::Display for Canceled {
189+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190+
f.write_str("singleton connection canceled")
191+
}
192+
}
193+
194+
impl std::error::Error for Canceled {}
195+
196+
#[cfg(test)]
197+
mod tests {
198+
use std::future::Future;
199+
use std::pin::Pin;
200+
use std::task::Poll;
201+
202+
use tower_service::Service;
203+
204+
use super::Singleton;
205+
206+
#[tokio::test]
207+
async fn first_call_drives_subsequent_wait() {
208+
let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>();
209+
210+
let mut singleton = Singleton::new(mock_svc);
211+
212+
handle.allow(1);
213+
crate::common::future::poll_fn(|cx| singleton.poll_ready(cx))
214+
.await
215+
.unwrap();
216+
// First call: should go into Driving
217+
let fut1 = singleton.call(());
218+
// Second call: should go into Waiting
219+
let fut2 = singleton.call(());
220+
221+
// Expect exactly one request to the inner service
222+
let ((), send_response) = handle.next_request().await.unwrap();
223+
send_response.send_response("svc");
224+
225+
// Both futures should resolve to the same value
226+
assert_eq!(fut1.await.unwrap(), "svc");
227+
assert_eq!(fut2.await.unwrap(), "svc");
228+
}
229+
230+
#[tokio::test]
231+
async fn made_state_returns_immediately() {
232+
let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>();
233+
let mut singleton = Singleton::new(mock_svc);
234+
235+
handle.allow(1);
236+
crate::common::future::poll_fn(|cx| singleton.poll_ready(cx))
237+
.await
238+
.unwrap();
239+
// Drive first call to completion
240+
let fut1 = singleton.call(());
241+
let ((), send_response) = handle.next_request().await.unwrap();
242+
send_response.send_response("svc");
243+
assert_eq!(fut1.await.unwrap(), "svc");
244+
245+
// Second call should not hit inner service
246+
let res = singleton.call(()).await.unwrap();
247+
assert_eq!(res, "svc");
248+
}
249+
250+
#[tokio::test]
251+
async fn cancel_waiter_does_not_affect_others() {
252+
let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>();
253+
let mut singleton = Singleton::new(mock_svc);
254+
255+
crate::common::future::poll_fn(|cx| singleton.poll_ready(cx))
256+
.await
257+
.unwrap();
258+
let fut1 = singleton.call(());
259+
let fut2 = singleton.call(());
260+
drop(fut2); // cancel one waiter
261+
262+
let ((), send_response) = handle.next_request().await.unwrap();
263+
send_response.send_response("svc");
264+
265+
assert_eq!(fut1.await.unwrap(), "svc");
266+
}
267+
268+
// TODO: this should be able to be improved with a cooperative baton refactor
269+
#[tokio::test]
270+
async fn cancel_driver_cancels_all() {
271+
let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>();
272+
let mut singleton = Singleton::new(mock_svc);
273+
274+
crate::common::future::poll_fn(|cx| singleton.poll_ready(cx))
275+
.await
276+
.unwrap();
277+
let mut fut1 = singleton.call(());
278+
let fut2 = singleton.call(());
279+
280+
// poll driver just once, and then drop
281+
crate::common::future::poll_fn(move |cx| {
282+
let _ = Pin::new(&mut fut1).poll(cx);
283+
Poll::Ready(())
284+
})
285+
.await;
286+
287+
let ((), send_response) = handle.next_request().await.unwrap();
288+
send_response.send_response("svc");
289+
290+
assert_eq!(
291+
fut2.await.unwrap_err().0.to_string(),
292+
"singleton connection canceled"
293+
);
294+
}
295+
}

0 commit comments

Comments
 (0)