[replica-parallel] Add replica slices concept #1319
Conversation
Thanks!! Looks good overall, just some minor comments.
checkpoint/orbax/checkpoint/_src/serialization/replica_slices.py
sharding: jax.sharding.Sharding
dtype: np.dtype
# Whether the replica slices have been transferred and are ready as ndarrays
transferred: bool
Is this really necessary if we can just check `replica_slices` to see whether it contains numpy arrays or not?
There might be no `replica_slices` (e.g. because we are in single-replica mode and the current process has no shards with `shard.replica_id == replica_id`)?
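A minimal sketch of the trade-off being discussed (illustrative only; the tuple layout follows the `list[tuple[Index, np.ndarray]]` annotation quoted just below, and the helper name is made up):

```python
import numpy as np

def looks_transferred(replica_slices) -> bool:
  # Infer the state from the payload: only works if the list is non-empty.
  return bool(replica_slices) and isinstance(replica_slices[0][1], np.ndarray)

# With an empty list (e.g. this process owns no shard with the chosen
# replica_id), the inference above cannot tell "nothing to transfer" apart
# from "not yet transferred" -- hence an explicit `transferred` flag.
```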
dtype: np.dtype
# Whether the replica slices have been transferred and are ready as ndarrays
transferred: bool
replica_slices: list[ReplicaSliceOnDevice] | list[tuple[Index, np.ndarray]]
It seems like we could just have a ReplicaSlice object that can store `data: jax.Array | np.ndarray`, with an optional `replica_id`. That would simplify this typing a bit.
Moved from `ReplicaSliceOnDevice` to `ReplicaSlice` (which may be either on-device or on-host) and added some invariants.
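For illustration, a unified type along these lines might look roughly as follows. This is a sketch under the assumptions discussed above, not the code added in the PR; the `Index` alias and the specific invariant are illustrative.

```python
import dataclasses
import jax
import numpy as np

Index = tuple[slice, ...]  # hypothetical alias for an array index


@dataclasses.dataclass(frozen=True)
class ReplicaSlice:
  """One slice of an array, either still on device or already on host."""
  index: Index
  data: jax.Array | np.ndarray
  # Only meaningful while the data still lives on device.
  replica_id: int | None = None

  @property
  def is_on_host(self) -> bool:
    return isinstance(self.data, np.ndarray)

  def __post_init__(self):
    # Example invariant: on-device slices carry a replica id, on-host ones do not.
    assert self.is_on_host == (self.replica_id is None)
```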
checkpoint/orbax/checkpoint/_src/serialization/replica_slices_test.py
assert num_devices >= 2
assert is_pow_of_two(num_devices)

def test_get_replica_slices_single_replica(self):
Would be good to add new cases, or parameterize the existing ones, to include tests for arrays that are not fully replicated.
Added a second variant of the test that operates on a partially-replicated array (first dim partitioned).
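For context, a partially-replicated array of the kind this new test variant targets can be constructed roughly like this (a sketch, not the actual test code; assumes at least two devices and an even device count):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def make_partially_replicated_array():
  devices = np.array(jax.devices()).reshape(2, -1)
  mesh = Mesh(devices, ('x', 'y'))
  # First dim partitioned over 'x'; replicated over 'y'.
  sharding = NamedSharding(mesh, P('x', None))
  arr = jax.device_put(np.arange(16.0).reshape(4, 4), sharding)
  # Shards holding copies of the same slice differ only in replica_id.
  for shard in arr.addressable_shards:
    print(shard.index, shard.replica_id)
  return arr
```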
Force-pushed from fcdeb8e to 5f533d5.
@cpgaffney1 Thanks for the review, PTAL!
…chmid. PiperOrigin-RevId: 696250850
This is conflicting with a few internal behaviors - I have created a change that fixes most issues, just waiting for some advice on one particular issue. Likely by tomorrow the internal change should be ready to go, at which point I will merge this change and submit mine just after.
Sounds good, thanks for helping shepherd this through! :-)
…thanks to https://github.com/gspschmid. PiperOrigin-RevId: 696601983
…1319, as performance regressions have been observed. Note also that simply setting `use_replica_parallel=False` does not fix the issue. PiperOrigin-RevId: 720202277
…1319, as performance regressions have been observed. Note also that simply setting `use_replica_parallel=False` does not fix the issue. Also disable `enable_pinned_host_transfer` feature, as allowing this also results in poor performance. PiperOrigin-RevId: 720202277
…1319, as performance regressions have been observed. Note also that simply setting `use_replica_parallel=False` does not fix the issue. Also disable `enable_pinned_host_transfer` feature, as allowing this also results in poor performance. PiperOrigin-RevId: 720622279
(Followed by #1320)
Adds the concept of "replica slices", an explicit representation of which replica ids save which slices out of an array. The intent is for replica slices to allow us to generalize beyond Orbax's current restriction that requires each shard to be saved by exactly one replica.
Motivation: Depending on their sharding and replication, JAX arrays may consist of multiple shards. In the case of replication, each shard carries a distinct `replica_id`, distinguishing the copies of the same logical shard from one another. Orbax's current behavior is to save the same `replica_id` copy for all shards of all arrays ("single-replica" saving). In the presence of replication this is suboptimal, since the work could be parallelized across all replicas.

This PR is an initial step in the direction of "replica-parallel" saving: we make "replica slices" and related metadata explicit, but do not change any of Orbax's behavior.
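To make the `replica_id` notion concrete, here is a small illustration (not Orbax code) of what single-replica saving selects and what replica-parallel saving would change; it assumes the chosen replica id is 0:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A fully replicated array: every device holds a copy of the same logical
# shard, and the copies are told apart by shard.replica_id.
mesh = Mesh(np.array(jax.devices()), ('x',))
arr = jax.device_put(np.arange(8.0), NamedSharding(mesh, P(None)))

# "Single-replica" saving: only the replica_id == 0 copies get written.
to_save = [s for s in arr.addressable_shards if s.replica_id == 0]

# "Replica-parallel" saving (the direction this PR prepares for): each replica
# would instead write a distinct 1/num_replicas slice of its shard, spreading
# the save work across all copies.
print([(s.index, s.replica_id) for s in arr.addressable_shards])
```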
Care is taken to compute the resulting `local_shape` (the shape of the slices written by each replica) even when the local process does not end up saving any data. This seems necessary since, to the best of my understanding, only one particular process may set tensorstore metadata, and tensorstore's chunk shape, in particular, is derived from `local_shape`.
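The per-replica shape can be derived from the shard shape and replica count alone, which is what makes it computable even on processes that save nothing. A hypothetical sketch of such a computation (the split axis and the even-divisibility assumption are illustrative, not necessarily what Orbax does):

```python
def replica_parallel_local_shape(
    shard_shape: tuple[int, ...], num_replicas: int, axis: int = 0
) -> tuple[int, ...]:
  """Shape of the slice each replica would write, splitting along `axis`."""
  assert shard_shape[axis] % num_replicas == 0, 'sketch assumes an even split'
  return tuple(
      dim // num_replicas if i == axis else dim
      for i, dim in enumerate(shard_shape)
  )

# e.g. an (8, 128) shard split across 4 replicas -> each replica writes (2, 128);
# with a single replica the local shape is just the shard shape itself.
assert replica_parallel_local_shape((8, 128), 4) == (2, 128)
assert replica_parallel_local_shape((8, 128), 1) == (8, 128)
```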