Skip to content

Commit 689b8af

Browse files
author
uoosef
committed
fix dot
1 parent 9fbbc72 commit 689b8af

File tree

1 file changed

+22
-51
lines changed

1 file changed

+22
-51
lines changed

main.go

+22-51
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,6 @@ import (
2020
)
2121

2222
var (
23-
// BufferPool for reuse of byte slices
24-
BufferPool = sync.Pool{
25-
New: func() interface{} {
26-
return make([]byte, 4096) // Adjust the size according to your needs
27-
},
28-
}
2923
config *Config
3024
limiter *rate.Limiter
3125
)
@@ -56,6 +50,7 @@ func findValueByKeyContains(m map[string]string, substr string) (string, bool) {
5650
return "", false // Return empty string and false if no key contains the substring
5751
}
5852

53+
// processDNSQuery processes the DNS query and returns a response.
5954
// processDNSQuery processes the DNS query and returns a response.
6055
func processDNSQuery(query []byte) ([]byte, error) {
6156
var msg dns.Msg
@@ -81,32 +76,7 @@ func processDNSQuery(query []byte) ([]byte, error) {
8176
return nil, err
8277
}
8378
defer resp.Body.Close()
84-
85-
// Use a fixed-size buffer from the pool for the initial read
86-
buffer := BufferPool.Get().([]byte)
87-
defer BufferPool.Put(buffer)
88-
89-
// Read the initial chunk of the response
90-
n, err := resp.Body.Read(buffer)
91-
if err != nil && err != io.EOF {
92-
return nil, err
93-
}
94-
95-
// If the buffer was large enough to hold the entire response, return it
96-
if n < len(buffer) {
97-
return buffer[:n], nil
98-
}
99-
100-
// If the response is larger than our buffer, we need to read the rest
101-
// and append to a dynamically-sized buffer
102-
var dynamicBuffer bytes.Buffer
103-
dynamicBuffer.Write(buffer[:n])
104-
_, err = dynamicBuffer.ReadFrom(resp.Body)
105-
if err != nil {
106-
return nil, err
107-
}
108-
109-
return dynamicBuffer.Bytes(), nil
79+
return io.ReadAll(resp.Body)
11080
}
11181

11282
return msg.Pack()
@@ -117,48 +87,49 @@ func handleDoTConnection(conn net.Conn) {
11787
defer conn.Close()
11888

11989
if !limiter.Allow() {
120-
log.Println("limit exceeded")
90+
// Log rate limit exceeded
12191
return
12292
}
12393

124-
// Get a buffer from the pool and put it back after use
125-
buffer := BufferPool.Get().([]byte)
126-
defer BufferPool.Put(buffer)
127-
12894
// Read the first two bytes to determine the length of the DNS message
129-
_, err := io.ReadFull(conn, buffer[:2])
95+
lengthBuf := make([]byte, 2)
96+
_, err := io.ReadFull(conn, lengthBuf)
13097
if err != nil {
13198
log.Println(err)
13299
return
133100
}
134101

135102
// Parse the length of the DNS message
136-
dnsMessageLength := binary.BigEndian.Uint16(buffer[:2])
137-
138-
// Check if the buffer is large enough to hold the DNS message, otherwise get a larger one
139-
if int(dnsMessageLength) > cap(buffer) {
140-
buffer = make([]byte, dnsMessageLength)
141-
defer BufferPool.Put(buffer[:4096]) // Put back the original buffer size to the pool
142-
} else {
143-
buffer = buffer[:dnsMessageLength]
144-
}
103+
dnsMessageLength := binary.BigEndian.Uint16(lengthBuf)
145104

146-
// Read the DNS message
105+
// Allocate a buffer of the size indicated by the length and read the DNS message
106+
buffer := make([]byte, dnsMessageLength)
147107
_, err = io.ReadFull(conn, buffer)
148108
if err != nil {
149109
log.Println(err)
150110
return
151111
}
152112

153113
// Process the DNS query and generate a response
154-
response, err := processDNSQuery(buffer)
114+
response, err := processDNSQuery(buffer) // Process the full message
115+
if err != nil {
116+
log.Println(err)
117+
return
118+
}
119+
120+
// Prepare the response with the length header
121+
responseLength := make([]byte, 2)
122+
binary.BigEndian.PutUint16(responseLength, uint16(len(response)))
123+
124+
// Write the length of the response followed by the response itself
125+
_, err = conn.Write(responseLength)
155126
if err != nil {
156127
log.Println(err)
157128
return
158129
}
159130

160-
// Write response
161-
if _, err := conn.Write(response); err != nil {
131+
_, err = conn.Write(response)
132+
if err != nil {
162133
log.Println(err)
163134
return
164135
}

0 commit comments

Comments
 (0)