
[replica-parallel] Add replica slices concept #1319

Merged

1 commit merged into google:main on Nov 14, 2024

Conversation

@gspschmid (Contributor) commented on Nov 11, 2024

(Followed by #1320)

Adds the concept of "replica slices", an explicit representation of which replica ids save which slices of an array. The intent is for replica slices to let us generalize beyond Orbax's current restriction that each shard be saved by exactly one replica.

Motivation: Depending on their sharding and replication, JAX arrays may consist of multiple shards. Under replication, each shard carries a distinct replica_id that distinguishes copies of the same logical shard from one another. Orbax's current behavior is to save the copy with the same fixed replica_id for all shards of all arrays ("single-replica" saving). In the presence of replication this is suboptimal, since the work could instead be parallelized across all replicas.
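For context, a minimal JAX sketch (not part of this PR; the device count and mesh shape are assumptions) showing how replica_id distinguishes the copies of a fully replicated array's single logical shard:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumes 8 local devices arranged in a 2x4 mesh. An empty PartitionSpec
# fully replicates the array, so every device holds a copy of the same
# logical shard.
mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
arr = jax.device_put(
    np.arange(16.0).reshape(8, 2), NamedSharding(mesh, PartitionSpec())
)

for shard in arr.addressable_shards:
    # All shards cover the same index; replica_id tells the copies apart.
    # Single-replica saving writes only the copy with one chosen replica_id.
    print(shard.index, shard.replica_id, shard.data.shape)
```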

This PR is an initial step in the direction of "replica-parallel" saving: we make "replica slices" and related metadata explicit, but do not change any of Orbax's behavior.

Care is taken to compute the resulting local_shape (the shape of the slices written by each replica) even when the local process does not end up saving any data. This seems necessary because, to the best of my understanding, only one particular process may set the tensorstore metadata, and tensorstore's chunk shape in particular is derived from local_shape.
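To illustrate the point about local_shape, a hedged sketch of the idea (not the PR's actual code; since this PR keeps single-replica behavior, the slice written per replica is simply a full shard):

```python
import jax

def local_shape_for_metadata(arr: jax.Array) -> tuple[int, ...]:
    # The per-shard shape is a property of the sharding and the global shape,
    # not of which process ends up owning data, so every process can compute
    # it, including one that saves no replica slices. The process that writes
    # the tensorstore metadata derives the chunk shape from this value.
    return arr.sharding.shard_shape(arr.shape)
```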

@gspschmid (Contributor Author) commented:

@cpgaffney1

@cpgaffney1 (Collaborator) left a comment

Thanks!! Looks good overall, just some minor comments.

sharding: jax.sharding.Sharding
dtype: np.dtype
# Whether the replica slices have been transferred and are ready as ndarrays
transferred: bool
Collaborator:
Is this really necessary if we can just check replica_slices to see whether it contains numpy arrays or not?

@gspschmid (Contributor Author) replied Nov 13, 2024:

There might be no replica_slices (e.g. because we are in single-replica mode and the current process has no shards with shard.replica_id == replica_id)?

dtype: np.dtype
# Whether the replica slices have been transferred and are ready as ndarrays
transferred: bool
replica_slices: list[ReplicaSliceOnDevice] | list[tuple[Index, np.ndarray]]
Collaborator:
It seems like we could just have a ReplicaSlice object that can store data: jax.Array | np.ndarray, with optional replica_id. That would simplify this typing a bit.

Contributor Author:
Moved from ReplicaSliceOnDevice to ReplicaSlice (which may be either on-device or on-host) and added some invariants.
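A rough sketch of what such a unified type could look like (field names and invariants here are illustrative assumptions, not the code that was merged):

```python
import dataclasses
import jax
import numpy as np

Index = tuple[slice, ...]


@dataclasses.dataclass(frozen=True)
class ReplicaSlice:
    """A slice of an array saved by one replica, either on device or on host."""

    index: Index
    # Only meaningful while the slice is still on device; dropped once the
    # data has been transferred to host as a numpy array.
    replica_id: int | None
    data: jax.Array | np.ndarray

    @property
    def is_on_host(self) -> bool:
        return isinstance(self.data, np.ndarray)

    def __post_init__(self):
        # Invariant: on-host slices no longer carry a replica_id.
        assert self.is_on_host == (self.replica_id is None)
```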

assert num_devices >= 2
assert is_pow_of_two(num_devices)

def test_get_replica_slices_single_replica(self):
Collaborator:
Would be good to add new cases, or parameterize the existing ones, to include tests for arrays that are not fully replicated.

Contributor Author:
Added a second variant of the test that operates on a partially-replicated array (first dim partitioned).
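For reference, one way such a partially-replicated test input could be constructed (mesh shape, sizes, and the helper name are assumptions for illustration, not the test's actual code):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def make_partially_replicated_array() -> jax.Array:
    # 2x4 mesh: the first array dimension is partitioned across 'x', while
    # 'y' only replicates, so each logical shard exists in 4 replica copies.
    devices = np.array(jax.devices()[:8]).reshape(2, 4)
    mesh = Mesh(devices, ('x', 'y'))
    sharding = NamedSharding(mesh, PartitionSpec('x'))
    return jax.device_put(np.arange(64.0).reshape(8, 8), sharding)
```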

@gspschmid force-pushed the gschmid/replica_parallel_0 branch from fcdeb8e to 5f533d5 on November 13, 2024 16:58
@gspschmid (Contributor Author)

@cpgaffney1 Thanks for the review, PTAL!

@cpgaffney1 (Collaborator)

This conflicts with a few internal behaviors. I have created a change that fixes most issues and am just waiting for advice on one particular issue. Likely by tomorrow the internal change should be ready to go, at which point I will merge this change and submit mine just after.

@gspschmid (Contributor Author)

Sounds good, thanks for helping shepherd this through! :-)

copybara-service bot pushed a commit that referenced this pull request Nov 14, 2024
@cpgaffney1 merged commit acd7869 into google:main on Nov 14, 2024
1 check passed
copybara-service bot pushed a commit that referenced this pull request Nov 14, 2024
@gspschmid deleted the gschmid/replica_parallel_0 branch on November 15, 2024 09:57
copybara-service bot pushed a commit that referenced this pull request Jan 27, 2025

…1319, as performance regressions have been observed. Note also that simply setting `use_replica_parallel=False` does not fix the issue.

PiperOrigin-RevId: 720202277

copybara-service bot pushed a commit that referenced this pull request Jan 28, 2025

…1319, as performance regressions have been observed. Note also that simply setting `use_replica_parallel=False` does not fix the issue. Also disable `enable_pinned_host_transfer` feature, as allowing this also results in poor performance.

PiperOrigin-RevId: 720622279