From 4a0bf065d98c78670b3f0d1cf3cc925ebd30ecc0 Mon Sep 17 00:00:00 2001 From: Mrinal Paliwal Date: Fri, 20 May 2022 16:07:08 +0530 Subject: [PATCH] feat(client): add option to HttpConnector to enable or disable HttpInfo Add `http_info` configuration option in `Config`. Change return type of `HttpConnector` from `TcpStream` to `HttpConnection`. This gives `TcpStream` access to `Config` for setting `HttpInfo` value. BREAKING CHANGE: HttpConnector returns a `HttpConnection` instead of `TcpStream`. --- src/client/connect/http.rs | 99 +++++++++++++++++++++++++++++++++----- src/client/connect/mod.rs | 2 +- src/client/tests.rs | 4 +- tests/client.rs | 5 +- 4 files changed, 92 insertions(+), 18 deletions(-) diff --git a/src/client/connect/http.rs b/src/client/connect/http.rs index afe7b155eb..39022aad6c 100644 --- a/src/client/connect/http.rs +++ b/src/client/connect/http.rs @@ -6,14 +6,16 @@ use std::marker::PhantomData; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::pin::Pin; use std::sync::Arc; -use std::task::{self, Poll}; +use std::task::{self, Poll, Context}; use std::time::Duration; +use std::ops::{Deref, DerefMut}; use futures_util::future::Either; use http::uri::{Scheme, Uri}; use pin_project_lite::pin_project; use tokio::net::{TcpSocket, TcpStream}; use tokio::time::Sleep; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tracing::{debug, trace, warn}; use super::dns::{self, resolve, GaiResolver, Resolve}; @@ -35,6 +37,12 @@ pub struct HttpConnector { resolver: R, } +/// Connection returned by `HttpConnector`. +pub struct HttpConnection { + inner: TcpStream, + config: Arc, +} + /// Extra information about the transport when an HttpConnector is used. /// /// # Example @@ -81,6 +89,7 @@ struct Config { reuse_address: bool, send_buffer_size: Option, recv_buffer_size: Option, + http_info: bool, } // ===== impl HttpConnector ===== @@ -121,6 +130,7 @@ impl HttpConnector { reuse_address: false, send_buffer_size: None, recv_buffer_size: None, + http_info: true, }), resolver, } @@ -164,6 +174,14 @@ impl HttpConnector { self.config_mut().recv_buffer_size = size; } + /// Set if `HttpInfo` is enabled or disabled in connection metadata. + /// + /// Default is `true`. + #[inline] + pub fn set_httpinfo(&mut self, httpinfo: bool) { + self.config_mut().http_info = httpinfo; + } + /// Set that all sockets are bound to the configured address before connection. /// /// If `None`, the sockets will not be bound. @@ -256,7 +274,7 @@ where R: Resolve + Clone + Send + Sync + 'static, R::Future: Send, { - type Response = TcpStream; + type Response = HttpConnection; type Error = ConnectError; type Future = HttpConnecting; @@ -323,7 +341,7 @@ impl HttpConnector where R: Resolve, { - async fn call_async(&mut self, dst: Uri) -> Result { + async fn call_async(&mut self, dst: Uri) -> Result { let config = &self.config; let (host, port) = get_host_port(config, &dst)?; @@ -340,7 +358,7 @@ where let addrs = addrs .map(|mut addr| { addr.set_port(port); - addr + addr }) .collect(); dns::SocketAddrs::new(addrs) @@ -354,18 +372,74 @@ where warn!("tcp set_nodelay error: {}", e); } - Ok(sock) + Ok(HttpConnection{inner:sock, config: config.clone()}) + } +} + +impl AsyncWrite for HttpConnection { + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write(cx, buf) + } +} + +impl AsyncRead for HttpConnection { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_read(cx, buf) + } +} + +impl fmt::Debug for HttpConnection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("HttpConnection").finish() + } +} + +impl Deref for HttpConnection { + type Target = TcpStream; + fn deref(&self) -> &TcpStream { + &self.inner } } -impl Connection for TcpStream { +impl DerefMut for HttpConnection { + fn deref_mut(&mut self) -> &mut TcpStream { + &mut self.inner + } +} + +impl Connection for HttpConnection { fn connected(&self) -> Connected { - let connected = Connected::new(); - if let (Ok(remote_addr), Ok(local_addr)) = (self.peer_addr(), self.local_addr()) { - connected.extra(HttpInfo { remote_addr, local_addr }) - } else { - connected + let mut connected = Connected::new(); + + if self.config.http_info { + if let (Ok(remote_addr), Ok(local_addr)) = (self.inner.peer_addr(), self.inner.local_addr()) { + connected = connected.extra(HttpInfo { remote_addr, local_addr }); + } } + + connected } } @@ -396,7 +470,7 @@ pin_project! { } } -type ConnectResult = Result; +type ConnectResult = Result; type BoxConnecting = Pin + Send>>; impl Future for HttpConnecting { @@ -942,6 +1016,7 @@ mod tests { enforce_http: false, send_buffer_size: None, recv_buffer_size: None, + http_info: true, }; let connecting_tcp = ConnectingTcp::new(dns::SocketAddrs::new(addrs), &cfg); let start = Instant::now(); diff --git a/src/client/connect/mod.rs b/src/client/connect/mod.rs index 862a0e65c1..5c04b9d9e2 100644 --- a/src/client/connect/mod.rs +++ b/src/client/connect/mod.rs @@ -86,7 +86,7 @@ use ::http::Extensions; cfg_feature! { #![feature = "tcp"] - pub use self::http::{HttpConnector, HttpInfo}; + pub use self::http::{HttpConnector, HttpInfo, HttpConnection}; pub mod dns; mod http; diff --git a/src/client/tests.rs b/src/client/tests.rs index 0a281a637d..f21894c30d 100644 --- a/src/client/tests.rs +++ b/src/client/tests.rs @@ -1,9 +1,9 @@ use std::io; use futures_util::future; -use tokio::net::TcpStream; use super::Client; +use super::connect::HttpConnection; #[tokio::test] async fn client_connect_uri_argument() { @@ -13,7 +13,7 @@ async fn client_connect_uri_argument() { assert_eq!(dst.port(), None); assert_eq!(dst.path(), "/", "path should be removed"); - future::err::(io::Error::new(io::ErrorKind::Other, "expect me")) + future::err::(io::Error::new(io::ErrorKind::Other, "expect me")) }); let client = Client::builder().build::<_, crate::Body>(connector); diff --git a/tests/client.rs b/tests/client.rs index 417e9bf2d9..0015c02ed9 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -1120,11 +1120,10 @@ mod dispatch_impl { use futures_util::stream::StreamExt; use http::Uri; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; - use tokio::net::TcpStream; use super::support; use hyper::body::HttpBody; - use hyper::client::connect::{Connected, Connection, HttpConnector}; + use hyper::client::connect::{Connected, Connection, HttpConnection, HttpConnector}; use hyper::Client; #[test] @@ -2090,7 +2089,7 @@ mod dispatch_impl { } struct DebugStream { - tcp: TcpStream, + tcp: HttpConnection, on_drop: mpsc::Sender<()>, is_alpn_h2: bool, is_proxy: bool,