diff --git a/src/client.rs b/src/client.rs index 0df1d357..0726f570 100644 --- a/src/client.rs +++ b/src/client.rs @@ -12,6 +12,7 @@ use crate::{ }; use futures_io::AsyncRead; use futures_util::{future::BoxFuture, pin_mut}; +use http::header::{HeaderName, HeaderValue}; use http::{Request, Response}; use lazy_static::lazy_static; use std::{ @@ -58,6 +59,8 @@ pub struct HttpClientBuilder { agent_builder: AgentBuilder, defaults: http::Extensions, middleware: Vec>, + default_headers: http::HeaderMap, + error: Option, } impl Default for HttpClientBuilder { @@ -84,6 +87,8 @@ impl HttpClientBuilder { agent_builder: AgentBuilder::default(), defaults, middleware: Vec::new(), + default_headers: http::HeaderMap::new(), + error: None, } } @@ -240,14 +245,72 @@ impl HttpClientBuilder { self.configure(map) } + /// Set a default header to be passed with every request + /// + /// NOTE: In case there is an error in parsing the HeaderName or HeaderValue + /// the tuple is silently discarded. + /// + /// # Examples + /// + /// ``` + /// # use isahc::prelude::*; + /// # + /// let client = HttpClient::builder() + /// .default_header("some-header", "some-value") + /// .build()?; + /// # Ok::<(), Box>(()) + /// ``` + pub fn default_header(mut self, key: K, value: V) -> Self + where + HeaderName: TryFrom, + HeaderValue: TryFrom, + >::Error: Into, + >::Error: Into, + { + match HeaderName::try_from(key) { + Ok(key) => match HeaderValue::try_from(value) { + Ok(value) => { + self.default_headers.append(key, value); + } + Err(e) => { + self.error = Some(e.into()); + } + }, + Err(e) => { + self.error = Some(e.into()); + } + } + self + } + + /// Get the underlying HeaderMap from the current client-builder. + /// + /// # Examples + /// + /// ``` + /// # use isahc::prelude::*; + /// # + /// let mut builder = HttpClient::builder() + /// .default_header("some-header", "some-value"); + /// let header_opt = builder.default_headers_mut(); + /// # Ok::<(), Box>(()) + /// ``` + pub fn default_headers_mut(&mut self) -> &mut http::HeaderMap { + &mut self.default_headers + } + /// Build an [`HttpClient`] using the configured options. /// /// If the client fails to initialize, an error will be returned. pub fn build(self) -> Result { + if let Some(err) = self.error { + return Err(err); + } Ok(HttpClient { agent: Arc::new(self.agent_builder.spawn()?), defaults: self.defaults, middleware: self.middleware, + default_headers: self.default_headers, }) } } @@ -343,6 +406,8 @@ pub struct HttpClient { defaults: http::Extensions, /// Any middleware implementations that requests should pass through. middleware: Vec>, + + default_headers: http::HeaderMap, } impl HttpClient { @@ -827,7 +892,17 @@ impl HttpClient { } } - // Set custom request headers. + // We are checking here if header already contains the key, simply ignore it. + // In case the key wasn't present in parts.headers ensure that + // we have all the headers from default headers. + for name in self.default_headers.keys() { + if !parts.headers.contains_key(name) { + for v in self.default_headers.get_all(name).iter() { + parts.headers.append(name, v.clone()); + } + } + } + parts.headers.set_opt(&mut easy)?; Ok((easy, future)) @@ -889,4 +964,33 @@ mod tests { static_assertions::assert_impl_all!(HttpClient: Send, Sync); static_assertions::assert_impl_all!(HttpClientBuilder: Send); + + #[test] + fn test_default_header() { + let client = HttpClientBuilder::new() + .default_header("some-key", "some-value") + .build(); + match client { + Ok(_) => assert!(true), + Err(_) => assert!(false), + } + } + + #[test] + fn test_default_headers_mut() { + let mut builder = HttpClientBuilder::new().default_header("some-key", "some-value"); + let headers_map = builder.default_headers_mut(); + assert!(headers_map.len() == 1); + + let mut builder = HttpClientBuilder::new() + .default_header("some-key", "some-value1") + .default_header("some-key", "some-value2"); + let headers_map = builder.default_headers_mut(); + + assert!(headers_map.len() == 2); + + let mut builder = HttpClientBuilder::new(); + let header_map = builder.default_headers_mut(); + assert!(header_map.is_empty()) + } } diff --git a/src/error.rs b/src/error.rs index ff614c7b..cb206a1d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -147,6 +147,20 @@ impl From for Error { } } +#[doc(hidden)] +impl From for Error { + fn from(error: http::header::InvalidHeaderName) -> Error { + Error::InvalidHttpFormat(error.into()) + } +} + +#[doc(hidden)] +impl From for Error { + fn from(error: http::header::InvalidHeaderValue) -> Error { + Error::InvalidHttpFormat(error.into()) + } +} + #[doc(hidden)] impl From for Error { fn from(error: io::Error) -> Error { diff --git a/tests/headers.rs b/tests/headers.rs index 2c680d5e..bea0251e 100644 --- a/tests/headers.rs +++ b/tests/headers.rs @@ -1,3 +1,4 @@ +use isahc::prelude::*; use mockito::{mock, server_url, Matcher}; speculate::speculate! { @@ -25,4 +26,120 @@ speculate::speculate! { m.assert(); } + + test "header can be inserted in HttpClient::builder" { + + let host_header = server_url().replace("http://", ""); + let m = mock("GET", "/") + .match_header("host", host_header.as_ref()) + .match_header("accept", "*/*") + .match_header("accept-encoding", "deflate, gzip") + // .match_header("user-agent", Matcher::Regex(r"^curl/\S+ isahc/\S+$".into())) + .match_header("user-agent", Matcher::Any) + .match_header("X-header", "some-value1") + .create(); + + let client = HttpClient::builder() + .default_header("X-header", "some-value1") + .build() + .unwrap(); + + let request = Request::builder() + .method("GET") + .uri(server_url()) + .body(()) + .unwrap(); + + let _ = client.send(request).unwrap(); + m.assert(); + } + + test "headers in Request::builder must override headers in HttpClient::builder" { + + let host_header = server_url().replace("http://", ""); + let m = mock("GET", "/") + .match_header("host", host_header.as_ref()) + .match_header("accept", "*/*") + .match_header("accept-encoding", "deflate, gzip") + // .match_header("user-agent", Matcher::Regex(r"^curl/\S+ isahc/\S+$".into())) + .match_header("user-agent", Matcher::Any) + .match_header("X-header", "some-value2") + .create(); + + let client = HttpClient::builder() + .default_header("X-header", "some-value1") + .build() + .unwrap(); + + let request = Request::builder() + .method("GET") + .header("X-header", "some-value2") + .uri(server_url()) + .body(()) + .unwrap(); + + let _ = client.send(request).unwrap(); + m.assert(); + } + + // test "multiple headers with same key can be inserted in HttpClient::builder" { + + // let host_header = server_url().replace("http://", ""); + // let m = mock("GET", "/") + // .match_header("host", host_header.as_ref()) + // .match_header("accept", "*/*") + // .match_header("accept-encoding", "deflate, gzip") + // // .match_header("user-agent", Matcher::Regex(r"^curl/\S+ isahc/\S+$".into())) + // .match_header("user-agent", Matcher::Any) + // .match_header("X-header", "some-value1") + // .match_header("X-header", "some-value2") + // .create(); + + // let client = HttpClient::builder() + // .default_header("X-header", "some-value1") + // .default_header("X-header", "some-value2") + // .build() + // .unwrap(); + + // let request = Request::builder() + // .method("GET") + // .uri(server_url()) + // .body(()) + // .unwrap(); + + // let _ = client.send(request).unwrap(); + // m.assert(); + // } + + + + + test "headers in Request::builder must override multiple headers in HttpClient::builder" { + + let host_header = server_url().replace("http://", ""); + let m = mock("GET", "/") + .match_header("host", host_header.as_ref()) + .match_header("accept", "*/*") + .match_header("accept-encoding", "deflate, gzip") + // .match_header("user-agent", Matcher::Regex(r"^curl/\S+ isahc/\S+$".into())) + .match_header("user-agent", Matcher::Any) + .match_header("X-header", "some-value3") + .create(); + + let client = HttpClient::builder() + .default_header("X-header", "some-value1") + .default_header("X-header", "some-value2") + .build() + .unwrap(); + + let request = Request::builder() + .method("GET") + .header("X-header", "some-value3") + .uri(server_url()) + .body(()) + .unwrap(); + + let _ = client.send(request).unwrap(); + m.assert(); + } }