From 7565f8eb5911761f68b6a2540dc20288174c9942 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 16 May 2022 10:12:25 -0700 Subject: [PATCH] DeviceArray: Improve support for copy, deepcopy, and pickle --- CHANGELOG.md | 10 +++++++++- jax/_src/device_array.py | 22 +++++++++++++++++---- jax/_src/numpy/lax_numpy.py | 9 +++++++++ jax/core.py | 14 ++++++++------ tests/lax_numpy_test.py | 10 ++++++++-- tests/pickle_test.py | 38 +++++++++++++++++++++++++++++++++++++ 6 files changed, 90 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9cac1254fa32..4253a5166050 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,15 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. --> ## jax 0.3.13 (Unreleased) -* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.12...main). +* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.12...main). +* Changes + * `pickle`, `copy.copy`, and `copy.deepcopy` now have more complete support when + used on jax arrays ({jax-issue}`#10659`). In particular: + - `pickle` and `deepcopy` previously returned `np.ndarray` objects when used + on a `DeviceArray`; now `DeviceArray` objects are returned. + - Within function transformations (i.e. traced code), `deepcopy` and `copy` + previously were no-ops. Now they result in calls to `lax._array_copy`. + - Calling `pickle` on a traced array results in a `ConcretizationTypeError`. ## jaxlib 0.3.11 (Unreleased) * [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main). diff --git a/jax/_src/device_array.py b/jax/_src/device_array.py index a926db3f22ca..92246f5e93cf 100644 --- a/jax/_src/device_array.py +++ b/jax/_src/device_array.py @@ -21,6 +21,7 @@ import numpy as np +import jax from jax import core from jax._src.config import config from jax._src import abstract_arrays @@ -265,6 +266,14 @@ def __array__(self, dtype=None, context=None): setattr(device_array, "__array__", __array__) + def __reduce__(self): + fun, args, arr_state = self._value.__reduce__() + aval_state = {'weak_type': self.aval.weak_type, + 'named_shape': self.aval.named_shape} + return (reconstruct_device_array, (fun, args, arr_state, aval_state)) + + setattr(device_array, "__reduce__", __reduce__) + setattr(device_array, "__str__", partialmethod(_forward_to_value, str)) setattr(device_array, "__bool__", partialmethod(_forward_to_value, bool)) setattr(device_array, "__nonzero__", partialmethod(_forward_to_value, bool)) @@ -280,10 +289,6 @@ def __array__(self, dtype=None, context=None): del to_bytes setattr(device_array, "tolist", lambda self: self._value.tolist()) - # pickle saves and loads just like an ndarray - setattr(device_array, "__reduce__", - partialmethod(_forward_to_value, operator.methodcaller("__reduce__"))) - # explicitly set to be unhashable. setattr(device_array, "__hash__", None) @@ -298,6 +303,15 @@ def raise_not_implemented(): # pylint: enable=protected-access +def reconstruct_device_array(fun, args, arr_state, aval_state): + """Method to reconstruct a device array from a serialized state.""" + np_value = fun(*args) + np_value.__setstate__(arr_state) + jnp_value = jax.device_put(np_value) + jnp_value.aval = jnp_value.aval.update(**aval_state) + return jnp_value + + class DeletedBuffer(object): pass deleted_buffer = DeletedBuffer() diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 7b10b17b1d07..c6b99135123a 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -4572,9 +4572,18 @@ def _operator_round(number, ndigits=None): # If `ndigits` is None, for a builtin float round(7.5) returns an integer. return out.astype(int) if ndigits is None else out +def _copy(self): + return self.copy() + +def _deepcopy(self, memo): + del memo # unused + return self.copy() + _operators = { "getitem": _rewriting_take, "setitem": _unimplemented_setitem, + "copy": _copy, + "deepcopy": _deepcopy, "neg": negative, "pos": positive, "eq": _defer_to_unrecognized_arg(equal), diff --git a/jax/core.py b/jax/core.py index 9f67c1f6b8be..dce40ad5d874 100644 --- a/jax/core.py +++ b/jax/core.py @@ -604,6 +604,14 @@ def __hex__(self): return self.aval._hex(self) def __oct__(self): return self.aval._oct(self) def __float__(self): return self.aval._float(self) def __complex__(self): return self.aval._complex(self) + def __copy__(self): return self.aval._copy(self) + def __deepcopy__(self, memo): return self.aval._deepcopy(self, memo) + + # raises a useful error on attempts to pickle a Tracer. + def __reduce__(self): + raise ConcretizationTypeError( + self, ("The error occurred in the __reduce__ method, which may " + "indicate an attempt to serialize/pickle a traced value.")) # raises the better error message from ShapedArray def __setitem__(self, idx, val): return self.aval._setitem(self, idx, val) @@ -651,12 +659,6 @@ def _contents(self): except AttributeError: return () - def __copy__(self): - return self - - def __deepcopy__(self, unused_memo): - return self - def _origin_msg(self) -> str: return "" diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index efdd34d516bd..afc1de0d9fd7 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -14,6 +14,7 @@ import collections +import copy import functools from functools import partial import inspect @@ -3828,10 +3829,15 @@ def _check(obj, out_dtype, weak_type): {"testcase_name": f"_dtype={np.dtype(dtype)}_func={func}", "dtype": dtype, "func": func} for dtype in all_dtypes - for func in ["array", "copy"])) + for func in ["array", "copy", "copy.copy", "copy.deepcopy"])) def testArrayCopy(self, dtype, func): x = jnp.ones(10, dtype=dtype) - copy_func = getattr(jnp, func) + if func == "copy.deepcopy": + copy_func = copy.deepcopy + elif func == "copy.copy": + copy_func = copy.copy + else: + copy_func = getattr(jnp, func) x_view = jnp.asarray(x) x_view_jit = jax.jit(jnp.asarray)(x) diff --git a/tests/pickle_test.py b/tests/pickle_test.py index 9a51813dd047..424b4cda93e5 100644 --- a/tests/pickle_test.py +++ b/tests/pickle_test.py @@ -24,6 +24,7 @@ cloudpickle = None import jax +from jax import core from jax import numpy as jnp from jax.config import config from jax._src import test_util as jtu @@ -73,5 +74,42 @@ def g(z): self.assertEqual(expected, actual) +class PickleTest(jtu.JaxTestCase): + + def testPickleOfDeviceArray(self): + x = jnp.arange(10.0) + s = pickle.dumps(x) + y = pickle.loads(s) + self.assertArraysEqual(x, y) + self.assertIsInstance(y, type(x)) + self.assertEqual(x.aval, y.aval) + + def testPickleOfDeviceArrayWeakType(self): + x = jnp.array(4.0) + self.assertEqual(x.aval.weak_type, True) + s = pickle.dumps(x) + y = pickle.loads(s) + self.assertArraysEqual(x, y) + self.assertIsInstance(y, type(x)) + self.assertEqual(x.aval, y.aval) + + def testPickleX64(self): + with jax.experimental.enable_x64(): + x = jnp.array(4.0, dtype='float64') + s = pickle.dumps(x) + + with jax.experimental.disable_x64(): + y = pickle.loads(s) + + self.assertEqual(x.dtype, jnp.float64) + self.assertArraysEqual(x, y, check_dtypes=False) + self.assertEqual(y.dtype, jnp.float32) + self.assertEqual(y.aval.dtype, jnp.float32) + self.assertIsInstance(y, type(x)) + + def testPickleTracerError(self): + with self.assertRaises(core.ConcretizationTypeError): + jax.jit(pickle.dumps)(0) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())