diff --git a/upstream/CMakeLists.txt b/upstream/CMakeLists.txt index 6cd36d95..4284a6db 100644 --- a/upstream/CMakeLists.txt +++ b/upstream/CMakeLists.txt @@ -86,3 +86,4 @@ if (APPLE) endif () add_unit_test(test_dot_invalid_address "${TEST_DIR}" "${DNSLIBS_DIR}/upstream/src" TRUE TRUE) add_unit_test(test_bootstrapper "${TEST_DIR}" "${DNSLIBS_DIR}/upstream/src" TRUE TRUE) +add_unit_test(test_doh_credentials "${TEST_DIR}" "${DNSLIBS_DIR}/upstream/src" TRUE TRUE) diff --git a/upstream/test/test_doh_credentials.cpp b/upstream/test/test_doh_credentials.cpp new file mode 100644 index 00000000..9bcbb484 --- /dev/null +++ b/upstream/test/test_doh_credentials.cpp @@ -0,0 +1,75 @@ +#include +#include +#include + +#include "common/gtest_coro.h" +#include "common/base64.h" +#include "common/utils.h" +#include "dns/common/event_loop.h" +#include "dns/net/socket.h" +#include "dns/upstream/upstream.h" + +#include "../upstream_doh.h" + +namespace ag::dns::upstream::test { + +struct DohCredentialTestParam { + std::string url; + std::string expected_username; + std::string expected_password; +}; + +class DohUpstreamParamTest : public ::testing::TestWithParam { +public: + void SetUp() override { + m_loop = EventLoop::create(); + m_loop->start(); + } + + void TearDown() override { + m_loop->stop(); + m_loop->join(); + } + + EventLoopPtr m_loop; +}; + +TEST_P(DohUpstreamParamTest, ParsesCredentialsCorrectly) { + co_await m_loop->co_submit(); + + const auto ¶m = GetParam(); + + SocketFactory sf{{.loop = *m_loop}}; + UpstreamFactory factory({.loop = *m_loop, .socket_factory = &sf}); + + UpstreamOptions options; + options.address = param.url; + options.bootstrap = {"8.8.8.8"}; + + auto upstream_res = factory.create_upstream(options); + ASSERT_FALSE(upstream_res.has_error()) << "Failed to create Upstream for " << param.url; + + auto doh_upstream = std::dynamic_pointer_cast(upstream_res.value()); + ASSERT_NE(doh_upstream, nullptr) << "Failed to cast to DohUpstream"; + + ASSERT_TRUE(doh_upstream->get_request_template().headers().get("Authorization").has_value()); + std::string test_header(doh_upstream->get_request_template().headers().get("Authorization").value()); + auto creds_expected = AG_FMT("{}:{}", param.expected_username, param.expected_password); + auto creds_expected_base64 = ag::encode_to_base64(as_u8v(creds_expected), false); + EXPECT_EQ(test_header, AG_FMT("Basic {}", creds_expected_base64)); +} + +static const DohCredentialTestParam doh_credential_test_cases[] = { + {"https://username:password@dns.google/dns-query", "username", "password"}, + {"https://user%20name:password@dns.google/dns-query", "user name", "password"}, + {"https://username:pass%7Eword@dns.google/dns-query", "username", "pass~word"}, + {"https://user%7Cname:pass~word@dns.google/dns-query", "user|name", "pass~word"}, + {"https://username:pass%word@dns.google/dns-query", "username", "pass%word"}, + {"https://username:%7E%%7Cpassword@dns.google/dns-query", "username", "~%|password"}, + {"https://username:%00password@dns.google/dns-query", "username", std::string("\0password", 9)}, +}; + +INSTANTIATE_TEST_SUITE_P(DohUpstreamParamTest, DohUpstreamParamTest, + ::testing::ValuesIn(doh_credential_test_cases)); + +} // namespace ag::dns::upstream::test diff --git a/upstream/upstream_doh.cpp b/upstream/upstream_doh.cpp index 9dee905b..8e013b5d 100644 --- a/upstream/upstream_doh.cpp +++ b/upstream/upstream_doh.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -329,7 +330,9 @@ ag::Error ag::dns::DohUpstream::init() { m_request_template.authority(std::string(m_url.get_hostname())); if (!m_url.get_username().empty() && !m_url.get_password().empty()) { - auto creds_fmt = AG_FMT("{}:{}", m_url.get_username(), m_url.get_password()); + auto decode_username = ada::unicode::percent_decode(m_url.get_username(), m_url.get_username().find('%')); + auto decode_password = ada::unicode::percent_decode(m_url.get_password(), m_url.get_password().find('%')); + auto creds_fmt = AG_FMT("{}:{}", decode_username, decode_password); auto creds_base64 = ag::encode_to_base64(as_u8v(creds_fmt), false); m_request_template.headers().put("Authorization", AG_FMT("Basic {}", creds_base64)); } diff --git a/upstream/upstream_doh.h b/upstream/upstream_doh.h index e37248aa..312a052d 100644 --- a/upstream/upstream_doh.h +++ b/upstream/upstream_doh.h @@ -37,6 +37,10 @@ class DohUpstream : public Upstream { std::vector fingerprints); ~DohUpstream() override; + const http::Request get_request_template() const { + return m_request_template; + } + DohUpstream() = delete; DohUpstream(const DohUpstream &) = delete; DohUpstream &operator=(const DohUpstream &) = delete;