diff --git a/Makefile b/Makefile index 1b6e354..3b3fdea 100644 --- a/Makefile +++ b/Makefile @@ -62,7 +62,7 @@ vet: ## Run go vet against code. .PHONY: test test: manifests generate fmt vet envtest ## Run tests. - KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test $$(go list ./... | grep -v /e2e) -coverprofile cover.out + KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test $$(go list ./... | grep -v /e2e) -timeout 0 -coverprofile cover.out # TODO(user): To use a different vendor for e2e tests, modify the setup under 'tests/e2e'. # The default setup assumes Kind is pre-installed and builds/loads the Manager Docker image locally. diff --git a/internal/component/client.go b/internal/component/client.go new file mode 100644 index 0000000..da8d39a --- /dev/null +++ b/internal/component/client.go @@ -0,0 +1,117 @@ +package component + +import ( + "context" + "fmt" + "sort" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/NexusGPU/tensor-fusion/internal/constants" + "github.com/NexusGPU/tensor-fusion/internal/utils" + corev1 "k8s.io/api/core/v1" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/log" +) + +const ( + ClientUpdateInProgressAnnotation = constants.Domain + "/client-update-in-progress" + ClientBatchUpdateLastTimeAnnotation = constants.Domain + "/client-batch-update-last-time" +) + +type Client struct { + podsToUpdate []*corev1.Pod +} + +func (c *Client) GetName() string { + return "client" +} + +func (c *Client) DetectConfigChange(pool *tfv1.GPUPool, status *tfv1.PoolComponentStatus) (bool, string, string) { + oldHash := status.ClientVersion + changed, newHash := utils.CompareAndGetObjectHash(oldHash, pool.Spec.ComponentConfig.Client) + return changed, newHash, oldHash +} + +func (c *Client) SetConfigHash(status *tfv1.PoolComponentStatus, hash string) { + status.ClientVersion = hash +} + +func (c *Client) GetUpdateInProgressInfo(pool *tfv1.GPUPool) string { + return pool.Annotations[ClientUpdateInProgressAnnotation] +} + +func (c *Client) SetUpdateInProgressInfo(pool *tfv1.GPUPool, hash string) { + pool.Annotations[ClientUpdateInProgressAnnotation] = hash +} + +func (c *Client) GetBatchUpdateLastTimeInfo(pool *tfv1.GPUPool) string { + return pool.Annotations[ClientBatchUpdateLastTimeAnnotation] +} + +func (c *Client) SetBatchUpdateLastTimeInfo(pool *tfv1.GPUPool, time string) { + pool.Annotations[ClientBatchUpdateLastTimeAnnotation] = time +} + +func (c *Client) GetUpdateProgress(status *tfv1.PoolComponentStatus) int32 { + return status.ClientUpdateProgress +} + +func (c *Client) SetUpdateProgress(status *tfv1.PoolComponentStatus, progress int32) { + status.ClientUpdateProgress = progress + status.ClientConfigSynced = false + if progress == 100 { + status.ClientConfigSynced = true + } +} + +func (c *Client) GetResourcesInfo(r client.Client, ctx context.Context, pool *tfv1.GPUPool, configHash string) (int, int, bool, error) { + podList := &corev1.PodList{} + if err := r.List(ctx, podList, + client.MatchingLabels{ + constants.TensorFusionEnabledLabelKey: constants.LabelValueTrue, + fmt.Sprintf(constants.GPUNodePoolIdentifierLabelFormat, pool.Name): constants.LabelValueTrue, + }); err != nil { + return 0, 0, false, fmt.Errorf("failed to list pods: %w", err) + } + + total := len(podList.Items) + + for _, pod := range podList.Items { + if !pod.DeletionTimestamp.IsZero() { + return 0, 0, true, nil + } + + if 
pod.Labels[constants.LabelKeyPodTemplateHash] != configHash { + c.podsToUpdate = append(c.podsToUpdate, &pod) + } + } + + sort.Sort(ClientPodsByCreationTimestamp(c.podsToUpdate)) + + return total, total - len(c.podsToUpdate), false, nil +} + +func (c *Client) PerformBatchUpdate(r client.Client, ctx context.Context, pool *tfv1.GPUPool, delta int) (bool, error) { + log := log.FromContext(ctx) + + log.Info("perform batch update", "component", c.GetName()) + for i := range delta { + pod := c.podsToUpdate[i] + if err := r.Delete(ctx, pod); err != nil { + return false, fmt.Errorf("failed to delete pod: %w", err) + } + } + + return true, nil +} + +type ClientPodsByCreationTimestamp []*corev1.Pod + +func (o ClientPodsByCreationTimestamp) Len() int { return len(o) } +func (o ClientPodsByCreationTimestamp) Swap(i, j int) { o[i], o[j] = o[j], o[i] } +func (o ClientPodsByCreationTimestamp) Less(i, j int) bool { + if o[i].CreationTimestamp.Equal(&o[j].CreationTimestamp) { + return o[i].Name < o[j].Name + } + return o[i].CreationTimestamp.Before(&o[j].CreationTimestamp) +} diff --git a/internal/component/component.go b/internal/component/component.go new file mode 100644 index 0000000..c03e291 --- /dev/null +++ b/internal/component/component.go @@ -0,0 +1,173 @@ +package component + +import ( + "context" + "fmt" + "math" + "time" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/NexusGPU/tensor-fusion/internal/constants" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/log" +) + +type Interface interface { + GetName() string + DetectConfigChange(pool *tfv1.GPUPool, status *tfv1.PoolComponentStatus) (bool, string, string) + SetConfigHash(status *tfv1.PoolComponentStatus, hash string) + GetUpdateInProgressInfo(pool *tfv1.GPUPool) string + SetUpdateInProgressInfo(pool *tfv1.GPUPool, hash string) + GetBatchUpdateLastTimeInfo(pool *tfv1.GPUPool) string + SetBatchUpdateLastTimeInfo(pool *tfv1.GPUPool, time string) + GetUpdateProgress(status *tfv1.PoolComponentStatus) int32 + SetUpdateProgress(status *tfv1.PoolComponentStatus, progress int32) + GetResourcesInfo(r client.Client, ctx context.Context, pool *tfv1.GPUPool, hash string) (int, int, bool, error) + PerformBatchUpdate(r client.Client, ctx context.Context, pool *tfv1.GPUPool, delta int) (bool, error) +} + +func ManageUpdate(r client.Client, ctx context.Context, pool *tfv1.GPUPool, component Interface) (*ctrl.Result, error) { + log := log.FromContext(ctx) + + autoUpdate, batchInterval := getUpdatePolicy(pool) + newStatus := pool.Status.ComponentStatus.DeepCopy() + + changed, configHash, oldHash := component.DetectConfigChange(pool, newStatus) + if changed { + log.Info("component configuration changed", "component", component.GetName(), "old hash", oldHash, "new hash", configHash) + component.SetConfigHash(newStatus, configHash) + component.SetUpdateProgress(newStatus, 0) + if oldHash == "" || !autoUpdate { + return nil, patchComponentStatus(r, ctx, pool, newStatus) + } + if pool.Annotations == nil { + pool.Annotations = map[string]string{} + } + patch := client.MergeFrom(pool.DeepCopy()) + component.SetUpdateInProgressInfo(pool, configHash) + component.SetBatchUpdateLastTimeInfo(pool, "") + if err := r.Patch(ctx, pool, patch); err != nil { + return nil, fmt.Errorf("failed to patch pool: %w", err) + } + } else { + if !autoUpdate || component.GetUpdateInProgressInfo(pool) != configHash { + return nil, nil + } + if timeInfo := 
component.GetBatchUpdateLastTimeInfo(pool); len(timeInfo) != 0 { + lastBatchUpdateTime, err := time.Parse(time.RFC3339, timeInfo) + if err != nil { + return nil, err + } + nextBatchUpdateTime := lastBatchUpdateTime.Add(batchInterval) + if now := time.Now(); now.Before(nextBatchUpdateTime) { + log.Info("next batch update time not yet reached", "now", now, "nextBatchUpdateTime", nextBatchUpdateTime) + return &ctrl.Result{RequeueAfter: nextBatchUpdateTime.Sub(now)}, nil + } + log.Info("next batch update time reached", "BatchUpdateTime", nextBatchUpdateTime) + } + } + + totalSize, updatedSize, recheck, err := component.GetResourcesInfo(r, ctx, pool, configHash) + if err != nil { + return nil, err + } else if recheck { + return &ctrl.Result{RequeueAfter: constants.PendingRequeueDuration}, err + } else if totalSize <= 0 { + return nil, nil + } + + batchPercentage := pool.Spec.NodeManagerConfig.NodePoolRollingUpdatePolicy.BatchPercentage + updateProgress := component.GetUpdateProgress(newStatus) + delta, newUpdateProgress, currentBatchIndex := calculateDesiredUpdatedDelta(totalSize, updatedSize, batchPercentage, updateProgress) + component.SetUpdateProgress(newStatus, newUpdateProgress) + log.Info("update in progress", "component", component.GetName(), "hash", configHash, + "updateProgress", newUpdateProgress, "totalSize", totalSize, "updatedSize", updatedSize, + "batchPercentage", batchPercentage, "currentBatchIndex", currentBatchIndex, "delta", delta) + + var ctrlResult *ctrl.Result + if delta == 0 { + patch := client.MergeFrom(pool.DeepCopy()) + newUpdateProgress = min((currentBatchIndex+1)*batchPercentage, 100) + component.SetUpdateProgress(newStatus, newUpdateProgress) + if newUpdateProgress != 100 { + component.SetBatchUpdateLastTimeInfo(pool, time.Now().Format(time.RFC3339)) + interval := max(batchInterval, constants.PendingRequeueDuration) + ctrlResult = &ctrl.Result{RequeueAfter: interval} + log.Info("current batch update has completed", "progress", newUpdateProgress, "currentBatchIndex", currentBatchIndex, "nextUpdateTime", time.Now().Add(interval)) + } else { + component.SetUpdateInProgressInfo(pool, "") + component.SetBatchUpdateLastTimeInfo(pool, "") + log.Info("all batch update has completed", "component", component.GetName(), "hash", configHash) + } + if err := r.Patch(ctx, pool, patch); err != nil { + return nil, fmt.Errorf("failed to patch pool: %w", err) + } + } else if delta > 0 { + recheck, err := component.PerformBatchUpdate(r, ctx, pool, int(delta)) + if err != nil { + return nil, err + } else if recheck { + ctrlResult = &ctrl.Result{RequeueAfter: constants.PendingRequeueDuration} + } + } + + return ctrlResult, patchComponentStatus(r, ctx, pool, newStatus) +} + +func patchComponentStatus(r client.Client, ctx context.Context, pool *tfv1.GPUPool, newStatus *tfv1.PoolComponentStatus) error { + patch := client.MergeFrom(pool.DeepCopy()) + pool.Status.ComponentStatus = *newStatus + if err := r.Status().Patch(ctx, pool, patch); err != nil { + return fmt.Errorf("failed to patch pool status: %w", err) + } + return nil +} + +func getUpdatePolicy(pool *tfv1.GPUPool) (bool, time.Duration) { + autoUpdate := false + batchInterval := time.Duration(600) * time.Second + + if pool.Spec.NodeManagerConfig != nil { + updatePolicy := pool.Spec.NodeManagerConfig.NodePoolRollingUpdatePolicy + if updatePolicy != nil { + if updatePolicy.AutoUpdate != nil { + autoUpdate = *updatePolicy.AutoUpdate + } + + duration, err := time.ParseDuration(updatePolicy.BatchInterval) + if err == nil { + batchInterval 
= duration
+			}
+		}
+	}
+
+	return autoUpdate, batchInterval
+}
+
+func calculateDesiredUpdatedDelta(total int, updatedSize int, batchPercentage int32, updateProgress int32) (int32, int32, int32) {
+	batchSize := getValueFromPercent(int(batchPercentage), total, true)
+	var delta, desiredSize, currentBatchIndex int32
+	newUpdateProgress := updateProgress
+	for {
+		currentBatchIndex = newUpdateProgress / batchPercentage
+		desiredSize = min((currentBatchIndex+1)*int32(batchSize), int32(total))
+		delta = desiredSize - int32(updatedSize)
+		// if the rolling update policy changed or new nodes were added during the update, we need to update progress
+		if delta < 0 {
+			newUpdateProgress = min(newUpdateProgress+batchPercentage, 100)
+		} else {
+			break
+		}
+	}
+
+	return delta, newUpdateProgress, currentBatchIndex
+}
+
+func getValueFromPercent(percent int, total int, roundUp bool) int {
+	if roundUp {
+		return int(math.Ceil(float64(percent) * (float64(total)) / 100))
+	} else {
+		return int(math.Floor(float64(percent) * (float64(total)) / 100))
+	}
+}
diff --git a/internal/component/hypervisor.go b/internal/component/hypervisor.go
new file mode 100644
index 0000000..b33d03c
--- /dev/null
+++ b/internal/component/hypervisor.go
@@ -0,0 +1,134 @@
+package component
+
+import (
+	"context"
+	"fmt"
+	"sort"
+
+	tfv1 "github.com/NexusGPU/tensor-fusion/api/v1"
+	"github.com/NexusGPU/tensor-fusion/internal/constants"
+	"github.com/NexusGPU/tensor-fusion/internal/utils"
+	corev1 "k8s.io/api/core/v1"
+	"k8s.io/apimachinery/pkg/api/errors"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+	"sigs.k8s.io/controller-runtime/pkg/log"
+)
+
+const (
+	HypervisorUpdateInProgressAnnotation = constants.Domain + "/hypervisor-update-in-progress"
+	HypervisorBatchUpdateLastTimeAnnotation = constants.Domain + "/hypervisor-batch-update-last-time"
+)
+
+type Hypervisor struct {
+	nodesToUpdate []*tfv1.GPUNode
+}
+
+func (h *Hypervisor) GetName() string {
+	return "hypervisor"
+}
+
+func (h *Hypervisor) DetectConfigChange(pool *tfv1.GPUPool, status *tfv1.PoolComponentStatus) (bool, string, string) {
+	oldHash := status.HypervisorVersion
+	changed, newHash := utils.CompareAndGetObjectHash(oldHash, pool.Spec.ComponentConfig.Hypervisor)
+	return changed, newHash, oldHash
+}
+
+func (h *Hypervisor) SetConfigHash(status *tfv1.PoolComponentStatus, hash string) {
+	status.HypervisorVersion = hash
+}
+
+func (h *Hypervisor) GetUpdateInProgressInfo(pool *tfv1.GPUPool) string {
+	return pool.Annotations[HypervisorUpdateInProgressAnnotation]
+}
+
+func (h *Hypervisor) SetUpdateInProgressInfo(pool *tfv1.GPUPool, hash string) {
+	pool.Annotations[HypervisorUpdateInProgressAnnotation] = hash
+}
+
+func (h *Hypervisor) GetBatchUpdateLastTimeInfo(pool *tfv1.GPUPool) string {
+	return pool.Annotations[HypervisorBatchUpdateLastTimeAnnotation]
+}
+
+func (h *Hypervisor) SetBatchUpdateLastTimeInfo(pool *tfv1.GPUPool, time string) {
+	pool.Annotations[HypervisorBatchUpdateLastTimeAnnotation] = time
+}
+
+func (h *Hypervisor) GetUpdateProgress(status *tfv1.PoolComponentStatus) int32 {
+	return status.HyperVisorUpdateProgress
+}
+
+func (h *Hypervisor) SetUpdateProgress(status *tfv1.PoolComponentStatus, progress int32) {
+	status.HyperVisorUpdateProgress = progress
+	status.HypervisorConfigSynced = false
+	if progress == 100 {
+		status.HypervisorConfigSynced = true
+	}
+}
+
+func (h *Hypervisor) GetResourcesInfo(r client.Client, ctx context.Context, pool *tfv1.GPUPool, configHash string) (int, int, bool, error) {
+	log := log.FromContext(ctx)
+
+	nodeList := 
&tfv1.GPUNodeList{} + if err := r.List(ctx, nodeList, client.MatchingLabels(map[string]string{ + fmt.Sprintf(constants.GPUNodePoolIdentifierLabelFormat, pool.Name): "true", + })); err != nil { + return 0, 0, false, fmt.Errorf("failed to list nodes: %w", err) + } + + total := len(nodeList.Items) + + for _, node := range nodeList.Items { + if !node.DeletionTimestamp.IsZero() { + total-- + continue + } + if node.Status.Phase == tfv1.TensorFusionGPUNodePhasePending { + log.Info("node in pending status", "name", node.Name) + return 0, 0, true, nil + } + key := client.ObjectKey{ + Namespace: utils.CurrentNamespace(), + Name: fmt.Sprintf("hypervisor-%s", node.Name), + } + pod := &corev1.Pod{} + err := r.Get(ctx, key, pod) + if errors.IsNotFound(err) || + pod.Labels[constants.LabelKeyPodTemplateHash] != configHash { + h.nodesToUpdate = append(h.nodesToUpdate, &node) + } + } + + // TODO: sort by creation time desc, need to adjust test + sort.Sort(GPUNodeByCreationTimestamp(h.nodesToUpdate)) + + return total, total - len(h.nodesToUpdate), false, nil +} + +func (h *Hypervisor) PerformBatchUpdate(r client.Client, ctx context.Context, pool *tfv1.GPUPool, delta int) (bool, error) { + log := log.FromContext(ctx) + + log.Info("perform batch update", "component", h.GetName()) + for i := range delta { + node := h.nodesToUpdate[i] + if node.Status.Phase != tfv1.TensorFusionGPUNodePhasePending { + node.Status.Phase = tfv1.TensorFusionGPUNodePhasePending + if err := r.Status().Update(ctx, node); err != nil { + return false, fmt.Errorf("failed to update node status : %w", err) + } + log.Info("node phase has been updated to pending", "node", node.Name) + } + } + + return false, nil +} + +type GPUNodeByCreationTimestamp []*tfv1.GPUNode + +func (o GPUNodeByCreationTimestamp) Len() int { return len(o) } +func (o GPUNodeByCreationTimestamp) Swap(i, j int) { o[i], o[j] = o[j], o[i] } +func (o GPUNodeByCreationTimestamp) Less(i, j int) bool { + if o[i].CreationTimestamp.Equal(&o[j].CreationTimestamp) { + return o[i].Name < o[j].Name + } + return o[i].CreationTimestamp.Before(&o[j].CreationTimestamp) +} diff --git a/internal/component/worker.go b/internal/component/worker.go new file mode 100644 index 0000000..2ac7b3b --- /dev/null +++ b/internal/component/worker.go @@ -0,0 +1,127 @@ +package component + +import ( + "context" + "fmt" + "sort" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/NexusGPU/tensor-fusion/internal/constants" + "github.com/NexusGPU/tensor-fusion/internal/utils" + "github.com/NexusGPU/tensor-fusion/internal/worker" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/log" +) + +const ( + WorkerUpdateInProgressAnnotation = constants.Domain + "/worker-update-in-progress" + WorkerBatchUpdateLastTimeAnnotation = constants.Domain + "/worker-batch-update-last-time" +) + +type Worker struct { + workloadsToUpdate []*tfv1.TensorFusionWorkload +} + +func (w *Worker) GetName() string { + return "worker" +} + +func (w *Worker) DetectConfigChange(pool *tfv1.GPUPool, status *tfv1.PoolComponentStatus) (bool, string, string) { + oldHash := status.WorkerVersion + changed, newHash := utils.CompareAndGetObjectHash(oldHash, pool.Spec.ComponentConfig.Worker) + return changed, newHash, oldHash +} + +func (w *Worker) SetConfigHash(status *tfv1.PoolComponentStatus, hash string) { + status.WorkerVersion = hash +} + +func (w *Worker) GetUpdateInProgressInfo(pool *tfv1.GPUPool) string { + return pool.Annotations[WorkerUpdateInProgressAnnotation] +} + +func (w *Worker) 
SetUpdateInProgressInfo(pool *tfv1.GPUPool, hash string) { + pool.Annotations[WorkerUpdateInProgressAnnotation] = hash +} + +func (w *Worker) GetBatchUpdateLastTimeInfo(pool *tfv1.GPUPool) string { + return pool.Annotations[WorkerBatchUpdateLastTimeAnnotation] +} + +func (w *Worker) SetBatchUpdateLastTimeInfo(pool *tfv1.GPUPool, time string) { + pool.Annotations[WorkerBatchUpdateLastTimeAnnotation] = time +} + +func (w *Worker) GetUpdateProgress(status *tfv1.PoolComponentStatus) int32 { + return status.WorkerUpdateProgress +} + +func (w *Worker) SetUpdateProgress(status *tfv1.PoolComponentStatus, progress int32) { + status.WorkerUpdateProgress = progress + status.WorkerConfigSynced = false + if progress == 100 { + status.WorkerConfigSynced = true + } +} + +func (w *Worker) GetResourcesInfo(r client.Client, ctx context.Context, pool *tfv1.GPUPool, configHash string) (int, int, bool, error) { + log := log.FromContext(ctx) + workloadList := &tfv1.TensorFusionWorkloadList{} + if err := r.List(ctx, workloadList, client.MatchingLabels(map[string]string{ + constants.LabelKeyOwner: pool.Name, + })); err != nil { + return 0, 0, false, fmt.Errorf("failed to list workloads : %w", err) + } + + total := len(workloadList.Items) + + workerGenerator := &worker.WorkerGenerator{WorkerConfig: pool.Spec.ComponentConfig.Worker} + for _, workload := range workloadList.Items { + if !workload.DeletionTimestamp.IsZero() { + total-- + continue + } + if workload.Status.PodTemplateHash == "" { + log.Info("workload in pending status", "name", workload.Name) + return 0, 0, true, nil + } + podTemplateHash, err := workerGenerator.PodTemplateHash(workload.Spec.Resources.Limits) + if err != nil { + return 0, 0, false, fmt.Errorf("failed to get pod template hash: %w", err) + } + if workload.Status.PodTemplateHash != podTemplateHash { + w.workloadsToUpdate = append(w.workloadsToUpdate, &workload) + } + } + + sort.Sort(TensorFusionWorkloadByCreationTimestamp(w.workloadsToUpdate)) + + return total, total - len(w.workloadsToUpdate), false, nil +} + +func (w *Worker) PerformBatchUpdate(r client.Client, ctx context.Context, pool *tfv1.GPUPool, delta int) (bool, error) { + log := log.FromContext(ctx) + log.Info("perform batch update", "component", w.GetName()) + + for i := range delta { + workload := w.workloadsToUpdate[i] + workload.Status.PodTemplateHash = "" + if err := r.Status().Update(ctx, workload); err != nil { + return false, fmt.Errorf("failed to update workload status : %w", err) + } + log.Info("workload pod template hash in status has changed", "workload", workload.Name) + } + + return true, nil +} + +type TensorFusionWorkloadByCreationTimestamp []*tfv1.TensorFusionWorkload + +func (o TensorFusionWorkloadByCreationTimestamp) Len() int { return len(o) } +func (o TensorFusionWorkloadByCreationTimestamp) Swap(i, j int) { o[i], o[j] = o[j], o[i] } +func (o TensorFusionWorkloadByCreationTimestamp) Less(i, j int) bool { + if o[i].CreationTimestamp.Equal(&o[j].CreationTimestamp) { + return o[i].Name < o[j].Name + } + return o[i].CreationTimestamp.Before(&o[j].CreationTimestamp) +} diff --git a/internal/config/deployment_mock.go b/internal/config/deployment_mock.go new file mode 100644 index 0000000..6e0d686 --- /dev/null +++ b/internal/config/deployment_mock.go @@ -0,0 +1,55 @@ +package config + +import ( + appsv1 "k8s.io/api/apps/v1" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/utils/ptr" +) + +var MockDeployment = &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: 
"pytorch-example", + Namespace: "tensor-fusion", + Labels: map[string]string{ + "app": "pytorch-example", + "tensor-fusion.ai/enabled": "true", + }, + }, + Spec: appsv1.DeploymentSpec{ + Replicas: ptr.To[int32](1), + Selector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "app": "pytorch-example", + }, + }, + Template: corev1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + "app": "pytorch-example", + "tensor-fusion.ai/enabled": "true", + }, + Annotations: map[string]string{ + "tensor-fusion.ai/generate-workload": "true", + "tensor-fusion.ai/gpupool": "mock", + "tensor-fusion.ai/inject-container": "python", + "tensor-fusion.ai/replicas": "1", + "tensor-fusion.ai/tflops-limit": "10", + "tensor-fusion.ai/tflops-request": "10", + "tensor-fusion.ai/vram-limit": "1Gi", + "tensor-fusion.ai/vram-request": "1Gi", + "tensor-fusion.ai/workload": "pytorch-example", + }, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "python", + Image: "pytorch/pytorch:2.4.1-cuda12.1-cudnn9-runtime", + Command: []string{"sh", "-c", "sleep", "1d"}, + }, + }, + }, + }, + }, +} diff --git a/internal/config/gpupool_mock.go b/internal/config/gpupool_mock.go index aa3f402..63ca4ff 100644 --- a/internal/config/gpupool_mock.go +++ b/internal/config/gpupool_mock.go @@ -31,6 +31,13 @@ var MockGPUPoolSpec = &tfv1.GPUPoolSpec{ }, }, }, + NodePoolRollingUpdatePolicy: &tfv1.NodeRollingUpdatePolicy{ + AutoUpdate: ptr.To(false), + BatchPercentage: 25, + BatchInterval: "10m", + MaxDuration: "10m", + MaintenanceWindow: tfv1.MaintenanceWindow{}, + }, }, ComponentConfig: &tfv1.ComponentConfig{ Hypervisor: &tfv1.HypervisorConfig{ diff --git a/internal/controller/gpu_controller.go b/internal/controller/gpu_controller.go index 9014838..1c11df6 100644 --- a/internal/controller/gpu_controller.go +++ b/internal/controller/gpu_controller.go @@ -88,14 +88,14 @@ func (r *GPUReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.R return ctrl.Result{}, fmt.Errorf("node %s is not assigned to any pool", gpunode.Name) } + patch := client.MergeFrom(gpu.DeepCopy()) if gpu.Labels == nil { gpu.Labels = make(map[string]string) } gpu.Labels[constants.GpuPoolKey] = poolName - // update gpu - if err := r.Update(ctx, gpu); err != nil { - return ctrl.Result{}, fmt.Errorf("update gpu %s: %w", gpu.Name, err) + if err := r.Patch(ctx, gpu, patch); err != nil { + return ctrl.Result{}, fmt.Errorf("patch gpu %s: %w", gpu.Name, err) } return ctrl.Result{}, nil } diff --git a/internal/controller/gpunode_controller.go b/internal/controller/gpunode_controller.go index 43e55be..0cc54a1 100644 --- a/internal/controller/gpunode_controller.go +++ b/internal/controller/gpunode_controller.go @@ -20,6 +20,7 @@ import ( "context" "encoding/json" "fmt" + "reflect" "strings" "time" @@ -205,10 +206,12 @@ func (r *GPUNodeReconciler) checkStatusAndUpdateVirtualCapacity(ctx context.Cont // Reconcile GPUNode status with hypervisor pod status, when changed if pod.Status.Phase != corev1.PodRunning || !utils.IsPodConditionTrue(pod.Status.Conditions, corev1.PodReady) { - node.Status.Phase = tfv1.TensorFusionGPUNodePhasePending - err := r.Status().Update(ctx, node) - if err != nil { - return true, fmt.Errorf("failed to update GPU node status: %w", err) + if node.Status.Phase != tfv1.TensorFusionGPUNodePhasePending { + node.Status.Phase = tfv1.TensorFusionGPUNodePhasePending + err := r.Status().Update(ctx, node) + if err != nil { + return true, fmt.Errorf("failed to update GPU node status: %w", err) + } } // 
Update all GPU devices status to Pending @@ -228,6 +231,8 @@ func (r *GPUNodeReconciler) checkStatusAndUpdateVirtualCapacity(ctx context.Cont return true, nil } + statusCopy := node.Status.DeepCopy() + node.Status.AvailableVRAM = resource.Quantity{} node.Status.AvailableTFlops = resource.Quantity{} node.Status.TotalTFlops = resource.Quantity{} @@ -245,9 +250,12 @@ func (r *GPUNodeReconciler) checkStatusAndUpdateVirtualCapacity(ctx context.Cont node.Status.VirtualVRAM = virtualVRAM node.Status.Phase = tfv1.TensorFusionGPUNodePhaseRunning - err = r.Status().Update(ctx, node) - if err != nil { - return true, fmt.Errorf("failed to update GPU node status: %w", err) + + if !reflect.DeepEqual(node.Status, statusCopy) { + err = r.Status().Update(ctx, node) + if err != nil { + return true, fmt.Errorf("failed to update GPU node status: %w", err) + } } err = r.syncStatusToGPUDevices(ctx, node, tfv1.TensorFusionGPUPhaseRunning) @@ -266,9 +274,10 @@ func (r *GPUNodeReconciler) syncStatusToGPUDevices(ctx context.Context, node *tf for _, gpu := range gpuList { if gpu.Status.Phase != state { + patch := client.MergeFrom(gpu.DeepCopy()) gpu.Status.Phase = state - if err := r.Status().Update(ctx, &gpu); err != nil { - return fmt.Errorf("failed to update GPU device status: %w", err) + if err := r.Status().Patch(ctx, &gpu, patch); err != nil { + return fmt.Errorf("failed to patch GPU device status: %w", err) } } } @@ -371,8 +380,6 @@ func (r *GPUNodeReconciler) reconcileNodeDiscoveryJob( func (r *GPUNodeReconciler) reconcileHypervisorPod(ctx context.Context, node *tfv1.GPUNode, pool *tfv1.GPUPool) (string, error) { log := log.FromContext(ctx) - log.Info("reconciling hypervisor pod") - if pool.Spec.ComponentConfig == nil || pool.Spec.ComponentConfig.Hypervisor == nil { return "", fmt.Errorf("missing hypervisor config") } @@ -388,6 +395,10 @@ func (r *GPUNodeReconciler) reconcileHypervisorPod(ctx context.Context, node *tf return "", fmt.Errorf("failed to get current hypervisor pod: %w", err) } } else { + if node.Status.Phase == tfv1.TensorFusionGPUNodePhaseRunning { + return key.Name, nil + } + if !currentPod.DeletionTimestamp.IsZero() { log.Info("hypervisor pod is being deleted", "name", key.Name) return key.Name, nil @@ -575,6 +586,7 @@ func (r *GPUNodeReconciler) CalculateVirtualCapacity(node *tfv1.GPUNode, pool *t ramSize, _ := node.Status.NodeInfo.RAMSize.AsInt64() virtualVRAM := node.Status.TotalVRAM.DeepCopy() + // TODO: panic if not set TFlopsOversellRatio vTFlops := node.Status.TotalTFlops.AsApproximateFloat64() * (float64(pool.Spec.CapacityConfig.Oversubscription.TFlopsOversellRatio) / 100.0) virtualVRAM.Add(*resource.NewQuantity( diff --git a/internal/controller/gpunode_controller_test.go b/internal/controller/gpunode_controller_test.go index b1a1ee4..a9d6ed2 100644 --- a/internal/controller/gpunode_controller_test.go +++ b/internal/controller/gpunode_controller_test.go @@ -73,35 +73,6 @@ var _ = Describe("GPUNode Controller", func() { tfEnv.Cleanup() - // By("checking that it will recreate terminated hypervisor pod") - // Expect(k8sClient.Delete(ctx, pod)).Should(Succeed()) - // Eventually(func() error { - // return k8sClient.Get(ctx, types.NamespacedName{ - // Name: fmt.Sprintf("hypervisor-%s", gpuNode.Name), - // Namespace: utils.CurrentNamespace(), - // }, pod) - // }, timeout, interval).Should(Succeed()) - - // TODO: make this test pass when implement rolling udpate - // By("checking that the hypervisor config changed") - // tfc := getMockCluster(ctx) - // hypervisor := 
tfc.Spec.GPUPools[0].SpecTemplate.ComponentConfig.Hypervisor - // podTmpl := &corev1.PodTemplate{} - // err := json.Unmarshal(hypervisor.PodTemplate.Raw, podTmpl) - // Expect(err).NotTo(HaveOccurred()) - // podTmpl.Template.Spec.Containers[0].Name = "foo" - // hypervisor.PodTemplate.Raw = lo.Must(json.Marshal(podTmpl)) - // Expect(k8sClient.Update(ctx, tfc)).To(Succeed()) - // Eventually(func() string { - // pod := &corev1.Pod{} - // if err = k8sClient.Get(ctx, types.NamespacedName{ - // Name: fmt.Sprintf("hypervisor-%s", gpuNode.Name), - // Namespace: utils.CurrentNamespace(), - // }, pod); err != nil { - // return "" - // } - // return pod.Spec.Containers[0].Name - // }, timeout, interval).Should(Equal("foo")) }) }) }) diff --git a/internal/controller/gpupool_controller.go b/internal/controller/gpupool_controller.go index 65c43f2..efe4521 100644 --- a/internal/controller/gpupool_controller.go +++ b/internal/controller/gpupool_controller.go @@ -19,14 +19,18 @@ package controller import ( "context" "fmt" + "sort" "sync" + "time" tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/NexusGPU/tensor-fusion/internal/component" "github.com/NexusGPU/tensor-fusion/internal/constants" utils "github.com/NexusGPU/tensor-fusion/internal/utils" "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/resource" "k8s.io/apimachinery/pkg/runtime" + utilerrors "k8s.io/apimachinery/pkg/util/errors" "k8s.io/client-go/tools/record" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/builder" @@ -124,8 +128,11 @@ func (r *GPUPoolReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ct } } - // TODO, when componentConfig changed, it should notify corresponding resource to upgrade - // eg. when hypervisor changed, should change all owned GPUNode's status.phase to Updating + if ctrlResult, err := r.reconcilePoolComponents(ctx, pool); err != nil { + return ctrl.Result{}, err + } else if ctrlResult != nil { + return *ctrlResult, nil + } return ctrl.Result{}, nil } @@ -218,6 +225,47 @@ func (r *GPUPoolReconciler) reconcilePoolCurrentCapacityAndReadiness(ctx context return nil } +func (r *GPUPoolReconciler) reconcilePoolComponents(ctx context.Context, pool *tfv1.GPUPool) (*ctrl.Result, error) { + if pool.Spec.ComponentConfig == nil { + return nil, fmt.Errorf(`missing componentconfig in pool spec`) + } + + log := log.FromContext(ctx) + startTime := time.Now() + log.Info("Started reconciling components", "startTime", startTime) + defer func() { + log.Info("Finished reconciling components", "duration", time.Since(startTime)) + }() + + components := []component.Interface{ + &component.Hypervisor{}, + &component.Worker{}, + &component.Client{}, + } + + errs := []error{} + ctrlResults := []*ctrl.Result{} + for _, c := range components { + ctrlResult, err := component.ManageUpdate(r.Client, ctx, pool, c) + if err != nil { + errs = append(errs, err) + } + if ctrlResult != nil { + ctrlResults = append(ctrlResults, ctrlResult) + } + } + + var ctrlResult *ctrl.Result + if len(ctrlResults) > 0 { + sort.Slice(ctrlResults, func(i, j int) bool { + return ctrlResults[i].RequeueAfter < ctrlResults[j].RequeueAfter + }) + ctrlResult = ctrlResults[0] + } + + return ctrlResult, utilerrors.NewAggregate(errs) +} + // SetupWithManager sets up the controller with the Manager. func (r *GPUPoolReconciler) SetupWithManager(mgr ctrl.Manager) error { return ctrl.NewControllerManagedBy(mgr). 
diff --git a/internal/controller/gpupool_controller_test.go b/internal/controller/gpupool_controller_test.go index 5a33e5b..03c492a 100644 --- a/internal/controller/gpupool_controller_test.go +++ b/internal/controller/gpupool_controller_test.go @@ -17,9 +17,19 @@ limitations under the License. package controller import ( + "encoding/json" + "fmt" + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/NexusGPU/tensor-fusion/internal/constants" + "github.com/NexusGPU/tensor-fusion/internal/utils" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "github.com/samber/lo" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/utils/ptr" + "sigs.k8s.io/controller-runtime/pkg/client" ) var _ = Describe("GPUPool Controller", func() { @@ -33,13 +43,743 @@ var _ = Describe("GPUPool Controller", func() { pool := tfEnv.GetGPUPool(0) g.Expect(pool.Status.Phase).Should(Equal(tfv1.TensorFusionPoolPhaseRunning)) }, timeout, interval).Should(Succeed()) + tfEnv.Cleanup() }) }) - Context("When pool hypervisor config changed", func() { - It("Should trigger reconciliation for all gpunodes in the pool", func() { - By("changing pool hypervisor config") + Context("When reconciling hypervisor", func() { + It("Should update hypervisor status upon configuration changes", func() { + tfEnv := NewTensorFusionEnvBuilder().AddPoolWithNodeCount(0).Build() + By("verifying hypervisor status should be initialized when the gpu pool is created") + pool := tfEnv.GetGPUPool(0) + oldHash := utils.GetObjectHash(pool.Spec.ComponentConfig.Hypervisor) + Eventually(func(g Gomega) { + pool := tfEnv.GetGPUPool(0) + g.Expect(pool.Status.ComponentStatus.HypervisorVersion).To(Equal(oldHash)) + g.Expect(pool.Status.ComponentStatus.HyperVisorUpdateProgress).To(BeZero()) + g.Expect(pool.Status.ComponentStatus.HypervisorConfigSynced).To(BeFalse()) + }, timeout, interval).Should(Succeed()) + + By("verifying hypervisor version should be updated upon configuration changes") + updateHypervisorConfig(tfEnv) + Eventually(func(g Gomega) { + pool := tfEnv.GetGPUPool(0) + newHash := utils.GetObjectHash(pool.Spec.ComponentConfig.Hypervisor) + g.Expect(newHash).ShouldNot(Equal(oldHash)) + g.Expect(pool.Status.ComponentStatus.HypervisorVersion).To(Equal(newHash)) + g.Expect(pool.Status.ComponentStatus.HyperVisorUpdateProgress).To(BeZero()) + g.Expect(pool.Status.ComponentStatus.HypervisorConfigSynced).To(BeFalse()) + }, timeout, interval).Should(Succeed()) + + tfEnv.Cleanup() + }) + + It("Should not update anything if AutoUpdate is false", func() { + tfEnv := NewTensorFusionEnvBuilder(). + AddPoolWithNodeCount(1). + SetGpuCountPerNode(1). + Build() + updateRollingUpdatePolicy(tfEnv, false, 100, "3s") + _, oldHash := triggerHypervisorUpdate(tfEnv) + verifyAllHypervisorPodHashConsistently(tfEnv, oldHash) + tfEnv.Cleanup() + }) + + It("Should update according to batch interval", func() { + tfEnv := NewTensorFusionEnvBuilder(). + AddPoolWithNodeCount(2). + SetGpuCountPerNode(1). 
+				Build()
+
+			By("configuring a large enough batch interval to prevent the next update batch")
+			updateRollingUpdatePolicy(tfEnv, true, 50, "10m")
+			newHash, oldHash := triggerHypervisorUpdate(tfEnv)
+			verifyHypervisorPodHash(tfEnv.GetGPUNode(0, 0), newHash)
+			verifyHypervisorPodHashConsistently(tfEnv.GetGPUNode(0, 1), oldHash)
+			verifyHypervisorUpdateProgressConsistently(tfEnv, 50)
+
+			By("changing the batch interval to trigger the next update batch")
+			updateRollingUpdatePolicy(tfEnv, true, 50, "3s")
+			verifyHypervisorPodHash(tfEnv.GetGPUNode(0, 1), newHash)
+			verifyHypervisorUpdateProgress(tfEnv, 100)
+
+			tfEnv.Cleanup()
+		})
+
+		It("Should pause the update according to batch interval", func() {
+			tfEnv := NewTensorFusionEnvBuilder().
+				AddPoolWithNodeCount(2).
+				SetGpuCountPerNode(1).
+				Build()
+
+			By("configuring a large enough batch interval to prevent the next update batch")
+			updateRollingUpdatePolicy(tfEnv, true, 50, "10m")
+			newHash, oldHash := triggerHypervisorUpdate(tfEnv)
+			verifyHypervisorPodHash(tfEnv.GetGPUNode(0, 0), newHash)
+			verifyHypervisorUpdateProgress(tfEnv, 50)
+			verifyHypervisorPodHashConsistently(tfEnv.GetGPUNode(0, 1), oldHash)
+			verifyHypervisorUpdateProgressConsistently(tfEnv, 50)
+			tfEnv.Cleanup()
+		})
+
+		It("Should perform update according to batch percentage", func() {
+			tfEnv := NewTensorFusionEnvBuilder().
+				AddPoolWithNodeCount(2).
+				SetGpuCountPerNode(1).
+				Build()
+			updateRollingUpdatePolicy(tfEnv, true, 50, "3s")
+			newHash, _ := triggerHypervisorUpdate(tfEnv)
+			verifyAllHypervisorPodHash(tfEnv, newHash)
+			verifyHypervisorUpdateProgress(tfEnv, 100)
+			tfEnv.Cleanup()
+		})
+
+		// It("Should perform update according to non-divisible batch percentage", func() {
+		// 	tfEnv := NewTensorFusionEnvBuilder().
+		// 		AddPoolWithNodeCount(3).
+		// 		SetGpuCountPerNode(1).
+		// 		Build()
+		// 	updateRollingUpdatePolicy(tfEnv, true, 66, "3s")
+		// 	newHash, _ := triggerHypervisorUpdate(tfEnv)
+		// 	verifyAllHypervisorPodHash(tfEnv, newHash)
+		// 	verifyHypervisorUpdateProgress(tfEnv, 100)
+		// 	tfEnv.Cleanup()
+		// })
+
+		It("Should update all nodes at once if BatchPercentage is 100", func() {
+			tfEnv := NewTensorFusionEnvBuilder().
+				AddPoolWithNodeCount(3).
+				SetGpuCountPerNode(1).
+				Build()
+			updateRollingUpdatePolicy(tfEnv, true, 100, "3s")
+			newHash, _ := triggerHypervisorUpdate(tfEnv)
+			verifyAllHypervisorPodHash(tfEnv, newHash)
+			verifyHypervisorUpdateProgress(tfEnv, 100)
+			tfEnv.Cleanup()
+		})
+	})
+
+	Context("When reconciling worker", func() {
+		It("Should update worker status upon configuration changes", func() {
+			tfEnv := NewTensorFusionEnvBuilder().AddPoolWithNodeCount(0).Build()
+			By("verifying worker status should be initialized when the gpu pool is created")
+			pool := tfEnv.GetGPUPool(0)
+			oldHash := utils.GetObjectHash(pool.Spec.ComponentConfig.Worker)
+			Eventually(func(g Gomega) {
+				pool := tfEnv.GetGPUPool(0)
+				g.Expect(pool.Status.ComponentStatus.WorkerVersion).To(Equal(oldHash))
+				g.Expect(pool.Status.ComponentStatus.WorkerUpdateProgress).To(BeZero())
+				g.Expect(pool.Status.ComponentStatus.WorkerConfigSynced).To(BeFalse())
+			}, timeout, interval).Should(Succeed())
+
+			By("verifying worker version should be updated upon configuration changes")
+			updateWorkerConfig(tfEnv)
+			Eventually(func(g Gomega) {
+				pool := tfEnv.GetGPUPool(0)
+				newHash := utils.GetObjectHash(pool.Spec.ComponentConfig.Worker)
+				g.Expect(newHash).ShouldNot(Equal(oldHash))
+				g.Expect(pool.Status.ComponentStatus.WorkerVersion).To(Equal(newHash))
+				g.Expect(pool.Status.ComponentStatus.WorkerUpdateProgress).To(BeZero())
+				g.Expect(pool.Status.ComponentStatus.WorkerConfigSynced).To(BeFalse())
+			}, timeout, interval).Should(Succeed())
+
+			tfEnv.Cleanup()
+		})
+
+		It("Should update according to batch interval", func() {
+			tfEnv := NewTensorFusionEnvBuilder().
+				AddPoolWithNodeCount(2).
+				SetGpuCountPerNode(1).
+				Build()
+
+			By("configuring a large enough batch interval to prevent the next update batch")
+			updateRollingUpdatePolicy(tfEnv, true, 50, "10m")
+			createWorkloads(tfEnv, 2)
+			triggerWorkerUpdate(tfEnv)
+			verifyWorkerPodContainerNameConsistently(1, "tensorfusion-worker")
+			verifyWorkerUpdateProgressConsistently(tfEnv, 50)
+
+			By("changing the batch interval to trigger the next update batch")
+			updateRollingUpdatePolicy(tfEnv, true, 50, "3s")
+			verifyAllWorkerPodContainerName(tfEnv, "updated-name")
+			verifyWorkerUpdateProgress(tfEnv, 100)
+
+			deleteWorkloads(2)
+			tfEnv.Cleanup()
+		})
+
+		It("Should update according to batch percentage", func() {
+			tfEnv := NewTensorFusionEnvBuilder().
+				AddPoolWithNodeCount(1).
+				SetGpuCountPerNode(2).
+				Build()
+			updateRollingUpdatePolicy(tfEnv, true, 50, "3s")
+			createWorkloads(tfEnv, 2)
+			triggerWorkerUpdate(tfEnv)
+			verifyAllWorkerPodContainerName(tfEnv, "updated-name")
+			verifyWorkerUpdateProgress(tfEnv, 100)
+			deleteWorkloads(2)
+			tfEnv.Cleanup()
+		})
+
+		It("Should update all workloads at once if BatchPercentage is 100", func() {
+			tfEnv := NewTensorFusionEnvBuilder().
+				AddPoolWithNodeCount(1).
+				SetGpuCountPerNode(2).
+				Build()
+			updateRollingUpdatePolicy(tfEnv, true, 100, "3s")
+			createWorkloads(tfEnv, 2)
+			triggerWorkerUpdate(tfEnv)
+			verifyAllWorkerPodContainerName(tfEnv, "updated-name")
+			verifyWorkerUpdateProgress(tfEnv, 100)
+		})
+	})
+
+	Context("When reconciling client", func() {
+		It("Should update client status upon configuration changes", func() {
+			tfEnv := NewTensorFusionEnvBuilder().AddPoolWithNodeCount(0).Build()
+			By("verifying client status should be initialized when the gpu pool is created")
+			pool := tfEnv.GetGPUPool(0)
+			oldHash := utils.GetObjectHash(pool.Spec.ComponentConfig.Client)
+			Eventually(func(g Gomega) {
+				pool := tfEnv.GetGPUPool(0)
+				g.Expect(pool.Status.ComponentStatus.ClientVersion).To(Equal(oldHash))
+				g.Expect(pool.Status.ComponentStatus.ClientUpdateProgress).To(BeZero())
+				g.Expect(pool.Status.ComponentStatus.ClientConfigSynced).To(BeFalse())
+			}, timeout, interval).Should(Succeed())
+
+			By("verifying client version should be updated upon configuration changes")
+			updateClientConfig(tfEnv)
+			Eventually(func(g Gomega) {
+				pool := tfEnv.GetGPUPool(0)
+				newHash := utils.GetObjectHash(pool.Spec.ComponentConfig.Client)
+				g.Expect(newHash).ShouldNot(Equal(oldHash))
+				g.Expect(pool.Status.ComponentStatus.ClientVersion).To(Equal(newHash))
+				g.Expect(pool.Status.ComponentStatus.ClientUpdateProgress).To(BeZero())
+				g.Expect(pool.Status.ComponentStatus.ClientConfigSynced).To(BeFalse())
+			}, timeout, interval).Should(Succeed())
+
+			tfEnv.Cleanup()
+		})
+
+		It("Should update according to batch interval", func() {
+			tfEnv := NewTensorFusionEnvBuilder().
+				AddPoolWithNodeCount(2).
+				SetGpuCountPerNode(1).
+				Build()
+			ensureGpuPoolIsRunning(tfEnv)
+
+			createClientPods(tfEnv, 2)
+
+			By("configuring a large enough batch interval to prevent the next update batch")
+			updateRollingUpdatePolicy(tfEnv, true, 50, "10m")
+			newHash, oldHash := triggerClientUpdate(tfEnv)
+
+			verifyClientPodWasDeleted(0)
+			createClientPodByIndex(tfEnv, 0)
+			verifyClientPodHash(0, newHash)
+
+			verifyClientPodHashConsistently(1, oldHash)
+			verifyClientUpdateProgressConsistently(tfEnv, 50)
+
+			By("changing the batch interval to trigger the next update batch")
+			updateRollingUpdatePolicy(tfEnv, true, 50, "3s")
+			verifyClientPodWasDeleted(1)
+			createClientPodByIndex(tfEnv, 1)
+			verifyClientPodHash(1, newHash)
+			verifyClientUpdateProgress(tfEnv, 100)
+
+			cleanupClientPods()
+			tfEnv.Cleanup()
+		})
+
+		It("Should update all client pods at once if BatchPercentage is 100", func() {
+			tfEnv := NewTensorFusionEnvBuilder().
+				AddPoolWithNodeCount(1).
+				SetGpuCountPerNode(1).
+ Build() + ensureGpuPoolIsRunning(tfEnv) + updateRollingUpdatePolicy(tfEnv, true, 100, "3s") + replicas := 2 + createClientPods(tfEnv, replicas) + updateClientConfig(tfEnv) + verifyClientPodWasDeleted(0) + verifyClientPodWasDeleted(1) + createClientPods(tfEnv, replicas) + verifyClientUpdateProgress(tfEnv, 100) + + cleanupClientPods() + tfEnv.Cleanup() }) }) }) + +func triggerHypervisorUpdate(tfEnv *TensorFusionEnv) (string, string) { + GinkgoHelper() + ensureGpuPoolIsRunning(tfEnv) + oldHash := verifyGpuPoolHypervisorHash(tfEnv, "") + updateHypervisorConfig(tfEnv) + newHash := verifyGpuPoolHypervisorHash(tfEnv, oldHash) + Expect(newHash).ShouldNot(Equal(oldHash)) + return newHash, oldHash +} + +func updateHypervisorConfig(tfEnv *TensorFusionEnv) { + GinkgoHelper() + tfc := tfEnv.GetCluster() + hypervisor := tfc.Spec.GPUPools[0].SpecTemplate.ComponentConfig.Hypervisor + podTmpl := &corev1.PodTemplate{} + Expect(json.Unmarshal(hypervisor.PodTemplate.Raw, podTmpl)).Should(Succeed()) + podTmpl.Template.Spec.Containers[0].Name = "updated-name" + hypervisor.PodTemplate.Raw = lo.Must(json.Marshal(podTmpl)) + tfEnv.UpdateCluster(tfc) +} + +func updateClientConfig(tfEnv *TensorFusionEnv) { + GinkgoHelper() + tfc := tfEnv.GetCluster() + client := tfc.Spec.GPUPools[0].SpecTemplate.ComponentConfig.Client + client.OperatorEndpoint = "http://localhost:8081" + tfEnv.UpdateCluster(tfc) +} + +func triggerClientUpdate(tfEnv *TensorFusionEnv) (string, string) { + GinkgoHelper() + ensureGpuPoolIsRunning(tfEnv) + oldHash := verifyGpuPoolClientHash(tfEnv, "") + updateClientConfig(tfEnv) + newHash := verifyGpuPoolClientHash(tfEnv, oldHash) + Expect(newHash).ShouldNot(Equal(oldHash)) + return newHash, oldHash +} + +func triggerWorkerUpdate(tfEnv *TensorFusionEnv) { + GinkgoHelper() + ensureGpuPoolIsRunning(tfEnv) + oldHash := verifyGpuPoolWorkerHash(tfEnv, "") + updateWorkerConfig(tfEnv) + newHash := verifyGpuPoolWorkerHash(tfEnv, oldHash) + Expect(newHash).ShouldNot(Equal(oldHash)) +} + +func updateWorkerConfig(tfEnv *TensorFusionEnv) { + GinkgoHelper() + tfc := tfEnv.GetCluster() + worker := tfc.Spec.GPUPools[0].SpecTemplate.ComponentConfig.Worker + podTmpl := &corev1.PodTemplate{} + Expect(json.Unmarshal(worker.PodTemplate.Raw, podTmpl)).Should(Succeed()) + podTmpl.Template.Spec.Containers[0].Name = "updated-name" + worker.PodTemplate.Raw = lo.Must(json.Marshal(podTmpl)) + tfEnv.UpdateCluster(tfc) +} + +func updateRollingUpdatePolicy(tfEnv *TensorFusionEnv, autoUpdate bool, batchPercentage int32, batchInterval string) { + GinkgoHelper() + tfc := tfEnv.GetCluster() + policy := tfc.Spec.GPUPools[0].SpecTemplate.NodeManagerConfig.NodePoolRollingUpdatePolicy + policy.AutoUpdate = ptr.To(autoUpdate) + policy.BatchPercentage = batchPercentage + policy.BatchInterval = batchInterval + tfEnv.UpdateCluster(tfc) + Eventually(func(g Gomega) { + pool := tfEnv.GetGPUPool(0) + newPolicy := pool.Spec.NodeManagerConfig.NodePoolRollingUpdatePolicy + g.Expect(newPolicy.AutoUpdate).Should(Equal(policy.AutoUpdate)) + g.Expect(newPolicy.BatchPercentage).Should(Equal(policy.BatchPercentage)) + g.Expect(newPolicy.BatchInterval).Should(Equal(policy.BatchInterval)) + }, timeout, interval).Should(Succeed()) +} + +func verifyGpuPoolClientHash(tfEnv *TensorFusionEnv, oldHash string) string { + GinkgoHelper() + pool := &tfv1.GPUPool{} + Eventually(func(g Gomega) { + pool = tfEnv.GetGPUPool(0) + newHash := utils.GetObjectHash(pool.Spec.ComponentConfig.Client) + g.Expect(newHash).ShouldNot(Equal(oldHash)) + 
g.Expect(pool.Status.ComponentStatus.ClientVersion).To(Equal(newHash)) + }, timeout, interval).Should(Succeed()) + + return pool.Status.ComponentStatus.ClientVersion +} + +func verifyGpuPoolHypervisorHash(tfEnv *TensorFusionEnv, oldHash string) string { + GinkgoHelper() + pool := &tfv1.GPUPool{} + Eventually(func(g Gomega) { + pool = tfEnv.GetGPUPool(0) + newHash := utils.GetObjectHash(pool.Spec.ComponentConfig.Hypervisor) + g.Expect(newHash).ShouldNot(Equal(oldHash)) + g.Expect(pool.Status.ComponentStatus.HypervisorVersion).To(Equal(newHash)) + }, timeout, interval).Should(Succeed()) + + return pool.Status.ComponentStatus.HypervisorVersion +} + +func verifyGpuPoolWorkerHash(tfEnv *TensorFusionEnv, oldHash string) string { + GinkgoHelper() + pool := &tfv1.GPUPool{} + Eventually(func(g Gomega) { + pool = tfEnv.GetGPUPool(0) + newHash := utils.GetObjectHash(pool.Spec.ComponentConfig.Worker) + g.Expect(newHash).ShouldNot(Equal(oldHash)) + g.Expect(pool.Status.ComponentStatus.WorkerVersion).To(Equal(newHash)) + }, timeout, interval).Should(Succeed()) + + return pool.Status.ComponentStatus.WorkerVersion +} + +func verifyHypervisorPodHash(gpuNode *tfv1.GPUNode, hash string) { + GinkgoHelper() + Eventually(func(g Gomega) { + pod := &corev1.Pod{} + g.Expect(k8sClient.Get(ctx, client.ObjectKey{ + Name: fmt.Sprintf("hypervisor-%s", gpuNode.Name), + Namespace: utils.CurrentNamespace(), + }, pod)).Should(Succeed()) + g.Expect(pod.Labels[constants.LabelKeyPodTemplateHash]).Should(Equal(hash)) + updatePodPhaseToRunning(pod, hash) + }, timeout, interval).Should(Succeed()) +} + +func verifyClientPodHash(index int, hash string) { + GinkgoHelper() + Eventually(func(g Gomega) { + pod := &corev1.Pod{} + key := client.ObjectKey{Namespace: utils.CurrentNamespace(), Name: getClientPodName(index)} + g.Expect(k8sClient.Get(ctx, key, pod)).Should(Succeed()) + g.Expect(pod.Labels[constants.LabelKeyPodTemplateHash]).Should(Equal(hash)) + }, timeout, interval).Should(Succeed()) +} + +func verifyClientPodHashConsistently(index int, hash string) { + GinkgoHelper() + Consistently(func(g Gomega) { + pod := &corev1.Pod{} + key := client.ObjectKey{Namespace: utils.CurrentNamespace(), Name: getClientPodName(index)} + g.Expect(k8sClient.Get(ctx, key, pod)).Should(Succeed()) + g.Expect(pod.Labels[constants.LabelKeyPodTemplateHash]).Should(Equal(hash)) + }, duration, interval).Should(Succeed()) +} + +func verifyHypervisorPodHashConsistently(gpuNode *tfv1.GPUNode, hash string) { + GinkgoHelper() + Consistently(func(g Gomega) { + pod := &corev1.Pod{} + g.Expect(k8sClient.Get(ctx, client.ObjectKey{ + Name: fmt.Sprintf("hypervisor-%s", gpuNode.Name), + Namespace: utils.CurrentNamespace(), + }, pod)).Should(Succeed()) + g.Expect(pod.Labels[constants.LabelKeyPodTemplateHash]).Should(Equal(hash)) + updatePodPhaseToRunning(pod, hash) + }, duration, interval).Should(Succeed()) +} + +func verifyClientPodWasDeleted(index int) { + Eventually(func(g Gomega) { + pod := &corev1.Pod{} + key := client.ObjectKey{Namespace: utils.CurrentNamespace(), Name: getClientPodName(index)} + g.Expect(k8sClient.Get(ctx, key, pod)).ShouldNot(Succeed()) + }, timeout, interval).Should(Succeed()) +} + +func verifyAllHypervisorPodHash(tfEnv *TensorFusionEnv, hash string) { + GinkgoHelper() + Eventually(func(g Gomega) { + nodeList := tfEnv.GetGPUNodeList(0) + for _, gpuNode := range nodeList.Items { + pod := &corev1.Pod{} + g.Expect(k8sClient.Get(ctx, client.ObjectKey{ + Name: fmt.Sprintf("hypervisor-%s", gpuNode.Name), + Namespace: utils.CurrentNamespace(), + }, 
pod)).Should(Succeed()) + g.Expect(pod.Spec.Containers[0].Name).Should(Equal("updated-name")) + g.Expect(pod.Labels[constants.LabelKeyPodTemplateHash]).Should(Equal(hash)) + updatePodPhaseToRunning(pod, hash) + } + }, timeout, interval).Should(Succeed()) +} + +// func verifyWorkerPodContainerName(workloadIndex int, name string) { +// GinkgoHelper() +// Eventually(func(g Gomega) { +// podList := &corev1.PodList{} +// g.Expect(k8sClient.List(ctx, podList, +// client.InNamespace("default"), +// client.MatchingLabels{constants.WorkloadKey: getWorkloadName(workloadIndex)})).Should(Succeed()) +// g.Expect(podList.Items).Should(HaveLen(1)) +// for _, pod := range podList.Items { +// g.Expect(pod.Spec.Containers[0].Name).Should(Equal(name)) +// } +// }, timeout, interval).Should(Succeed()) +// } + +func verifyWorkerPodContainerNameConsistently(workloadIndex int, name string) { + GinkgoHelper() + Consistently(func(g Gomega) { + podList := &corev1.PodList{} + g.Expect(k8sClient.List(ctx, podList, + client.InNamespace("default"), + client.MatchingLabels{constants.WorkloadKey: getWorkloadName(workloadIndex)})).Should(Succeed()) + g.Expect(podList.Items).Should(HaveLen(1)) + for _, pod := range podList.Items { + g.Expect(pod.Spec.Containers[0].Name).Should(Equal(name)) + } + }, duration, interval).Should(Succeed()) +} + +func verifyAllWorkerPodContainerName(tfEnv *TensorFusionEnv, name string) { + GinkgoHelper() + pool := tfEnv.GetGPUPool(0) + Eventually(func(g Gomega) { + workloadList := &tfv1.TensorFusionWorkloadList{} + g.Expect(k8sClient.List(ctx, workloadList, client.MatchingLabels(map[string]string{ + constants.LabelKeyOwner: pool.Name, + }))).Should(Succeed()) + for _, workload := range workloadList.Items { + podList := &corev1.PodList{} + g.Expect(k8sClient.List(ctx, podList, + client.InNamespace(workload.Namespace), + client.MatchingLabels{constants.WorkloadKey: workload.Name})).Should(Succeed()) + g.Expect(podList.Items).Should(HaveLen(int(*workload.Spec.Replicas))) + for _, pod := range podList.Items { + g.Expect(pod.Spec.Containers[0].Name).Should(Equal(name)) + } + } + + }, timeout, interval).Should(Succeed()) +} + +func verifyAllHypervisorPodHashConsistently(tfEnv *TensorFusionEnv, hash string) { + GinkgoHelper() + Consistently(func(g Gomega) { + nodeList := tfEnv.GetGPUNodeList(0) + for _, gpuNode := range nodeList.Items { + pod := &corev1.Pod{} + g.Expect(k8sClient.Get(ctx, client.ObjectKey{ + Name: fmt.Sprintf("hypervisor-%s", gpuNode.Name), + Namespace: utils.CurrentNamespace(), + }, pod)).Should(Succeed()) + g.Expect(pod.Labels[constants.LabelKeyPodTemplateHash]).Should(Equal(hash)) + updatePodPhaseToRunning(pod, hash) + } + }, duration, interval).Should(Succeed()) +} + +// func verifyAllWorkerPodContainerNameConsistently(tfEnv *TensorFusionEnv, name string) { +// GinkgoHelper() +// pool := tfEnv.GetGPUPool(0) +// Consistently(func(g Gomega) { +// workloadList := &tfv1.TensorFusionWorkloadList{} +// g.Expect(k8sClient.List(ctx, workloadList, client.MatchingLabels(map[string]string{ +// constants.LabelKeyOwner: pool.Name, +// }))).Should(Succeed()) +// for _, workload := range workloadList.Items { +// podList := &corev1.PodList{} +// g.Expect(k8sClient.List(ctx, podList, +// client.InNamespace(workload.Namespace), +// client.MatchingLabels{constants.WorkloadKey: workload.Name})).Should(Succeed()) +// g.Expect(podList.Items).Should(HaveLen(int(*workload.Spec.Replicas))) +// for _, pod := range podList.Items { +// g.Expect(pod.Spec.Containers[0].Name).Should(Equal(name)) +// } +// } + 
+// }, duration, interval).Should(Succeed()) +// } + +func verifyHypervisorUpdateProgress(tfEnv *TensorFusionEnv, progress int32) { + GinkgoHelper() + Eventually(func(g Gomega) { + pool := tfEnv.GetGPUPool(0) + g.Expect(pool.Status.ComponentStatus.HyperVisorUpdateProgress).To(Equal(progress)) + if progress == 100 { + g.Expect(pool.Status.ComponentStatus.HypervisorConfigSynced).To(BeTrue()) + } else { + g.Expect(pool.Status.ComponentStatus.HypervisorConfigSynced).To(BeFalse()) + } + }, timeout, interval).Should(Succeed()) +} + +func verifyWorkerUpdateProgress(tfEnv *TensorFusionEnv, progress int32) { + GinkgoHelper() + Eventually(func(g Gomega) { + pool := tfEnv.GetGPUPool(0) + g.Expect(pool.Status.ComponentStatus.WorkerUpdateProgress).To(Equal(progress)) + if progress == 100 { + g.Expect(pool.Status.ComponentStatus.WorkerConfigSynced).To(BeTrue()) + } else { + g.Expect(pool.Status.ComponentStatus.WorkerConfigSynced).To(BeFalse()) + } + }, timeout, interval).Should(Succeed()) +} + +func verifyClientUpdateProgress(tfEnv *TensorFusionEnv, progress int32) { + GinkgoHelper() + Eventually(func(g Gomega) { + pool := tfEnv.GetGPUPool(0) + g.Expect(pool.Status.ComponentStatus.ClientUpdateProgress).To(Equal(progress)) + if progress == 100 { + g.Expect(pool.Status.ComponentStatus.ClientConfigSynced).To(BeTrue()) + } else { + g.Expect(pool.Status.ComponentStatus.ClientConfigSynced).To(BeFalse()) + } + }, timeout, interval).Should(Succeed()) +} + +func verifyClientUpdateProgressConsistently(tfEnv *TensorFusionEnv, progress int32) { + GinkgoHelper() + Consistently(func(g Gomega) { + pool := tfEnv.GetGPUPool(0) + g.Expect(pool.Status.ComponentStatus.ClientUpdateProgress).To(Equal(progress)) + if progress == 100 { + g.Expect(pool.Status.ComponentStatus.ClientConfigSynced).To(BeTrue()) + } else { + g.Expect(pool.Status.ComponentStatus.ClientConfigSynced).To(BeFalse()) + } + }, duration, interval).Should(Succeed()) +} + +func verifyHypervisorUpdateProgressConsistently(tfEnv *TensorFusionEnv, progress int32) { + GinkgoHelper() + Eventually(func(g Gomega) { + pool := tfEnv.GetGPUPool(0) + g.Expect(pool.Status.ComponentStatus.HyperVisorUpdateProgress).To(Equal(progress)) + if progress == 100 { + g.Expect(pool.Status.ComponentStatus.HypervisorConfigSynced).To(BeTrue()) + } else { + g.Expect(pool.Status.ComponentStatus.HypervisorConfigSynced).To(BeFalse()) + } + }, timeout, interval).Should(Succeed()) +} + +func verifyWorkerUpdateProgressConsistently(tfEnv *TensorFusionEnv, progress int32) { + GinkgoHelper() + Eventually(func(g Gomega) { + pool := tfEnv.GetGPUPool(0) + g.Expect(pool.Status.ComponentStatus.WorkerUpdateProgress).To(Equal(progress)) + if progress == 100 { + g.Expect(pool.Status.ComponentStatus.WorkerConfigSynced).To(BeTrue()) + } else { + g.Expect(pool.Status.ComponentStatus.WorkerConfigSynced).To(BeFalse()) + } + }, duration, interval).Should(Succeed()) +} + +// no pod controller in EnvTest, need to manually update pod status +func updatePodPhaseToRunning(pod *corev1.Pod, hash string) { + GinkgoHelper() + if pod.Labels[constants.LabelKeyPodTemplateHash] == hash && pod.Status.Phase != corev1.PodRunning { + patch := client.MergeFrom(pod.DeepCopy()) + pod.Status.Phase = corev1.PodRunning + pod.Status.Conditions = append(pod.Status.Conditions, corev1.PodCondition{Type: corev1.PodReady, Status: corev1.ConditionTrue}) + Expect(k8sClient.Status().Patch(ctx, pod, patch)).Should(Succeed()) + } +} + +func ensureGpuPoolIsRunning(tfEnv *TensorFusionEnv) { + GinkgoHelper() + Eventually(func(g Gomega) { + pool 
:= tfEnv.GetGPUPool(0)
+		g.Expect(pool.Status.Phase).Should(Equal(tfv1.TensorFusionPoolPhaseRunning))
+	}, timeout, interval).Should(Succeed())
+}
+
+// no ReplicaSet-like controller in EnvTest, so client pods need to be created by the test itself
+func createClientPodByIndex(tfEnv *TensorFusionEnv, index int) {
+	GinkgoHelper()
+	pool := tfEnv.GetGPUPool(0)
+	pod := &corev1.Pod{
+		ObjectMeta: metav1.ObjectMeta{
+			Name: getClientPodName(index),
+			Namespace: utils.CurrentNamespace(),
+			Labels: map[string]string{
+				constants.TensorFusionEnabledLabelKey: constants.LabelValueTrue,
+				fmt.Sprintf(constants.GPUNodePoolIdentifierLabelFormat, pool.Name): constants.LabelValueTrue,
+				constants.LabelKeyPodTemplateHash: utils.GetObjectHash(pool.Spec.ComponentConfig.Client),
+			},
+		},
+		Spec: corev1.PodSpec{
+			Containers: []corev1.Container{
+				{
+					Name: "mock",
+					Image: "mock",
+				},
+			},
+		},
+	}
+	Expect(k8sClient.Create(ctx, pod)).Should(Succeed())
+
+	Eventually(func(g Gomega) {
+		pod := &corev1.Pod{}
+		key := client.ObjectKey{Namespace: utils.CurrentNamespace(), Name: getClientPodName(index)}
+		g.Expect(k8sClient.Get(ctx, key, pod)).Should(Succeed())
+	}, timeout, interval).Should(Succeed())
+}
+
+func createClientPods(tfEnv *TensorFusionEnv, count int) {
+	GinkgoHelper()
+	pool := tfEnv.GetGPUPool(0)
+	for i := range count {
+		pod := &corev1.Pod{
+			ObjectMeta: metav1.ObjectMeta{
+				Name: getClientPodName(i),
+				Namespace: utils.CurrentNamespace(),
+				Labels: map[string]string{
+					constants.TensorFusionEnabledLabelKey: constants.LabelValueTrue,
+					fmt.Sprintf(constants.GPUNodePoolIdentifierLabelFormat, pool.Name): constants.LabelValueTrue,
+					constants.LabelKeyPodTemplateHash: utils.GetObjectHash(pool.Spec.ComponentConfig.Client),
+				},
+			},
+			Spec: corev1.PodSpec{
+				Containers: []corev1.Container{
+					{
+						Name: "mock",
+						Image: "mock",
+					},
+				},
+			},
+		}
+		Expect(k8sClient.Create(ctx, pod)).Should(Succeed())
+	}
+
+	Eventually(func(g Gomega) {
+		for i := range count {
+			pod := &corev1.Pod{}
+			key := client.ObjectKey{Namespace: utils.CurrentNamespace(), Name: getClientPodName(i)}
+			g.Expect(k8sClient.Get(ctx, key, pod)).Should(Succeed())
+		}
+	}, timeout, interval).Should(Succeed())
+}
+
+func cleanupClientPods() {
+	Expect(k8sClient.DeleteAllOf(ctx, &corev1.Pod{}, client.InNamespace(utils.CurrentNamespace()))).Should(Succeed())
+}
+
+func createWorkloads(tfEnv *TensorFusionEnv, count int) {
+	GinkgoHelper()
+	pool := tfEnv.GetGPUPool(0)
+	for workloadIndex := range count {
+		key := client.ObjectKey{Name: getWorkloadName(workloadIndex), Namespace: "default"}
+		replicas := 1
+		workload := createTensorFusionWorkload(pool.Name, key, replicas)
+		checkWorkerPodCount(workload)
+	}
+}
+
+func deleteWorkloads(count int) {
+	GinkgoHelper()
+	for workloadIndex := range count {
+		key := client.ObjectKey{Name: getWorkloadName(workloadIndex), Namespace: "default"}
+		cleanupWorkload(key)
+	}
+}
+
+func getWorkloadName(index int) string {
+	return fmt.Sprintf("workload-%d", index)
+}
+
+func getClientPodName(index int) string {
+	return fmt.Sprintf("client-%d", index)
+}
diff --git a/internal/controller/suite_test.go b/internal/controller/suite_test.go
index 11d48c8..5fcb9da 100644
--- a/internal/controller/suite_test.go
+++ b/internal/controller/suite_test.go
@@ -62,6 +62,7 @@ var cancel context.CancelFunc
 const (
 	timeout = time.Second * 10
+	duration = time.Second * 5
 	interval = time.Millisecond * 100
 )
@@ -224,6 +225,7 @@ type TensorFusionEnv struct {
 }
 
 func (c *TensorFusionEnv) GetCluster() *tfv1.TensorFusionCluster {
+	GinkgoHelper()
 	tfc := 
diff --git a/internal/controller/suite_test.go b/internal/controller/suite_test.go
index 11d48c8..5fcb9da 100644
--- a/internal/controller/suite_test.go
+++ b/internal/controller/suite_test.go
@@ -62,6 +62,7 @@ var cancel context.CancelFunc
 
 const (
     timeout  = time.Second * 10
+    duration = time.Second * 5
     interval = time.Millisecond * 100
 )
 
@@ -224,6 +225,7 @@ type TensorFusionEnv struct {
 }
 
 func (c *TensorFusionEnv) GetCluster() *tfv1.TensorFusionCluster {
+    GinkgoHelper()
     tfc := &tfv1.TensorFusionCluster{}
     Eventually(func(g Gomega) {
         g.Expect(k8sClient.Get(ctx, c.clusterKey, tfc)).Should(Succeed())
@@ -232,10 +234,12 @@ func (c *TensorFusionEnv) GetCluster() *tfv1.TensorFusionCluster {
 }
 
 func (c *TensorFusionEnv) UpdateCluster(tfc *tfv1.TensorFusionCluster) {
+    GinkgoHelper()
     Expect(k8sClient.Update(ctx, tfc)).Should(Succeed())
 }
 
 func (c *TensorFusionEnv) Cleanup() {
+    GinkgoHelper()
     for poolIndex, nodeGpuMap := range c.poolNodeMap {
         for nodeIndex := range nodeGpuMap {
             c.DeleteGPUNode(poolIndex, nodeIndex)
@@ -264,6 +268,7 @@
 }
 
 func (c *TensorFusionEnv) GetGPUPoolList() *tfv1.GPUPoolList {
+    GinkgoHelper()
     poolList := &tfv1.GPUPoolList{}
     Eventually(func(g Gomega) {
         g.Expect(k8sClient.List(ctx, poolList, client.MatchingLabels(map[string]string{
@@ -275,14 +280,16 @@
 }
 
 func (c *TensorFusionEnv) GetGPUPool(poolIndex int) *tfv1.GPUPool {
+    GinkgoHelper()
     pool := &tfv1.GPUPool{}
     Eventually(func(g Gomega) {
         g.Expect(k8sClient.Get(ctx, client.ObjectKey{Name: c.getPoolName(poolIndex)}, pool)).Should(Succeed())
-    }).Should(Succeed())
+    }, timeout, interval).Should(Succeed())
     return pool
 }
 
 func (c *TensorFusionEnv) GetGPUNodeList(poolIndex int) *tfv1.GPUNodeList {
+    GinkgoHelper()
     nodeList := &tfv1.GPUNodeList{}
     Eventually(func(g Gomega) {
         g.Expect(k8sClient.List(ctx, nodeList, client.MatchingLabels(map[string]string{
@@ -294,6 +301,7 @@
 }
 
 func (c *TensorFusionEnv) GetGPUNode(poolIndex int, nodeIndex int) *tfv1.GPUNode {
+    GinkgoHelper()
     node := &tfv1.GPUNode{}
     Eventually(func(g Gomega) {
         g.Expect(k8sClient.Get(ctx, client.ObjectKey{Name: c.getNodeName(poolIndex, nodeIndex)}, node)).Should(Succeed())
@@ -302,6 +310,7 @@
 }
 
 func (c *TensorFusionEnv) DeleteGPUNode(poolIndex int, nodeIndex int) {
+    GinkgoHelper()
     c.DeleteNodeGpuList(poolIndex, nodeIndex)
     node := c.GetGPUNode(poolIndex, nodeIndex)
     Expect(k8sClient.Delete(ctx, node)).Should(Succeed())
@@ -312,6 +321,7 @@
 }
 
 func (c *TensorFusionEnv) GetNodeGpuList(poolIndex int, nodeIndex int) *tfv1.GPUList {
+    GinkgoHelper()
     gpuList := &tfv1.GPUList{}
     Eventually(func(g Gomega) {
         g.Expect(k8sClient.List(ctx, gpuList, client.MatchingLabels(map[string]string{
@@ -323,12 +333,14 @@
 }
 
 func (c *TensorFusionEnv) DeleteNodeGpuList(poolIndex int, nodeIndex int) {
+    GinkgoHelper()
     Expect(k8sClient.DeleteAllOf(ctx, &tfv1.GPU{},
         client.MatchingLabels{constants.LabelKeyOwner: c.getNodeName(poolIndex, nodeIndex)},
     )).Should(Succeed())
 }
 
 func (c *TensorFusionEnv) GetPoolGpuList(poolIndex int) *tfv1.GPUList {
+    GinkgoHelper()
     gpuList := &tfv1.GPUList{}
     poolGpuCount := 0
     for _, gpuCount := range c.poolNodeMap[poolIndex] {
@@ -348,6 +360,7 @@
 // So the checkStatusAndUpdateVirtualCapacity in gpunode_controller.go checking pod status always pending and the gpunode status can't change to running
 // When using an existing cluster, the test speed go a lot faster, may change later?
 func (c *TensorFusionEnv) UpdateHypervisorStatus() {
+    GinkgoHelper()
     if os.Getenv("USE_EXISTING_CLUSTER") != "true" {
         for poolIndex := range c.poolNodeMap {
             podList := &corev1.PodList{}
@@ -422,6 +435,7 @@ func (b *TensorFusionEnvBuilder) SetGpuCountForNode(nodeIndex int, gpuCount int)
 var testEnvId int = 0
 
 func (b *TensorFusionEnvBuilder) Build() *TensorFusionEnv {
+    GinkgoHelper()
     b.clusterKey = client.ObjectKey{
         Name:      fmt.Sprintf("cluster-%d", testEnvId),
         Namespace: "default",
@@ -458,7 +472,7 @@ func (b *TensorFusionEnvBuilder) Build() *TensorFusionEnv {
             constants.LabelKeyOwner: tfc.Name,
         }))).Should(Succeed())
         g.Expect(gpuPoolList.Items).Should(HaveLen(b.poolCount))
-    }, timeout*1000, interval).Should(Succeed())
+    }, timeout, interval).Should(Succeed())
 
     // generate nodes
     selectors := strings.Split(constants.InitialGPUNodeSelector, "=")
diff --git a/internal/utils/reconcile.go b/internal/utils/reconcile.go
index fab149e..ec6f063 100644
--- a/internal/utils/reconcile.go
+++ b/internal/utils/reconcile.go
@@ -137,12 +137,17 @@ func GetObjectHash(objs ...any) string {
     return fmt.Sprintf("%x", hasher.Sum(nil))
 }
 
+func CompareAndGetObjectHash(hash string, obj ...any) (bool, string) {
+    newHash := GetObjectHash(obj...)
+    return hash != newHash, newHash
+}
+
 const DebounceKeySuffix = ":in_queue"
 
 func DebouncedReconcileCheck(ctx context.Context, lastProcessedItems *sync.Map, name types.NamespacedName) (runNow bool, alreadyQueued bool, waitTime time.Duration) {
     const (
         // Minimum time between reconciliations for the same object
-        debounceInterval = 5 * time.Second
+        debounceInterval = 3 * time.Second
     )
     now := time.Now()
     key := name.String()
diff --git a/internal/webhook/v1/pod_webhook.go b/internal/webhook/v1/pod_webhook.go
index 8fd9f36..575326d 100644
--- a/internal/webhook/v1/pod_webhook.go
+++ b/internal/webhook/v1/pod_webhook.go
@@ -124,6 +124,12 @@ func (m *TensorFusionPodMutator) Handle(ctx context.Context, req admission.Reque
         nodeSelector = workloadStatus.NodeSelector
     }
 
+    if pod.Labels == nil {
+        pod.Labels = map[string]string{}
+    }
+    pod.Labels[constants.LabelKeyPodTemplateHash] = utils.GetObjectHash(pool.Spec.ComponentConfig)
+    pod.Labels[fmt.Sprintf(constants.GPUNodePoolIdentifierLabelFormat, pool.Name)] = constants.LabelValueTrue
+
     // Inject initContainer and env variables
     patches, err := m.patchTFClient(pod, pool.Spec.ComponentConfig.Client, tfInfo.ContainerNames, nodeSelector)
     if err != nil {
@@ -173,6 +179,9 @@ func (m *TensorFusionPodMutator) createOrUpdateWorkload(ctx context.Context, pod
         ObjectMeta: metav1.ObjectMeta{
             Name:      tfInfo.WorkloadName,
             Namespace: pod.Namespace,
+            Labels: map[string]string{
+                constants.LabelKeyOwner: tfInfo.Profile.PoolName,
+            },
         },
         Spec: tfv1.TensorFusionWorkloadSpec{
             Replicas: &replicas,