Skip to content

Commit 5bb7a54

Browse files
committed
Support generators for the excluded values
1 parent ac68d9d commit 5bb7a54

File tree

3 files changed

+75
-13
lines changed

3 files changed

+75
-13
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Compiled python modules.
22
*.pyc
3+
.venv
34

45
# Setuptools distribution folder.
56
/dist/
@@ -8,4 +9,4 @@
89
/build/
910

1011
# Python egg metadata, regenerated from source files by setuptools.
11-
/*.egg-info
12+
/*.egg-info

filtercascade/__init__.py

+25-10
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,28 @@ def __init__(self, filters, error_rates=[0.02, 0.5], growth_factor=1.1,
115115
self.version = version
116116

117117
def initialize(self, *, include, exclude):
118-
log.debug("{} include and {} exclude".format(
119-
len(include), len(exclude)))
118+
"""
119+
Arg "exclude" is potentially larger than main memory, so it should
120+
be assumed to be passed as a lazy-loading iterator. If it isn't,
121+
that's fine. The "include" arg must fit in memory and should be
122+
assumed to be a set.
123+
"""
124+
try:
125+
iter(exclude)
126+
except TypeError as te:
127+
raise TypeError("exclude is not iterable", te)
128+
try:
129+
len(include)
130+
except TypeError as te:
131+
raise TypeError("include is not a list", te)
132+
133+
include_len = len(include)
134+
120135
depth = 1
121136
maxSequentialGrowthLayers = 3
122137
sequentialGrowthLayers = 0
123138

124-
while len(include) > 0:
139+
while include_len > 0:
125140
starttime = datetime.datetime.utcnow()
126141
er = self.error_rates[-1]
127142
if depth < len(self.error_rates):
@@ -133,24 +148,23 @@ def initialize(self, *, include, exclude):
133148
# min_filter_length large. This is important for the deep layers near the end.
134149
Bloomer.filter_with_characteristics(
135150
max(
136-
int(len(include) * self.growth_factor),
151+
int(include_len * self.growth_factor),
137152
self.min_filter_length), er, depth))
138153
else:
139154
# Filter already created for this layer. Check size and resize if needed.
140155
required_size = Bloomer.calc_size(
141-
self.filters[depth - 1].nHashFuncs, len(include), er)
156+
self.filters[depth - 1].nHashFuncs, include_len, er)
142157
if self.filters[depth - 1].size < required_size:
143158
# Resize filter
144159
self.filters[depth -
145160
1] = Bloomer.filter_with_characteristics(
146-
int(len(include) * self.growth_factor),
161+
int(include_len * self.growth_factor),
147162
er, depth)
148163
log.info("Resized filter at {}-depth layer".format(depth))
149164
filter = self.filters[depth - 1]
150165
log.debug(
151-
"Initializing the {}-depth layer. err={} include={} exclude={} size={} hashes={}"
152-
.format(depth, er, len(include), len(exclude), filter.size,
153-
filter.nHashFuncs))
166+
"Initializing the {}-depth layer. err={} include_len={} size={} hashes={}"
167+
.format(depth, er, include_len, filter.size, filter.nHashFuncs))
154168
# loop over the elements that *should* be there. Add them to the filter.
155169
for elem in include:
156170
filter.add(elem)
@@ -188,7 +202,8 @@ def initialize(self, *, include, exclude):
188202
sequentialGrowthLayers = 0
189203

190204
include, exclude = false_positives, include
191-
if len(include) > 0:
205+
include_len = len(include)
206+
if include_len > 0:
192207
depth = depth + 1
193208
# Filter characteristics loaded from meta file may result in unused layers.
194209
# Remove them.

filtercascade/test.py

+48-2
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,48 @@
1-
import unittest
21
import filtercascade
2+
import hashlib
3+
import unittest
4+
from itertools import islice
5+
36

47
class MockFile(object):
58
def __init__(self):
69
self.data = b""
10+
711
def __len__(self):
812
return len(self.data)
13+
914
def __getitem__(self, idx):
1015
return self.data[idx]
1116

1217
def write(self, s):
1318
self.data = self.data + s
19+
1420
def read(self):
1521
return self.data
22+
1623
def flush(self):
1724
pass
1825

26+
1927
class SimpleToByteClass(object):
2028
def __init__(self, ordinal):
2129
self.o = ordinal
2230
self.method_called = False
31+
2332
def to_bytes(self):
2433
self.method_called = True
2534
return self.o.to_bytes(1, "little")
2635

36+
37+
def predictable_serial_gen(end):
38+
counter = 0
39+
while counter < end:
40+
counter += 1
41+
m = hashlib.sha256()
42+
m.update(counter.to_bytes(4, byteorder="big"))
43+
yield m.hexdigest()
44+
45+
2746
class TestFilterCascade(unittest.TestCase):
2847
def assertBloomerEqual(self, b1, b2):
2948
self.assertEqual(b1.nHashFuncs, b2.nHashFuncs)
@@ -76,6 +95,33 @@ def test_fc_input_formats(self):
7695
self.assertFilterCascadeEqual(f1, f2)
7796
self.assertFilterCascadeEqual(f1, f3)
7897

98+
def test_fc_include_not_list(self):
99+
f = filtercascade.FilterCascade([])
100+
with self.assertRaises(TypeError):
101+
f.initialize(include=predictable_serial_gen(1),
102+
exclude=predictable_serial_gen(1))
103+
104+
def test_fc_exclude_must_be_iterable(self):
105+
f = filtercascade.FilterCascade([])
106+
with self.assertRaises(TypeError):
107+
f.initialize(include=[], exclude=list(1))
108+
109+
def test_fc_iterable(self):
110+
f = filtercascade.FilterCascade([])
111+
112+
serials = predictable_serial_gen(500_000)
113+
# revocations must be disjoint from the main set, so
114+
# slice off a set and re-use the remainder
115+
revocations = set(islice(serials, 3_000))
116+
117+
f.initialize(include=revocations,
118+
exclude=serials)
119+
120+
self.assertEqual(len(f.filters), 3)
121+
self.assertEqual(f.filters[0].size, 81272)
122+
self.assertEqual(f.filters[1].size, 14400)
123+
self.assertEqual(f.filters[2].size, 14400)
124+
79125

80126
if __name__ == '__main__':
81-
unittest.main()
127+
unittest.main()

0 commit comments

Comments
 (0)