Skip to content

feat: implement component update #170

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ vet: ## Run go vet against code.

.PHONY: test
test: manifests generate fmt vet envtest ## Run tests.
KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test $$(go list ./... | grep -v /e2e) -coverprofile cover.out
KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test $$(go list ./... | grep -v /e2e) -timeout 0 -coverprofile cover.out

# TODO(user): To use a different vendor for e2e tests, modify the setup under 'tests/e2e'.
# The default setup assumes Kind is pre-installed and builds/loads the Manager Docker image locally.
Expand Down
117 changes: 117 additions & 0 deletions internal/component/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package component

import (
"context"
"fmt"
"sort"

tfv1 "github.com/NexusGPU/tensor-fusion/api/v1"
"github.com/NexusGPU/tensor-fusion/internal/constants"
"github.com/NexusGPU/tensor-fusion/internal/utils"
corev1 "k8s.io/api/core/v1"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/log"
)

const (
ClientUpdateInProgressAnnotation = constants.Domain + "/client-update-in-progress"
ClientBatchUpdateLastTimeAnnotation = constants.Domain + "/client-batch-update-last-time"
)

type Client struct {
podsToUpdate []*corev1.Pod
}

func (c *Client) GetName() string {
return "client"
}

func (c *Client) DetectConfigChange(pool *tfv1.GPUPool, status *tfv1.PoolComponentStatus) (bool, string, string) {
oldHash := status.ClientVersion
changed, newHash := utils.CompareAndGetObjectHash(oldHash, pool.Spec.ComponentConfig.Client)
return changed, newHash, oldHash
}

func (c *Client) SetConfigHash(status *tfv1.PoolComponentStatus, hash string) {
status.ClientVersion = hash
}

func (c *Client) GetUpdateInProgressInfo(pool *tfv1.GPUPool) string {
return pool.Annotations[ClientUpdateInProgressAnnotation]
}

func (c *Client) SetUpdateInProgressInfo(pool *tfv1.GPUPool, hash string) {
pool.Annotations[ClientUpdateInProgressAnnotation] = hash
}

func (c *Client) GetBatchUpdateLastTimeInfo(pool *tfv1.GPUPool) string {
return pool.Annotations[ClientBatchUpdateLastTimeAnnotation]
}

func (c *Client) SetBatchUpdateLastTimeInfo(pool *tfv1.GPUPool, time string) {
pool.Annotations[ClientBatchUpdateLastTimeAnnotation] = time
}

func (c *Client) GetUpdateProgress(status *tfv1.PoolComponentStatus) int32 {
return status.ClientUpdateProgress
}

func (c *Client) SetUpdateProgress(status *tfv1.PoolComponentStatus, progress int32) {
status.ClientUpdateProgress = progress
status.ClientConfigSynced = false
if progress == 100 {
status.ClientConfigSynced = true
}
}

func (c *Client) GetResourcesInfo(r client.Client, ctx context.Context, pool *tfv1.GPUPool, configHash string) (int, int, bool, error) {
podList := &corev1.PodList{}
if err := r.List(ctx, podList,
client.MatchingLabels{
constants.TensorFusionEnabledLabelKey: constants.LabelValueTrue,
fmt.Sprintf(constants.GPUNodePoolIdentifierLabelFormat, pool.Name): constants.LabelValueTrue,
}); err != nil {
return 0, 0, false, fmt.Errorf("failed to list pods: %w", err)
}

total := len(podList.Items)

for _, pod := range podList.Items {
if !pod.DeletionTimestamp.IsZero() {
return 0, 0, true, nil
}

if pod.Labels[constants.LabelKeyPodTemplateHash] != configHash {
c.podsToUpdate = append(c.podsToUpdate, &pod)
}
}

sort.Sort(ClientPodsByCreationTimestamp(c.podsToUpdate))

return total, total - len(c.podsToUpdate), false, nil
}

func (c *Client) PerformBatchUpdate(r client.Client, ctx context.Context, pool *tfv1.GPUPool, delta int) (bool, error) {
log := log.FromContext(ctx)

log.Info("perform batch update", "component", c.GetName())
for i := range delta {
pod := c.podsToUpdate[i]
if err := r.Delete(ctx, pod); err != nil {
return false, fmt.Errorf("failed to delete pod: %w", err)
}
}

return true, nil
}

type ClientPodsByCreationTimestamp []*corev1.Pod

func (o ClientPodsByCreationTimestamp) Len() int { return len(o) }
func (o ClientPodsByCreationTimestamp) Swap(i, j int) { o[i], o[j] = o[j], o[i] }
func (o ClientPodsByCreationTimestamp) Less(i, j int) bool {
if o[i].CreationTimestamp.Equal(&o[j].CreationTimestamp) {
return o[i].Name < o[j].Name
}
return o[i].CreationTimestamp.Before(&o[j].CreationTimestamp)
}
173 changes: 173 additions & 0 deletions internal/component/component.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
package component

import (
"context"
"fmt"
"math"
"time"

tfv1 "github.com/NexusGPU/tensor-fusion/api/v1"
"github.com/NexusGPU/tensor-fusion/internal/constants"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/log"
)

type Interface interface {
GetName() string
DetectConfigChange(pool *tfv1.GPUPool, status *tfv1.PoolComponentStatus) (bool, string, string)
SetConfigHash(status *tfv1.PoolComponentStatus, hash string)
GetUpdateInProgressInfo(pool *tfv1.GPUPool) string
SetUpdateInProgressInfo(pool *tfv1.GPUPool, hash string)
GetBatchUpdateLastTimeInfo(pool *tfv1.GPUPool) string
SetBatchUpdateLastTimeInfo(pool *tfv1.GPUPool, time string)
GetUpdateProgress(status *tfv1.PoolComponentStatus) int32
SetUpdateProgress(status *tfv1.PoolComponentStatus, progress int32)
GetResourcesInfo(r client.Client, ctx context.Context, pool *tfv1.GPUPool, hash string) (int, int, bool, error)
PerformBatchUpdate(r client.Client, ctx context.Context, pool *tfv1.GPUPool, delta int) (bool, error)
}

func ManageUpdate(r client.Client, ctx context.Context, pool *tfv1.GPUPool, component Interface) (*ctrl.Result, error) {
log := log.FromContext(ctx)

autoUpdate, batchInterval := getUpdatePolicy(pool)
newStatus := pool.Status.ComponentStatus.DeepCopy()

changed, configHash, oldHash := component.DetectConfigChange(pool, newStatus)
if changed {
log.Info("component configuration changed", "component", component.GetName(), "old hash", oldHash, "new hash", configHash)
component.SetConfigHash(newStatus, configHash)
component.SetUpdateProgress(newStatus, 0)
if oldHash == "" || !autoUpdate {
return nil, patchComponentStatus(r, ctx, pool, newStatus)
}
if pool.Annotations == nil {
pool.Annotations = map[string]string{}
}
patch := client.MergeFrom(pool.DeepCopy())
component.SetUpdateInProgressInfo(pool, configHash)
component.SetBatchUpdateLastTimeInfo(pool, "")
if err := r.Patch(ctx, pool, patch); err != nil {
return nil, fmt.Errorf("failed to patch pool: %w", err)
}
} else {
if !autoUpdate || component.GetUpdateInProgressInfo(pool) != configHash {
return nil, nil
}
if timeInfo := component.GetBatchUpdateLastTimeInfo(pool); len(timeInfo) != 0 {
lastBatchUpdateTime, err := time.Parse(time.RFC3339, timeInfo)
if err != nil {
return nil, err
}
nextBatchUpdateTime := lastBatchUpdateTime.Add(batchInterval)
if now := time.Now(); now.Before(nextBatchUpdateTime) {
log.Info("next batch update time not yet reached", "now", now, "nextBatchUpdateTime", nextBatchUpdateTime)
return &ctrl.Result{RequeueAfter: nextBatchUpdateTime.Sub(now)}, nil
}
log.Info("next batch update time reached", "BatchUpdateTime", nextBatchUpdateTime)
}
}

totalSize, updatedSize, recheck, err := component.GetResourcesInfo(r, ctx, pool, configHash)
if err != nil {
return nil, err
} else if recheck {
return &ctrl.Result{RequeueAfter: constants.PendingRequeueDuration}, err
} else if totalSize <= 0 {
return nil, nil
}

batchPercentage := pool.Spec.NodeManagerConfig.NodePoolRollingUpdatePolicy.BatchPercentage
updateProgress := component.GetUpdateProgress(newStatus)
delta, newUpdateProgress, currentBatchIndex := calculateDesiredUpdatedDelta(totalSize, updatedSize, batchPercentage, updateProgress)
component.SetUpdateProgress(newStatus, newUpdateProgress)
log.Info("update in progress", "component", component.GetName(), "hash", configHash,
"updateProgress", newUpdateProgress, "totalSize", totalSize, "updatedSize", updatedSize,
"batchPercentage", batchPercentage, "currentBatchIndex", currentBatchIndex, "delta", delta)

var ctrlResult *ctrl.Result
if delta == 0 {
patch := client.MergeFrom(pool.DeepCopy())
newUpdateProgress = min((currentBatchIndex+1)*batchPercentage, 100)
component.SetUpdateProgress(newStatus, newUpdateProgress)
if newUpdateProgress != 100 {
component.SetBatchUpdateLastTimeInfo(pool, time.Now().Format(time.RFC3339))
interval := max(batchInterval, constants.PendingRequeueDuration)
ctrlResult = &ctrl.Result{RequeueAfter: interval}
log.Info("current batch update has completed", "progress", newUpdateProgress, "currentBatchIndex", currentBatchIndex, "nextUpdateTime", time.Now().Add(interval))
} else {
component.SetUpdateInProgressInfo(pool, "")
component.SetBatchUpdateLastTimeInfo(pool, "")
log.Info("all batch update has completed", "component", component.GetName(), "hash", configHash)
}
if err := r.Patch(ctx, pool, patch); err != nil {
return nil, fmt.Errorf("failed to patch pool: %w", err)
}
} else if delta > 0 {
recheck, err := component.PerformBatchUpdate(r, ctx, pool, int(delta))
if err != nil {
return nil, err
} else if recheck {
ctrlResult = &ctrl.Result{RequeueAfter: constants.PendingRequeueDuration}
}
}

return ctrlResult, patchComponentStatus(r, ctx, pool, newStatus)
}

func patchComponentStatus(r client.Client, ctx context.Context, pool *tfv1.GPUPool, newStatus *tfv1.PoolComponentStatus) error {
patch := client.MergeFrom(pool.DeepCopy())
pool.Status.ComponentStatus = *newStatus
if err := r.Status().Patch(ctx, pool, patch); err != nil {
return fmt.Errorf("failed to patch pool status: %w", err)
}
return nil
}

func getUpdatePolicy(pool *tfv1.GPUPool) (bool, time.Duration) {
autoUpdate := false
batchInterval := time.Duration(600) * time.Second

if pool.Spec.NodeManagerConfig != nil {
updatePolicy := pool.Spec.NodeManagerConfig.NodePoolRollingUpdatePolicy
if updatePolicy != nil {
if updatePolicy.AutoUpdate != nil {
autoUpdate = *updatePolicy.AutoUpdate
}

duration, err := time.ParseDuration(updatePolicy.BatchInterval)
if err == nil {
batchInterval = duration
}
}
}

return autoUpdate, batchInterval
}

func calculateDesiredUpdatedDelta(total int, updatedSize int, batchPercentage int32, updateProgress int32) (int32, int32, int32) {
batchSize := getValueFromPercent(int(batchPercentage), total, true)
var delta, desiredSize, currentBatchIndex int32
newUpdateProgress := updateProgress
for {
currentBatchIndex = newUpdateProgress / batchPercentage
desiredSize = min((currentBatchIndex+1)*int32(batchSize), int32(total))
delta = desiredSize - int32(updatedSize)
// if rolling udpate policy changed or new nodes were added during update, we need to update progress
if delta < 0 {
newUpdateProgress = min(newUpdateProgress+batchPercentage, 100)
} else {
break
}
}

return delta, newUpdateProgress, currentBatchIndex
}

func getValueFromPercent(percent int, total int, roundUp bool) int {
if roundUp {
return int(math.Ceil(float64(percent) * (float64(total)) / 100))
} else {
return int(math.Floor(float64(percent) * (float64(total)) / 100))
}
}
Loading