diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py index 1d4c85cbe..772f17f62 100644 --- a/cloudpickle/cloudpickle.py +++ b/cloudpickle/cloudpickle.py @@ -43,7 +43,7 @@ from __future__ import print_function import dis -from functools import partial +import functools import io import itertools import logging @@ -1067,6 +1067,28 @@ def save_mappingproxy(self, obj): dispatch[types.MappingProxyType] = save_mappingproxy + # In CPython, functions decorated with functools.lru_cache are actually + # instances of a non-serializable built-in type. We pickle them by pickling + # the underlying function, along with the size of the lru cache. We do + # **not** attempt to pickle the contents of the function's cache. + if hasattr(functools, 'lru_cache'): # pragma: no branch + _lru_cache_instance = functools.lru_cache()(lambda: None) + + # PyPy's lru_cache returns a regular function object that closes over + # the cache state. We can't easily treat this specially because + # Pickle's dispatching is purely type-based. + if not isinstance(_lru_cache_instance, types.FunctionType): + # Assume CPython native LRU Cache. + def save_lru_cached_function(self, obj): + self.save_reduce( + _rebuild_lru_cached_function, + (obj.cache_info().maxsize, obj.__wrapped__), + obj=obj, + ) + + dispatch[type(_lru_cache_instance)] = save_lru_cached_function + del _lru_cache_instance # Remove from class namespace. + """Special functions for Add-on libraries""" def inject_addons(self): """Plug in system. Register additional pickling functions if modules already loaded""" @@ -1395,3 +1417,11 @@ def _is_dynamic(module): except ImportError: return True return False + + +def _rebuild_lru_cached_function(maxsize, func): + """Reconstruct a function that was decorated with functools.lru_cache. + + The rebuilt function will have an empty cache. 
+ """ + return functools.lru_cache(maxsize)(func) diff --git a/tests/cloudpickle_test.py b/tests/cloudpickle_test.py index 8bd210abe..9a604a3dd 100644 --- a/tests/cloudpickle_test.py +++ b/tests/cloudpickle_test.py @@ -50,6 +50,8 @@ _TEST_GLOBAL_VARIABLE = "default_value" +IS_PYPY = platform.python_implementation() == 'PyPy' + class RaiserOnPickle(object): @@ -368,7 +370,7 @@ def test_partial(self): partial_clone = pickle_depickle(partial_obj, protocol=self.protocol) self.assertEqual(partial_clone(4), 1) - @pytest.mark.skipif(platform.python_implementation() == 'PyPy', + @pytest.mark.skipif(IS_PYPY, reason="Skip numpy and scipy tests on PyPy") def test_ufunc(self): # test a numpy ufunc (universal function), which is a C-based function @@ -476,7 +478,7 @@ def method(self, x): self.assertEqual(mod.f(5), mod2.f(5)) self.assertEqual(mod.Foo().method(5), mod2.Foo().method(5)) - if platform.python_implementation() != 'PyPy': + if not IS_PYPY: # XXX: this fails with excessive recursion on PyPy. mod3 = subprocess_pickle_echo(mod, protocol=self.protocol) self.assertEqual(mod.x, mod3.x) @@ -639,7 +641,7 @@ def test_is_dynamic_module(self): dynamic_module = types.ModuleType('dynamic_module') assert _is_dynamic(dynamic_module) - if platform.python_implementation() == 'PyPy': + if IS_PYPY: import _codecs assert not _is_dynamic(_codecs) @@ -674,8 +676,7 @@ def test_builtin_function(self): # builtin function from a "regular" module assert pickle_depickle(mkdir, protocol=self.protocol) is mkdir - @pytest.mark.skipif(platform.python_implementation() == 'PyPy' and - sys.version_info[:2] == (3, 5), + @pytest.mark.skipif(IS_PYPY and sys.version_info[:2] == (3, 5), reason="bug of pypy3.5 in builtin-type constructors") def test_builtin_type_constructor(self): # Due to a bug in pypy3.5, cloudpickling builtin-type constructors @@ -750,7 +751,7 @@ def test_builtin_classmethod(self): # Roundtripping a classmethod_descriptor results in a # builtin_function_or_method (CPython upstream issue). 
assert depickled_clsdict_meth(arg) == clsdict_clsmethod(float, arg) - if platform.python_implementation() == 'PyPy': + if IS_PYPY: # builtin-classmethods are simple classmethod in PyPy (not # callable). We test equality of types and the functionality of the # __func__ attribute instead. We do not test the the identity of @@ -781,7 +782,7 @@ def test_builtin_slotmethod(self): assert depickled_clsdict_meth is clsdict_slotmethod @pytest.mark.skipif( - platform.python_implementation() == "PyPy" or + IS_PYPY or sys.version_info[:1] < (3,), reason="No known staticmethod example in the python 2 / pypy stdlib") def test_builtin_staticmethod(self): @@ -1499,7 +1500,7 @@ class A: """.format(protocol=self.protocol) assert_run_python_script(code) - @pytest.mark.skipif(platform.python_implementation() == 'PyPy', + @pytest.mark.skipif(IS_PYPY, reason="Skip PyPy because memory grows too much") def test_interactive_remote_function_calls_no_memory_leak(self): code = """if __name__ == "__main__": @@ -1876,6 +1877,59 @@ def __getattr__(self, name): with pytest.raises(pickle.PicklingError, match='recursion'): cloudpickle.dumps(a) + @unittest.skipIf(IS_PYPY or not hasattr(functools, "lru_cache"), + "Old versions of Python do not have lru_cache. " + "PyPy's lru_cache is a regular function.") + def test_pickle_lru_cached_function(self): + + for maxsize in None, 1, 2: + + @functools.lru_cache(maxsize=maxsize) + def func(x, y): + return x + y + + # Populate original function's cache. + func(1, 2) + + new_func = pickle_depickle(func, protocol=self.protocol) + assert type(new_func) == type(func) + + # We don't attempt to pickle the original function's cache, so the + # new function should have an empty cache. 
+ self._expect_cache_info( + new_func.cache_info(), + hits=0, + misses=0, + maxsize=maxsize, + currsize=0, + ) + + assert new_func(1, 2) == 3 + + self._expect_cache_info( + new_func.cache_info(), + hits=0, + misses=1, + maxsize=maxsize, + currsize=1, + ) + + assert new_func(1, 2) == 3 + + self._expect_cache_info( + new_func.cache_info(), + hits=1, + misses=1, + maxsize=maxsize, + currsize=1, + ) + + def _expect_cache_info(self, cache_info, hits, misses, maxsize, currsize): + assert cache_info.hits == hits + assert cache_info.misses == misses + assert cache_info.maxsize == maxsize + assert cache_info.currsize == currsize + class Protocol2CloudPickleTest(CloudPickleTest):