diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index 02266c31e..b0c4b639e 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -2225,6 +2225,7 @@ impl Client { }; self.proxy_auth(&uri, &mut headers); + self.proxy_custom_headers(&uri, &mut headers); let builder = hyper::Request::builder() .method(method.clone()) @@ -2303,6 +2304,20 @@ impl Client { } } } + + fn proxy_custom_headers(&self, dst: &Uri, headers: &mut HeaderMap) { + for proxy in self.inner.proxies.iter() { + if proxy.is_match(dst) { + if let Some(iter) = proxy.http_custom_headers(dst) { + iter.iter().for_each(|(key, value)| { + headers.insert(key, value.clone()); + }); + } + + break; + } + } + } } impl fmt::Debug for Client { diff --git a/src/connect.rs b/src/connect.rs index 609982b50..0a5ad6346 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -1,6 +1,8 @@ #[cfg(feature = "__tls")] use http::header::HeaderValue; use http::uri::{Authority, Scheme}; +#[cfg(feature = "__tls")] +use http::HeaderMap; use http::Uri; use hyper::rt::{Read, ReadBufCursor, Write}; use hyper_util::client::legacy::connect::{Connected, Connection}; @@ -500,9 +502,9 @@ impl ConnectorService { ) -> Result { log::debug!("proxy({proxy_scheme:?}) intercepts '{dst:?}'"); - let (proxy_dst, _auth) = match proxy_scheme { - ProxyScheme::Http { host, auth } => (into_uri(Scheme::HTTP, host), auth), - ProxyScheme::Https { host, auth } => (into_uri(Scheme::HTTPS, host), auth), + let (proxy_dst, _auth, _misc) = match proxy_scheme { + ProxyScheme::Http { host, auth, misc } => (into_uri(Scheme::HTTP, host), auth, misc), + ProxyScheme::Https { host, auth, misc } => (into_uri(Scheme::HTTPS, host), auth, misc), #[cfg(feature = "socks")] ProxyScheme::Socks4 { .. } => return self.connect_socks(dst, proxy_scheme).await, #[cfg(feature = "socks")] @@ -512,6 +514,9 @@ impl ConnectorService { #[cfg(feature = "__tls")] let auth = _auth; + #[cfg(feature = "__tls")] + let misc = _misc; + match &self.inner { #[cfg(feature = "default-tls")] Inner::DefaultTls(http, tls) => { @@ -529,6 +534,7 @@ impl ConnectorService { port, self.user_agent.clone(), auth, + misc, ) .await?; let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone()); @@ -564,7 +570,8 @@ impl ConnectorService { log::trace!("tunneling HTTPS over proxy"); let maybe_server_name = ServerName::try_from(host.as_str().to_owned()) .map_err(|_| "Invalid Server Name"); - let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?; + let tunneled = + tunnel(conn, host, port, self.user_agent.clone(), auth, misc).await?; let server_name = maybe_server_name?; let io = RustlsConnector::from(tls) .connect(server_name, TokioIo::new(tunneled)) @@ -851,6 +858,7 @@ async fn tunnel( port: u16, user_agent: Option, auth: Option, + misc: Option, ) -> Result where T: Read + Write + Unpin, @@ -881,6 +889,16 @@ where buf.extend_from_slice(b"\r\n"); } + if let Some(headers) = misc { + log::debug!("tunnel to {host}:{port} using headers"); + headers.iter().for_each(|(key, value)| { + buf.extend_from_slice(key.as_str().as_bytes()); + buf.extend_from_slice(b": "); + buf.extend_from_slice(value.as_bytes()); + buf.extend_from_slice(b"\r\n"); + }); + } + // headers end buf.extend_from_slice(b"\r\n"); @@ -1476,7 +1494,7 @@ mod tests { let tcp = TokioIo::new(TcpStream::connect(&addr).await?); let host = addr.ip().to_string(); let port = addr.port(); - tunnel(tcp, host, port, ua(), None).await + tunnel(tcp, host, port, ua(), None, None).await }; rt.block_on(f).unwrap(); @@ -1494,7 +1512,7 @@ mod tests { let tcp = TokioIo::new(TcpStream::connect(&addr).await?); let host = addr.ip().to_string(); let port = addr.port(); - tunnel(tcp, host, port, ua(), None).await + tunnel(tcp, host, port, ua(), None, None).await }; rt.block_on(f).unwrap_err(); @@ -1512,7 +1530,7 @@ mod tests { let tcp = TokioIo::new(TcpStream::connect(&addr).await?); let host = addr.ip().to_string(); let port = addr.port(); - tunnel(tcp, host, port, ua(), None).await + tunnel(tcp, host, port, ua(), None, None).await }; rt.block_on(f).unwrap_err(); @@ -1536,7 +1554,7 @@ mod tests { let tcp = TokioIo::new(TcpStream::connect(&addr).await?); let host = addr.ip().to_string(); let port = addr.port(); - tunnel(tcp, host, port, ua(), None).await + tunnel(tcp, host, port, ua(), None, None).await }; let error = rt.block_on(f).unwrap_err(); @@ -1564,6 +1582,7 @@ mod tests { port, ua(), Some(proxy::encode_basic_auth("Aladdin", "open sesame")), + None, ) .await }; diff --git a/src/proxy.rs b/src/proxy.rs index 5dc053d42..40e92e634 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use crate::into_url::{IntoUrl, IntoUrlSealed}; use crate::Url; -use http::{header::HeaderValue, Uri}; +use http::{header::HeaderValue, HeaderMap, Uri}; use ipnet::IpNet; use percent_encoding::percent_decode; use std::collections::HashMap; @@ -102,10 +102,12 @@ pub enum ProxyScheme { Http { auth: Option, host: http::uri::Authority, + misc: Option, }, Https { auth: Option, host: http::uri::Authority, + misc: Option, }, #[cfg(feature = "socks")] Socks4 { addr: SocketAddr, remote_dns: bool }, @@ -125,6 +127,14 @@ impl ProxyScheme { _ => None, } } + + fn maybe_http_custom_headers(&self) -> Option<&HeaderMap> { + match self { + ProxyScheme::Http { misc, .. } | ProxyScheme::Https { misc, .. } => misc.as_ref(), + #[cfg(feature = "socks")] + _ => None, + } + } } /// Trait used for converting into a proxy scheme. This trait supports @@ -360,6 +370,26 @@ impl Proxy { self } + /// Adds custom headers to this Proxy + /// + /// # Example + /// ``` + /// # extern crate reqwest; + /// # use reqwest::header::*; + /// # fn run() -> Result<(), Box> { + /// let mut headers = HeaderMap::new(); + /// headers.insert(USER_AGENT, "reqwest".parse().unwrap()); + /// let proxy = reqwest::Proxy::https("http://localhost:1234")? + /// .headers(headers); + /// # Ok(()) + /// # } + /// # fn main() {} + /// ``` + pub fn headers(mut self, headers: HeaderMap) -> Proxy { + self.intercept.set_custom_headers(headers); + self + } + pub(crate) fn maybe_has_http_auth(&self) -> bool { match &self.intercept { Intercept::All(p) | Intercept::Http(p) => p.maybe_http_auth().is_some(), @@ -386,6 +416,19 @@ impl Proxy { } } + pub(crate) fn http_custom_headers(&self, uri: &D) -> Option { + match &self.intercept { + Intercept::All(p) | Intercept::Http(p) => p.maybe_http_custom_headers().cloned(), + Intercept::System(system) => system + .get("http") + .and_then(|s| s.maybe_http_custom_headers().cloned()), + Intercept::Custom(custom) => custom + .call(uri) + .and_then(|s| s.maybe_http_custom_headers().cloned()), + Intercept::Https(_) => None, + } + } + pub(crate) fn intercept(&self, uri: &D) -> Option { let in_no_proxy = self .no_proxy @@ -578,6 +621,7 @@ impl ProxyScheme { Ok(ProxyScheme::Http { auth: None, host: host.parse().map_err(crate::error::builder)?, + misc: None, }) } @@ -586,6 +630,7 @@ impl ProxyScheme { Ok(ProxyScheme::Https { auth: None, host: host.parse().map_err(crate::error::builder)?, + misc: None, }) } @@ -697,6 +742,25 @@ impl ProxyScheme { } } + fn set_custom_headers(&mut self, headers: HeaderMap) { + match *self { + ProxyScheme::Http { ref mut misc, .. } => { + misc.get_or_insert_with(HeaderMap::new).extend(headers) + } + ProxyScheme::Https { ref mut misc, .. } => { + misc.get_or_insert_with(HeaderMap::new).extend(headers) + } + #[cfg(feature = "socks")] + ProxyScheme::Socks4 { .. } => { + panic!("Socks4 is not supported for this method") + } + #[cfg(feature = "socks")] + ProxyScheme::Socks5 { .. } => { + panic!("Socks5 is not supported for this method") + } + } + } + fn if_no_auth(mut self, update: &Option) -> Self { match self { ProxyScheme::Http { ref mut auth, .. } => { @@ -791,8 +855,16 @@ impl ProxyScheme { impl fmt::Debug for ProxyScheme { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - ProxyScheme::Http { auth: _auth, host } => write!(f, "http://{host}"), - ProxyScheme::Https { auth: _auth, host } => write!(f, "https://{host}"), + ProxyScheme::Http { + auth: _auth, + host, + misc: _misc, + } => write!(f, "http://{host}"), + ProxyScheme::Https { + auth: _auth, + host, + misc: _misc, + } => write!(f, "https://{host}"), #[cfg(feature = "socks")] ProxyScheme::Socks4 { addr, remote_dns } => { let h = if *remote_dns { "a" } else { "" }; @@ -847,6 +919,16 @@ impl Intercept { } } } + + fn set_custom_headers(&mut self, headers: HeaderMap) { + match self { + Intercept::All(ref mut s) + | Intercept::Http(ref mut s) + | Intercept::Https(ref mut s) => s.set_custom_headers(headers), + Intercept::System(_) => unimplemented!(), + Intercept::Custom(_) => unimplemented!(), + } + } } #[derive(Clone)] @@ -1277,9 +1359,10 @@ mod tests { let ps = "http://foo:bar@localhost:1239".into_proxy_scheme().unwrap(); match ps { - ProxyScheme::Http { auth, host } => { + ProxyScheme::Http { auth, host, misc } => { assert_eq!(auth.unwrap(), encode_basic_auth("foo", "bar")); assert_eq!(host, "localhost:1239"); + assert_eq!(format!("{:?}", misc), "None"); } other => panic!("unexpected: {other:?}"), } @@ -1290,9 +1373,10 @@ mod tests { let ps = "192.168.1.1:8888".into_proxy_scheme().unwrap(); match ps { - ProxyScheme::Http { auth, host } => { + ProxyScheme::Http { auth, host, misc } => { assert!(auth.is_none()); assert_eq!(host, "192.168.1.1:8888"); + assert_eq!(format!("{:?}", misc), "None"); } other => panic!("unexpected: {other:?}"), } @@ -1304,9 +1388,10 @@ mod tests { let ps = "foo:bar@localhost:1239".into_proxy_scheme().unwrap(); match ps { - ProxyScheme::Http { auth, host } => { + ProxyScheme::Http { auth, host, misc } => { assert_eq!(auth.unwrap(), encode_basic_auth("foo", "bar")); assert_eq!(host, "localhost:1239"); + assert_eq!(format!("{:?}", misc), "None"); } other => panic!("unexpected: {other:?}"), } @@ -1746,6 +1831,7 @@ mod tests { intercept: Intercept::Http(ProxyScheme::Http { auth: Some(HeaderValue::from_static("auth1")), host: http::uri::Authority::from_static("authority"), + misc: None, }), no_proxy: None, }; @@ -1759,6 +1845,7 @@ mod tests { intercept: Intercept::Http(ProxyScheme::Http { auth: None, host: http::uri::Authority::from_static("authority"), + misc: None, }), no_proxy: None, }; @@ -1772,6 +1859,7 @@ mod tests { intercept: Intercept::Http(ProxyScheme::Https { auth: Some(HeaderValue::from_static("auth2")), host: http::uri::Authority::from_static("authority"), + misc: None, }), no_proxy: None, }; @@ -1785,6 +1873,7 @@ mod tests { intercept: Intercept::All(ProxyScheme::Http { auth: Some(HeaderValue::from_static("auth3")), host: http::uri::Authority::from_static("authority"), + misc: None, }), no_proxy: None, }; @@ -1798,6 +1887,7 @@ mod tests { intercept: Intercept::All(ProxyScheme::Https { auth: Some(HeaderValue::from_static("auth4")), host: http::uri::Authority::from_static("authority"), + misc: None, }), no_proxy: None, }; @@ -1811,6 +1901,7 @@ mod tests { intercept: Intercept::All(ProxyScheme::Https { auth: None, host: http::uri::Authority::from_static("authority"), + misc: None, }), no_proxy: None, }; @@ -1828,6 +1919,7 @@ mod tests { ProxyScheme::Http { auth: Some(HeaderValue::from_static("auth5")), host: http::uri::Authority::from_static("authority"), + misc: None, }, ); m @@ -1848,6 +1940,7 @@ mod tests { ProxyScheme::Https { auth: Some(HeaderValue::from_static("auth6")), host: http::uri::Authority::from_static("authority"), + misc: None, }, ); m diff --git a/tests/proxy.rs b/tests/proxy.rs index cc32f04dd..0fe91d980 100644 --- a/tests/proxy.rs +++ b/tests/proxy.rs @@ -172,6 +172,41 @@ async fn test_no_proxy() { assert_eq!(res.status(), reqwest::StatusCode::OK); } +#[tokio::test] +async fn test_custom_headers() { + let url = "http://hyper.rs/prox"; + let server = server::http(move |req| { + assert_eq!(req.method(), "GET"); + assert_eq!(req.uri(), url); + assert_eq!(req.headers()["host"], "hyper.rs"); + assert_eq!( + req.headers()["proxy-authorization"], + "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" + ); + async { http::Response::default() } + }); + + let proxy = format!("http://{}", server.addr()); + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + // reqwest::header::HeaderName::from_static("Proxy-Authorization"), + reqwest::header::PROXY_AUTHORIZATION, + "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==".parse().unwrap(), + ); + + let res = reqwest::Client::builder() + .proxy(reqwest::Proxy::http(&proxy).unwrap().headers(headers)) + .build() + .unwrap() + .get(url) + .send() + .await + .unwrap(); + + assert_eq!(res.url().as_str(), url); + assert_eq!(res.status(), reqwest::StatusCode::OK); +} + #[tokio::test] async fn test_using_system_proxy() { let url = "http://not.a.real.sub.hyper.rs/prox";