Commit e9a8b23

Refactor and add e2e test

Signed-off-by: Ryan O'Leary <[email protected]>
1 parent 7a5659a commit e9a8b23

7 files changed: +389 −132 lines changed

ray-operator/controllers/ray/common/pod.go

Lines changed: 1 addition & 1 deletion
@@ -317,7 +317,7 @@ func DefaultWorkerPodTemplate(ctx context.Context, instance rayv1.RayCluster, wo
 	}
 	podTemplate.Labels = labelPod(rayv1.WorkerNode, instance.Name, workerSpec.GroupName, workerSpec.Template.ObjectMeta.Labels)
 	// Add additional labels for RayMultihostIndexing
-	multihostIndexingEnabled := features.Enabled(features.RayMulithostIndexing) && workerSpec.NumOfHosts > 1
+	multihostIndexingEnabled := features.Enabled(features.RayMultiHostIndexing) && workerSpec.NumOfHosts > 1
 	if multihostIndexingEnabled {
 		podTemplate.Labels = addMultihostIndexingPodLabels(podTemplate.Labels, replicaGrpName, numHostIndex)
 	}
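
The addMultihostIndexingPodLabels helper called here is not part of this diff. As a minimal sketch only (not the repository's actual implementation; the parameter types are assumptions), it presumably stamps the same utils.RayWorkerReplicaIndexKey and utils.RayHostIndexKey labels that the tests below read back:

package common

import (
	"strconv"

	"github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
)

// Hypothetical sketch: attach the replica-group name and host index as pod
// labels so the controller and tests can identify every pod of a multi-host replica.
func addMultihostIndexingPodLabels(labels map[string]string, replicaGrpName string, numHostIndex int) map[string]string {
	if labels == nil {
		labels = map[string]string{}
	}
	labels[utils.RayWorkerReplicaIndexKey] = replicaGrpName
	labels[utils.RayHostIndexKey] = strconv.Itoa(numHostIndex)
	return labels
}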

ray-operator/controllers/ray/common/pod_test.go

Lines changed: 1 addition & 1 deletion
@@ -1179,7 +1179,7 @@ func TestDeafultWorkerPodTemplateWithReplicaGrpAndIndex(t *testing.T) {
 	fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
 	worker := cluster.Spec.WorkerGroupSpecs[0]

-	features.SetFeatureGateDuringTest(t, features.RayMulithostIndexing, true)
+	features.SetFeatureGateDuringTest(t, features.RayMultiHostIndexing, true)

 	worker.Template.ObjectMeta.Name = "ray-worker-test"
 	worker.NumOfHosts = 4

ray-operator/controllers/ray/raycluster_controller.go

Lines changed: 182 additions & 102 deletions
Large diffs are not rendered by default.
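
Since the 182-line controller diff is not rendered here, the following is only a rough, hypothetical sketch of the grouping behavior the tests below exercise: worker pods are bucketed by the replica-group label added in pod.go, and any group that does not contain exactly NumOfHosts pods is treated as broken and rebuilt as a unit. None of these helper names come from the actual diff.

package ray

import (
	corev1 "k8s.io/api/core/v1"

	"github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
)

// groupWorkerPodsByReplica is a hypothetical helper, not the commit's code:
// it buckets worker pods by their replica-group label.
func groupWorkerPodsByReplica(pods []corev1.Pod) map[string][]corev1.Pod {
	groups := map[string][]corev1.Pod{}
	for _, pod := range pods {
		name := pod.Labels[utils.RayWorkerReplicaIndexKey]
		groups[name] = append(groups[name], pod)
	}
	return groups
}

// incompleteReplicaGroups is likewise a sketch: a replica group missing any of
// its NumOfHosts pods (for example after a manual pod deletion) is returned so
// a reconciler could delete the remainder and re-create the whole group.
func incompleteReplicaGroups(groups map[string][]corev1.Pod, numOfHosts int) []string {
	var broken []string
	for name, pods := range groups {
		if len(pods) != numOfHosts {
			broken = append(broken, name)
		}
	}
	return broken
}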

ray-operator/controllers/ray/raycluster_controller_test.go

Lines changed: 28 additions & 24 deletions
@@ -49,6 +49,15 @@ func rayClusterTemplate(name string, namespace string) *rayv1.RayCluster {
 		maxReplicas int32 = 4
 		replicas    int32 = 3
 	)
+	sharedMemVolume := corev1.Volume{
+		Name: "shared-mem",
+		VolumeSource: corev1.VolumeSource{
+			EmptyDir: &corev1.EmptyDirVolumeSource{
+				Medium:    corev1.StorageMediumMemory,
+				SizeLimit: ptr.To(resource.MustParse("1Gi")),
+			},
+		},
+	}
 	return &rayv1.RayCluster{
 		ObjectMeta: metav1.ObjectMeta{
 			Name: name,
@@ -58,6 +67,7 @@ func rayClusterTemplate(name string, namespace string) *rayv1.RayCluster {
 			HeadGroupSpec: rayv1.HeadGroupSpec{
 				Template: corev1.PodTemplateSpec{
 					Spec: corev1.PodSpec{
+						Volumes: []corev1.Volume{sharedMemVolume},
 						Containers: []corev1.Container{
 							{
 								Name: "ray-head",
@@ -75,6 +85,7 @@ func rayClusterTemplate(name string, namespace string) *rayv1.RayCluster {
 					GroupName: "small-group",
 					Template: corev1.PodTemplateSpec{
 						Spec: corev1.PodSpec{
+							Volumes: []corev1.Volume{sharedMemVolume},
 							Containers: []corev1.Container{
 								{
 									Name: "ray-worker",
@@ -922,8 +933,9 @@ var _ = Context("Inside the default namespace", func() {
 		numWorkerPods := 3 * int(numOfHosts)
 		workerFilters := common.RayClusterGroupPodsAssociationOptions(rayCluster, rayCluster.Spec.WorkerGroupSpecs[0].GroupName).ToListOptions()

-		// Checks if the multi-host indexing is enabled
-		multihostIndexingEnabled := features.Enabled(features.RayMulithostIndexing)
+		BeforeEach(func() {
+			features.SetFeatureGateDuringTest(GinkgoTB(), features.RayMultiHostIndexing, true)
+		})

 		It("Verify RayCluster spec", func() {
 			// These test are designed based on the following assumptions:
@@ -969,22 +981,20 @@ var _ = Context("Inside the default namespace", func() {
 		})

 		It("All multi-host pods are properly labeled", func() {
-			if multihostIndexingEnabled {
-				workerGrpReplicaMap := make(map[string][]string)
-				for _, pod := range workerPods.Items {
-					hostIndex := pod.Labels[utils.RayHostIndexKey]
-					hostGrpId := pod.Labels[utils.RayWorkerReplicaIndexKey]
-
-					grpReplicaIndexList, grpIdExists := workerGrpReplicaMap[hostGrpId]
-					if grpIdExists {
-						Expect(strconv.Atoi(hostIndex)).Should(BeNumerically("<", numOfHosts))
-						Expect(strconv.Atoi(hostIndex)).Should(BeNumerically(">=", 0))
-						Expect(slices.Contains(grpReplicaIndexList, hostIndex)).To(BeFalse())
-						workerGrpReplicaMap[hostGrpId] = append(grpReplicaIndexList, hostIndex)
-					} else {
-						workerGrpReplicaMap[hostGrpId] = []string{}
-						Expect(len(workerGrpReplicaMap)).Should(BeNumerically("<", replicas))
-					}
+			workerGrpReplicaMap := make(map[string][]string)
+			for _, pod := range workerPods.Items {
+				hostIndex := pod.Labels[utils.RayHostIndexKey]
+				hostGrpId := pod.Labels[utils.RayWorkerReplicaIndexKey]
+
+				grpReplicaIndexList, grpIdExists := workerGrpReplicaMap[hostGrpId]
+				if grpIdExists {
+					Expect(strconv.Atoi(hostIndex)).Should(BeNumerically("<", numOfHosts))
+					Expect(strconv.Atoi(hostIndex)).Should(BeNumerically(">=", 0))
+					Expect(slices.Contains(grpReplicaIndexList, hostIndex)).To(BeFalse())
+					workerGrpReplicaMap[hostGrpId] = append(grpReplicaIndexList, hostIndex)
+				} else {
+					workerGrpReplicaMap[hostGrpId] = []string{}
+					Expect(len(workerGrpReplicaMap)).Should(BeNumerically("<=", int(replicas)))
 				}
 			}
 		})
@@ -1059,12 +1069,6 @@ var _ = Context("Inside the default namespace", func() {
 			pod := workerPods.Items[0]
 			err := k8sClient.Delete(ctx, &pod, &client.DeleteOptions{GracePeriodSeconds: ptr.To[int64](0)})
 			Expect(err).NotTo(HaveOccurred(), "Failed to delete a Pod")
-			if multihostIndexingEnabled {
-				// Number of pods should go down by num of hosts but then be re-created
-				Eventually(
-					listResourceFunc(ctx, &workerPods, workerFilters...),
-					time.Second*3, time.Millisecond*500).Should(Equal(numWorkerPods-int(numOfHosts)), fmt.Sprintf("workerGroup %v", workerPods.Items))
-			}
 			Eventually(
 				listResourceFunc(ctx, &workerPods, workerFilters...),
 				time.Second*3, time.Millisecond*500).Should(Equal(numWorkerPods), fmt.Sprintf("workerGroup %v", workerPods.Items))

ray-operator/controllers/ray/utils/util.go

Lines changed: 1 addition & 1 deletion
@@ -331,7 +331,7 @@ func GenerateRayJobId(rayjob string) string {
 }

 // GenerateRayWorkerReplicaGroupName generates a name for the replica group
-// currently used for RayMulithostIndexing
+// currently used for RayMultiHostIndexing
 func GenerateRayWorkerReplicaGroupName(workerGroupName string) string {
 	return fmt.Sprintf("%s-%s", workerGroupName, rand.String(5))
 }
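
As a usage illustration (a standalone sketch, not code from this commit), one generated replica-group name is shared by all NumOfHosts pods of a replica, while the host index distinguishes the individual pods:

package main

import (
	"fmt"

	"k8s.io/apimachinery/pkg/util/rand"
)

// Mirrors GenerateRayWorkerReplicaGroupName above: "<workerGroupName>-<5 random chars>",
// e.g. "multi-host-group-x7k2q" (the suffix is random).
func generateReplicaGroupName(workerGroupName string) string {
	return fmt.Sprintf("%s-%s", workerGroupName, rand.String(5))
}

func main() {
	const numOfHosts = 4
	replicaGrpName := generateReplicaGroupName("multi-host-group")
	for hostIndex := 0; hostIndex < numOfHosts; hostIndex++ {
		// Every pod of this replica carries the same group name; hostIndex tells the hosts apart.
		fmt.Printf("pod %d: replica-group=%s host-index=%d\n", hostIndex, replicaGrpName, hostIndex)
	}
}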

ray-operator/pkg/features/features.go

Lines changed: 3 additions & 3 deletions
@@ -28,8 +28,8 @@ const (
 	// owner: @aaronliang
 	// rep: N/A
 	// alpha: v1.0
-	// Enables multihost worker indexing
-	RayMulithostIndexing featuregate.Feature = "RayMultihostIndexing"
+	// Enables multi-host worker indexing
+	RayMultiHostIndexing featuregate.Feature = "RayMultiHostIndexing"
 )

 func init() {
@@ -39,7 +39,7 @@ func init() {
 var defaultFeatureGates = map[featuregate.Feature]featuregate.FeatureSpec{
 	RayClusterStatusConditions: {Default: true, PreRelease: featuregate.Beta},
 	RayJobDeletionPolicy:       {Default: false, PreRelease: featuregate.Alpha},
-	RayMulithostIndexing:       {Default: false, PreRelease: featuregate.Alpha},
+	RayMultiHostIndexing:       {Default: false, PreRelease: featuregate.Alpha},
 }

 // SetFeatureGateDuringTest is a helper method to override feature gates in tests.
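
The renamed gate stays alpha and off by default, so callers opt in explicitly. The following is a small sketch (hypothetical file layout, using only the calls that appear in this commit) of the two patterns involved: the controller-side Enabled check and the test-side override:

package features_test

import (
	"testing"

	"github.com/ray-project/kuberay/ray-operator/pkg/features"
)

// multihostIndexingActive mirrors the controller-side check in pod.go: the gate
// must be on and the worker group must actually span more than one host.
func multihostIndexingActive(numOfHosts int32) bool {
	return features.Enabled(features.RayMultiHostIndexing) && numOfHosts > 1
}

func TestMultiHostIndexingGate(t *testing.T) {
	// Off by default (alpha), so without this override the multi-host path is skipped.
	features.SetFeatureGateDuringTest(t, features.RayMultiHostIndexing, true)
	if !multihostIndexingActive(4) {
		t.Fatal("expected the RayMultiHostIndexing path to be active")
	}
}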
New file (e2e test; path not rendered in this view)

Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
+package e2e
+
+import (
+	"testing"
+
+	. "github.com/onsi/gomega"
+	corev1 "k8s.io/api/core/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/api/resource"
+	corev1ac "k8s.io/client-go/applyconfigurations/core/v1"
+
+	rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
+	"github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
+	rayv1ac "github.com/ray-project/kuberay/ray-operator/pkg/client/applyconfiguration/ray/v1"
+	"github.com/ray-project/kuberay/ray-operator/pkg/features"
+	. "github.com/ray-project/kuberay/ray-operator/test/support"
+)
+
+func TestRayClusterMultiHost(t *testing.T) {
+	test := With(t)
+	g := NewWithT(t)
+
+	// Create a namespace
+	namespace := test.NewTestNamespace()
+
+	features.SetFeatureGateDuringTest(t, features.RayMultiHostIndexing, true)
+
+	const (
+		numOfHosts      = 4
+		initialReplicas = 2
+		clusterName     = "raycluster-multihost"
+	)
+	sharedMemVolumeAC := corev1ac.Volume().
+		WithName("shared-mem").
+		WithEmptyDir(corev1ac.EmptyDirVolumeSource().
+			WithMedium(corev1.StorageMediumMemory).
+			WithSizeLimit(resource.MustParse("1Gi")),
+		)
+
+	// Define the RayCluster spec with a multi-host worker group.
+	rayClusterAC := rayv1ac.RayCluster(clusterName, namespace.Name).
+		WithSpec(rayv1ac.RayClusterSpec().
+			WithRayVersion(GetRayVersion()).
+			WithEnableInTreeAutoscaling(true).
+			WithHeadGroupSpec(rayv1ac.HeadGroupSpec().
+				WithRayStartParams(map[string]string{"dashboard-host": "0.0.0.0"}).
+				WithTemplate(HeadPodTemplateApplyConfiguration().
+					// All PodSpec configurations go inside WithSpec.
+					WithSpec(corev1ac.PodSpec().
+						WithVolumes(sharedMemVolumeAC).
+						WithRestartPolicy(corev1.RestartPolicyNever).
+						WithContainers(corev1ac.Container().
+							WithName("ray-head").
+							WithImage(GetRayImage()).
+							WithEnv(corev1ac.EnvVar().WithName(utils.RAY_ENABLE_AUTOSCALER_V2).WithValue("1")).
+							WithPorts(
+								corev1ac.ContainerPort().WithName(utils.GcsServerPortName).WithContainerPort(utils.DefaultGcsServerPort),
+								corev1ac.ContainerPort().WithName(utils.ServingPortName).WithContainerPort(utils.DefaultServingPort),
+								corev1ac.ContainerPort().WithName(utils.DashboardPortName).WithContainerPort(utils.DefaultDashboardPort),
+								corev1ac.ContainerPort().WithName(utils.ClientPortName).WithContainerPort(utils.DefaultClientPort),
+							).
+							WithResources(corev1ac.ResourceRequirements().
+								WithRequests(corev1.ResourceList{
+									corev1.ResourceCPU:    resource.MustParse("2"),
+									corev1.ResourceMemory: resource.MustParse("3Gi"),
+								}).
+								WithLimits(corev1.ResourceList{
+									corev1.ResourceCPU:    resource.MustParse("2"),
+									corev1.ResourceMemory: resource.MustParse("3Gi"),
+								})),
+						),
+					),
+				),
+			).
+			WithWorkerGroupSpecs(rayv1ac.WorkerGroupSpec().
+				WithGroupName("multi-host-group").
+				WithReplicas(initialReplicas).
+				WithMinReplicas(0).
+				WithMaxReplicas(5).
+				WithNumOfHosts(numOfHosts).
+				WithTemplate(WorkerPodTemplateApplyConfiguration().
+					// All PodSpec configurations go inside WithSpec here as well.
+					WithSpec(corev1ac.PodSpec().
+						WithVolumes(sharedMemVolumeAC).
+						WithRestartPolicy(corev1.RestartPolicyNever).
+						WithContainers(corev1ac.Container().
+							WithName("ray-worker").
+							WithImage(GetRayImage()).
+							WithResources(corev1ac.ResourceRequirements().
+								WithRequests(corev1.ResourceList{
+									corev1.ResourceCPU:    resource.MustParse("300m"),
+									corev1.ResourceMemory: resource.MustParse("1G"),
+								}).
+								WithLimits(corev1.ResourceList{
+									corev1.ResourceCPU:    resource.MustParse("500m"),
+									corev1.ResourceMemory: resource.MustParse("1G"),
+								})),
+						),
+					),
+				),
+			),
+		)
+
+	// Create the RayCluster.
+	rayCluster, err := test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions)
+	g.Expect(err).NotTo(HaveOccurred())
+	LogWithTimestamp(test.T(), "Created RayCluster %s/%s successfully", rayCluster.Namespace, rayCluster.Name)
+
+	// Wait for the cluster to become Ready and verify the initial Pod count.
+	LogWithTimestamp(test.T(), "Waiting for RayCluster %s/%s to become ready", rayCluster.Namespace, rayCluster.Name)
+	g.Eventually(RayCluster(test, rayCluster.Namespace, rayCluster.Name), TestTimeoutLong).
+		Should(WithTransform(RayClusterState, Equal(rayv1.Ready)))
+
+	expectedPodCount := initialReplicas * numOfHosts
+	g.Eventually(func() ([]corev1.Pod, error) {
+		return GetWorkerPods(test, rayCluster)
+	}, TestTimeoutShort).Should(HaveLen(expectedPodCount))
+
+	// Verify that all pods are correctly labeled.
+	LogWithTimestamp(test.T(), "Verifying labels on multi-host pods for %s/%s", rayCluster.Namespace, rayCluster.Name)
+	workerPods, err := GetWorkerPods(test, rayCluster)
+	g.Expect(err).NotTo(HaveOccurred())
+	replicaMap := make(map[string][]string)
+	for _, pod := range workerPods {
+		replicaName, ok := pod.Labels[utils.RayWorkerReplicaIndexKey]
+		g.Expect(ok).To(BeTrue(), "Pod %s should have a replica index label", pod.Name)
+		hostIndex, ok := pod.Labels[utils.RayHostIndexKey]
+		g.Expect(ok).To(BeTrue(), "Pod %s should have a host index label", pod.Name)
+		replicaMap[replicaName] = append(replicaMap[replicaName], hostIndex)
+	}
+	g.Expect(replicaMap).To(HaveLen(initialReplicas), "Should have the correct number of replica groups")
+	for replicaName, hostIndices := range replicaMap {
+		g.Expect(hostIndices).To(HaveLen(numOfHosts), "Replica group %s should be complete", replicaName)
+	}
+
+	// Scale down replicas from 2 to 1. Verify we scale by a multiple of NumOfHosts.
+	LogWithTimestamp(test.T(), "Scaling down RayCluster %s/%s", rayCluster.Namespace, rayCluster.Name)
+	rayClusterAC.Spec.WorkerGroupSpecs[0].WithReplicas(1)
+	_, err = test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions)
+	g.Expect(err).NotTo(HaveOccurred())
+
+	expectedPodCount = 1 * numOfHosts
+	g.Eventually(func() ([]corev1.Pod, error) {
+		return GetWorkerPods(test, rayCluster)
+	}, TestTimeoutShort).Should(HaveLen(expectedPodCount), "Should scale down to 1 multi-host group (4 pods)")
+
+	// Test scale up: Increase replicas from 1 to 3.
+	LogWithTimestamp(test.T(), "Scaling up RayCluster %s/%s", rayCluster.Namespace, rayCluster.Name)
+	rayClusterAC.Spec.WorkerGroupSpecs[0].WithReplicas(3)
+	_, err = test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions)
+	g.Expect(err).NotTo(HaveOccurred())
+
+	expectedPodCount = 3 * numOfHosts
+	g.Eventually(func() ([]corev1.Pod, error) {
+		return GetWorkerPods(test, rayCluster)
+	}, TestTimeoutShort).Should(HaveLen(expectedPodCount), "Should scale up to 3 multi-host groups (12 pods)")
+
+	// Manually delete a single pod and verify the controller atomically re-creates the slice.
+	LogWithTimestamp(test.T(), "Testing atomic multi-host group recreation for RayCluster %s/%s", rayCluster.Namespace, rayCluster.Name)
+	workerPods, err = GetWorkerPods(test, rayCluster)
+	g.Expect(err).NotTo(HaveOccurred())
+	podToDelete := workerPods[0]
+	err = test.Client().Core().CoreV1().Pods(namespace.Name).Delete(test.Ctx(), podToDelete.Name, metav1.DeleteOptions{})
+	g.Expect(err).NotTo(HaveOccurred())
+
+	// The controller should first clean up the broken multi-host group (-4 pods), and then re-scale it up (+4 pods).
+	LogWithTimestamp(test.T(), "Waiting for controller to reconcile the multi-host group.")
+	// Reconciliation happens too quickly to catch the intermediate state where the pod count drops to
+	// expectedPodCount-numOfHosts, but we can still verify that externally deleted Pods are re-created
+	// to restore the expected number.
+	g.Eventually(func() ([]corev1.Pod, error) {
+		return GetWorkerPods(test, rayCluster)
+	}, TestTimeoutShort).Should(HaveLen(expectedPodCount), "Controller restored the cluster to the correct number of pods.")
+}
