Skip to content

Commit f3851e0

Browse files
committedMar 13, 2020
Add SHA256 and Salt support
1 parent 790443e commit f3851e0

File tree

3 files changed

+145
-24
lines changed

3 files changed

+145
-24
lines changed
 

‎filtercascade/__init__.py

+80-18
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,18 @@
77
import logging
88
import math
99
import mmh3
10+
import hashlib
11+
from deprecated import deprecated
1012
from struct import pack, unpack, calcsize
11-
from enum import IntEnum
13+
from enum import IntEnum, unique
1214

1315
log = logging.getLogger(__name__)
1416

1517

18+
@unique
1619
class HashAlgorithm(IntEnum):
1720
MURMUR3 = 1
21+
SHA256 = 2
1822

1923

2024
# A simple-as-possible bloom filter implementation making use of version 3 of the 32-bit murmur
@@ -23,15 +27,23 @@ class HashAlgorithm(IntEnum):
2327
class Bloomer:
2428
LAYER_FMT = b"<BIIB"
2529

26-
def __init__(self, *, size, nHashFuncs, level, hashAlg=HashAlgorithm.MURMUR3):
30+
def __init__(
31+
self, *, size, nHashFuncs, level, hashAlg=HashAlgorithm.MURMUR3, salt=None
32+
):
2733
self.nHashFuncs = nHashFuncs
2834
self.size = size
2935
self.level = level
30-
self.hashAlg = hashAlg
36+
self.hashAlg = HashAlgorithm(hashAlg)
37+
self.salt = salt
3138

3239
self.bitarray = bitarray.bitarray(self.size, endian="little")
3340
self.bitarray.setall(False)
3441

42+
if self.salt and not isinstance(self.salt, bytes):
43+
raise ValueError("salts must be passed as bytes")
44+
if self.salt and self.hashAlg == HashAlgorithm.MURMUR3:
45+
raise ValueError("salts not permitted for MurmurHash3")
46+
3547
def hash(self, seed, key):
3648
if not isinstance(key, bytes):
3749
to_bytes_op = getattr(key, "to_bytes", None)
@@ -42,12 +54,27 @@ def hash(self, seed, key):
4254
else:
4355
key = str(key).encode("utf-8")
4456

45-
if self.hashAlg != HashAlgorithm.MURMUR3:
46-
raise Exception(f"Unknown hash algorithm: {self.hashAlg}")
47-
4857
hash_seed = ((seed << 16) + self.level) & 0xFFFFFFFF
49-
h = (mmh3.hash(key, hash_seed) & 0xFFFFFFFF) % self.size
50-
return h
58+
59+
if self.hashAlg == HashAlgorithm.MURMUR3:
60+
if self.salt:
61+
raise ValueError("salts not permitted for MurmurHash3")
62+
h = (mmh3.hash(key, hash_seed) & 0xFFFFFFFF) % self.size
63+
return h
64+
65+
if self.hashAlg == HashAlgorithm.SHA256:
66+
m = hashlib.sha256()
67+
if self.salt:
68+
m.update(salt)
69+
m.update(hash_seed)
70+
m.update(key)
71+
h = (
72+
int.from_bytes(m.digest()[:4], byteorder="little", signed=False)
73+
% self.size
74+
)
75+
return h
76+
77+
raise Exception(f"Unknown hash algorithm: {self.hashAlg}")
5178

5279
def add(self, key):
5380
for i in range(self.nHashFuncs):
@@ -76,10 +103,20 @@ def tofile(self, f):
76103
self.bitarray.tofile(f)
77104

78105
@classmethod
79-
def filter_with_characteristics(cls, elements, falsePositiveRate, level=1):
106+
def filter_with_characteristics(
107+
cls,
108+
*,
109+
elements,
110+
falsePositiveRate,
111+
hashAlg=HashAlgorithm.MURMUR3,
112+
salt=None,
113+
level=1,
114+
):
80115
nHashFuncs = Bloomer.calc_n_hashes(falsePositiveRate)
81116
size = Bloomer.calc_size(nHashFuncs, elements, falsePositiveRate)
82-
return Bloomer(size=size, nHashFuncs=nHashFuncs, level=level)
117+
return Bloomer(
118+
size=size, nHashFuncs=nHashFuncs, level=level, hashAlg=hashAlg, salt=salt
119+
)
83120

84121
@classmethod
85122
def calc_n_hashes(cls, falsePositiveRate):
@@ -91,7 +128,7 @@ def calc_size(cls, nHashFuncs, elements, falsePositiveRate):
91128
return math.ceil(1.44 * elements * math.log(1 / falsePositiveRate, 2))
92129

93130
@classmethod
94-
def from_buf(cls, buf):
131+
def from_buf(cls, buf, salt=None):
95132
log.debug(len(buf))
96133
hashAlgInt, size, nHashFuncs, level = unpack(Bloomer.LAYER_FMT, buf[0:10])
97134
byte_count = math.ceil(size / 8)
@@ -102,6 +139,7 @@ def from_buf(cls, buf):
102139
nHashFuncs=nHashFuncs,
103140
level=level,
104141
hashAlg=HashAlgorithm(hashAlgInt),
142+
salt=salt,
105143
)
106144
bloomer.size = size
107145
log.debug(
@@ -123,12 +161,21 @@ def __init__(
123161
growth_factor=1.1,
124162
min_filter_length=10000,
125163
version=1,
164+
hashAlg=HashAlgorithm.MURMUR3,
165+
salt=None,
126166
):
127167
self.filters = filters
128168
self.error_rates = error_rates
129169
self.growth_factor = growth_factor
130170
self.min_filter_length = min_filter_length
131171
self.version = version
172+
self.hashAlg = hashAlg
173+
self.salt = salt
174+
175+
if self.salt and not isinstance(self.salt, bytes):
176+
raise ValueError("salts must be passed as byteas")
177+
if self.salt and self.hashAlg == HashAlgorithm.MURMUR3:
178+
raise ValueError("salts not permitted for MurmurHash3")
132179

133180
def initialize(self, *, include, exclude):
134181
"""
@@ -163,12 +210,12 @@ def initialize(self, *, include, exclude):
163210
# For growth-stability reasons, we force all layers to be at least
164211
# min_filter_length large. This is important for the deep layers near the end.
165212
Bloomer.filter_with_characteristics(
166-
max(
213+
elements=max(
167214
int(include_len * self.growth_factor),
168215
self.min_filter_length,
169216
),
170-
er,
171-
depth,
217+
falsePositiveRate=er,
218+
level=depth,
172219
)
173220
)
174221
else:
@@ -254,10 +301,17 @@ def __contains__(self, elem):
254301
else:
255302
return False != even
256303

304+
@deprecated(
305+
version="0.2.3",
306+
reason="Use the verify function which has the same semantics as initialize",
307+
)
257308
def check(self, *, entries, exclusions):
258-
for entry in entries:
309+
self.verify(include=entries, exclude=exclusions)
310+
311+
def verify(self, *, include, exclude):
312+
for entry in include:
259313
assert entry in self, "oops! false negative!"
260-
for entry in exclusions:
314+
for entry in exclude:
261315
assert not entry in self, "oops! false positive!"
262316

263317
def bitCount(self):
@@ -315,8 +369,16 @@ def loadDiffMeta(cls, f):
315369
return FilterCascade(filters)
316370

317371
@classmethod
318-
def cascade_with_characteristics(cls, capacity, error_rates, layer=0):
372+
def cascade_with_characteristics(
373+
cls, *, capacity, error_rates, hashAlg=HashAlgorithm.MURMUR3, salt=None, layer=0
374+
):
319375
return FilterCascade(
320-
[Bloomer.filter_with_characteristics(capacity, error_rates[0])],
376+
[
377+
Bloomer.filter_with_characteristics(
378+
elements=capacity, falsePositiveRate=error_rates[0]
379+
)
380+
],
321381
error_rates=error_rates,
382+
hashAlg=hashAlg,
383+
salt=salt,
322384
)

‎filtercascade/test.py

+64-6
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ def predictable_serial_gen(end):
4343
yield m.hexdigest()
4444

4545

46+
def get_serial_sets(*, num_revoked, num_valid):
47+
valid = predictable_serial_gen(num_revoked + num_valid)
48+
# revocations must be disjoint from the main set, so
49+
# slice off a set and re-use the remainder
50+
revoked = set(islice(valid, num_revoked))
51+
return (valid, revoked)
52+
53+
4654
class TestFilterCascade(unittest.TestCase):
4755
def assertBloomerEqual(self, b1, b2):
4856
self.assertEqual(b1.nHashFuncs, b2.nHashFuncs)
@@ -110,18 +118,68 @@ def test_fc_exclude_must_be_iterable(self):
110118
def test_fc_iterable(self):
111119
f = filtercascade.FilterCascade([])
112120

113-
serials = predictable_serial_gen(500_000)
114-
# revocations must be disjoint from the main set, so
115-
# slice off a set and re-use the remainder
116-
revocations = set(islice(serials, 3_000))
117-
118-
f.initialize(include=revocations, exclude=serials)
121+
valid, revoked = get_serial_sets(num_valid=500_000, num_revoked=3_000)
122+
f.initialize(include=revoked, exclude=valid)
119123

120124
self.assertEqual(len(f.filters), 3)
121125
self.assertEqual(f.filters[0].size, 81272)
122126
self.assertEqual(f.filters[1].size, 14400)
123127
self.assertEqual(f.filters[2].size, 14400)
124128

125129

130+
class TestFilterCascadeSalts(unittest.TestCase):
131+
def test_non_byte_salt(self):
132+
with self.assertRaises(ValueError):
133+
filtercascade.FilterCascade(
134+
[], hashAlg=filtercascade.HashAlgorithm.SHA256, salt=64
135+
)
136+
137+
def test_murmur_with_salt(self):
138+
with self.assertRaises(ValueError):
139+
filtercascade.FilterCascade(
140+
[], hashAlg=filtercascade.HashAlgorithm.MURMUR3, salt=b"happiness"
141+
)
142+
143+
def test_sha256_with_salt(self):
144+
fc = filtercascade.FilterCascade(
145+
[], hashAlg=filtercascade.HashAlgorithm.SHA256, salt=b"happiness"
146+
)
147+
148+
valid, revoked = get_serial_sets(num_valid=10, num_revoked=1)
149+
fc.initialize(include=revoked, exclude=valid)
150+
151+
self.assertEqual(len(fc.filters), 1)
152+
self.assertEqual(fc.bitCount(), 81272)
153+
154+
f = MockFile()
155+
fc.tofile(f)
156+
self.assertEqual(len(f.data), 10171)
157+
158+
159+
class TestFilterCascadeAlgorithms(unittest.TestCase):
160+
def verify_minimum_sets(self, *, hashAlg):
161+
fc = filtercascade.FilterCascade([], hashAlg=hashAlg)
162+
163+
valid, revoked = get_serial_sets(num_valid=10, num_revoked=1)
164+
fc.initialize(include=revoked, exclude=valid)
165+
166+
self.assertEqual(len(fc.filters), 1)
167+
self.assertEqual(fc.bitCount(), 81272)
168+
169+
f = MockFile()
170+
fc.tofile(f)
171+
self.assertEqual(len(f.data), 10171)
172+
173+
fc2 = filtercascade.FilterCascade.from_buf(f)
174+
valid2, revoked2 = get_serial_sets(num_valid=10, num_revoked=1)
175+
fc2.verify(include=revoked2, exclude=valid2)
176+
177+
def test_murmurhash3(self):
178+
self.verify_minimum_sets(hashAlg=filtercascade.HashAlgorithm.MURMUR3)
179+
180+
def test_sha256(self):
181+
self.verify_minimum_sets(hashAlg=filtercascade.HashAlgorithm.SHA256)
182+
183+
126184
if __name__ == "__main__":
127185
unittest.main()

‎requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
bitarray >= 0.9.2
2+
Deprecated >= 1.2
23
mmh3 >= 2.5.1

0 commit comments

Comments
 (0)
Please sign in to comment.