diff --git a/pkg/model-serving-controller/controller/model_serving_controller.go b/pkg/model-serving-controller/controller/model_serving_controller.go index ae34204c2..e8038a8fd 100644 --- a/pkg/model-serving-controller/controller/model_serving_controller.go +++ b/pkg/model-serving-controller/controller/model_serving_controller.go @@ -786,13 +786,13 @@ func (c *ModelServingController) manageRole(ctx context.Context, ms *workloadv1a return fmt.Errorf("cannot get ServingGroup of modelServing: %s from map: %v", ms.GetName(), err) } partition := c.getPartition(ms) - for index, servingGroup := range servingGroupList { + for _, servingGroup := range servingGroupList { if c.store.GetServingGroupStatus(utils.GetNamespaceName(ms), servingGroup.Name) == datastore.ServingGroupDeleting { // Deleting ServingGroup will be recreated after the deletion is complete, so there is no need to scale the roles continue } _, servingGroupOrdinal := utils.GetParentNameAndOrdinal(servingGroup.Name) - isPartitionProtected := partition > 0 && index < partition + isPartitionProtected := partition > 0 && servingGroupOrdinal >= 0 && servingGroupOrdinal < partition rolesToManage := ms.Spec.Template.Roles revisionToUse := newRevision diff --git a/test/e2e/controller-manager/model_serving_test.go b/test/e2e/controller-manager/model_serving_test.go index 6097ce1ae..c9d758e20 100644 --- a/test/e2e/controller-manager/model_serving_test.go +++ b/test/e2e/controller-manager/model_serving_test.go @@ -37,6 +37,7 @@ import ( clientset "github.com/volcano-sh/kthena/client-go/clientset/versioned" workload "github.com/volcano-sh/kthena/pkg/apis/workload/v1alpha1" + controllerutils "github.com/volcano-sh/kthena/pkg/model-serving-controller/utils" "github.com/volcano-sh/kthena/test/e2e/utils" ) @@ -1234,6 +1235,336 @@ func TestLWSAPIBasic(t *testing.T) { t.Log("LWS API basic test passed successfully") } +// TestModelServingPartitionBoundaryProtection verifies partition boundaries during rolling updates. +func TestModelServingPartitionBoundaryProtection(t *testing.T) { + ctx, kthenaClient, kubeClient := setupControllerManagerE2ETest(t) + + const ( + replicas = int32(5) + partition = int32(3) + ) + + modelServing := createPartitionedModelServing("test-partition-boundary", replicas, partition) + t.Logf("Creating ModelServing with %d replicas and partition=%d", replicas, partition) + createAndWaitForModelServing(t, ctx, kthenaClient, modelServing) + + initialMS, err := kthenaClient.WorkloadV1alpha1().ModelServings(testNamespace).Get(ctx, modelServing.Name, metav1.GetOptions{}) + require.NoError(t, err) + initialRevision := initialMS.Status.CurrentRevision + t.Logf("Initial CurrentRevision: %s", initialMS.Status.CurrentRevision) + require.NotEmpty(t, initialRevision, "Initial CurrentRevision should be set") + + updatedMS := initialMS.DeepCopy() + updatedMS.Spec.Template.Roles[0].EntryTemplate.Spec.Containers[0].Image = nginxAlpineImage + t.Logf("Updating image to %s", nginxAlpineImage) + + _, err = kthenaClient.WorkloadV1alpha1().ModelServings(testNamespace).Update(ctx, updatedMS, metav1.UpdateOptions{}) + require.NoError(t, err) + utils.WaitForModelServingReady(t, ctx, kthenaClient, testNamespace, modelServing.Name) + + updateRevision := waitForPartitionState(t, ctx, kthenaClient, kubeClient, modelServing.Name, partition, replicas, initialRevision) + assert.NotEqual(t, initialRevision, updateRevision) +} + +// TestModelServingPartitionDeletedGroupHistoricalRevision verifies deleted groups +// within partition are rebuilt using historical revision. +func TestModelServingPartitionDeletedGroupHistoricalRevision(t *testing.T) { + ctx, kthenaClient, kubeClient := setupControllerManagerE2ETest(t) + + const ( + replicas = int32(5) + partition = int32(3) + ) + + modelServing := createPartitionedModelServing("test-partition-historical", replicas, partition) + modelServing.Spec.RecoveryPolicy = workload.RoleRecreate + t.Logf("Creating ModelServing with %d replicas and partition=%d", replicas, partition) + createAndWaitForModelServing(t, ctx, kthenaClient, modelServing) + + initialMS, err := kthenaClient.WorkloadV1alpha1().ModelServings(testNamespace).Get(ctx, modelServing.Name, metav1.GetOptions{}) + require.NoError(t, err) + initialRevision := initialMS.Status.CurrentRevision + t.Logf("Initial CurrentRevision: %s", initialRevision) + require.NotEmpty(t, initialRevision, "Initial CurrentRevision should be set") + + updatedMS := initialMS.DeepCopy() + updatedMS.Spec.Template.Roles[0].EntryTemplate.Spec.Containers[0].Image = nginxAlpineImage + t.Logf("Updating image to %s", nginxAlpineImage) + + _, err = kthenaClient.WorkloadV1alpha1().ModelServings(testNamespace).Update(ctx, updatedMS, metav1.UpdateOptions{}) + require.NoError(t, err) + utils.WaitForModelServingReady(t, ctx, kthenaClient, testNamespace, modelServing.Name) + + updateRevision := waitForPartitionState(t, ctx, kthenaClient, kubeClient, modelServing.Name, partition, replicas, initialRevision) + t.Log("Partitioned update established") + + targetOrdinal := 1 + targetGroupName := fmt.Sprintf("%s-%d", modelServing.Name, targetOrdinal) + labelSelector := fmt.Sprintf("%s=%s", workload.GroupNameLabelKey, targetGroupName) + + pods, err := kubeClient.CoreV1().Pods(testNamespace).List(ctx, metav1.ListOptions{ + LabelSelector: labelSelector, + }) + require.NoError(t, err) + require.NotEmpty(t, pods.Items) + + podToDelete := pods.Items[0] + originalUID := string(podToDelete.UID) + t.Logf("Deleting pod %s (ordinal %d)", podToDelete.Name, targetOrdinal) + + err = kubeClient.CoreV1().Pods(testNamespace).Delete(ctx, podToDelete.Name, metav1.DeleteOptions{}) + require.NoError(t, err) + + utils.WaitForModelServingReady(t, ctx, kthenaClient, testNamespace, modelServing.Name) + + require.Eventually(t, func() bool { + ordinalStates, err := collectRunningServingGroupStates(ctx, kubeClient, modelServing.Name) + if err != nil { + t.Logf("Failed to collect serving group states: %v", err) + return false + } + state, ok := ordinalStates[int32(targetOrdinal)] + if !ok { + return false + } + t.Logf("Recreated protected ordinal %d => group=%s pod=%s revision=%s image=%s", targetOrdinal, state.GroupName, state.PodName, state.Revision, state.Image) + return state.PodUID != originalUID && + state.Revision == initialRevision && + state.Image == nginxImage + }, 3*time.Minute, 2*time.Second, "Recreated pod should use historical revision") + + finalMS, err := kthenaClient.WorkloadV1alpha1().ModelServings(testNamespace).Get(ctx, modelServing.Name, metav1.GetOptions{}) + require.NoError(t, err) + ordinalStates, err := collectRunningServingGroupStates(ctx, kubeClient, modelServing.Name) + require.NoError(t, err) + protectedCorrect, updatedCorrect := verifyPartitionState(t, ordinalStates, partition, replicas, initialRevision, updateRevision) + assert.Equal(t, int(partition), protectedCorrect) + assert.Equal(t, int(replicas-partition), updatedCorrect) + assert.Equal(t, initialRevision, finalMS.Status.CurrentRevision) + assert.Equal(t, updateRevision, finalMS.Status.UpdateRevision) +} + +// TestModelServingRollingUpdate verifies rolling updates without partition. +func TestModelServingRollingUpdate(t *testing.T) { + ctx, kthenaClient, kubeClient := setupControllerManagerE2ETest(t) + + const replicas = int32(3) + + modelServing := createBasicModelServing("test-rolling-update", replicas, 0) + t.Logf("Creating ModelServing with %d replicas", replicas) + createAndWaitForModelServing(t, ctx, kthenaClient, modelServing) + + initialMS, err := kthenaClient.WorkloadV1alpha1().ModelServings(testNamespace).Get(ctx, modelServing.Name, metav1.GetOptions{}) + require.NoError(t, err) + initialRevision := initialMS.Status.CurrentRevision + t.Logf("Initial CurrentRevision: %s", initialRevision) + + labelSelector := modelServingLabelSelector(modelServing.Name) + verifyAllPodsHaveImage(t, ctx, kubeClient, labelSelector, nginxImage, "before update") + + updatedMS := initialMS.DeepCopy() + updatedMS.Spec.Template.Roles[0].EntryTemplate.Spec.Containers[0].Image = nginxAlpineImage + t.Logf("Updating image to %s", nginxAlpineImage) + + _, err = kthenaClient.WorkloadV1alpha1().ModelServings(testNamespace).Update(ctx, updatedMS, metav1.UpdateOptions{}) + require.NoError(t, err) + utils.WaitForModelServingReady(t, ctx, kthenaClient, testNamespace, modelServing.Name) + + verifyAllPodsHaveImage(t, ctx, kubeClient, labelSelector, nginxAlpineImage, "after update") + + finalMS, err := kthenaClient.WorkloadV1alpha1().ModelServings(testNamespace).Get(ctx, modelServing.Name, metav1.GetOptions{}) + require.NoError(t, err) + require.NotEmpty(t, finalMS.Status.UpdateRevision, "UpdateRevision should be set after rollout") + + assert.Equal(t, finalMS.Status.CurrentRevision, finalMS.Status.UpdateRevision) + assert.NotEqual(t, initialRevision, finalMS.Status.UpdateRevision) + + ordinalStates, err := collectRunningServingGroupStates(ctx, kubeClient, modelServing.Name) + require.NoError(t, err) + require.Len(t, ordinalStates, int(replicas), "Expected one running group per replica after rollout") + for ordinal, state := range ordinalStates { + assert.Equalf(t, finalMS.Status.UpdateRevision, state.Revision, "Ordinal %d should use UpdateRevision without partition", ordinal) + assert.Equalf(t, nginxAlpineImage, state.Image, "Ordinal %d should run the updated image without partition", ordinal) + } + t.Logf("Rolling update completed - CurrentRevision: %s", finalMS.Status.CurrentRevision) +} + +func createPartitionedModelServing(name string, replicas, partition int32) *workload.ModelServing { + roleReplicas := int32(1) + return &workload.ModelServing{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: testNamespace, + }, + Spec: workload.ModelServingSpec{ + Replicas: &replicas, + RolloutStrategy: &workload.RolloutStrategy{ + Type: workload.ServingGroupRollingUpdate, + RollingUpdateConfiguration: &workload.RollingUpdateConfiguration{ + Partition: ptr.To(intstr.FromInt32(partition)), + MaxUnavailable: ptr.To(intstr.FromInt(int(replicas))), + }, + }, + Template: workload.ServingGroup{ + Roles: []workload.Role{ + { + Name: "prefill", + Replicas: &roleReplicas, + EntryTemplate: workload.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "test-container", + Image: nginxImage, + Ports: []corev1.ContainerPort{ + { + Name: "http", + ContainerPort: 80, + }, + }, + }, + }, + }, + }, + WorkerReplicas: 0, + }, + }, + }, + }, + } +} + +type servingGroupState struct { + GroupName string + PodName string + PodUID string + Ordinal int32 + Revision string + Image string +} + +func collectRunningServingGroupStates(ctx context.Context, kubeClient *kubernetes.Clientset, msName string) (map[int32]servingGroupState, error) { + pods, err := kubeClient.CoreV1().Pods(testNamespace).List(ctx, metav1.ListOptions{ + LabelSelector: modelServingLabelSelector(msName), + }) + if err != nil { + return nil, err + } + + states := make(map[int32]servingGroupState) + for _, pod := range pods.Items { + if pod.DeletionTimestamp != nil || pod.Status.Phase != corev1.PodRunning { + continue + } + groupName := pod.Labels[workload.GroupNameLabelKey] + if groupName == "" { + continue + } + parentName, ordinal := controllerutils.GetParentNameAndOrdinal(groupName) + if parentName != msName || ordinal < 0 { + continue + } + revision := pod.Labels[workload.RevisionLabelKey] + if revision == "" || len(pod.Spec.Containers) == 0 { + continue + } + + state := servingGroupState{ + GroupName: groupName, + PodName: pod.Name, + PodUID: string(pod.UID), + Ordinal: int32(ordinal), + Revision: revision, + Image: pod.Spec.Containers[0].Image, + } + if existing, ok := states[state.Ordinal]; !ok || state.PodName < existing.PodName { + states[state.Ordinal] = state + } + } + + return states, nil +} + +func waitForPartitionState(t *testing.T, ctx context.Context, kthenaClient *clientset.Clientset, + kubeClient *kubernetes.Clientset, msName string, partition, replicas int32, initialRevision string) string { + t.Helper() + + var updateRevision string + require.Eventually(t, func() bool { + ms, err := kthenaClient.WorkloadV1alpha1().ModelServings(testNamespace).Get(ctx, msName, metav1.GetOptions{}) + if err != nil { + return false + } + ordinalStates, err := collectRunningServingGroupStates(ctx, kubeClient, msName) + if err != nil { + t.Logf("Failed to collect serving group states: %v", err) + return false + } + if len(ordinalStates) != int(replicas) { + t.Logf("Running serving group count: %d (expecting %d)", len(ordinalStates), replicas) + return false + } + protectedCorrect, updatedCorrect := verifyPartitionState(t, ordinalStates, partition, replicas, initialRevision, ms.Status.UpdateRevision) + t.Logf("CurrentRevision: %s, UpdateRevision: %s, Protected: %d/%d, Updated: %d/%d", + ms.Status.CurrentRevision, ms.Status.UpdateRevision, protectedCorrect, partition, updatedCorrect, replicas-partition) + if ms.Status.CurrentRevision != initialRevision || + ms.Status.UpdateRevision == "" || + ms.Status.UpdateRevision == initialRevision || + protectedCorrect != int(partition) || + updatedCorrect != int(replicas-partition) { + return false + } + updateRevision = ms.Status.UpdateRevision + return true + }, 3*time.Minute, 2*time.Second, "Partition state did not converge") + + return updateRevision +} + +func verifyPartitionState(t *testing.T, ordinalStates map[int32]servingGroupState, + partition, replicas int32, currentRevision, updateRevision string) (protectedCorrect, updatedCorrect int) { + t.Helper() + for ordinal, state := range ordinalStates { + isProtected := partition > 0 && ordinal < partition + if isProtected && state.Revision == currentRevision && state.Image == nginxImage { + protectedCorrect++ + } else if !isProtected && state.Revision == updateRevision && state.Image == nginxAlpineImage { + updatedCorrect++ + } + } + return +} + +func verifyAllPodsHaveImage(t *testing.T, ctx context.Context, kubeClient *kubernetes.Clientset, + labelSelector, expectedImage, phase string) { + t.Helper() + require.Eventually(t, func() bool { + pods, err := kubeClient.CoreV1().Pods(testNamespace).List(ctx, metav1.ListOptions{ + LabelSelector: labelSelector, + }) + if err != nil || len(pods.Items) == 0 { + return false + } + + for _, pod := range pods.Items { + if pod.DeletionTimestamp != nil { + continue + } + if pod.Status.Phase != corev1.PodRunning { + return false + } + for _, container := range pod.Spec.Containers { + if container.Image != expectedImage { + return false + } + } + } + return true + }, 2*time.Minute, 1*time.Second, "All pods should have image %s %s", expectedImage, phase) + + t.Logf("Verified all pods have image %s %s", expectedImage, phase) +} + // TestModelServingControllerManagerRestart verifies that ModelServing pod creation // is successful even when the controller-manager restarts during reconciliation. // NOTE: This test must remain last among ModelServing tests because it restarts the