[CherryPick][Feature Enhancement] Set ordered replica index label to support mult… #4171
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
** This PR is a cherry pick of the below PR / description to the release v1.5 branch **
Why are these changes needed?
This is a feature enhancement to the
RayMultiHostIndexingfeature that adds replica id and host-index labels to Pods created by KubeRay that was implemented in this PR: #3998.To support the multi-host use-case for TPUs and GPUs, we already set a replica id label (a unique string based on the group name) and a host-index (a unique int from 0-NumOfHosts-1) when this feature is enabled and
NumOfHosts > 1. These labels add value to users through observability when running SPMD workloads, and through atomic creation and deletion/re-creation of multi-host groups.However, these labels do not support the use-case for multi-slice workloads where it is important to know the ordered index of the replica within the multi-slice set. Frameworks like JAX require the slice ID, which is an int between 0 and the # of slices, to be set to configure multi-slice workloads (source).
To solve this issue, this PR adds a label for the ordered replica index, an int value between 0 and replicas-1 for each worker group when this feature is enabled. This label will greatly simplify the process of setting environment variables like
MEGASCALE_SLICE_IDfor multi-slice workloads that use JAX. We can then check these KubeRay labels from the KubeRay TPU webhook when injecting environment vars. Before the TPU webhook change, these labels are still useful because users can pass them to the Pod environment using downward API.Related issue number
#3902
Checks