-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathdns_framed.h
133 lines (120 loc) · 4.81 KB
/
dns_framed.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#pragma once
#include <mutex>
#include "common/logger.h"
#if defined _WIN32 && !defined __clang__
#pragma optimize( "", off )
#endif
#include "common/parallel.h"
#if defined _WIN32 && !defined __clang__
#pragma optimize( "", on )
#endif
#include "dns/common/event_loop.h"
#include "dns/net/tcp_dns_buffer.h"
#include "dns/upstream/upstream.h"
#include "connection.h"
namespace ag::dns {
class DnsFramedConnection;
using DnsFramedConnectionPtr = std::shared_ptr<DnsFramedConnection>;

/**
 * DNS framed connection class.
 * It uses the format specified in the DNS RFC for TCP connections:
 * - Header is 2 bytes - length of payload
 * - Payload is DNS packet content as sent by UDP
 *
 * Note that this class is designed to be extended (e.g. for DoT) and already inherits from
 * `std::enable_shared_from_this<DnsFramedConnection>`.
 */
class DnsFramedConnection : public Connection, public std::enable_shared_from_this<DnsFramedConnection> {
public:
    DnsFramedConnection(const ConstructorAccess &access, EventLoop &loop, const ConnectionPoolPtr &pool,
            const std::string &address_str);

    /**
     * Create a new connection object owned by a `shared_ptr`
     * (required because the class relies on `weak_from_this()`).
     * @param loop        Event loop the connection runs on
     * @param pool        Connection pool passed to the constructor
     * @param address_str Address string of the peer
     * @return Newly created connection
     */
    static DnsFramedConnectionPtr create(
            EventLoop &loop, const ConnectionPoolPtr &pool, const std::string &address_str) {
        return std::make_shared<DnsFramedConnection>(ConstructorAccess{}, loop, pool, address_str);
    }

    virtual void connect();

    ~DnsFramedConnection() override;

    /** Logger */
    Logger m_log;
    /** Connection id */
    uint32_t m_id{};
    /** Address */
    AddressVariant m_address;
    /** Input buffer */
    TcpDnsBuffer m_input_buffer;
    /** Connection handle */
    SocketFactory::SocketPtr m_stream;
    /** Idle timeout */
    std::chrono::milliseconds m_idle_timeout{};
    /**
     * Map of in-flight requests, keyed by request id.
     * NOTE(review): the key type is `int` while request ids are `uint16_t` everywhere else.
     */
    HashMap<int, Request *> m_requests;
    /** Next request id */
    static std::atomic<uint16_t> m_next_request_id;

    coro::Task<Reply> perform_request(Uint8View packet, Millis timeout) override;
    virtual void resume_request(uint16_t request_id);
    virtual void finish_request(uint16_t request_id, Reply &&reply);

    static void on_connected(void *arg);
    static void on_read(void *arg, Uint8View data);
    static void on_close(void *arg, Error<SocketError> error);
    void on_close(Error<DnsError> dns_error);

    /**
     * Human-readable description of the peer: `<address string>(<resolved socket address>)`.
     * The parenthesized part is empty unless `m_address` holds a valid `SocketAddress`.
     */
    std::string address_str() const {
        if (const auto *saddr = std::get_if<SocketAddress>(&m_address); saddr && saddr->valid()) {
            return AG_FMT("{}({})", m_address_str, saddr->str());
        }
        return AG_FMT("{}()", m_address_str);
    }

    /**
     * Make an awaitable that completes once the connection can be used for `request`.
     *
     * - CLOSED: `request->reply` is set to `AE_CONNECTION_CLOSED` and the awaiter does not suspend.
     * - ACTIVE: the awaiter does not suspend.
     * - Otherwise: the request is registered in `m_requests`, the suspended coroutine is stored as
     *   `request->caller`, and `connect()` is called if the connection is IDLE.
     *
     * A detached timer is started as soon as this function is called (before the result is awaited);
     * it finishes the request with `AE_TIMED_OUT` after `request->timeout` if the connection is still alive.
     */
    auto ensure_connected(Request *request) {
        struct Awaitable {
            DnsFramedConnection *self;
            Request *req;
            bool await_ready() {
                if (self->m_state == Connection::Status::CLOSED) {
                    req->reply = Reply(make_error(DnsError::AE_CONNECTION_CLOSED));
                    return true;
                }
                return self->m_state == Connection::Status::ACTIVE;
            }
            void await_suspend(std::coroutine_handle<> h) {
                self->m_requests[req->request_id] = req;
                req->caller = h;
                if (self->m_state == Connection::Status::IDLE) {
                    self->connect();
                }
            }
            void await_resume() {}
        };
        coro::run_detached(timeout_request(m_loop, weak_from_this(), request->timeout, request->request_id));
        return Awaitable{.self = this, .req = request};
    }

    /**
     * Make an awaitable that always suspends until the response for `request` arrives.
     * On suspension the request is registered in `m_requests` and the coroutine is stored as
     * `request->caller`.
     *
     * A detached timer is started as soon as this function is called; it finishes the request with
     * `AE_TIMED_OUT` after `request->timeout` if the connection is still alive.
     */
    auto wait_response(Request *request) {
        struct Awaitable {
            DnsFramedConnection *self;
            Request *req;
            bool await_ready() {
                return false;
            }
            void await_suspend(std::coroutine_handle<> h) {
                dbglog(self->m_log, "Waiting response...");
                self->m_requests[req->request_id] = req;
                req->caller = h;
            }
            void await_resume() {}
        };
        coro::run_detached(timeout_request(m_loop, weak_from_this(), request->timeout, request->request_id));
        return Awaitable{.self = this, .req = request};
    }

private:
    /**
     * Sleep for `timeout`, then finish request `request_id` with `AE_TIMED_OUT` if the connection
     * is still alive.
     *
     * The strong reference returned by `conn.lock()` is kept for the whole `finish_request()` call.
     * Binding only the raw pointer (`conn.lock().get()`) would release it at the end of the
     * declaration, and `finish_request()` may run code (e.g. resume the waiting coroutine) that drops
     * the last owner, leaving `this` dangling mid-call.
     */
    static coro::Task<void> timeout_request(
            EventLoop &loop, std::weak_ptr<DnsFramedConnection> conn, Millis timeout, uint16_t request_id) {
        co_await loop.co_sleep(timeout);
        if (auto self = conn.lock()) {
            self->finish_request(request_id, Reply{make_error(DnsError::AE_TIMED_OUT)});
        }
    }
};
} // namespace ag::dns