-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathwriter.go
311 lines (279 loc) · 8.31 KB
/
writer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
package mmdbmeld
import (
"fmt"
"net"
"net/netip"
"os"
"slices"
"strings"
"time"
"github.com/maxmind/mmdbwriter"
"github.com/maxmind/mmdbwriter/inserter"
"github.com/maxmind/mmdbwriter/mmdbtype"
"go4.org/netipx"
)
// reportSlotSize is the number of entries inserted from a single source after
// which WriteMMDB emits an intermediate progress update. It is kept untyped so
// it can divide a time.Duration as well as an int.
const (
	reportSlotSize = 100_000
)
// WriteMMDB builds an MMDB database from the given sources according to
// dbConfig and writes it to dbConfig.Output.
//
// Sources are processed in order; every entry is inserted using Inserter with
// dbConfig.Merge. Entries that fail to parse, convert, or insert are reported
// via updates and skipped. Entries filtered out by the ForceIPVersion or
// MaxPrefix optimizations are skipped silently.
//
// Progress messages are sent to updates if it is non-nil. Sends never block
// (see sendUpdate), so messages may be dropped when the receiver is not ready.
// The updates channel is closed when WriteMMDB returns, on success and on
// every error path.
//
// An error is returned if the writer cannot be created, the output file cannot
// be created, written, or closed, or a source reports a failure via Err.
func WriteMMDB(dbConfig DatabaseConfig, sources []Source, updates chan string) error {
	// Close the update channel on every return path, including the early error
	// returns below, so a receiver ranging over it always terminates.
	if updates != nil {
		defer close(updates)
	}

	// Init writer.
	opts := mmdbwriter.Options{
		DatabaseType:            dbConfig.Name,
		IncludeReservedNetworks: true,
		DisableIPv4Aliasing:     true,
		IPVersion:               dbConfig.MMDB.IPVersion,
		RecordSize:              dbConfig.MMDB.RecordSize,
	}
	writer, err := mmdbwriter.New(opts)
	if err != nil {
		return fmt.Errorf("failed to create mmdb writer for %s: %w", dbConfig.Name, err)
	}
	sendUpdate(updates, fmt.Sprintf(
		"database options set: IPVersion=%d RecordSize=%d (IncludeReservedNetworks=%v DisableIPv4Aliasing=%v)",
		opts.IPVersion,
		opts.RecordSize,
		opts.IncludeReservedNetworks,
		opts.DisableIPv4Aliasing,
	))

	// Report the configured type keys in a stable order; "-" and "" mark
	// columns that are not written to the database.
	typeKeys := make([]string, 0, len(dbConfig.Types))
	for k, v := range dbConfig.Types {
		if v != "-" && v != "" {
			typeKeys = append(typeKeys, k)
		}
	}
	slices.Sort(typeKeys)
	sendUpdate(updates, fmt.Sprintf(
		"database types: %s",
		strings.Join(typeKeys, ", "),
	))
	sendUpdate(updates, fmt.Sprintf(
		"optimizations set: FloatDecimals=%d ForceIPVersion=%v MaxPrefix=%d",
		dbConfig.Optimize.FloatDecimals,
		dbConfig.Optimize.ForceIPVersionEnabled(),
		dbConfig.Optimize.MaxPrefix,
	))
	sendUpdate(updates, fmt.Sprintf(
		"merge config: AlwaysReplace=%v MergeArrays=%v ConditionalResets=%+v",
		dbConfig.Merge.AlwaysReplace,
		dbConfig.Merge.MergeArrays,
		dbConfig.Merge.ConditionalResets,
	))

	// Open output file to detect errors before processing.
	outputFile, err := os.Create(dbConfig.Output)
	if err != nil {
		return fmt.Errorf("failed to open output file for %s: %w", dbConfig.Name, err)
	}
	// Release the file descriptor on early returns. The success path closes
	// the file explicitly below so that a close error is reported; the second
	// Close here then returns an error that is deliberately ignored.
	defer func() { _ = outputFile.Close() }()

	// Process sources.
	var (
		totalInserts   int
		totalStartTime = time.Now()
	)
	for _, source := range sources {
		var inserted int
		// Start a fresh batch timer per source so the first batch of this
		// source does not include time spent on the previous one.
		slotStartTime := time.Now()
		sendUpdate(updates, fmt.Sprintf("---\nprocessing %s...", source.Name()))
		for {
			entry, err := source.NextEntry()
			if err != nil {
				// NOTE(review): assumes NextEntry advances past a failing entry;
				// a source that kept returning the same error would loop here.
				sendUpdate(updates, fmt.Sprintf("failed to parse entry: %s", err.Error()))
				continue
			}
			if entry == nil {
				break
			}

			mmdbMap, err := entry.ToMMDBMap(dbConfig.Optimize)
			if err != nil {
				sendUpdate(updates, fmt.Sprintf("failed to convert %+v to mmdb map: %s", entry, err.Error()))
				continue
			}

			if entry.Net != nil {
				// Handle Network/Prefix Format.
				// Ignore entry if the IP version is forced and it does not match the mmdb DB.
				if dbConfig.Optimize.ForceIPVersionEnabled() && ipVersion(entry.Net.IP) != opts.IPVersion {
					continue
				}
				// Ignore entry if prefix is greater than the max prefix.
				if dbConfig.Optimize.MaxPrefix > 0 {
					prefixBits, _ := entry.Net.Mask.Size()
					if prefixBits > dbConfig.Optimize.MaxPrefix {
						continue
					}
				}
				err = writer.InsertFunc(entry.Net, Inserter(mmdbMap, dbConfig.Merge))
				if err != nil {
					sendUpdate(updates, fmt.Sprintf("failed to insert %+v: %s", entry, err.Error()))
					continue
				}
			} else {
				// Handle From-To IP Format.
				// Ignore entry if the IP version is forced and it does not match the mmdb DB.
				if dbConfig.Optimize.ForceIPVersionEnabled() && ipVersion(entry.From) != opts.IPVersion {
					continue
				}
				start, ok1 := netip.AddrFromSlice(entry.From)
				end, ok2 := netip.AddrFromSlice(entry.To)
				if !ok1 || !ok2 {
					sendUpdate(updates, fmt.Sprintf("range with invalid IPs: %s - %s", entry.From, entry.To))
					continue
				}
				r := netipx.IPRangeFrom(start, end)
				if !r.IsValid() {
					sendUpdate(updates, fmt.Sprintf("range is invalid: %s - %s", entry.From, entry.To))
					continue
				}
				// Split the range into the minimal set of CIDR prefixes covering it.
				for _, subnet := range r.Prefixes() {
					// Ignore entry if prefix is greater than the max prefix.
					if dbConfig.Optimize.MaxPrefix > 0 && subnet.Bits() > dbConfig.Optimize.MaxPrefix {
						continue
					}
					err = writer.InsertFunc(netipx.PrefixIPNet(subnet), Inserter(mmdbMap, dbConfig.Merge))
					if err != nil {
						sendUpdate(updates, fmt.Sprintf("failed to insert %+v: %s", entry, err.Error()))
						continue
					}
				}
			}

			inserted++
			totalInserts++
			if inserted%reportSlotSize == 0 {
				sendBatchUpdate(updates, inserted, reportSlotSize, slotStartTime)
				slotStartTime = time.Now()
			}
		}
		if err := source.Err(); err != nil {
			return fmt.Errorf("source %s failed: %w", source.Name(), err)
		}
		// Report the final, possibly partial, batch of this source using its
		// real size so the per-op time is not understated.
		sendBatchUpdate(updates, inserted, inserted%reportSlotSize, slotStartTime)
	}

	// Write final db to file.
	_, err = writer.WriteTo(outputFile)
	if err != nil {
		return fmt.Errorf("failed to write %s to output file: %w", dbConfig.Name, err)
	}
	// Close explicitly so that errors surfacing at close time are not lost.
	if err := outputFile.Close(); err != nil {
		return fmt.Errorf("failed to close output file for %s: %w", dbConfig.Name, err)
	}

	// Send final update. The file size is best-effort and reported as 0 if
	// the stat fails.
	var fileSize int64
	if stat, err := os.Stat(dbConfig.Output); err == nil {
		fileSize = stat.Size()
	}
	sendUpdate(updates, fmt.Sprintf(
		"---\n%s finished: inserted %d entries in %s, resulting in %.2f MB written to %s",
		dbConfig.Name,
		totalInserts,
		time.Since(totalStartTime).Round(time.Second),
		float64(fileSize)/1000000,
		dbConfig.Output,
	))
	return nil
}

// sendBatchUpdate reports progress for a batch of batchSize inserts that
// started at batchStart. inserted is the running count for the current source.
// A batchSize of zero reports a per-op time of zero instead of dividing by zero.
func sendBatchUpdate(updates chan string, inserted, batchSize int, batchStart time.Time) {
	elapsed := time.Since(batchStart)
	var perOp time.Duration
	if batchSize > 0 {
		perOp = elapsed / time.Duration(batchSize)
	}
	sendUpdate(updates, fmt.Sprintf(
		"inserted %d entries - batch in %s (%s/op)",
		inserted,
		elapsed.Round(time.Millisecond),
		perOp.Round(time.Microsecond),
	))
}
// Inserter returns an inserter.Func that merges newValue into the value that
// is already stored for a network. It is modeled on inserter.TopLevelMergeWith
// and adds the behavior selected in cfg:
//
//   - AlwaysReplace: the existing value is discarded and newValue is stored as is.
//   - MergeArrays: when a key holds a Slice in both the existing and the new
//     map, the new elements are appended to a copy of the existing slice
//     instead of replacing it.
//   - ConditionalResets: for each rule, if any key listed in IfChanged is
//     present in both maps with differing values, every key listed in Reset
//     is set to its value from newValue, or removed if newValue lacks it.
//
// Unless AlwaysReplace is set, newValue and any existing value must both be
// mmdbtype.Map values; otherwise the returned function reports an error.
func Inserter(newValue mmdbtype.DataType, cfg MergeConfig) inserter.Func {
	return func(existingValue mmdbtype.DataType) (mmdbtype.DataType, error) {
		if cfg.AlwaysReplace {
			return newValue, nil
		}

		newMap, isMap := newValue.(mmdbtype.Map)
		if !isMap {
			return nil, fmt.Errorf(
				"the new value is a %T, not a Map; ConditionalResetTopLevelMerge only works if both values are Map values",
				newValue,
			)
		}
		if existingValue == nil {
			// Nothing stored yet for this network: store the new map unchanged.
			return newValue, nil
		}
		existingMap, isMap := existingValue.(mmdbtype.Map)
		if !isMap {
			return nil, fmt.Errorf(
				"the existing value is a %T, not a Map; ConditionalResetTopLevelMerge only works if both values are Map values",
				existingValue,
			)
		}

		// Top-level merge on a copy of the existing map, so the stored value
		// itself is never mutated.
		merged := existingMap.Copy().(mmdbtype.Map) //nolint:forcetypeassert
		for key, val := range newMap {
			copied := val.Copy()
			if cfg.MergeArrays {
				newSlice, newIsSlice := copied.(mmdbtype.Slice)
				oldSlice, oldIsSlice := merged[key].(mmdbtype.Slice)
				if newIsSlice && oldIsSlice {
					merged[key] = append(oldSlice, newSlice...)
					continue
				}
			}
			merged[key] = copied
		}

		// Apply conditional resets: a rule fires when any watched key exists
		// on both sides with a different value.
		for _, rule := range cfg.ConditionalResets {
			triggered := false
			for _, key := range rule.IfChanged {
				oldVal, hadOld := existingMap[mmdbtype.String(key)]
				newVal, hasNew := newMap[mmdbtype.String(key)]
				if hadOld && hasNew && !newVal.Equal(oldVal) {
					triggered = true
					break
				}
			}
			if !triggered {
				continue
			}
			for _, key := range rule.Reset {
				name := mmdbtype.String(key)
				if resetVal, present := newMap[name]; present {
					merged[name] = resetVal
				} else {
					delete(merged, name)
				}
			}
		}
		return merged, nil
	}
}
// sendUpdate delivers msg on the to channel without ever blocking. If to is
// nil, or no receiver is ready and the buffer (if any) is full, the message is
// dropped.
func sendUpdate(to chan string, msg string) {
	if to != nil {
		select {
		case to <- msg:
		default:
			// Receiver not ready: drop the message rather than stall the caller.
		}
	}
}
func ipVersion(ip net.IP) int {
if ip.To4() != nil {
return 4
}
return 6
}