@@ -20,6 +20,12 @@ import (
20
20
)
21
21
22
22
var (
23
+ // BufferPool for reuse of byte slices
24
+ BufferPool = sync.Pool {
25
+ New : func () interface {} {
26
+ return make ([]byte , 4096 ) // Adjust the size according to your needs
27
+ },
28
+ }
23
29
config * Config
24
30
limiter * rate.Limiter
25
31
)
@@ -50,7 +56,6 @@ func findValueByKeyContains(m map[string]string, substr string) (string, bool) {
50
56
return "" , false // Return empty string and false if no key contains the substring
51
57
}
52
58
53
- // processDNSQuery processes the DNS query and returns a response.
54
59
// processDNSQuery processes the DNS query and returns a response.
55
60
func processDNSQuery (query []byte ) ([]byte , error ) {
56
61
var msg dns.Msg
@@ -76,7 +81,32 @@ func processDNSQuery(query []byte) ([]byte, error) {
76
81
return nil , err
77
82
}
78
83
defer resp .Body .Close ()
79
- return io .ReadAll (resp .Body )
84
+
85
+ // Use a fixed-size buffer from the pool for the initial read
86
+ buffer := BufferPool .Get ().([]byte )
87
+ defer BufferPool .Put (buffer )
88
+
89
+ // Read the initial chunk of the response
90
+ n , err := resp .Body .Read (buffer )
91
+ if err != nil && err != io .EOF {
92
+ return nil , err
93
+ }
94
+
95
+ // If the buffer was large enough to hold the entire response, return it
96
+ if n < len (buffer ) {
97
+ return buffer [:n ], nil
98
+ }
99
+
100
+ // If the response is larger than our buffer, we need to read the rest
101
+ // and append to a dynamically-sized buffer
102
+ var dynamicBuffer bytes.Buffer
103
+ dynamicBuffer .Write (buffer [:n ])
104
+ _ , err = dynamicBuffer .ReadFrom (resp .Body )
105
+ if err != nil {
106
+ return nil , err
107
+ }
108
+
109
+ return dynamicBuffer .Bytes (), nil
80
110
}
81
111
82
112
return msg .Pack ()
@@ -87,31 +117,43 @@ func handleDoTConnection(conn net.Conn) {
87
117
defer conn .Close ()
88
118
89
119
if ! limiter .Allow () {
90
- // Log rate limit exceeded
120
+ log . Println ( " limit exceeded" )
91
121
return
92
122
}
93
123
124
+ // Use a fixed-size buffer from the pool for the initial read
125
+ poolBuffer := BufferPool .Get ().([]byte )
126
+ defer BufferPool .Put (poolBuffer )
127
+
94
128
// Read the first two bytes to determine the length of the DNS message
95
- lengthBuf := make ([]byte , 2 )
96
- _ , err := io .ReadFull (conn , lengthBuf )
129
+ _ , err := io .ReadFull (conn , poolBuffer [:2 ])
97
130
if err != nil {
98
131
log .Println (err )
99
132
return
100
133
}
101
134
102
135
// Parse the length of the DNS message
103
- dnsMessageLength := binary .BigEndian .Uint16 (lengthBuf )
136
+ dnsMessageLength := binary .BigEndian .Uint16 (poolBuffer [:2 ])
137
+
138
+ // Prepare a buffer to read the full DNS message
139
+ var buffer []byte
140
+ if int (dnsMessageLength ) > len (poolBuffer ) {
141
+ // If pool buffer is too small, allocate a new buffer
142
+ buffer = make ([]byte , dnsMessageLength )
143
+ } else {
144
+ // Use the pool buffer directly
145
+ buffer = poolBuffer [:dnsMessageLength ]
146
+ }
104
147
105
- // Allocate a buffer of the size indicated by the length and read the DNS message
106
- buffer := make ([]byte , dnsMessageLength )
148
+ // Read the DNS message
107
149
_ , err = io .ReadFull (conn , buffer )
108
150
if err != nil {
109
151
log .Println (err )
110
152
return
111
153
}
112
154
113
155
// Process the DNS query and generate a response
114
- response , err := processDNSQuery (buffer ) // Process the full message
156
+ response , err := processDNSQuery (buffer )
115
157
if err != nil {
116
158
log .Println (err )
117
159
return
0 commit comments