Skip to content

Commit e7324e8

Browse files
committed
set nodelay for tcp stream
1 parent 9791bc3 commit e7324e8

File tree

3 files changed

+112
-122
lines changed

3 files changed

+112
-122
lines changed

src/connect.rs

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
#[cfg(unix)]
2+
use async_std::os::unix::net::UnixStream;
3+
4+
use crate::Socket;
5+
use async_std::io;
6+
use async_std::net::TcpStream;
7+
use std::future::Future;
8+
use std::time::Duration;
9+
use tokio_postgres::config::{Config, Host};
10+
use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
11+
use tokio_postgres::{Client, Connection};
12+
13+
/// Default socket port of postgres.
14+
const DEFAULT_PORT: u16 = 5432;
15+
16+
/// Connect to postgres server with a tls connector.
17+
///
18+
/// ```rust
19+
/// use async_postgres::connect;
20+
///
21+
/// use std::error::Error;
22+
/// use async_std::task::spawn;
23+
///
24+
/// async fn play() -> Result<(), Box<dyn Error>> {
25+
/// let url = "host=localhost user=postgres";
26+
/// let (client, conn) = connect(url.parse()?).await?;
27+
/// spawn(conn);
28+
/// let row = client.query_one("SELECT * FROM user WHERE id=$1", &[&0]).await?;
29+
/// let value: &str = row.get(0);
30+
/// println!("value: {}", value);
31+
/// Ok(())
32+
/// }
33+
/// ```
34+
pub async fn connect_tls<T>(
35+
config: Config,
36+
mut tls: T,
37+
) -> io::Result<(Client, Connection<Socket, T::Stream>)>
38+
where
39+
T: MakeTlsConnect<Socket>,
40+
{
41+
let mut error = io::Error::new(io::ErrorKind::Other, "host missing");
42+
let mut ports = config.get_ports().iter().cloned();
43+
for host in config.get_hosts() {
44+
let port = ports.next().unwrap_or(DEFAULT_PORT);
45+
let hostname = match host {
46+
#[cfg(unix)]
47+
Host::Unix(path) => path.as_os_str().to_str().unwrap_or(""),
48+
Host::Tcp(tcp) => tcp.as_str(),
49+
};
50+
let connector = tls
51+
.make_tls_connect(hostname)
52+
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
53+
match connect_once(&config, host, port, connector).await {
54+
Err(err) => error = err,
55+
ok => return ok,
56+
}
57+
}
58+
Err(error)
59+
}
60+
61+
async fn connect_once<T>(
62+
config: &Config,
63+
host: &Host,
64+
port: u16,
65+
tls: T,
66+
) -> io::Result<(Client, Connection<Socket, T::Stream>)>
67+
where
68+
T: TlsConnect<Socket>,
69+
{
70+
let dur = config.get_connect_timeout();
71+
let socket = connect_socket(host, port, dur).await?;
72+
config
73+
.connect_raw(socket, tls)
74+
.await
75+
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))
76+
}
77+
78+
async fn connect_socket(
79+
host: &Host,
80+
port: u16,
81+
dur: Option<&Duration>,
82+
) -> io::Result<Socket> {
83+
match host {
84+
#[cfg(unix)]
85+
Host::Unix(path) => {
86+
let sock = path.join(format!(".s.PGSQL.{}", port));
87+
let fut = UnixStream::connect(sock);
88+
let socket = timeout(dur, fut).await?;
89+
Ok(socket.into())
90+
}
91+
Host::Tcp(tcp) => {
92+
let fut = TcpStream::connect((tcp.as_str(), port));
93+
let socket = timeout(dur, fut).await?;
94+
socket.set_nodelay(true)?;
95+
Ok(socket.into())
96+
}
97+
}
98+
}
99+
100+
async fn timeout<F, T>(dur: Option<&Duration>, fut: F) -> io::Result<T>
101+
where
102+
F: Future<Output = io::Result<T>>,
103+
{
104+
if let Some(timeout) = dur {
105+
io::timeout(timeout.clone(), fut).await
106+
} else {
107+
fut.await
108+
}
109+
}

src/lib.rs

Lines changed: 3 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
33
#![warn(missing_docs)]
44

5+
pub use connect::connect_tls;
56
pub use socket::Socket;
67
pub use tokio_postgres::*;
78

8-
use socket::connect_socket;
99
use std::io;
10-
use tokio_postgres::tls::{NoTls, NoTlsStream, TlsConnect};
10+
use tokio_postgres::tls::{NoTls, NoTlsStream};
1111
use tokio_postgres::{Client, Connection};
1212

1313
/// Connect to postgres server.
@@ -35,68 +35,5 @@ pub async fn connect(
3535
connect_tls(config, NoTls).await
3636
}
3737

38-
/// Connect to postgres server with a tls connector.
39-
///
40-
/// ```rust
41-
/// use async_postgres::connect;
42-
///
43-
/// use std::error::Error;
44-
/// use async_std::task::spawn;
45-
///
46-
/// async fn play() -> Result<(), Box<dyn Error>> {
47-
/// let url = "host=localhost user=postgres";
48-
/// let (client, conn) = connect(url.parse()?).await?;
49-
/// spawn(conn);
50-
/// let row = client.query_one("SELECT * FROM user WHERE id=$1", &[&0]).await?;
51-
/// let value: &str = row.get(0);
52-
/// println!("value: {}", value);
53-
/// Ok(())
54-
/// }
55-
/// ```
56-
#[inline]
57-
pub async fn connect_tls<T>(
58-
config: Config,
59-
tls: T,
60-
) -> io::Result<(Client, Connection<Socket, T::Stream>)>
61-
where
62-
T: TlsConnect<Socket>,
63-
{
64-
let stream = connect_socket(&config).await?;
65-
connect_raw(stream, config, tls).await
66-
}
67-
68-
/// Connect to postgres server with a tls connector.
69-
///
70-
/// ```rust
71-
/// use async_postgres::connect;
72-
///
73-
/// use std::error::Error;
74-
/// use async_std::task::spawn;
75-
///
76-
/// async fn play() -> Result<(), Box<dyn Error>> {
77-
/// let url = "host=localhost user=postgres";
78-
/// let (client, conn) = connect(url.parse()?).await?;
79-
/// spawn(conn);
80-
/// let row = client.query_one("SELECT * FROM user WHERE id=$1", &[&0]).await?;
81-
/// let value: &str = row.get(0);
82-
/// println!("value: {}", value);
83-
/// Ok(())
84-
/// }
85-
/// ```
86-
#[inline]
87-
pub async fn connect_raw<S, T>(
88-
stream: S,
89-
config: Config,
90-
tls: T,
91-
) -> io::Result<(Client, Connection<Socket, T::Stream>)>
92-
where
93-
S: Into<Socket>,
94-
T: TlsConnect<Socket>,
95-
{
96-
config
97-
.connect_raw(stream.into(), tls)
98-
.await
99-
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))
100-
}
101-
38+
mod connect;
10239
mod socket;

src/socket.rs

Lines changed: 0 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,8 @@
1-
#[cfg(unix)]
2-
use async_std::os::unix::net::UnixStream;
3-
41
use async_std::io::{self, Read, Write};
5-
use async_std::net::TcpStream;
6-
use std::future::Future;
72
use std::mem::MaybeUninit;
83
use std::pin::Pin;
94
use std::task::{Context, Poll};
10-
use std::time::Duration;
115
use tokio::io::{AsyncRead, AsyncWrite};
12-
use tokio_postgres::config::{Config, Host};
13-
14-
/// Default socket port of postgres.
15-
const DEFAULT_PORT: u16 = 5432;
166

177
/// A alias for 'static + Unpin + Send + Read + Write
188
pub trait AsyncReadWriter: 'static + Unpin + Send + Read + Write {}
@@ -73,49 +63,3 @@ impl AsyncWrite for Socket {
7363
Pin::new(&mut self.0).poll_close(cx)
7464
}
7565
}
76-
77-
/// Establish connection to postgres server by AsyncStream.
78-
///
79-
///
80-
#[inline]
81-
pub async fn connect_socket(config: &Config) -> io::Result<Socket> {
82-
let mut error = io::Error::new(io::ErrorKind::Other, "host missing");
83-
let mut ports = config.get_ports().iter().cloned();
84-
for host in config.get_hosts() {
85-
let port = ports.next().unwrap_or(DEFAULT_PORT);
86-
let dur = config.get_connect_timeout();
87-
let result = match host {
88-
#[cfg(unix)]
89-
Host::Unix(path) => {
90-
let sock = path.join(format!(".s.PGSQL.{}", port));
91-
let fut = UnixStream::connect(sock);
92-
timeout(dur, fut).await.map(Into::into)
93-
}
94-
Host::Tcp(tcp) => {
95-
let fut = TcpStream::connect((tcp.as_str(), port));
96-
timeout(dur, fut).await.map(Into::into)
97-
}
98-
#[cfg(not(unix))]
99-
Host::Unix(_) => {
100-
io::Error::new(io::ErrorKind::Other, "unix domain socket is unsupported")
101-
}
102-
};
103-
104-
match result {
105-
Err(err) => error = err,
106-
stream => return stream,
107-
}
108-
}
109-
Err(error)
110-
}
111-
112-
async fn timeout<F, T>(dur: Option<&Duration>, fut: F) -> io::Result<T>
113-
where
114-
F: Future<Output = io::Result<T>>,
115-
{
116-
if let Some(timeout) = dur {
117-
io::timeout(timeout.clone(), fut).await
118-
} else {
119-
fut.await
120-
}
121-
}

0 commit comments

Comments
 (0)