diff --git a/upstream/resolver.cpp b/upstream/resolver.cpp index ff201390..3a7f80f5 100644 --- a/upstream/resolver.cpp +++ b/upstream/resolver.cpp @@ -8,6 +8,7 @@ #include "upstream_doh.h" #include "upstream_dot.h" #include "upstream_plain.h" +#include "upstream_system.h" #define log_ip(l_, lvl_, ip_, fmt_, ...) lvl_##log(l_, "[{}] " fmt_, ip_, ##__VA_ARGS__) @@ -86,7 +87,7 @@ static std::string get_server_address(const Logger &log, std::string_view addres warnlog(log, "Failed to parse DNS stamp"); return ""; } - } else if (!check_ip_address(result)) { + } else if (!result.starts_with(SystemUpstream::SYSTEM_SCHEME) && !check_ip_address(result)) { warnlog(log, "Resolver address must be a valid ip address"); return ""; } diff --git a/upstream/system_resolver.cpp b/upstream/system_resolver.cpp index cb9110da..20393449 100644 --- a/upstream/system_resolver.cpp +++ b/upstream/system_resolver.cpp @@ -4,135 +4,174 @@ #include "system_resolver.h" #include "dns/common/event_loop.h" +#include namespace ag::dns { -SystemResolver::SystemResolver(ConstructorAccess, EventLoop *loop, uint32_t if_index) - : m_loop(loop) - , m_if_index(if_index) { - DNSServiceRef service_ref = nullptr; - DNSServiceErrorType error_code = DNSServiceCreateConnection(&service_ref); - if (error_code != kDNSServiceErr_NoError) { - m_error_code = error_code; +static ag::Logger g_log{"SystemResolver"}; + +class SystemResolver::Impl { +public: + Impl(EventLoop *loop, uint32_t if_index) + : m_loop(loop) + , m_if_index(if_index) { + m_queue = dispatch_queue_create("SystemResolver", DISPATCH_QUEUE_SERIAL); } - m_service_ref.reset(service_ref); - m_queue = dispatch_queue_create("SystemResolver", DISPATCH_QUEUE_SERIAL); - DNSServiceSetDispatchQueue(m_service_ref.get(), m_queue); -} -ag::Result, SystemResolverError> SystemResolver::create(EventLoop *loop, uint32_t if_index) { - std::unique_ptr ret = std::make_unique(ConstructorAccess{}, loop, if_index); - if (ret && ret->m_error_code != 0) { - return make_error(SystemResolverError{ret->m_error_code}); + auto resolve(std::string_view domain, ldns_rr_type rr_type) { + struct Awaitable { + Impl *parent; + EventLoop *loop; + std::mutex mutex; + std::string domain; + ldns_rr_type rr_type; + bool rr_type_received; + SystemResolver::LdnsRrListPtr rr_list{ldns_rr_list_new()}; + Error error; + ServiceRefPtr service; + bool done; + std::coroutine_handle<> caller; + + auto await_ready() { + std::scoped_lock l{mutex}; + DNSServiceRef service_ref; + + auto error_code = DNSServiceQueryRecord(&service_ref, + kDNSServiceFlagsUnicastResponse | kDNSServiceFlagsReturnIntermediates + | kDNSServiceFlagAnsweredFromCache, + parent->m_if_index, domain.data(), rr_type, kDNSServiceClass_IN, handle_dns_service_query_record_reply, this); + + if (error_code != kDNSServiceErr_NoError) { + error = make_error(SystemResolverError::AE_SYSTEM_RESOLVE_ERROR); + return true; + } + tracelog(g_log, "Requested domain {} rrtype {}", + domain, AllocatedPtr{ldns_rr_type2str(ldns_rr_type(rr_type))}.get() ?: AG_FMT("TYPE{}", (int)rr_type)); + DNSServiceSetDispatchQueue(service_ref, parent->m_queue); + service.reset(service_ref); + + return false; + } + + auto await_suspend(std::coroutine_handle<> h) { + std::scoped_lock l{mutex}; + if (done) { + h(); + } else { + caller = h; + } + } + + Result await_resume() { + if (error) { + return error; + } else { + return std::move(rr_list); + } + } + /** + * Handles the reply from a DNSServiceQueryRecord request. + * @param sdRef The DNSServiceRef initialized by DNSServiceQueryRecord. + * @param flags Possible values are kDNSServiceFlagsMoreComing and kDNSServiceFlagsAdd. + * @param interfaceIndex The interface on which the query was resolved. + * @param errorCode Indicates whether the operation succeeded. + * @param fullname The full domain name of the resource record. + * @param rrtype The type of the resource record. + * @param rrclass The class of the resource record. + * @param rdlen The length of the rdata. + * @param rdata The raw rdata of the resource record. + * @param ttl The time to live of the resource record. + * @param context A pointer to the user-defined context. + */ + static void handle_dns_service_query_record_reply(DNSServiceRef sdRef, DNSServiceFlags flags, + uint32_t interfaceIndex, DNSServiceErrorType errorCode, const char *fullname, uint16_t rrtype, + uint16_t rrclass, uint16_t rdlen, const void *rdata, uint32_t ttl, void *context) { + tracelog(g_log, "Reply: error {}, {} {} {} {} {}, flags {:x}", + errorCode, fullname ?: "(null)", + ttl, + AllocatedPtr{ldns_rr_class2str(ldns_rr_class(rrclass))}.get() ?: AG_FMT("CLASS{}", rrclass), + AllocatedPtr{ldns_rr_type2str(ldns_rr_type(rrtype))}.get() ?: AG_FMT("TYPE{}", rrtype), + utils::encode_to_hex(Uint8View{(const uint8_t *) rdata, rdlen}), flags); + auto *self = (Awaitable *) context; + std::scoped_lock l{self->mutex}; + if (errorCode == kDNSServiceErr_NoError) { + if (rrtype == self->rr_type) { + self->rr_type_received = true; + } + ldns_rr *rr = ldns_rr_new(); + ldns_rr_set_owner(rr, ldns_dname_new_frm_str(fullname)); + ldns_rr_set_type(rr, static_cast(rrtype)); + ldns_rr_set_class(rr, static_cast(rrclass)); + ldns_rr_set_ttl(rr, ttl); + size_t pos = 0; + std::vector r_vecdata; + r_vecdata.resize(rdlen + 2); + uint16_t rdlen_network = htons(rdlen); + memcpy(r_vecdata.data(), &rdlen_network, 2); + memcpy(r_vecdata.data() + 2, rdata, rdlen); + ldns_status status = ldns_wire2rdf(rr, r_vecdata.data(), r_vecdata.size(), &pos); + if (status != LDNS_STATUS_OK) { + self->error = make_error(SystemResolverError::AE_DECODE_ERROR, make_error(status)); + } + ldns_rr_list_push_rr(self->rr_list.get(), rr); + } else { + if (errorCode == kDNSServiceErr_NoSuchRecord || errorCode == kDNSServiceErr_NoSuchName) { + self->error = make_error(SystemResolverError::AE_DOMAIN_NOT_FOUND); + } else { + self->error = make_error(SystemResolverError::AE_SYSTEM_RESOLVE_ERROR); + } + } + + if ((flags & kDNSServiceFlagsMoreComing) == 0) { + if ((flags & kDNSServiceFlagAnsweredFromCache) && !self->rr_type_received + && errorCode == kDNSServiceErr_NoError) { + tracelog(g_log, "Detected partial answer from cache, waiting more"); + return; + } + tracelog(g_log, "Done"); + self->done = true; + self->service.reset(); + self->loop->submit(self->caller); + return; + } + tracelog(g_log, "More coming"); + } + }; + return Awaitable{ + .parent = this, + .loop = m_loop, + .domain = std::string(domain), + .rr_type = rr_type, + }; } - return ret; -} -struct Context { - EventLoop *loop; - std::mutex mutex; - ldns_rr_type rr_type; - bool rr_type_received; - SystemResolver::LdnsRrListPtr rr_list{ldns_rr_list_new()}; - Error error; - SystemResolver::ServiceRefPtr service; - bool done; - std::coroutine_handle<> caller; + using ServiceRefPtr = ag::UniquePtr, &DNSServiceRefDeallocate>; + + EventLoop *m_loop = nullptr; + uint32_t m_if_index{}; ///< The network interface index. + DNSServiceErrorType m_error_code = 0; + ServiceRefPtr m_service_ref{}; + dispatch_queue_t m_queue; }; -bool SystemResolver::ResolveAwaitable::await_ready() { - Context &ctx = *(Context *)context; - std::scoped_lock l{ctx.mutex}; - return ctx.done; +SystemResolver::SystemResolver(ag::dns::SystemResolver::ConstructorAccess, ag::dns::EventLoop *loop, uint32_t if_index) { + m_pimpl = std::make_unique(loop, if_index); } -void SystemResolver::ResolveAwaitable::await_suspend(std::coroutine_handle<> h) { - Context &ctx = *(Context *)context; - std::scoped_lock l{ctx.mutex}; - if (ctx.done) { - h(); - } else { - ctx.caller = h; - } -} +SystemResolver::~SystemResolver() = default; -Result SystemResolver::ResolveAwaitable::await_resume() { - Context &ctx = *(Context *)context; - Result ret; - if (ctx.error) { - ret = ctx.error; - } else { - ret = std::move(ctx.rr_list); +ag::Result, SystemResolverError> SystemResolver::create(EventLoop *loop, uint32_t if_index) { + std::unique_ptr ret = std::make_unique(ConstructorAccess{}, loop, if_index); + if (ret && ret->m_pimpl && ret->m_pimpl->m_error_code != 0) { + return make_error(SystemResolverError{ret->m_pimpl->m_error_code}); } - delete (Context *) context; return ret; } -SystemResolver::ResolveAwaitable SystemResolver::resolve( - std::string_view domain, ldns_rr_type rr_type) { - ResolveAwaitable awaitable{}; - awaitable.context = new Context{}; - Context &context = *(Context *)awaitable.context; - context.loop = m_loop; - context.rr_type = rr_type; - std::unique_lock l{context.mutex}; - DNSServiceRef service_ref = m_service_ref.get(); - - auto error_code = DNSServiceQueryRecord(&service_ref, - kDNSServiceFlagsUnicastResponse | kDNSServiceFlagsReturnIntermediates | kDNSServiceFlagsShareConnection - | kDNSServiceFlagAnsweredFromCache, - m_if_index, domain.data(), rr_type, kDNSServiceClass_IN, handle_dns_service_query_record_reply, &context); - - if (error_code != kDNSServiceErr_NoError) { - context.error = make_error(SystemResolverError::AE_SYSTEM_RESOLVE_ERROR); - context.done = true; - } - context.service.reset(service_ref); - return awaitable; +coro::Task> +SystemResolver::resolve(std::string_view domain, ldns_rr_type rr_type) { + co_return co_await m_pimpl->resolve(domain, rr_type); } -void SystemResolver::handle_dns_service_query_record_reply(DNSServiceRef sdRef, DNSServiceFlags flags, - uint32_t interfaceIndex, DNSServiceErrorType errorCode, const char *fullname, uint16_t rrtype, uint16_t rrclass, - uint16_t rdlen, const void *rdata, uint32_t ttl, void *arg) { - auto *context = static_cast(arg); - std::scoped_lock l{context->mutex}; - if (errorCode == kDNSServiceErr_NoError) { - if (rrtype == context->rr_type) { - context->rr_type_received = true; - } - ldns_rr *rr = ldns_rr_new(); - ldns_rr_set_owner(rr, ldns_dname_new_frm_str(fullname)); - ldns_rr_set_type(rr, static_cast(rrtype)); - ldns_rr_set_class(rr, static_cast(rrclass)); - ldns_rr_set_ttl(rr, ttl); - size_t pos = 0; - std::vector r_vecdata; - r_vecdata.resize(rdlen + 2); - uint16_t rdlen_network = htons(rdlen); - memcpy(r_vecdata.data(), &rdlen_network, 2); - memcpy(r_vecdata.data() + 2, rdata, rdlen); - ldns_status status = ldns_wire2rdf(rr, r_vecdata.data(), r_vecdata.size(), &pos); - if (status != LDNS_STATUS_OK) { - context->error = make_error(SystemResolverError::AE_DECODE_ERROR, make_error(status)); - } - ldns_rr_list_push_rr(context->rr_list.get(), rr); - } else { - if (errorCode == kDNSServiceErr_NoSuchRecord || errorCode == kDNSServiceErr_NoSuchName) { - context->error = make_error(SystemResolverError::AE_DOMAIN_NOT_FOUND); - } else { - context->error = make_error(SystemResolverError::AE_SYSTEM_RESOLVE_ERROR); - } - } - - if ((flags & kDNSServiceFlagsMoreComing) == 0) { - if ((flags & kDNSServiceFlagAnsweredFromCache) && !context->rr_type_received - && errorCode == kDNSServiceErr_NoError) { - return; - } - context->done = true; - context->service.reset(); - context->loop->submit(context->caller); - } -} } // namespace ag::dns diff --git a/upstream/system_resolver.h b/upstream/system_resolver.h index 0a32b696..7c05c1a6 100644 --- a/upstream/system_resolver.h +++ b/upstream/system_resolver.h @@ -6,7 +6,6 @@ #include "dns/common/event_loop.h" #include -#include #include #include @@ -38,48 +37,20 @@ class SystemResolver { */ SystemResolver(ConstructorAccess, EventLoop *loop, uint32_t if_index = 0); static Result, SystemResolverError> create(EventLoop *loop, uint32_t if_index); + ~SystemResolver(); - struct ResolveAwaitable { - void *context; - bool await_ready(); - void await_suspend(std::coroutine_handle<> h); - Result await_resume(); - }; /** * Resolves a domain to a list of resource records. * @param domain Domain to resolve. * @param rr_type Type of the resource record to resolve. * @return A unique pointer to a list of resource records. */ - ResolveAwaitable resolve(std::string_view domain, ldns_rr_type rr_type); - - using ServiceRefPtr = ag::UniquePtr, &DNSServiceRefDeallocate>; + coro::Task> + resolve(std::string_view domain, ldns_rr_type rr_type); private: - /** - * Handles the reply from a DNSServiceQueryRecord request. - * @param sdRef The DNSServiceRef initialized by DNSServiceQueryRecord. - * @param flags Possible values are kDNSServiceFlagsMoreComing and kDNSServiceFlagsAdd. - * @param interfaceIndex The interface on which the query was resolved. - * @param errorCode Indicates whether the operation succeeded. - * @param fullname The full domain name of the resource record. - * @param rrtype The type of the resource record. - * @param rrclass The class of the resource record. - * @param rdlen The length of the rdata. - * @param rdata The raw rdata of the resource record. - * @param ttl The time to live of the resource record. - * @param context A pointer to the user-defined context. - */ - static void handle_dns_service_query_record_reply(DNSServiceRef sdRef, DNSServiceFlags flags, - uint32_t interfaceIndex, DNSServiceErrorType errorCode, const char *fullname, uint16_t rrtype, - uint16_t rrclass, uint16_t rdlen, const void *rdata, uint32_t ttl, void *context); - - - EventLoop *m_loop = nullptr; - uint32_t m_if_index{}; ///< The network interface index. - DNSServiceErrorType m_error_code = 0; - ServiceRefPtr m_service_ref{}; - dispatch_queue_t m_queue; + class Impl; + std::unique_ptr m_pimpl; }; } // namespace ag::dns diff --git a/upstream/upstream.cpp b/upstream/upstream.cpp index ea06eb93..a22adb53 100644 --- a/upstream/upstream.cpp +++ b/upstream/upstream.cpp @@ -17,9 +17,7 @@ #include "upstream_doq.h" #include "upstream_dot.h" #include "upstream_plain.h" -#ifdef __APPLE__ #include "upstream_system.h" -#endif //_APPLE namespace ag::dns { @@ -45,7 +43,7 @@ static constexpr std::string_view SCHEME_WITH_SUFFIX[] = { DohUpstream::SCHEME_HTTPS, DohUpstream::SCHEME_H3, DoqUpstream::SCHEME, - "system://", + SystemUpstream::SYSTEM_SCHEME, }; static_assert(std::size(SCHEME_WITH_SUFFIX) + 1 == magic_enum::enum_count(),