Skip to content

Commit

Permalink
Pull request 644: Fix using system resolver in bootstrap
Browse files Browse the repository at this point in the history
Merge in ADGUARD-CORE-LIBS/dns-libs from fix/system_upstream to master

Squashed commit of the following:

commit 425f1d3c9377222d6a639e0bdaeb25d32ef72720
Author: Sergey Fionov <[email protected]>
Date:   Wed Mar 27 18:52:08 2024 +0200

    skipci: remove unused class

commit 4c9785d779b27d9676d57a5db6f92026ba5c0bc5
Author: Sergey Fionov <[email protected]>
Date:   Wed Mar 27 18:48:29 2024 +0200

    Extract constant

commit ec16925427771fd11a49d01cc11fb2aabd24c8b8
Author: Sergey Fionov <[email protected]>
Date:   Wed Mar 27 18:45:42 2024 +0200

    SystemResolver to pimpl

commit 90441fbd32179a3636f6fdd7ac810e3547971595
Author: Sergey Fionov <[email protected]>
Date:   Wed Mar 27 18:08:42 2024 +0200

    Fix using system resolver in bootstrap
  • Loading branch information
sfionov committed Mar 27, 2024
1 parent 2bae3c8 commit f6df999
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 150 deletions.
3 changes: 2 additions & 1 deletion upstream/resolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "upstream_doh.h"
#include "upstream_dot.h"
#include "upstream_plain.h"
#include "upstream_system.h"

#define log_ip(l_, lvl_, ip_, fmt_, ...) lvl_##log(l_, "[{}] " fmt_, ip_, ##__VA_ARGS__)

Expand Down Expand Up @@ -86,7 +87,7 @@ static std::string get_server_address(const Logger &log, std::string_view addres
warnlog(log, "Failed to parse DNS stamp");
return "";
}
} else if (!check_ip_address(result)) {
} else if (!result.starts_with(SystemUpstream::SYSTEM_SCHEME) && !check_ip_address(result)) {
warnlog(log, "Resolver address must be a valid ip address");
return "";
}
Expand Down
263 changes: 151 additions & 112 deletions upstream/system_resolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,135 +4,174 @@

#include "system_resolver.h"
#include "dns/common/event_loop.h"
#include <dns_sd.h>

namespace ag::dns {

SystemResolver::SystemResolver(ConstructorAccess, EventLoop *loop, uint32_t if_index)
: m_loop(loop)
, m_if_index(if_index) {
DNSServiceRef service_ref = nullptr;
DNSServiceErrorType error_code = DNSServiceCreateConnection(&service_ref);
if (error_code != kDNSServiceErr_NoError) {
m_error_code = error_code;
static ag::Logger g_log{"SystemResolver"};

class SystemResolver::Impl {
public:
Impl(EventLoop *loop, uint32_t if_index)
: m_loop(loop)
, m_if_index(if_index) {
m_queue = dispatch_queue_create("SystemResolver", DISPATCH_QUEUE_SERIAL);
}
m_service_ref.reset(service_ref);
m_queue = dispatch_queue_create("SystemResolver", DISPATCH_QUEUE_SERIAL);
DNSServiceSetDispatchQueue(m_service_ref.get(), m_queue);
}

ag::Result<std::unique_ptr<SystemResolver>, SystemResolverError> SystemResolver::create(EventLoop *loop, uint32_t if_index) {
std::unique_ptr<SystemResolver> ret = std::make_unique<SystemResolver>(ConstructorAccess{}, loop, if_index);
if (ret && ret->m_error_code != 0) {
return make_error(SystemResolverError{ret->m_error_code});
auto resolve(std::string_view domain, ldns_rr_type rr_type) {
struct Awaitable {
Impl *parent;
EventLoop *loop;
std::mutex mutex;
std::string domain;
ldns_rr_type rr_type;
bool rr_type_received;
SystemResolver::LdnsRrListPtr rr_list{ldns_rr_list_new()};
Error<SystemResolverError> error;
ServiceRefPtr service;
bool done;
std::coroutine_handle<> caller;

auto await_ready() {
std::scoped_lock l{mutex};
DNSServiceRef service_ref;

auto error_code = DNSServiceQueryRecord(&service_ref,
kDNSServiceFlagsUnicastResponse | kDNSServiceFlagsReturnIntermediates
| kDNSServiceFlagAnsweredFromCache,
parent->m_if_index, domain.data(), rr_type, kDNSServiceClass_IN, handle_dns_service_query_record_reply, this);

if (error_code != kDNSServiceErr_NoError) {
error = make_error(SystemResolverError::AE_SYSTEM_RESOLVE_ERROR);
return true;
}
tracelog(g_log, "Requested domain {} rrtype {}",
domain, AllocatedPtr<char>{ldns_rr_type2str(ldns_rr_type(rr_type))}.get() ?: AG_FMT("TYPE{}", (int)rr_type));
DNSServiceSetDispatchQueue(service_ref, parent->m_queue);
service.reset(service_ref);

return false;
}

auto await_suspend(std::coroutine_handle<> h) {
std::scoped_lock l{mutex};
if (done) {
h();
} else {
caller = h;
}
}

Result<SystemResolver::LdnsRrListPtr, SystemResolverError> await_resume() {
if (error) {
return error;
} else {
return std::move(rr_list);
}
}
/**
* Handles the reply from a DNSServiceQueryRecord request.
* @param sdRef The DNSServiceRef initialized by DNSServiceQueryRecord.
* @param flags Possible values are kDNSServiceFlagsMoreComing and kDNSServiceFlagsAdd.
* @param interfaceIndex The interface on which the query was resolved.
* @param errorCode Indicates whether the operation succeeded.
* @param fullname The full domain name of the resource record.
* @param rrtype The type of the resource record.
* @param rrclass The class of the resource record.
* @param rdlen The length of the rdata.
* @param rdata The raw rdata of the resource record.
* @param ttl The time to live of the resource record.
* @param context A pointer to the user-defined context.
*/
static void handle_dns_service_query_record_reply(DNSServiceRef sdRef, DNSServiceFlags flags,
uint32_t interfaceIndex, DNSServiceErrorType errorCode, const char *fullname, uint16_t rrtype,
uint16_t rrclass, uint16_t rdlen, const void *rdata, uint32_t ttl, void *context) {
tracelog(g_log, "Reply: error {}, {} {} {} {} {}, flags {:x}",
errorCode, fullname ?: "(null)",
ttl,
AllocatedPtr<char>{ldns_rr_class2str(ldns_rr_class(rrclass))}.get() ?: AG_FMT("CLASS{}", rrclass),
AllocatedPtr<char>{ldns_rr_type2str(ldns_rr_type(rrtype))}.get() ?: AG_FMT("TYPE{}", rrtype),
utils::encode_to_hex(Uint8View{(const uint8_t *) rdata, rdlen}), flags);
auto *self = (Awaitable *) context;
std::scoped_lock l{self->mutex};
if (errorCode == kDNSServiceErr_NoError) {
if (rrtype == self->rr_type) {
self->rr_type_received = true;
}
ldns_rr *rr = ldns_rr_new();
ldns_rr_set_owner(rr, ldns_dname_new_frm_str(fullname));
ldns_rr_set_type(rr, static_cast<ldns_rr_type>(rrtype));
ldns_rr_set_class(rr, static_cast<ldns_rr_class>(rrclass));
ldns_rr_set_ttl(rr, ttl);
size_t pos = 0;
std::vector<uint8_t> r_vecdata;
r_vecdata.resize(rdlen + 2);
uint16_t rdlen_network = htons(rdlen);
memcpy(r_vecdata.data(), &rdlen_network, 2);
memcpy(r_vecdata.data() + 2, rdata, rdlen);
ldns_status status = ldns_wire2rdf(rr, r_vecdata.data(), r_vecdata.size(), &pos);
if (status != LDNS_STATUS_OK) {
self->error = make_error(SystemResolverError::AE_DECODE_ERROR, make_error(status));
}
ldns_rr_list_push_rr(self->rr_list.get(), rr);
} else {
if (errorCode == kDNSServiceErr_NoSuchRecord || errorCode == kDNSServiceErr_NoSuchName) {
self->error = make_error(SystemResolverError::AE_DOMAIN_NOT_FOUND);
} else {
self->error = make_error(SystemResolverError::AE_SYSTEM_RESOLVE_ERROR);
}
}

if ((flags & kDNSServiceFlagsMoreComing) == 0) {
if ((flags & kDNSServiceFlagAnsweredFromCache) && !self->rr_type_received
&& errorCode == kDNSServiceErr_NoError) {
tracelog(g_log, "Detected partial answer from cache, waiting more");
return;
}
tracelog(g_log, "Done");
self->done = true;
self->service.reset();
self->loop->submit(self->caller);
return;
}
tracelog(g_log, "More coming");
}
};
return Awaitable{
.parent = this,
.loop = m_loop,
.domain = std::string(domain),
.rr_type = rr_type,
};
}
return ret;
}

struct Context {
EventLoop *loop;
std::mutex mutex;
ldns_rr_type rr_type;
bool rr_type_received;
SystemResolver::LdnsRrListPtr rr_list{ldns_rr_list_new()};
Error<SystemResolverError> error;
SystemResolver::ServiceRefPtr service;
bool done;
std::coroutine_handle<> caller;
using ServiceRefPtr = ag::UniquePtr<std::remove_pointer_t<DNSServiceRef>, &DNSServiceRefDeallocate>;

EventLoop *m_loop = nullptr;
uint32_t m_if_index{}; ///< The network interface index.
DNSServiceErrorType m_error_code = 0;
ServiceRefPtr m_service_ref{};
dispatch_queue_t m_queue;
};

bool SystemResolver::ResolveAwaitable::await_ready() {
Context &ctx = *(Context *)context;
std::scoped_lock l{ctx.mutex};
return ctx.done;
SystemResolver::SystemResolver(ag::dns::SystemResolver::ConstructorAccess, ag::dns::EventLoop *loop, uint32_t if_index) {
m_pimpl = std::make_unique<Impl>(loop, if_index);
}

void SystemResolver::ResolveAwaitable::await_suspend(std::coroutine_handle<> h) {
Context &ctx = *(Context *)context;
std::scoped_lock l{ctx.mutex};
if (ctx.done) {
h();
} else {
ctx.caller = h;
}
}
SystemResolver::~SystemResolver() = default;

Result<SystemResolver::LdnsRrListPtr, SystemResolverError> SystemResolver::ResolveAwaitable::await_resume() {
Context &ctx = *(Context *)context;
Result<SystemResolver::LdnsRrListPtr, SystemResolverError> ret;
if (ctx.error) {
ret = ctx.error;
} else {
ret = std::move(ctx.rr_list);
ag::Result<std::unique_ptr<SystemResolver>, SystemResolverError> SystemResolver::create(EventLoop *loop, uint32_t if_index) {
std::unique_ptr<SystemResolver> ret = std::make_unique<SystemResolver>(ConstructorAccess{}, loop, if_index);
if (ret && ret->m_pimpl && ret->m_pimpl->m_error_code != 0) {
return make_error(SystemResolverError{ret->m_pimpl->m_error_code});
}
delete (Context *) context;
return ret;
}

SystemResolver::ResolveAwaitable SystemResolver::resolve(
std::string_view domain, ldns_rr_type rr_type) {
ResolveAwaitable awaitable{};
awaitable.context = new Context{};
Context &context = *(Context *)awaitable.context;
context.loop = m_loop;
context.rr_type = rr_type;
std::unique_lock l{context.mutex};
DNSServiceRef service_ref = m_service_ref.get();

auto error_code = DNSServiceQueryRecord(&service_ref,
kDNSServiceFlagsUnicastResponse | kDNSServiceFlagsReturnIntermediates | kDNSServiceFlagsShareConnection
| kDNSServiceFlagAnsweredFromCache,
m_if_index, domain.data(), rr_type, kDNSServiceClass_IN, handle_dns_service_query_record_reply, &context);

if (error_code != kDNSServiceErr_NoError) {
context.error = make_error(SystemResolverError::AE_SYSTEM_RESOLVE_ERROR);
context.done = true;
}
context.service.reset(service_ref);
return awaitable;
coro::Task<Result<SystemResolver::LdnsRrListPtr, SystemResolverError>>
SystemResolver::resolve(std::string_view domain, ldns_rr_type rr_type) {
co_return co_await m_pimpl->resolve(domain, rr_type);
}

void SystemResolver::handle_dns_service_query_record_reply(DNSServiceRef sdRef, DNSServiceFlags flags,
uint32_t interfaceIndex, DNSServiceErrorType errorCode, const char *fullname, uint16_t rrtype, uint16_t rrclass,
uint16_t rdlen, const void *rdata, uint32_t ttl, void *arg) {
auto *context = static_cast<Context *>(arg);
std::scoped_lock l{context->mutex};
if (errorCode == kDNSServiceErr_NoError) {
if (rrtype == context->rr_type) {
context->rr_type_received = true;
}
ldns_rr *rr = ldns_rr_new();
ldns_rr_set_owner(rr, ldns_dname_new_frm_str(fullname));
ldns_rr_set_type(rr, static_cast<ldns_rr_type>(rrtype));
ldns_rr_set_class(rr, static_cast<ldns_rr_class>(rrclass));
ldns_rr_set_ttl(rr, ttl);
size_t pos = 0;
std::vector<uint8_t> r_vecdata;
r_vecdata.resize(rdlen + 2);
uint16_t rdlen_network = htons(rdlen);
memcpy(r_vecdata.data(), &rdlen_network, 2);
memcpy(r_vecdata.data() + 2, rdata, rdlen);
ldns_status status = ldns_wire2rdf(rr, r_vecdata.data(), r_vecdata.size(), &pos);
if (status != LDNS_STATUS_OK) {
context->error = make_error(SystemResolverError::AE_DECODE_ERROR, make_error(status));
}
ldns_rr_list_push_rr(context->rr_list.get(), rr);
} else {
if (errorCode == kDNSServiceErr_NoSuchRecord || errorCode == kDNSServiceErr_NoSuchName) {
context->error = make_error(SystemResolverError::AE_DOMAIN_NOT_FOUND);
} else {
context->error = make_error(SystemResolverError::AE_SYSTEM_RESOLVE_ERROR);
}
}

if ((flags & kDNSServiceFlagsMoreComing) == 0) {
if ((flags & kDNSServiceFlagAnsweredFromCache) && !context->rr_type_received
&& errorCode == kDNSServiceErr_NoError) {
return;
}
context->done = true;
context->service.reset();
context->loop->submit(context->caller);
}
}

} // namespace ag::dns
39 changes: 5 additions & 34 deletions upstream/system_resolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include "dns/common/event_loop.h"

#include <cassert>
#include <dns_sd.h>
#include <ldns/ldns.h>
#include <string_view>

Expand Down Expand Up @@ -38,48 +37,20 @@ class SystemResolver {
*/
SystemResolver(ConstructorAccess, EventLoop *loop, uint32_t if_index = 0);
static Result<std::unique_ptr<SystemResolver>, SystemResolverError> create(EventLoop *loop, uint32_t if_index);
~SystemResolver();

struct ResolveAwaitable {
void *context;
bool await_ready();
void await_suspend(std::coroutine_handle<> h);
Result<LdnsRrListPtr, SystemResolverError> await_resume();
};
/**
* Resolves a domain to a list of resource records.
* @param domain Domain to resolve.
* @param rr_type Type of the resource record to resolve.
* @return A unique pointer to a list of resource records.
*/
ResolveAwaitable resolve(std::string_view domain, ldns_rr_type rr_type);

using ServiceRefPtr = ag::UniquePtr<std::remove_pointer_t<DNSServiceRef>, &DNSServiceRefDeallocate>;
coro::Task<Result<LdnsRrListPtr, SystemResolverError>>
resolve(std::string_view domain, ldns_rr_type rr_type);

private:
/**
* Handles the reply from a DNSServiceQueryRecord request.
* @param sdRef The DNSServiceRef initialized by DNSServiceQueryRecord.
* @param flags Possible values are kDNSServiceFlagsMoreComing and kDNSServiceFlagsAdd.
* @param interfaceIndex The interface on which the query was resolved.
* @param errorCode Indicates whether the operation succeeded.
* @param fullname The full domain name of the resource record.
* @param rrtype The type of the resource record.
* @param rrclass The class of the resource record.
* @param rdlen The length of the rdata.
* @param rdata The raw rdata of the resource record.
* @param ttl The time to live of the resource record.
* @param context A pointer to the user-defined context.
*/
static void handle_dns_service_query_record_reply(DNSServiceRef sdRef, DNSServiceFlags flags,
uint32_t interfaceIndex, DNSServiceErrorType errorCode, const char *fullname, uint16_t rrtype,
uint16_t rrclass, uint16_t rdlen, const void *rdata, uint32_t ttl, void *context);


EventLoop *m_loop = nullptr;
uint32_t m_if_index{}; ///< The network interface index.
DNSServiceErrorType m_error_code = 0;
ServiceRefPtr m_service_ref{};
dispatch_queue_t m_queue;
class Impl;
std::unique_ptr<Impl> m_pimpl;
};

} // namespace ag::dns
Expand Down
4 changes: 1 addition & 3 deletions upstream/upstream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
#include "upstream_doq.h"
#include "upstream_dot.h"
#include "upstream_plain.h"
#ifdef __APPLE__
#include "upstream_system.h"
#endif //_APPLE

namespace ag::dns {

Expand All @@ -45,7 +43,7 @@ static constexpr std::string_view SCHEME_WITH_SUFFIX[] = {
DohUpstream::SCHEME_HTTPS,
DohUpstream::SCHEME_H3,
DoqUpstream::SCHEME,
"system://",
SystemUpstream::SYSTEM_SCHEME,
};

static_assert(std::size(SCHEME_WITH_SUFFIX) + 1 == magic_enum::enum_count<Scheme>(),
Expand Down

0 comments on commit f6df999

Please sign in to comment.