Skip to content

Commit

Permalink
Pull request 610: [2.5] Add udp:// scheme
Browse files Browse the repository at this point in the history
Merge in ADGUARD-CORE-LIBS/dns-libs from feauture/AG-27295 to dev-2.5

Squashed commit of the following:

commit 6d0a2ba82a8b8f384a30e574c1ff2787b0265d71
Author: Zholboldu Emilbekuulu <[email protected]>
Date:   Thu Dec 14 14:10:45 2023 +0300

    AG-27295: Removed the should_try_tcp function as it’s no longer needed

commit c5aae3aedfed66a86bfcf6ef17728474a138d270
Merge: 54694880 fa31990
Author: Sergey Fionov <[email protected]>
Date:   Wed Dec 13 21:11:07 2023 +0200

    Merge branch 'dev-2.5' into feauture/AG-27295

commit 54694880f2a6259eb1843a1ac932a4ee866d61bc
Author: Zholboldu Emilbekuulu <[email protected]>
Date:   Wed Dec 13 21:33:24 2023 +0300

    AG-27295: Implemented receive_and_decode_dns_packet function

commit 6f0048d62b9cb5b6ae080f30fec9ea6d5767a1ac
Author: Zholboldu Emilbekuulu <[email protected]>
Date:   Mon Dec 11 17:15:32 2023 +0300

    AG-27295: Implementing DNS over UDP without fallback to TCP.

commit 9d4e752cae84f13786da865babe990bea7553f30
Author: Zholboldu Emilbekuulu <[email protected]>
Date:   Mon Dec 11 14:15:56 2023 +0300

    AG-27295: "udp:" replaced within UDP_SCHEME

commit 59717d8fd5377200850d0a2fb3baea242368b688
Author: Zholboldu Emilbekuulu <[email protected]>
Date:   Mon Dec 11 13:45:38 2023 +0300

    AG-27295: "udp:" replaced within UDP_SCHEME

commit afa3da2eab019378cee0bb63d9d47a84f865f344
Author: Zholboldu Emilbekuulu <[email protected]>
Date:   Mon Dec 11 12:46:33 2023 +0300

    AG-27295: Slash is removed from constants TCPSCHEME AND UDPSCHEME.

commit f3dcbba2e639b36faae9b38a4dbe3e2f07654e86
Author: Zholboldu Emilbekuulu <[email protected]>
Date:   Fri Dec 8 21:27:52 2023 +0300

    AG-27295:  DNS over UDP без фоллбэка на TCP

commit 49b8a26fc3d26980697ee713eb588f0657cf1469
Author: Zholboldu Emilbekuulu <[email protected]>
Date:   Fri Dec 8 20:47:07 2023 +0300

    AG-27295:  DNS over UDP без фоллбэка на TCP
  • Loading branch information
Zholboldu Emilbekuulu authored and sfionov committed Dec 14, 2023
1 parent e4bcb30 commit 85abf90
Show file tree
Hide file tree
Showing 11 changed files with 92 additions and 37 deletions.
4 changes: 3 additions & 1 deletion common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ find_package(pcre2 REQUIRED)
find_package(magic_enum REQUIRED)
find_package(native_libs_common REQUIRED)
find_package(libuv REQUIRED)
find_package(ldns REQUIRED)

set(SRCS
sys.cpp
Expand All @@ -31,8 +32,9 @@ endif ()
set_target_properties(dnslibs_common PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_include_directories(dnslibs_common PUBLIC include)
target_compile_definitions(dnslibs_common PUBLIC PCRE2_STATIC=1 PCRE2_CODE_UNIT_WIDTH=8 SPDLOG_NO_EXCEPTIONS=1)
target_link_libraries(dnslibs_common libevent::libevent pcre2::pcre2)
target_link_libraries(dnslibs_common libevent::libevent pcre2::pcre2 )
target_link_libraries(dnslibs_common magic_enum::magic_enum native_libs_common::native_libs_common libuv::libuv)
target_link_libraries(dnslibs_common ldns::ldns)

if (NOT MSVC)
target_compile_options(dnslibs_common PRIVATE -Wall -Wextra)
Expand Down
6 changes: 6 additions & 0 deletions common/include/dns/common/dns_defs.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
#pragma once

#include <ldns/packet.h>

#include "common/defs.h"
#include "common/error.h"
#include "common/utils.h"

namespace ag {
namespace dns {

using ldns_pkt_ptr = UniquePtr<ldns_pkt, &ldns_pkt_free>; // NOLINT(readability-identifier-naming)
using ldns_buffer_ptr = UniquePtr<ldns_buffer, &ldns_buffer_free>; // NOLINT(readability-identifier-naming)

/**
* Enum for errors than can happen during DNS exchange
*/
Expand Down
8 changes: 3 additions & 5 deletions dnscrypt/include/dns/dnscrypt/dns_crypt_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <string_view>
#include "common/defs.h"
#include "common/net_utils.h"
#include "dns/common/dns_defs.h"
#include <ldns/ldns.h>

namespace ag::dns::dnscrypt {
Expand All @@ -16,15 +17,12 @@ constexpr size_t KEY_SIZE = 32;
using KeyArray = Uint8Array<KEY_SIZE>;
using ClientMagicArray = Uint8Array<CLIENT_MAGIC_LEN>;

using ldns_pkt_ptr = UniquePtr<ldns_pkt, &ldns_pkt_free>;
using ldns_buffer_ptr = UniquePtr<ldns_buffer, &ldns_buffer_free>;

/**
* Crypto construction represents the encryption algorithm
*/
enum class CryptoConstruction : uint16_t {
UNDEFINED, /** UNDEFINED is the default value for empty cert_info only */
X_SALSA_20_POLY_1305 = 0x0001, /** X_SALSA_20_POLY_1305 encryption */
UNDEFINED, /** UNDEFINED is the default value for empty cert_info only */
X_SALSA_20_POLY_1305 = 0x0001, /** X_SALSA_20_POLY_1305 encryption */
X_CHACHA_20_POLY_1305 = 0x0002, /** X_CHACHA_20_POLY_1305 encryption */
};

Expand Down
13 changes: 13 additions & 0 deletions net/include/dns/net/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "common/coro.h"
#include "common/defs.h"
#include "common/error.h"
#include "dns/common/dns_defs.h"
#include "dns/net/aio_socket.h"
#include "dns/net/socket.h"

Expand Down Expand Up @@ -30,4 +31,16 @@ namespace ag::dns {
*/
coro::Task<Result<Uint8Vector, SocketError>> receive_dns_packet(AioSocket *self, std::optional<Micros> timeout);

/**
* Receive, decode, and validate a DNS packet from the peer.
* Blocks until either an error occurred, an invalid packet is received, or a valid packet is fully received and decoded.
* @param self AioSocket instance
* @param timeout operation timeout
* @param check_and_decode function to decode and validate the received packet
* @return The received and decoded packet if succeeded and valid, nullptr otherwise
*/
coro::Task<Result<ldns_pkt_ptr, SocketError>> receive_and_decode_dns_packet(
AioSocket *self,
std::optional<Micros> timeout,
std::function<ldns_pkt_ptr(Uint8Vector)> check_and_decode);
} // namespace ag::dns
42 changes: 42 additions & 0 deletions net/utils.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
#ifndef _WIN32
#include <ldns/net.h>
#include <netinet/in.h>

#else
#include <Winsock2.h>
#endif

#include "dns/common/dns_defs.h"
#include "dns/net/tcp_dns_buffer.h"
#include "dns/net/utils.h"

Expand Down Expand Up @@ -74,4 +77,43 @@ coro::Task<Result<Uint8Vector, dns::SocketError>> dns::receive_dns_packet(
co_return std::move(context.reply);
}

struct ReadContext {
utils::TransportProtocol protocol;
dns::TcpDnsBuffer tcp_buffer;
dns::ldns_pkt_ptr reply_pkt;
std::function<dns::ldns_pkt_ptr(Uint8Vector)> check_and_decode;
};

static bool on_read(void *arg, Uint8View data) {
auto *ctx = (ReadContext *) arg;
bool done = false;
switch (ctx->protocol) {
case utils::TransportProtocol::TP_TCP:
ctx->tcp_buffer.store(data);
if (auto p = ctx->tcp_buffer.extract_packet(); p.has_value()) {
ctx->reply_pkt = ctx->check_and_decode(std::move(p.value()));
done = ctx->reply_pkt != nullptr;
}
break;
case utils::TransportProtocol::TP_UDP:
ctx->reply_pkt = ctx->check_and_decode({data.begin(), data.end()});
done = ctx->reply_pkt != nullptr;
break;
}
return !done;
}

coro::Task<Result<dns::ldns_pkt_ptr, dns::SocketError>> dns::receive_and_decode_dns_packet(
AioSocket *self, std::optional<Micros> timeout, std::function<ldns_pkt_ptr(Uint8Vector)> check_and_decode) {

ReadContext context = {.protocol = self->get_underlying()->get_protocol(), .check_and_decode = check_and_decode};
auto on_read_handler = AioSocket::OnReadCallback{on_read, &context};

if (auto err = co_await self->receive(on_read_handler, timeout)) {
co_return err;
}

co_return std::move(context.reply_pkt);
}

} // namespace ag
2 changes: 0 additions & 2 deletions proxy/response_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@

namespace ag::dns {

using ldns_pkt_ptr = UniquePtr<ldns_pkt, &ldns_pkt_free>;

/**
* Response cache
*
Expand Down
2 changes: 0 additions & 2 deletions upstream/include/dns/upstream/upstream.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ namespace dns {
class Upstream;

using UpstreamPtr = std::shared_ptr<Upstream>;
using ldns_pkt_ptr = UniquePtr<ldns_pkt, &ldns_pkt_free>; // NOLINT(readability-identifier-naming)
using ldns_buffer_ptr = UniquePtr<ldns_buffer, &ldns_buffer_free>; // NOLINT(readability-identifier-naming)

/**
* Upstream factory configuration
Expand Down
1 change: 1 addition & 0 deletions upstream/test/test_upstream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ TEST_P(DnsTruncatedTest, TestDnsTruncated) {
INSTANTIATE_TEST_SUITE_P(DnsTruncatedTest, DnsTruncatedTest, testing::ValuesIn(truncated_test_data));

static const UpstreamTestData test_upstreams_data[]{
{"udp://1.1.1.1:53", {}},
{"tcp://8.8.8.8", {}},
{"8.8.8.8:53", {"8.8.8.8:53"}},
{"1.0.0.1", {}},
Expand Down
3 changes: 3 additions & 0 deletions upstream/upstream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ enum class Scheme : size_t {
SDNS,
DNS,
TCP,
UDP,
TLS,
HTTPS,
H3,
Expand All @@ -36,6 +37,7 @@ static constexpr std::string_view SCHEME_WITH_SUFFIX[] = {
"sdns://",
"dns://",
PlainUpstream::TCP_SCHEME,
PlainUpstream::UDP_SCHEME,
DotUpstream::SCHEME,
DohUpstream::SCHEME_HTTPS,
DohUpstream::SCHEME_H3,
Expand Down Expand Up @@ -177,6 +179,7 @@ UpstreamFactory::CreateResult UpstreamFactory::Impl::create_upstream(const Upstr
&create_upstream_sdns,
&create_upstream_plain,
&create_upstream_plain,
&create_upstream_plain,
&create_upstream_tls,
&create_upstream_https,
&create_upstream_https,
Expand Down
43 changes: 17 additions & 26 deletions upstream/upstream_plain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ Error<Upstream::InitError> PlainUpstream::init() {
}
m_address = ag::utils::str_to_socket_address(m_url.get_host());

if (m_url.get_protocol() == "tcp:") {
if (m_url.get_protocol() == UDP_SCHEME) {
m_prefer_udp = true;
} else if (m_url.get_protocol() == TCP_SCHEME) {
m_prefer_tcp = true;
} else {
return make_error(InitError::AE_INVALID_ADDRESS, AG_FMT("Invalid URL scheme: {}", m_url.get_protocol()));
Expand All @@ -47,24 +49,6 @@ Error<Upstream::InitError> PlainUpstream::init() {
return {};
}

static bool should_try_tcp(const ldns_pkt *request, const ldns_pkt *reply, const ldns_status status) {
if (status != LDNS_STATUS_OK) {
return true;
}

auto orig_id = ldns_pkt_id(request);
auto reply_id = ldns_pkt_id(reply);
if (reply_id != orig_id) {
return true;
}

if (ldns_pkt_tc(reply)) {
return true;
}

return false;
}

coro::Task<Upstream::ExchangeResult> PlainUpstream::exchange(const ldns_pkt *request_pkt, const DnsMessageInfo *info) {
std::weak_ptr<bool> guard = m_shutdown_guard;

Expand Down Expand Up @@ -107,7 +91,17 @@ coro::Task<Upstream::ExchangeResult> PlainUpstream::exchange(const ldns_pkt *req
co_return make_error(DnsError::AE_SOCKET_ERROR, err);
}

auto r = co_await receive_dns_packet(&socket, timeout);
auto r = co_await receive_and_decode_dns_packet(
&socket, timeout, [id = ldns_pkt_id(request_pkt)](Uint8Vector buf) {
ldns_pkt *reply_pkt = nullptr;
auto status = ldns_wire2pkt(&reply_pkt, buf.data(), buf.size());
// Skip incorrect packets or packets with invalid id
if (status != LDNS_STATUS_OK || ldns_pkt_id(reply_pkt) != id) {
return ldns_pkt_ptr{nullptr}; // Return nullptr wrapped in ldns_pkt_ptr
}
return ldns_pkt_ptr{reply_pkt};
});

if (guard.expired()) {
co_return make_error(DnsError::AE_SHUTTING_DOWN);
}
Expand All @@ -117,14 +111,11 @@ coro::Task<Upstream::ExchangeResult> PlainUpstream::exchange(const ldns_pkt *req
: make_error(DnsError::AE_SOCKET_ERROR, r.error());
}

auto &reply = r.value();
ldns_pkt *reply_pkt = nullptr;
status = ldns_wire2pkt(&reply_pkt, reply.data(), reply.size());
if (!should_try_tcp(request_pkt, reply_pkt, status)) {
co_return ldns_pkt_ptr{reply_pkt};
auto &reply_pkt = r.value();
if (m_prefer_udp || !ldns_pkt_tc(reply_pkt.get())) {
co_return std::move(reply_pkt);
}
tracelog_id(m_log, request_pkt, "Trying TCP request after UDP failure");
ldns_pkt_free(reply_pkt);
}

timeout -= timer.elapsed<decltype(timeout)>();
Expand Down
5 changes: 4 additions & 1 deletion upstream/upstream_plain.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ class PlainUpstream;
*/
class PlainUpstream : public Upstream {
public:
static constexpr std::string_view TCP_SCHEME = "tcp://";
static constexpr std::string_view TCP_SCHEME = "tcp:";
static constexpr std::string_view UDP_SCHEME = "udp:";

/**
* Create plain DNS upstream
Expand All @@ -39,6 +40,8 @@ class PlainUpstream : public Upstream {

/** Prefer TCP */
bool m_prefer_tcp;
/** Prefer UDP */
bool m_prefer_udp;
/** TCP connection pool */
ConnectionPoolPtr m_pool;
/** Socket address */
Expand Down

0 comments on commit 85abf90

Please sign in to comment.