Skip to content

Commit

Permalink
DeviceArray: Improve support for copy, deepcopy, and pickle
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed May 16, 2022
1 parent f73f03e commit 7565f8e
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 13 deletions.
10 changes: 9 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,15 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
-->

## jax 0.3.13 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.12...main).
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.12...main).
* Changes
* `pickle`, `copy.copy`, and `copy.deepcopy` now have more complete support when
used on jax arrays ({jax-issue}`#10659`). In particular:
- `pickle` and `deepcopy` previously returned `np.ndarray` objects when used
on a `DeviceArray`; now `DeviceArray` objects are returned.
- Within function transformations (i.e. traced code), `deepcopy` and `copy`
previously were no-ops. Now they result in calls to `lax._array_copy`.
- Calling `pickle` on a traced array results in a `ConcretizationTypeError`.

## jaxlib 0.3.11 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).
Expand Down
22 changes: 18 additions & 4 deletions jax/_src/device_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import numpy as np

import jax
from jax import core
from jax._src.config import config
from jax._src import abstract_arrays
Expand Down Expand Up @@ -265,6 +266,14 @@ def __array__(self, dtype=None, context=None):

setattr(device_array, "__array__", __array__)

def __reduce__(self):
fun, args, arr_state = self._value.__reduce__()
aval_state = {'weak_type': self.aval.weak_type,
'named_shape': self.aval.named_shape}
return (reconstruct_device_array, (fun, args, arr_state, aval_state))

setattr(device_array, "__reduce__", __reduce__)

setattr(device_array, "__str__", partialmethod(_forward_to_value, str))
setattr(device_array, "__bool__", partialmethod(_forward_to_value, bool))
setattr(device_array, "__nonzero__", partialmethod(_forward_to_value, bool))
Expand All @@ -280,10 +289,6 @@ def __array__(self, dtype=None, context=None):
del to_bytes
setattr(device_array, "tolist", lambda self: self._value.tolist())

# pickle saves and loads just like an ndarray
setattr(device_array, "__reduce__",
partialmethod(_forward_to_value, operator.methodcaller("__reduce__")))

# explicitly set to be unhashable.
setattr(device_array, "__hash__", None)

Expand All @@ -298,6 +303,15 @@ def raise_not_implemented():
# pylint: enable=protected-access


def reconstruct_device_array(fun, args, arr_state, aval_state):
"""Method to reconstruct a device array from a serialized state."""
np_value = fun(*args)
np_value.__setstate__(arr_state)
jnp_value = jax.device_put(np_value)
jnp_value.aval = jnp_value.aval.update(**aval_state)
return jnp_value


class DeletedBuffer(object): pass
deleted_buffer = DeletedBuffer()

Expand Down
9 changes: 9 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4572,9 +4572,18 @@ def _operator_round(number, ndigits=None):
# If `ndigits` is None, for a builtin float round(7.5) returns an integer.
return out.astype(int) if ndigits is None else out

def _copy(self):
return self.copy()

def _deepcopy(self, memo):
del memo # unused
return self.copy()

_operators = {
"getitem": _rewriting_take,
"setitem": _unimplemented_setitem,
"copy": _copy,
"deepcopy": _deepcopy,
"neg": negative,
"pos": positive,
"eq": _defer_to_unrecognized_arg(equal),
Expand Down
14 changes: 8 additions & 6 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,14 @@ def __hex__(self): return self.aval._hex(self)
def __oct__(self): return self.aval._oct(self)
def __float__(self): return self.aval._float(self)
def __complex__(self): return self.aval._complex(self)
def __copy__(self): return self.aval._copy(self)
def __deepcopy__(self, memo): return self.aval._deepcopy(self, memo)

# raises a useful error on attempts to pickle a Tracer.
def __reduce__(self):
raise ConcretizationTypeError(
self, ("The error occurred in the __reduce__ method, which may "
"indicate an attempt to serialize/pickle a traced value."))

# raises the better error message from ShapedArray
def __setitem__(self, idx, val): return self.aval._setitem(self, idx, val)
Expand Down Expand Up @@ -651,12 +659,6 @@ def _contents(self):
except AttributeError:
return ()

def __copy__(self):
return self

def __deepcopy__(self, unused_memo):
return self

def _origin_msg(self) -> str:
return ""

Expand Down
10 changes: 8 additions & 2 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


import collections
import copy
import functools
from functools import partial
import inspect
Expand Down Expand Up @@ -3828,10 +3829,15 @@ def _check(obj, out_dtype, weak_type):
{"testcase_name": f"_dtype={np.dtype(dtype)}_func={func}",
"dtype": dtype, "func": func}
for dtype in all_dtypes
for func in ["array", "copy"]))
for func in ["array", "copy", "copy.copy", "copy.deepcopy"]))
def testArrayCopy(self, dtype, func):
x = jnp.ones(10, dtype=dtype)
copy_func = getattr(jnp, func)
if func == "copy.deepcopy":
copy_func = copy.deepcopy
elif func == "copy.copy":
copy_func = copy.copy
else:
copy_func = getattr(jnp, func)

x_view = jnp.asarray(x)
x_view_jit = jax.jit(jnp.asarray)(x)
Expand Down
38 changes: 38 additions & 0 deletions tests/pickle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
cloudpickle = None

import jax
from jax import core
from jax import numpy as jnp
from jax.config import config
from jax._src import test_util as jtu
Expand Down Expand Up @@ -73,5 +74,42 @@ def g(z):
self.assertEqual(expected, actual)


class PickleTest(jtu.JaxTestCase):

def testPickleOfDeviceArray(self):
x = jnp.arange(10.0)
s = pickle.dumps(x)
y = pickle.loads(s)
self.assertArraysEqual(x, y)
self.assertIsInstance(y, type(x))
self.assertEqual(x.aval, y.aval)

def testPickleOfDeviceArrayWeakType(self):
x = jnp.array(4.0)
self.assertEqual(x.aval.weak_type, True)
s = pickle.dumps(x)
y = pickle.loads(s)
self.assertArraysEqual(x, y)
self.assertIsInstance(y, type(x))
self.assertEqual(x.aval, y.aval)

def testPickleX64(self):
with jax.experimental.enable_x64():
x = jnp.array(4.0, dtype='float64')
s = pickle.dumps(x)

with jax.experimental.disable_x64():
y = pickle.loads(s)

self.assertEqual(x.dtype, jnp.float64)
self.assertArraysEqual(x, y, check_dtypes=False)
self.assertEqual(y.dtype, jnp.float32)
self.assertEqual(y.aval.dtype, jnp.float32)
self.assertIsInstance(y, type(x))

def testPickleTracerError(self):
with self.assertRaises(core.ConcretizationTypeError):
jax.jit(pickle.dumps)(0)

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 7565f8e

Please sign in to comment.