diff --git a/datasketch/lsh.py b/datasketch/lsh.py index f77e36e3..dbaa3ec9 100644 --- a/datasketch/lsh.py +++ b/datasketch/lsh.py @@ -226,6 +226,29 @@ def insert( """ self._insert(key, minhash, check_duplication=check_duplication, buffer=False) + def merge( + self, + other: MinHashLSH, + check_overlap: bool = False + ): + """Merge the other MinHashLSH with this one, making this one the union + of both. + + Note: + Only num_perm, number of bands and sizes of each band is checked for equivalency of two MinHashLSH indexes. + Other initialization parameters threshold, weights, storage_config, prepickle and hash_func are not checked. + + Args: + other (MinHashLSH): The other MinHashLSH. + check_overlap (bool): Check if there are any overlapping keys before merging and raise if there are any. + (`default=False`) + + Raises: + ValueError: If the two MinHashLSH have different initialization + parameters, or if `check_overlap` is `True` and there are overlapping keys. + """ + self._merge(other, check_overlap=check_overlap, buffer=False) + def insertion_session(self, buffer_size: int = 50000) -> MinHashLSHInsertionSession: """ Create a context manager for fast insertion into this index. @@ -282,6 +305,38 @@ def _insert( for H, hashtable in zip(Hs, self.hashtables): hashtable.insert(H, key, buffer=buffer) + def __equivalent(self, other:MinHashLSH) -> bool: + """ + Returns: + bool: If the two MinHashLSH have equal num_perm, number of bands, size of each band then two are equivalent. + """ + return ( + type(self) is type(other) and + self.h == other.h and + self.b == other.b and + self.r == other.r + ) + + def _merge( + self, + other: MinHashLSH, + check_overlap: bool = False, + buffer: bool = False + ) -> MinHashLSH: + if self.__equivalent(other): + if check_overlap and set(self.keys).intersection(set(other.keys)): + raise ValueError("The keys are overlapping, duplicate key exists.") + for key in other.keys: + Hs = other.keys.get(key) + self.keys.insert(key, *Hs, buffer=buffer) + for H, hashtable in zip(Hs, self.hashtables): + hashtable.insert(H, key, buffer=buffer) + else: + if type(self) is not type(other): + raise ValueError(f"Cannot merge type MinHashLSH and type {type(other).__name__}.") + raise ValueError( + "Cannot merge MinHashLSH with different initialization parameters.") + def query(self, minhash) -> List[Hashable]: """ Giving the MinHash of the query set, retrieve diff --git a/docs/lsh.rst b/docs/lsh.rst index 9df92e82..dcd0d47a 100644 --- a/docs/lsh.rst +++ b/docs/lsh.rst @@ -77,6 +77,14 @@ plotting code. .. figure:: /_static/lsh_benchmark.png :alt: MinHashLSH Benchmark +You can merge two MinHashLSH indexes to create a union index using the ``merge`` method. This +makes MinHashLSH useful in parallel processing. + +.. code:: python + + # This merges the lsh1 with lsh2. + lsh1.merge(lsh2) + There are other optional parameters that can be used to tune the index. See the documentation of :class:`datasketch.MinHashLSH` for details. diff --git a/examples/lsh_examples.py b/examples/lsh_examples.py index b16edf4f..007e1399 100644 --- a/examples/lsh_examples.py +++ b/examples/lsh_examples.py @@ -37,6 +37,19 @@ def eg1(): result = lsh.query(m1) print("Approximate neighbours with Jaccard similarity > 0.5", result) + # Merge two LSH index + lsh1 = MinHashLSH(threshold=0.5, num_perm=128) + lsh1.insert("m2", m2) + lsh1.insert("m3", m3) + + lsh2 = MinHashLSH(threshold=0.5, num_perm=128) + lsh2.insert("m1", m1) + + lsh1.merge(lsh2) + print("Does m1 exist in the lsh1...", "m1" in lsh1.keys) + # if check_overlap flag is set to True then it will check the overlapping of the keys in the two MinHashLSH + lsh1.merge(lsh2,check_overlap=True) + def eg2(): mg = WeightedMinHashGenerator(10, 5) m1 = mg.minhash(v1) diff --git a/test/test_lsh.py b/test/test_lsh.py index 38f8844f..a2893753 100644 --- a/test/test_lsh.py +++ b/test/test_lsh.py @@ -240,6 +240,117 @@ def test_get_counts(self): for table in counts: self.assertEqual(sum(table.values()), 2) + def test_merge(self): + lsh1 = MinHashLSH(threshold=0.5, num_perm=16) + m1 = MinHash(16) + m1.update("a".encode("utf-8")) + m2 = MinHash(16) + m2.update("b".encode("utf-8")) + lsh1.insert("a",m1) + lsh1.insert("b",m2) + + lsh2 = MinHashLSH(threshold=0.5, num_perm=16) + m3 = MinHash(16) + m3.update("c".encode("utf-8")) + m4 = MinHash(16) + m4.update("d".encode("utf-8")) + lsh2.insert("c",m1) + lsh2.insert("d",m2) + + lsh1.merge(lsh2) + for t in lsh1.hashtables: + self.assertTrue(len(t) >= 1) + items = [] + for H in t: + items.extend(t[H]) + self.assertTrue("c" in items) + self.assertTrue("d" in items) + self.assertTrue("a" in lsh1) + self.assertTrue("b" in lsh1) + self.assertTrue("c" in lsh1) + self.assertTrue("d" in lsh1) + for i, H in enumerate(lsh1.keys["c"]): + self.assertTrue("c" in lsh1.hashtables[i][H]) + + self.assertTrue(lsh1.merge, lsh2) + self.assertRaises(ValueError, lsh1.merge, lsh2, check_overlap=True) + + m5 = MinHash(16) + m5.update("e".encode("utf-8")) + lsh3 = MinHashLSH(threshold=0.5, num_perm=16) + lsh3.insert("a",m5) + + self.assertRaises(ValueError, lsh1.merge, lsh3, check_overlap=True) + + lsh1.merge(lsh3) + + m6 = MinHash(16) + m6.update("e".encode("utf-8")) + lsh4 = MinHashLSH(threshold=0.5, num_perm=16) + lsh4.insert("a",m6) + + lsh1.merge(lsh4, check_overlap=False) + + + def test_merge_redis(self): + with patch('redis.Redis', fake_redis) as mock_redis: + lsh1 = MinHashLSH(threshold=0.5, num_perm=16, storage_config={ + 'type': 'redis', 'redis': {'host': 'localhost', 'port': 6379} + }) + lsh2 = MinHashLSH(threshold=0.5, num_perm=16, storage_config={ + 'type': 'redis', 'redis': {'host': 'localhost', 'port': 6379} + }) + + m1 = MinHash(16) + m1.update("a".encode("utf8")) + m2 = MinHash(16) + m2.update("b".encode("utf8")) + lsh1.insert("a", m1) + lsh1.insert("b", m2) + + m3 = MinHash(16) + m3.update("c".encode("utf8")) + m4 = MinHash(16) + m4.update("d".encode("utf8")) + lsh2.insert("c", m3) + lsh2.insert("d", m4) + + lsh1.merge(lsh2) + for t in lsh1.hashtables: + self.assertTrue(len(t) >= 1) + items = [] + for H in t: + items.extend(t[H]) + self.assertTrue(pickle.dumps("c") in items) + self.assertTrue(pickle.dumps("d") in items) + self.assertTrue("a" in lsh1) + self.assertTrue("b" in lsh1) + self.assertTrue("c" in lsh1) + self.assertTrue("d" in lsh1) + for i, H in enumerate(lsh1.keys[pickle.dumps("c")]): + self.assertTrue(pickle.dumps("c") in lsh1.hashtables[i][H]) + + self.assertTrue(lsh1.merge, lsh2) + self.assertRaises(ValueError, lsh1.merge, lsh2, check_overlap=True) + + m5 = MinHash(16) + m5.update("e".encode("utf-8")) + lsh3 = MinHashLSH(threshold=0.5, num_perm=16, storage_config={ + 'type': 'redis', 'redis': {'host': 'localhost', 'port': 6379} + }) + lsh3.insert("a",m5) + + self.assertRaises(ValueError, lsh1.merge, lsh3, check_overlap=True) + + m6 = MinHash(16) + m6.update("e".encode("utf-8")) + lsh4 = MinHashLSH(threshold=0.5, num_perm=16, storage_config={ + 'type': 'redis', 'redis': {'host': 'localhost', 'port': 6379} + }) + lsh4.insert("a",m6) + + lsh1.merge(lsh4, check_overlap=False) + class TestWeightedMinHashLSH(unittest.TestCase):