DeviceArray: Improve support for copy, deepcopy, and pickle #10659
Pulling in for more testing; not sure whether there will be unintended side effects.
Turns out there are a number of downstream users implicitly relying on pickling and/or deepcopy converting device arrays to numpy arrays. Perhaps we can start with a deprecation warning.
I'm working on fixing internal breakages; probably will not be able to merge this for a week or two. But requesting review now because I think this will be close to the final state of the PR. Thanks!
Ready for a review: PTAL
With the built-in :mod:`copy` module, when :func:`copy.copy` or :func:`copy.deepcopy`
encounter a :class:`~jax.numpy.DeviceArray`, it is equivalent to calling the
:meth:`~jaxlib.xla_extension.DeviceArray.copy` method, which will create a copy of
A note, not necessarily for this change: I don't consider anything in `jaxlib` to be a public API. We should have public `jax.` names for anything we need to refer to.
Yeah, unfortunately array methods are only defined on the xla extension version of the `DeviceArray` class, so we can only cross-reference them that way: https://jax.readthedocs.io/en/latest/jax.numpy.html#jax.numpy.DeviceArray
Maybe the pjit array unification thing will let us clean this up?
Currently deep-copying or pickling of jax DeviceArrays is implemented by forwarding `arr.__reduce__` to the ndarray value, meaning that the copied/unpickled result is a normal numpy array, and calling `copy.copy` is a no-op.

This change implements custom `__copy__`, `__deepcopy__`, and `__reduce__` methods in order to properly copy and/or serialize DeviceArray objects in both traced and untraced contexts.

For `__reduce__` (i.e. pickle), deserialization is done on the default device. I'm not sure whether it would make sense to try to persist the device; in any case it's not straightforward because `jaxlib.xla_extension.Device` is not serializable. Further, because deserialization may happen in a different runtime than the one where serialization took place, it's not clear how the non-default device would be identified.

Addresses #2632

Note that pickling of bfloat16 arrays is still broken due to #8505
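
For readers unfamiliar with the three protocol hooks involved, here is a minimal sketch of how they fit together. It uses a hypothetical `ToyArray` class in plain Python and is not the implementation in this PR: the class name, the `DEFAULT_DEVICE` string, and the `pickle_roundtrip` helper are all assumptions made for illustration. The sketch shows copies keeping their type, and `__reduce__` dropping the device so that the restored object lands on the default device.

```python
import copy
import pickle

# Hypothetical stand-in for whatever the runtime treats as its default device.
DEFAULT_DEVICE = "device:0"


class ToyArray:
    """Illustrative stand-in for a device array; not JAX's DeviceArray."""

    def __init__(self, values, device=DEFAULT_DEVICE):
        self.values = list(values)  # stand-in for the on-device buffer
        self.device = device

    def copy(self):
        # Allocate a new buffer on the same device as the original.
        return ToyArray(self.values, self.device)

    def __copy__(self):
        # copy.copy() returns a real copy of the same type, not a no-op.
        return self.copy()

    def __deepcopy__(self, memo):
        # copy.deepcopy() also stays within the array type.
        return self.copy()

    def __reduce__(self):
        # Serialize only the host-side values. The device is deliberately
        # not persisted, so reconstruction uses the default device.
        return (ToyArray, (list(self.values),))


def pickle_roundtrip(arr):
    """Serialize and deserialize `arr` with the standard pickle module."""
    return pickle.loads(pickle.dumps(arr))
```

With this sketch, `copy.copy(a)` and `copy.deepcopy(a)` both return a new `ToyArray` on the same device as `a`, while `pickle_roundtrip(a)` returns a `ToyArray` on `DEFAULT_DEVICE` regardless of where `a` lived.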