[JAX] Make colocated Python's serialization pickle each shared object only once #32271
+190
−37
[JAX] Make colocated Python's serialization pickle each shared object only once
It is common for colocated Python's input spec, output spec, and function serialization to use the same object across multiple serialized objects. For example, an input or output spec can contain many shardings as leaves, where all shardings are on the same mesh (some shardings can even be identical). Serializing all of these objects repeatedly is slow and inflates the serialized size. The problem is most significant when colocated Python runs on a very large mesh, because it is very expensive to pickle/unpickle `Mesh` and `xc.DeviceList` repeatedly.

This change adds logic that detects shared objects within the input spec, output spec, and function, and outlines these objects so that each is pickled only once. Deserialization first unpickles the shared objects and then unpickles the individual objects, which can internally reference the shared objects.
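The outlining idea can be illustrated with the standard library's `pickle` persistent-ID hooks. This is only a minimal sketch and not the code in this PR: `OutliningPickler`, `InliningUnpickler`, `serialize`, `deserialize`, and the `is_shareable` predicate are hypothetical names, and a `SimpleNamespace` stands in for an expensive object such as a mesh. It assumes shared objects are recognized by identity and do not themselves need outlining.

```python
import io
import pickle
from types import SimpleNamespace


class OutliningPickler(pickle.Pickler):
    """Hypothetical pickler that replaces shareable objects with table indices."""

    def __init__(self, file, is_shareable, shared, index_by_id):
        super().__init__(file)
        self._is_shareable = is_shareable
        self._shared = shared            # side table, in first-seen order
        self._index_by_id = index_by_id  # id(obj) -> position in the side table

    def persistent_id(self, obj):
        if not self._is_shareable(obj):
            return None  # pickle this object inline as usual
        index = self._index_by_id.get(id(obj))
        if index is None:
            index = len(self._shared)
            self._index_by_id[id(obj)] = index
            self._shared.append(obj)  # also keeps obj alive so id() stays valid
        return index


class InliningUnpickler(pickle.Unpickler):
    """Hypothetical unpickler that resolves indices against the shared table."""

    def __init__(self, file, shared):
        super().__init__(file)
        self._shared = shared

    def persistent_load(self, pid):
        return self._shared[pid]


def serialize(objs, is_shareable):
    """Returns (pickled shared table, one pickled payload per object)."""
    shared, index_by_id = [], {}
    payloads = []
    for obj in objs:
        buf = io.BytesIO()
        OutliningPickler(buf, is_shareable, shared, index_by_id).dump(obj)
        payloads.append(buf.getvalue())
    # Each shared object is pickled exactly once, here.
    return pickle.dumps(shared), payloads


def deserialize(shared_bytes, payloads):
    """Unpickles the shared table first, then each payload that references it."""
    shared = pickle.loads(shared_bytes)
    return [InliningUnpickler(io.BytesIO(p), shared).load() for p in payloads]


# Stand-in data: eight "shardings" that all point at one large "mesh".
mesh = SimpleNamespace(devices=list(range(10_000)))
specs = [{"mesh": mesh, "axis": i} for i in range(8)]

shared_bytes, payloads = serialize(
    specs, lambda o: isinstance(o, SimpleNamespace)
)
restored = deserialize(shared_bytes, payloads)

naive_size = sum(len(pickle.dumps(s)) for s in specs)
outlined_size = len(shared_bytes) + sum(len(p) for p in payloads)
```

In this sketch the per-object payloads carry only a small index in place of the mesh, so `outlined_size` grows with the number of distinct shared objects rather than with the number of references to them. Deduplicating equal but non-identical objects would need a different key than `id(obj)`.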