diff --git a/toolz/curried/__init__.py b/toolz/curried/__init__.py index 356eddbd..f68ff848 100644 --- a/toolz/curried/__init__.py +++ b/toolz/curried/__init__.py @@ -53,7 +53,7 @@ thread_first, thread_last, ) -from .exceptions import merge, merge_with +from .exceptions import merge, merge_with, intersect accumulate = toolz.curry(toolz.accumulate) assoc = toolz.curry(toolz.assoc) diff --git a/toolz/curried/exceptions.py b/toolz/curried/exceptions.py index 75a52bbb..641511a6 100644 --- a/toolz/curried/exceptions.py +++ b/toolz/curried/exceptions.py @@ -1,7 +1,7 @@ import toolz -__all__ = ['merge_with', 'merge'] +__all__ = ['merge_with', 'merge', 'intersect'] @toolz.curry @@ -14,5 +14,10 @@ def merge(d, *dicts, **kwargs): return toolz.merge(d, *dicts, **kwargs) +@toolz.curry +def intersect(d, *dicts, **kwargs): + return toolz.intersect(d, *dicts, **kwargs) + + merge_with.__doc__ = toolz.merge_with.__doc__ merge.__doc__ = toolz.merge.__doc__ diff --git a/toolz/dicttoolz.py b/toolz/dicttoolz.py index 35048d32..65f22ea9 100644 --- a/toolz/dicttoolz.py +++ b/toolz/dicttoolz.py @@ -1,10 +1,13 @@ import operator -from functools import reduce +from functools import reduce, partial from collections.abc import Mapping +from .itertoolz import get + __all__ = ('merge', 'merge_with', 'valmap', 'keymap', 'itemmap', 'valfilter', 'keyfilter', 'itemfilter', - 'assoc', 'dissoc', 'assoc_in', 'update_in', 'get_in') + 'assoc', 'dissoc', 'assoc_in', 'update_in', 'get_in', + 'intersect') def _get_factory(f, kwargs): @@ -335,3 +338,28 @@ def get_in(keys, coll, default=None, no_default=False): if no_default: raise return default + + +def intersect(*dicts, **kwargs): + """Compute the intersection of dictionaries based on their keys. + + The return is a mapping where the keys are common to all dictionaries + and the values are a tuple of the values from each dictionary in + the *given* order. + + >>> intersect({0: 1, 1: 2, 2: 3, 3: 4}, {0:2, 2:10}, {0: 20}) + {0: (1, 2, 20)} + """ + if len(dicts) == 1 and not isinstance(dicts[0], Mapping): + dicts = dicts[0] + factory = _get_factory(merge, kwargs) + + dict_keys = map(operator.methodcaller('keys'), sorted(dicts, key=len)) + intersected_keys = list(reduce(operator.and_, dict_keys)) + + # curry get since we can't use curried.get + curried_get = partial(get, intersected_keys) + rv = factory() + for i, values in zip(intersected_keys, zip(*map(curried_get, dicts))): + rv[i] = values + return rv diff --git a/toolz/tests/test_dicttoolz.py b/toolz/tests/test_dicttoolz.py index d45cd6cf..d7d5e268 100644 --- a/toolz/tests/test_dicttoolz.py +++ b/toolz/tests/test_dicttoolz.py @@ -3,7 +3,7 @@ import os from toolz.dicttoolz import (merge, merge_with, valmap, keymap, update_in, assoc, dissoc, keyfilter, valfilter, itemmap, - itemfilter, assoc_in) + itemfilter, assoc_in, intersect) from toolz.functoolz import identity from toolz.utils import raises @@ -152,6 +152,12 @@ def test_factory(self): factory=lambda: defaultdict(int)) == {1: 2, 2: 3}) assert raises(TypeError, lambda: merge(D({1: 2}), D({2: 3}), factoryy=dict)) + def test_intersect(self): + D, kw = self.D, self.kw + assert intersect({0:1, 1:2}, {0:2}) == {0: (1, 2)} + assert (intersect(D({'a': 1, 'b': 2, 'c': 3}), D({'b': 'a', 'c': 'z'}), **kw) == + D({'b': (2, 'a'), 'c': (3, 'z')})) + class defaultdict(_defaultdict): def __eq__(self, other):