DeviceArray: Improve support for copy, deepcopy, and pickle #10659

Merged
merged 1 commit into from
May 20, 2022

Conversation

jakevdp

@jakevdp jakevdp commented May 10, 2022

Currently, deep-copying or pickling a JAX DeviceArray is implemented by forwarding arr.__reduce__ to the underlying ndarray value. As a result, the copied or unpickled object is a plain NumPy array, and calling copy.copy is a no-op.

This change implements custom __copy__, __deepcopy__, and __reduce__ methods in order to properly copy and/or serialize DeviceArray objects in both traced and untraced contexts.
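As a rough illustration of the three hooks named above, here is a minimal sketch using a toy class. `ToyDeviceArray` is a hypothetical stand-in backed by the standard-library `array` module; it is not the JAX implementation, and the real methods in this PR may differ in detail.

```python
import copy
import pickle
from array import array


class ToyDeviceArray:
    """Hypothetical stand-in for a device-backed array (not the JAX class)."""

    def __init__(self, values, typecode="d"):
        # Pretend this buffer lives on an accelerator.
        self._buf = array(typecode, values)

    def tolist(self):
        return self._buf.tolist()

    def __copy__(self):
        # Return a new object with its own buffer instead of a no-op.
        return ToyDeviceArray(self._buf, self._buf.typecode)

    def __deepcopy__(self, memo):
        # The buffer holds only scalars, so a deep copy equals a shallow copy.
        result = self.__copy__()
        memo[id(self)] = result
        return result

    def __reduce__(self):
        # Serialize host-side data plus the constructor, so unpickling
        # rebuilds the same wrapper type rather than a plain host array.
        return (ToyDeviceArray, (self._buf.tolist(), self._buf.typecode))


x = ToyDeviceArray([1.0, 2.0, 3.0])
y = copy.copy(x)
z = copy.deepcopy(x)
w = pickle.loads(pickle.dumps(x))
```

In this sketch all three results are new `ToyDeviceArray` objects holding the same values, which mirrors the intent described above: the copy or round trip preserves the array type instead of silently converting it.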

For __reduce__ (i.e. pickle), deserialization is done on the default device – I'm not sure whether it would make sense to try to persist the device; in any case it's not straightforward because jaxlib.xla_extension.Device is not serializable. Further, because deserialization may happen in a different runtime than the one where serialization took place, it's not clear how the non-default device would be identified.
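The device question can be sketched with a toy example. Everything here is hypothetical: `PlacedArray`, `DEFAULT_DEVICE`, and the string device names are illustrative only and do not correspond to JAX APIs. The point is that `__reduce__` serializes only host-side values, so the reconstructing function can only place the data on whatever the loading runtime treats as its default.

```python
import pickle

DEFAULT_DEVICE = "device:0"  # hypothetical name for the runtime's default device


def _rebuild_on_default_device(values):
    # Runs at unpickling time, possibly in a different process or runtime,
    # so it only knows about that runtime's default device.
    return PlacedArray(values, DEFAULT_DEVICE)


class PlacedArray:
    """Hypothetical array wrapper that remembers where its data lives."""

    def __init__(self, values, device=DEFAULT_DEVICE):
        self.values = list(values)
        self.device = device

    def __reduce__(self):
        # Only the host-side values are serialized; the device handle is dropped.
        return (_rebuild_on_default_device, (self.values,))


x = PlacedArray([1, 2, 3], device="device:1")
restored = pickle.loads(pickle.dumps(x))
```

Here `x` was created on a non-default device, but `restored` ends up on the default one, since the original placement never enters the pickle stream.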

Addresses #2632

Note that pickling of bfloat16 arrays is still broken due to #8505

@jakevdp jakevdp force-pushed the devicearray-pickle branch 2 times, most recently from 32282d9 to 6d224ea Compare May 10, 2022 22:55
@jakevdp

jakevdp commented May 10, 2022

pulling in for more testing - not sure whether there will be unintended side-effects.

@jakevdp jakevdp added the pull ready Ready for copybara import and testing label May 10, 2022
@jakevdp

jakevdp commented May 11, 2022

Turns out there are a number of downstream users implicitly relying on pickling and/or deepcopy converting device arrays to numpy arrays. Perhaps we can start with a deprecation warning.
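One possible shape for such a transition, sketched with a toy class: keep the old conversion behavior for now, but warn callers that it will change. `LegacyToyArray`, the choice of `FutureWarning`, and the message text are all assumptions for illustration; the thread does not say which warning category or wording was used.

```python
import copy
import warnings


class LegacyToyArray:
    """Hypothetical array type whose deepcopy still returns a plain list."""

    def __init__(self, values):
        self._values = list(values)

    def __deepcopy__(self, memo):
        # Keep the old behavior (conversion to a host-side container) so
        # downstream code keeps working, but announce the upcoming change.
        warnings.warn(
            "deepcopy of LegacyToyArray currently returns a plain list; "
            "a future version will return a LegacyToyArray.",
            FutureWarning,
            stacklevel=2,
        )
        return list(self._values)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = copy.deepcopy(LegacyToyArray([1, 2, 3]))
```

Downstream users that rely on the conversion still get a list here, while the recorded warning gives them time to switch to an explicit conversion.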

@jakevdp jakevdp force-pushed the devicearray-pickle branch 3 times, most recently from af62f1e to 36e0361 Compare May 11, 2022 17:43
@jakevdp jakevdp changed the title Support pickling of DeviceArray objects Support pickle/deepcopy of DeviceArray objects May 11, 2022
@jakevdp jakevdp force-pushed the devicearray-pickle branch from 36e0361 to db781ad Compare May 13, 2022 16:55
@jakevdp jakevdp changed the title Support pickle/deepcopy of DeviceArray objects DeviceArray: Improve support for copy, deepcopy, and pickle May 13, 2022
@jakevdp jakevdp force-pushed the devicearray-pickle branch 4 times, most recently from 360d993 to 8b34a3c Compare May 13, 2022 19:40
@jakevdp jakevdp requested a review from hawkinsp May 13, 2022 20:10
@jakevdp

jakevdp commented May 13, 2022

I'm working on fixing internal breakages; I probably won't be able to merge this for a week or two. But I'm requesting review now because I think this will be close to the final state of the PR. Thanks!

@jakevdp jakevdp force-pushed the devicearray-pickle branch 5 times, most recently from 7565f8e to 029268b Compare May 16, 2022 20:37
@jakevdp

jakevdp commented May 16, 2022

Ready for a review: PTAL

jax/_src/device_array.py
@jakevdp jakevdp force-pushed the devicearray-pickle branch 2 times, most recently from a0dad89 to a57b646 Compare May 17, 2022 19:20
docs/jax.numpy.rst (outdated)

With the built-in :mod:`copy` module, when :func:`copy.copy` or :func:`copy.deepcopy`
encounter a :class:`~jax.numpy.DeviceArray`, it is equivalent to calling the
:meth:`~jaxlib.xla_extension.DeviceArray.copy` method, which will create a copy of

A note, not necessarily for this change: I don't consider anything in jaxlib to be a public API. We should have public jax. names for anything we need to refer to.


@jakevdp jakevdp May 17, 2022


Yeah, unfortunately array methods are only defined on the xla extension version of the DeviceArray class, so we can only cross-reference them that way: https://jax.readthedocs.io/en/latest/jax.numpy.html#jax.numpy.DeviceArray

Maybe the pjit array unification thing will let us clean this up?

Labels
pull ready Ready for copybara import and testing

3 participants