Skip to content

feat: add customizable headers in proxy mode #2600

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
15 changes: 15 additions & 0 deletions src/async_impl/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2225,6 +2225,7 @@ impl Client {
};

self.proxy_auth(&uri, &mut headers);
self.proxy_custom_headers(&uri, &mut headers);

let builder = hyper::Request::builder()
.method(method.clone())
Expand Down Expand Up @@ -2303,6 +2304,20 @@ impl Client {
}
}
}

fn proxy_custom_headers(&self, dst: &Uri, headers: &mut HeaderMap) {
for proxy in self.inner.proxies.iter() {
if proxy.is_match(dst) {
if let Some(iter) = proxy.http_custom_headers(dst) {
iter.iter().for_each(|(key, value)| {
headers.insert(key, value.clone());
});
}

break;
}
}
}
}

impl fmt::Debug for Client {
Expand Down
35 changes: 27 additions & 8 deletions src/connect.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#[cfg(feature = "__tls")]
use http::header::HeaderValue;
use http::uri::{Authority, Scheme};
#[cfg(feature = "__tls")]
use http::HeaderMap;
use http::Uri;
use hyper::rt::{Read, ReadBufCursor, Write};
use hyper_util::client::legacy::connect::{Connected, Connection};
Expand Down Expand Up @@ -500,9 +502,9 @@ impl ConnectorService {
) -> Result<Conn, BoxError> {
log::debug!("proxy({proxy_scheme:?}) intercepts '{dst:?}'");

let (proxy_dst, _auth) = match proxy_scheme {
ProxyScheme::Http { host, auth } => (into_uri(Scheme::HTTP, host), auth),
ProxyScheme::Https { host, auth } => (into_uri(Scheme::HTTPS, host), auth),
let (proxy_dst, _auth, _misc) = match proxy_scheme {
ProxyScheme::Http { host, auth, misc } => (into_uri(Scheme::HTTP, host), auth, misc),
ProxyScheme::Https { host, auth, misc } => (into_uri(Scheme::HTTPS, host), auth, misc),
#[cfg(feature = "socks")]
ProxyScheme::Socks4 { .. } => return self.connect_socks(dst, proxy_scheme).await,
#[cfg(feature = "socks")]
Expand All @@ -512,6 +514,9 @@ impl ConnectorService {
#[cfg(feature = "__tls")]
let auth = _auth;

#[cfg(feature = "__tls")]
let misc = _misc;

match &self.inner {
#[cfg(feature = "default-tls")]
Inner::DefaultTls(http, tls) => {
Expand All @@ -529,6 +534,7 @@ impl ConnectorService {
port,
self.user_agent.clone(),
auth,
misc,
)
.await?;
let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
Expand Down Expand Up @@ -564,7 +570,8 @@ impl ConnectorService {
log::trace!("tunneling HTTPS over proxy");
let maybe_server_name = ServerName::try_from(host.as_str().to_owned())
.map_err(|_| "Invalid Server Name");
let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?;
let tunneled =
tunnel(conn, host, port, self.user_agent.clone(), auth, misc).await?;
let server_name = maybe_server_name?;
let io = RustlsConnector::from(tls)
.connect(server_name, TokioIo::new(tunneled))
Expand Down Expand Up @@ -851,6 +858,7 @@ async fn tunnel<T>(
port: u16,
user_agent: Option<HeaderValue>,
auth: Option<HeaderValue>,
misc: Option<HeaderMap>,
) -> Result<T, BoxError>
where
T: Read + Write + Unpin,
Expand Down Expand Up @@ -881,6 +889,16 @@ where
buf.extend_from_slice(b"\r\n");
}

if let Some(headers) = misc {
log::debug!("tunnel to {host}:{port} using headers");
headers.iter().for_each(|(key, value)| {
buf.extend_from_slice(key.as_str().as_bytes());
buf.extend_from_slice(b": ");
buf.extend_from_slice(value.as_bytes());
buf.extend_from_slice(b"\r\n");
});
}

// headers end
buf.extend_from_slice(b"\r\n");

Expand Down Expand Up @@ -1476,7 +1494,7 @@ mod tests {
let tcp = TokioIo::new(TcpStream::connect(&addr).await?);
let host = addr.ip().to_string();
let port = addr.port();
tunnel(tcp, host, port, ua(), None).await
tunnel(tcp, host, port, ua(), None, None).await
};

rt.block_on(f).unwrap();
Expand All @@ -1494,7 +1512,7 @@ mod tests {
let tcp = TokioIo::new(TcpStream::connect(&addr).await?);
let host = addr.ip().to_string();
let port = addr.port();
tunnel(tcp, host, port, ua(), None).await
tunnel(tcp, host, port, ua(), None, None).await
};

rt.block_on(f).unwrap_err();
Expand All @@ -1512,7 +1530,7 @@ mod tests {
let tcp = TokioIo::new(TcpStream::connect(&addr).await?);
let host = addr.ip().to_string();
let port = addr.port();
tunnel(tcp, host, port, ua(), None).await
tunnel(tcp, host, port, ua(), None, None).await
};

rt.block_on(f).unwrap_err();
Expand All @@ -1536,7 +1554,7 @@ mod tests {
let tcp = TokioIo::new(TcpStream::connect(&addr).await?);
let host = addr.ip().to_string();
let port = addr.port();
tunnel(tcp, host, port, ua(), None).await
tunnel(tcp, host, port, ua(), None, None).await
};

let error = rt.block_on(f).unwrap_err();
Expand Down Expand Up @@ -1564,6 +1582,7 @@ mod tests {
port,
ua(),
Some(proxy::encode_basic_auth("Aladdin", "open sesame")),
None,
)
.await
};
Expand Down
Loading
Loading