diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index e00cdbb8e..727ba384d 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -901,6 +901,8 @@ impl ClientBuilder { } let proxies_maybe_http_auth = proxies.iter().any(|p| p.maybe_has_http_auth()); + let proxies_maybe_http_custom_headers = + proxies.iter().any(|p| p.maybe_has_http_custom_headers()); Ok(Client { inner: Arc::new(ClientRef { @@ -924,6 +926,7 @@ impl ClientBuilder { request_timeout: RequestConfig::new(config.timeout), proxies, proxies_maybe_http_auth, + proxies_maybe_http_custom_headers, https_only: config.https_only, }), }) @@ -2375,6 +2378,7 @@ impl Client { }; self.proxy_auth(&uri, &mut headers); + self.proxy_custom_headers(&uri, &mut headers); let builder = hyper::Request::builder() .method(method.clone()) @@ -2454,6 +2458,26 @@ impl Client { break; } } + + fn proxy_custom_headers(&self, dst: &Uri, headers: &mut HeaderMap) { + if !self.inner.proxies_maybe_http_custom_headers { + return; + } + + if dst.scheme() != Some(&Scheme::HTTP) { + return; + } + + for proxy in self.inner.proxies.iter() { + if let Some(iter) = proxy.http_non_tunnel_custom_headers(dst) { + iter.iter().for_each(|(key, value)| { + headers.insert(key, value.clone()); + }); + } + + break; + } + } } impl fmt::Debug for Client { @@ -2643,6 +2667,7 @@ struct ClientRef { read_timeout: Option, proxies: Arc>, proxies_maybe_http_auth: bool, + proxies_maybe_http_custom_headers: bool, https_only: bool, } diff --git a/src/connect.rs b/src/connect.rs index decf26df3..bf97e111f 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -628,6 +628,9 @@ impl ConnectorService { #[cfg(feature = "__tls")] let auth = proxy.basic_auth().cloned(); + #[cfg(feature = "__tls")] + let misc = proxy.custom_headers().clone(); + match &self.inner { #[cfg(feature = "default-tls")] Inner::DefaultTls(http, tls) => { @@ -647,6 +650,10 @@ impl ConnectorService { headers.insert(http::header::USER_AGENT, ua); tunnel = tunnel.with_headers(headers); } + // Note that custom headers may override the user agent header. + if let Some(custom_headers) = misc { + tunnel = tunnel.with_headers(custom_headers.clone()); + } // We don't wrap this again in an HttpsConnector since that uses Maybe, // and we know this is definitely HTTPS. let tunneled = tunnel.call(dst.clone()).await?; @@ -683,6 +690,9 @@ impl ConnectorService { if let Some(auth) = auth { tunnel = tunnel.with_auth(auth); } + if let Some(custom_headers) = misc { + tunnel = tunnel.with_headers(custom_headers.clone()); + } if let Some(ua) = self.user_agent { let mut headers = http::HeaderMap::new(); headers.insert(http::header::USER_AGENT, ua); diff --git a/src/proxy.rs b/src/proxy.rs index a9126b162..a467436d1 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -2,7 +2,7 @@ use std::error::Error; use std::fmt; use std::sync::Arc; -use http::{header::HeaderValue, Uri}; +use http::{header::HeaderValue, HeaderMap, Uri}; use hyper_util::client::proxy::matcher; use crate::into_url::{IntoUrl, IntoUrlSealed}; @@ -67,6 +67,7 @@ pub struct NoProxy { #[derive(Clone)] struct Extra { auth: Option, + misc: Option, } // ===== Internal ===== @@ -75,6 +76,7 @@ pub(crate) struct Matcher { inner: Matcher_, extra: Extra, maybe_has_http_auth: bool, + maybe_has_http_custom_headers: bool, } enum Matcher_ { @@ -100,6 +102,14 @@ impl ProxyScheme { _ => None, } } + + fn maybe_http_custom_headers(&self) -> Option<&HeaderMap> { + match self { + ProxyScheme::Http { misc, .. } | ProxyScheme::Https { misc, .. } => misc.as_ref(), + #[cfg(feature = "socks")] + _ => None, + } + } } */ @@ -245,7 +255,10 @@ impl Proxy { fn new(intercept: Intercept) -> Proxy { Proxy { - extra: Extra { auth: None }, + extra: Extra { + auth: None, + misc: None, + }, intercept, no_proxy: None, } @@ -297,6 +310,32 @@ impl Proxy { self } + /// Adds a Custom Headers to Proxy + /// Adds custom headers to this Proxy + /// + /// # Example + /// ``` + /// # extern crate reqwest; + /// # use reqwest::header::*; + /// # fn run() -> Result<(), Box> { + /// let mut headers = HeaderMap::new(); + /// headers.insert(USER_AGENT, "reqwest".parse().unwrap()); + /// let proxy = reqwest::Proxy::https("http://localhost:1234")? + /// .headers(headers); + /// # Ok(()) + /// # } + /// # fn main() {} + /// ``` + pub fn headers(mut self, headers: HeaderMap) -> Proxy { + match self.intercept { + Intercept::All(_) | Intercept::Http(_) | Intercept::Https(_) | Intercept::Custom(_) => { + self.extra.misc = Some(headers); + } + } + + self + } + /// Adds a `No Proxy` exclusion list to this Proxy /// /// # Example @@ -323,10 +362,13 @@ impl Proxy { } = self; let maybe_has_http_auth; + let maybe_has_http_custom_headers; let inner = match intercept { Intercept::All(url) => { maybe_has_http_auth = cache_maybe_has_http_auth(&url, &extra.auth); + maybe_has_http_custom_headers = + cache_maybe_has_http_custom_headers(&url, &extra.misc); Matcher_::Util( matcher::Matcher::builder() .all(String::from(url)) @@ -336,6 +378,8 @@ impl Proxy { } Intercept::Http(url) => { maybe_has_http_auth = cache_maybe_has_http_auth(&url, &extra.auth); + maybe_has_http_custom_headers = + cache_maybe_has_http_custom_headers(&url, &extra.misc); Matcher_::Util( matcher::Matcher::builder() .http(String::from(url)) @@ -345,6 +389,8 @@ impl Proxy { } Intercept::Https(url) => { maybe_has_http_auth = cache_maybe_has_http_auth(&url, &extra.auth); + maybe_has_http_custom_headers = + cache_maybe_has_http_custom_headers(&url, &extra.misc); Matcher_::Util( matcher::Matcher::builder() .https(String::from(url)) @@ -354,6 +400,7 @@ impl Proxy { } Intercept::Custom(mut custom) => { maybe_has_http_auth = true; // never know + maybe_has_http_custom_headers = true; custom.no_proxy = no_proxy; Matcher_::Custom(custom) } @@ -363,6 +410,7 @@ impl Proxy { inner, extra, maybe_has_http_auth, + maybe_has_http_custom_headers, } } @@ -399,6 +447,10 @@ fn cache_maybe_has_http_auth(url: &Url, extra: &Option) -> bool { url.scheme() == "http" && (url.password().is_some() || extra.is_some()) } +fn cache_maybe_has_http_custom_headers(url: &Url, extra: &Option) -> bool { + url.scheme() == "http" && extra.is_some() +} + impl fmt::Debug for Proxy { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_tuple("Proxy") @@ -453,9 +505,13 @@ impl Matcher { pub(crate) fn system() -> Self { Self { inner: Matcher_::Util(matcher::Matcher::from_system()), - extra: Extra { auth: None }, + extra: Extra { + auth: None, + misc: None, + }, // maybe env vars have auth! maybe_has_http_auth: true, + maybe_has_http_custom_headers: true, } } @@ -493,6 +549,20 @@ impl Matcher { None } + + pub(crate) fn maybe_has_http_custom_headers(&self) -> bool { + self.maybe_has_http_custom_headers + } + + pub(crate) fn http_non_tunnel_custom_headers(&self, dst: &Uri) -> Option { + if let Some(proxy) = self.intercept(dst) { + if proxy.uri().scheme_str() == Some("http") { + return proxy.custom_headers().cloned(); + } + } + + None + } } impl fmt::Debug for Matcher { @@ -516,6 +586,13 @@ impl Intercepted { self.inner.basic_auth() } + pub(crate) fn custom_headers(&self) -> Option<&HeaderMap> { + if let Some(ref val) = self.extra.misc { + return Some(val); + } + None + } + #[cfg(feature = "socks")] pub(crate) fn raw_auth(&self) -> Option<(&str, &str)> { self.inner.raw_auth() @@ -580,6 +657,25 @@ impl ProxyScheme { } } + fn set_custom_headers(&mut self, headers: HeaderMap) { + match *self { + ProxyScheme::Http { ref mut misc, .. } => { + misc.get_or_insert_with(HeaderMap::new).extend(headers) + } + ProxyScheme::Https { ref mut misc, .. } => { + misc.get_or_insert_with(HeaderMap::new).extend(headers) + } + #[cfg(feature = "socks")] + ProxyScheme::Socks4 { .. } => { + panic!("Socks4 is not supported for this method") + } + #[cfg(feature = "socks")] + ProxyScheme::Socks5 { .. } => { + panic!("Socks5 is not supported for this method") + } + } + } + fn if_no_auth(mut self, update: &Option) -> Self { match self { ProxyScheme::Http { ref mut auth, .. } => { diff --git a/tests/proxy.rs b/tests/proxy.rs index d0cb1b478..d6f15851f 100644 --- a/tests/proxy.rs +++ b/tests/proxy.rs @@ -172,6 +172,41 @@ async fn test_no_proxy() { assert_eq!(res.status(), reqwest::StatusCode::OK); } +#[tokio::test] +async fn test_custom_headers() { + let url = "http://hyper.rs.local/prox"; + let server = server::http(move |req| { + assert_eq!(req.method(), "GET"); + assert_eq!(req.uri(), url); + assert_eq!(req.headers()["host"], "hyper.rs.local"); + assert_eq!( + req.headers()["proxy-authorization"], + "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" + ); + async { http::Response::default() } + }); + + let proxy = format!("http://{}", server.addr()); + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + // reqwest::header::HeaderName::from_static("Proxy-Authorization"), + reqwest::header::PROXY_AUTHORIZATION, + "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==".parse().unwrap(), + ); + + let res = reqwest::Client::builder() + .proxy(reqwest::Proxy::http(&proxy).unwrap().headers(headers)) + .build() + .unwrap() + .get(url) + .send() + .await + .unwrap(); + + assert_eq!(res.url().as_str(), url); + assert_eq!(res.status(), reqwest::StatusCode::OK); +} + #[tokio::test] async fn test_using_system_proxy() { let url = "http://not.a.real.sub.hyper.rs.local/prox";