Skip to content

Commit ef7164e

Browse files
author
uoosef
committed
memory improvements
1 parent 689b8af commit ef7164e

File tree

1 file changed

+51
-9
lines changed

1 file changed

+51
-9
lines changed

main.go

+51-9
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ import (
2020
)
2121

2222
var (
23+
// BufferPool for reuse of byte slices
24+
BufferPool = sync.Pool{
25+
New: func() interface{} {
26+
return make([]byte, 4096) // Adjust the size according to your needs
27+
},
28+
}
2329
config *Config
2430
limiter *rate.Limiter
2531
)
@@ -50,7 +56,6 @@ func findValueByKeyContains(m map[string]string, substr string) (string, bool) {
5056
return "", false // Return empty string and false if no key contains the substring
5157
}
5258

53-
// processDNSQuery processes the DNS query and returns a response.
5459
// processDNSQuery processes the DNS query and returns a response.
5560
func processDNSQuery(query []byte) ([]byte, error) {
5661
var msg dns.Msg
@@ -76,7 +81,32 @@ func processDNSQuery(query []byte) ([]byte, error) {
7681
return nil, err
7782
}
7883
defer resp.Body.Close()
79-
return io.ReadAll(resp.Body)
84+
85+
// Use a fixed-size buffer from the pool for the initial read
86+
buffer := BufferPool.Get().([]byte)
87+
defer BufferPool.Put(buffer)
88+
89+
// Read the initial chunk of the response
90+
n, err := resp.Body.Read(buffer)
91+
if err != nil && err != io.EOF {
92+
return nil, err
93+
}
94+
95+
// If the buffer was large enough to hold the entire response, return it
96+
if n < len(buffer) {
97+
return buffer[:n], nil
98+
}
99+
100+
// If the response is larger than our buffer, we need to read the rest
101+
// and append to a dynamically-sized buffer
102+
var dynamicBuffer bytes.Buffer
103+
dynamicBuffer.Write(buffer[:n])
104+
_, err = dynamicBuffer.ReadFrom(resp.Body)
105+
if err != nil {
106+
return nil, err
107+
}
108+
109+
return dynamicBuffer.Bytes(), nil
80110
}
81111

82112
return msg.Pack()
@@ -87,31 +117,43 @@ func handleDoTConnection(conn net.Conn) {
87117
defer conn.Close()
88118

89119
if !limiter.Allow() {
90-
// Log rate limit exceeded
120+
log.Println("limit exceeded")
91121
return
92122
}
93123

124+
// Use a fixed-size buffer from the pool for the initial read
125+
poolBuffer := BufferPool.Get().([]byte)
126+
defer BufferPool.Put(poolBuffer)
127+
94128
// Read the first two bytes to determine the length of the DNS message
95-
lengthBuf := make([]byte, 2)
96-
_, err := io.ReadFull(conn, lengthBuf)
129+
_, err := io.ReadFull(conn, poolBuffer[:2])
97130
if err != nil {
98131
log.Println(err)
99132
return
100133
}
101134

102135
// Parse the length of the DNS message
103-
dnsMessageLength := binary.BigEndian.Uint16(lengthBuf)
136+
dnsMessageLength := binary.BigEndian.Uint16(poolBuffer[:2])
137+
138+
// Prepare a buffer to read the full DNS message
139+
var buffer []byte
140+
if int(dnsMessageLength) > len(poolBuffer) {
141+
// If pool buffer is too small, allocate a new buffer
142+
buffer = make([]byte, dnsMessageLength)
143+
} else {
144+
// Use the pool buffer directly
145+
buffer = poolBuffer[:dnsMessageLength]
146+
}
104147

105-
// Allocate a buffer of the size indicated by the length and read the DNS message
106-
buffer := make([]byte, dnsMessageLength)
148+
// Read the DNS message
107149
_, err = io.ReadFull(conn, buffer)
108150
if err != nil {
109151
log.Println(err)
110152
return
111153
}
112154

113155
// Process the DNS query and generate a response
114-
response, err := processDNSQuery(buffer) // Process the full message
156+
response, err := processDNSQuery(buffer)
115157
if err != nil {
116158
log.Println(err)
117159
return

0 commit comments

Comments
 (0)