Skip to content

Commit dc37ce8

Browse files
authored
feat: refactor lookup logic (#161)
This commit does the following: * support lookup of more than NS from local storage * add convenience function for converting from our storage format to dns.RR * remove NS lookup logic for fallback servers and pass along query verbatim if not in local storage * remove (now) unneeded state helper function Fixes #107 and #159
1 parent 2fcce96 commit dc37ce8

File tree

2 files changed

+50
-111
lines changed

2 files changed

+50
-111
lines changed

internal/dns/dns.go

+50-73
Original file line numberDiff line numberDiff line change
@@ -68,41 +68,25 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
6868
}
6969
}
7070

71-
// Check for known record for domain nameserver
72-
records, err := state.GetState().LookupNameserverRecord(
71+
// Check for known record from local storage
72+
records, err := state.GetState().LookupRecords(
73+
[]string{dns.Type(r.Question[0].Qtype).String()},
7374
strings.TrimSuffix(r.Question[0].Name, "."),
7475
)
7576
if err != nil {
76-
logger.Errorf("failed to lookup record in state: %s", err)
77+
logger.Errorf("failed to lookup records in state: %s", err)
7778
return
7879
}
7980
if records != nil {
8081
// Assemble response
8182
m.SetReply(r)
82-
for k, v := range records {
83-
k = dns.Fqdn(k)
84-
address := net.ParseIP(v)
85-
// A or AAAA record
86-
if address.To4() != nil {
87-
// IPv4
88-
a := &dns.A{
89-
Hdr: dns.RR_Header{
90-
Name: k,
91-
Rrtype: dns.TypeA,
92-
Class: dns.ClassINET,
93-
Ttl: 999,
94-
},
95-
A: address,
96-
}
97-
m.Answer = append(m.Answer, a)
98-
} else {
99-
// IPv6
100-
aaaa := &dns.AAAA{
101-
Hdr: dns.RR_Header{Name: k, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 999},
102-
AAAA: address,
103-
}
104-
m.Answer = append(m.Answer, aaaa)
83+
for _, tmpRecord := range records {
84+
tmpRR, err := stateRecordToDnsRR(tmpRecord)
85+
if err != nil {
86+
logger.Errorf("failed to convert state record to dns.RR: %s", err)
87+
return
10588
}
89+
m.Answer = append(m.Answer, tmpRR)
10690
}
10791
// Send response
10892
if err := w.WriteMsg(m); err != nil {
@@ -112,6 +96,7 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
11296
return
11397
}
11498

99+
// Check for any NS records for parent domains from local storage
115100
nameserverDomain, nameservers, err := findNameserversForDomain(
116101
r.Question[0].Name,
117102
)
@@ -182,13 +167,52 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
182167
return
183168
}
184169

170+
// Query fallback servers, if configured
171+
if len(cfg.Dns.FallbackServers) > 0 {
172+
// Pick random fallback server
173+
fallbackServer := randomFallbackServer()
174+
// Pass along query to chosen fallback server
175+
resp, err := doQuery(r, fallbackServer, false)
176+
if err != nil {
177+
// Send failure response
178+
m.SetRcode(r, dns.RcodeServerFailure)
179+
if err := w.WriteMsg(m); err != nil {
180+
logger.Errorf("failed to write response: %s", err)
181+
}
182+
logger.Errorf("failed to query domain nameserver: %s", err)
183+
return
184+
} else {
185+
copyResponse(r, resp, m)
186+
// Send response
187+
if err := w.WriteMsg(m); err != nil {
188+
logger.Errorf("failed to write response: %s", err)
189+
}
190+
return
191+
}
192+
}
193+
185194
// Return NXDOMAIN if we have no information about the requested domain or any of its parents
186195
m.SetRcode(r, dns.RcodeNameError)
187196
if err := w.WriteMsg(m); err != nil {
188197
logger.Errorf("failed to write response: %s", err)
189198
}
190199
}
191200

201+
func stateRecordToDnsRR(record state.DomainRecord) (dns.RR, error) {
202+
tmpTtl := ""
203+
if record.Ttl > 0 {
204+
tmpTtl = fmt.Sprintf("%d", record.Ttl)
205+
}
206+
tmpRR := fmt.Sprintf(
207+
"%s %s IN %s %s",
208+
record.Lhs,
209+
tmpTtl,
210+
record.Type,
211+
record.Rhs,
212+
)
213+
return dns.NewRR(tmpRR)
214+
}
215+
192216
func copyResponse(req *dns.Msg, srcResp *dns.Msg, destResp *dns.Msg) {
193217
// Copy relevant data from original request and source response into destination response
194218
destResp.SetRcode(req, srcResp.MsgHdr.Rcode)
@@ -279,8 +303,6 @@ func doQuery(msg *dns.Msg, address string, recursive bool) (*dns.Msg, error) {
279303
func findNameserversForDomain(
280304
recordName string,
281305
) (string, map[string][]net.IP, error) {
282-
cfg := config.GetConfig()
283-
284306
// Split record name into labels and lookup each domain and parent until we get a hit
285307
queryLabels := dns.SplitDomainName(recordName)
286308

@@ -314,51 +336,6 @@ func findNameserversForDomain(
314336
}
315337
}
316338

317-
// Query fallback servers, if configured
318-
if len(cfg.Dns.FallbackServers) > 0 {
319-
// Pick random fallback server
320-
fallbackServer := randomFallbackServer()
321-
for startLabelIdx := 0; startLabelIdx < len(queryLabels); startLabelIdx++ {
322-
lookupDomainName := dns.Fqdn(
323-
strings.Join(queryLabels[startLabelIdx:], "."),
324-
)
325-
m := createQuery(lookupDomainName, dns.TypeNS)
326-
in, err := doQuery(m, fallbackServer, false)
327-
if err != nil {
328-
return "", nil, err
329-
}
330-
if in.Rcode == dns.RcodeSuccess {
331-
if len(in.Answer) > 0 {
332-
ret := map[string][]net.IP{}
333-
for _, answer := range in.Answer {
334-
switch v := answer.(type) {
335-
case *dns.NS:
336-
ns := v.Ns
337-
ret[ns] = make([]net.IP, 0)
338-
// Query for matching A/AAAA records
339-
m2 := createQuery(ns, dns.TypeA)
340-
in2, err := doQuery(m2, fallbackServer, false)
341-
if err != nil {
342-
return "", nil, err
343-
}
344-
for _, answer2 := range in2.Answer {
345-
switch v := answer2.(type) {
346-
case *dns.A:
347-
ret[ns] = append(ret[ns], v.A)
348-
case *dns.AAAA:
349-
ret[ns] = append(ret[ns], v.AAAA)
350-
}
351-
}
352-
}
353-
}
354-
if len(ret) > 0 {
355-
return lookupDomainName, ret, nil
356-
}
357-
}
358-
}
359-
}
360-
}
361-
362339
return "", nil, nil
363340
}
364341

internal/state/state.go

-38
Original file line numberDiff line numberDiff line change
@@ -259,44 +259,6 @@ func (s *State) LookupRecords(recordTypes []string, recordName string) ([]Domain
259259
return ret, nil
260260
}
261261

262-
// LookupNameserverRecord searches the domain nameserver entries for one matching the requested record
263-
func (s *State) LookupNameserverRecord(
264-
recordName string,
265-
) (map[string]string, error) {
266-
ret := map[string]string{}
267-
err := s.db.View(func(txn *badger.Txn) error {
268-
opts := badger.DefaultIteratorOptions
269-
// Makes key scans faster
270-
opts.PrefetchValues = false
271-
it := txn.NewIterator(opts)
272-
defer it.Close()
273-
for it.Rewind(); it.Valid(); it.Next() {
274-
item := it.Item()
275-
k := item.Key()
276-
if strings.HasSuffix(
277-
string(k),
278-
fmt.Sprintf("_nameserver_%s", recordName),
279-
) {
280-
err := item.Value(func(v []byte) error {
281-
ret[recordName] = string(v)
282-
return nil
283-
})
284-
if err != nil {
285-
return err
286-
}
287-
}
288-
}
289-
return nil
290-
})
291-
if err != nil {
292-
return nil, err
293-
}
294-
if len(ret) == 0 {
295-
return nil, nil
296-
}
297-
return ret, nil
298-
}
299-
300262
func GetState() *State {
301263
return globalState
302264
}

0 commit comments

Comments
 (0)