Skip to content

Commit 23a8c85

Browse files
committed
feat: implement component update
1 parent 30d43f7 commit 23a8c85

15 files changed

+1448
-52
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ vet: ## Run go vet against code.
6262

6363
.PHONY: test
6464
test: manifests generate fmt vet envtest ## Run tests.
65-
KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test $$(go list ./... | grep -v /e2e) -coverprofile cover.out
65+
KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test $$(go list ./... | grep -v /e2e) -timeout 0 -coverprofile cover.out
6666

6767
# TODO(user): To use a different vendor for e2e tests, modify the setup under 'tests/e2e'.
6868
# The default setup assumes Kind is pre-installed and builds/loads the Manager Docker image locally.

internal/component/client.go

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
package component
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"sort"
7+
8+
tfv1 "github.com/NexusGPU/tensor-fusion/api/v1"
9+
"github.com/NexusGPU/tensor-fusion/internal/constants"
10+
"github.com/NexusGPU/tensor-fusion/internal/utils"
11+
corev1 "k8s.io/api/core/v1"
12+
"sigs.k8s.io/controller-runtime/pkg/client"
13+
"sigs.k8s.io/controller-runtime/pkg/log"
14+
)
15+
16+
const (
17+
ClientUpdateInProgressAnnotation = constants.Domain + "/client-update-in-progress"
18+
ClientBatchUpdateLastTimeAnnotation = constants.Domain + "/client-batch-update-last-time"
19+
)
20+
21+
type Client struct {
22+
podsToUpdate []*corev1.Pod
23+
}
24+
25+
func (c *Client) GetName() string {
26+
return "client"
27+
}
28+
29+
func (c *Client) DetectConfigChange(pool *tfv1.GPUPool, status *tfv1.PoolComponentStatus) (bool, string, string) {
30+
oldHash := status.ClientVersion
31+
changed, newHash := utils.CompareAndGetObjectHash(oldHash, pool.Spec.ComponentConfig.Client)
32+
return changed, newHash, oldHash
33+
}
34+
35+
func (c *Client) SetConfigHash(status *tfv1.PoolComponentStatus, hash string) {
36+
status.ClientVersion = hash
37+
}
38+
39+
func (c *Client) GetUpdateInProgressInfo(pool *tfv1.GPUPool) string {
40+
return pool.Annotations[ClientUpdateInProgressAnnotation]
41+
}
42+
43+
func (c *Client) SetUpdateInProgressInfo(pool *tfv1.GPUPool, hash string) {
44+
pool.Annotations[ClientUpdateInProgressAnnotation] = hash
45+
}
46+
47+
func (c *Client) GetBatchUpdateLastTimeInfo(pool *tfv1.GPUPool) string {
48+
return pool.Annotations[ClientBatchUpdateLastTimeAnnotation]
49+
}
50+
51+
func (c *Client) SetBatchUpdateLastTimeInfo(pool *tfv1.GPUPool, time string) {
52+
pool.Annotations[ClientBatchUpdateLastTimeAnnotation] = time
53+
}
54+
55+
func (c *Client) GetUpdateProgress(status *tfv1.PoolComponentStatus) int32 {
56+
return status.ClientUpdateProgress
57+
}
58+
59+
func (c *Client) SetUpdateProgress(status *tfv1.PoolComponentStatus, progress int32) {
60+
status.ClientUpdateProgress = progress
61+
status.ClientConfigSynced = false
62+
if progress == 100 {
63+
status.ClientConfigSynced = true
64+
}
65+
}
66+
67+
func (c *Client) GetResourcesInfo(r client.Client, ctx context.Context, pool *tfv1.GPUPool, configHash string) (int, int, bool, error) {
68+
podList := &corev1.PodList{}
69+
if err := r.List(ctx, podList,
70+
client.MatchingLabels{
71+
constants.TensorFusionEnabledLabelKey: constants.LabelValueTrue,
72+
fmt.Sprintf(constants.GPUNodePoolIdentifierLabelFormat, pool.Name): constants.LabelValueTrue,
73+
}); err != nil {
74+
return 0, 0, false, fmt.Errorf("failed to list pods: %w", err)
75+
}
76+
77+
total := len(podList.Items)
78+
79+
for _, pod := range podList.Items {
80+
if !pod.DeletionTimestamp.IsZero() {
81+
return 0, 0, true, nil
82+
}
83+
84+
if pod.Labels[constants.LabelKeyPodTemplateHash] != configHash {
85+
c.podsToUpdate = append(c.podsToUpdate, &pod)
86+
}
87+
}
88+
89+
sort.Sort(ClientPodsByCreationTimestamp(c.podsToUpdate))
90+
91+
return total, total - len(c.podsToUpdate), false, nil
92+
}
93+
94+
func (c *Client) PerformBatchUpdate(r client.Client, ctx context.Context, pool *tfv1.GPUPool, delta int) (bool, error) {
95+
log := log.FromContext(ctx)
96+
97+
log.Info("perform batch update", "component", c.GetName())
98+
for i := range int(delta) {
99+
pod := c.podsToUpdate[i]
100+
if err := r.Delete(ctx, pod); err != nil {
101+
return false, fmt.Errorf("failed to delete pod: %w", err)
102+
}
103+
}
104+
105+
return true, nil
106+
}
107+
108+
type ClientPodsByCreationTimestamp []*corev1.Pod
109+
110+
func (o ClientPodsByCreationTimestamp) Len() int { return len(o) }
111+
func (o ClientPodsByCreationTimestamp) Swap(i, j int) { o[i], o[j] = o[j], o[i] }
112+
func (o ClientPodsByCreationTimestamp) Less(i, j int) bool {
113+
if o[i].CreationTimestamp.Equal(&o[j].CreationTimestamp) {
114+
return o[i].Name < o[j].Name
115+
}
116+
return o[i].CreationTimestamp.Before(&o[j].CreationTimestamp)
117+
}

internal/component/component.go

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
package component
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"math"
7+
"time"
8+
9+
tfv1 "github.com/NexusGPU/tensor-fusion/api/v1"
10+
"github.com/NexusGPU/tensor-fusion/internal/constants"
11+
ctrl "sigs.k8s.io/controller-runtime"
12+
"sigs.k8s.io/controller-runtime/pkg/client"
13+
"sigs.k8s.io/controller-runtime/pkg/log"
14+
)
15+
16+
type Interface interface {
17+
GetName() string
18+
DetectConfigChange(pool *tfv1.GPUPool, status *tfv1.PoolComponentStatus) (bool, string, string)
19+
SetConfigHash(status *tfv1.PoolComponentStatus, hash string)
20+
GetUpdateInProgressInfo(pool *tfv1.GPUPool) string
21+
SetUpdateInProgressInfo(pool *tfv1.GPUPool, hash string)
22+
GetBatchUpdateLastTimeInfo(pool *tfv1.GPUPool) string
23+
SetBatchUpdateLastTimeInfo(pool *tfv1.GPUPool, time string)
24+
GetUpdateProgress(status *tfv1.PoolComponentStatus) int32
25+
SetUpdateProgress(status *tfv1.PoolComponentStatus, progress int32)
26+
GetResourcesInfo(r client.Client, ctx context.Context, pool *tfv1.GPUPool, hash string) (int, int, bool, error)
27+
PerformBatchUpdate(r client.Client, ctx context.Context, pool *tfv1.GPUPool, delta int) (bool, error)
28+
}
29+
30+
func ManageUpdate(r client.Client, ctx context.Context, pool *tfv1.GPUPool, component Interface) (*ctrl.Result, error) {
31+
log := log.FromContext(ctx)
32+
33+
autoUpdate, batchInterval := getUpdatePolicy(pool)
34+
newStatus := pool.Status.ComponentStatus.DeepCopy()
35+
36+
changed, configHash, oldHash := component.DetectConfigChange(pool, newStatus)
37+
if changed {
38+
log.Info("component configuration changed", "component", component.GetName(), "old hash", oldHash, "new hash", configHash)
39+
component.SetConfigHash(newStatus, configHash)
40+
component.SetUpdateProgress(newStatus, 0)
41+
if oldHash == "" || autoUpdate == false {
42+
return nil, patchComponentStatus(r, ctx, pool, newStatus)
43+
}
44+
if pool.Annotations == nil {
45+
pool.Annotations = map[string]string{}
46+
}
47+
patch := client.MergeFrom(pool.DeepCopy())
48+
component.SetUpdateInProgressInfo(pool, configHash)
49+
component.SetBatchUpdateLastTimeInfo(pool, "")
50+
if err := r.Patch(ctx, pool, patch); err != nil {
51+
return nil, fmt.Errorf("failed to patch pool: %w", err)
52+
}
53+
} else {
54+
if autoUpdate == false || component.GetUpdateInProgressInfo(pool) != configHash {
55+
return nil, nil
56+
}
57+
if timeInfo := component.GetBatchUpdateLastTimeInfo(pool); len(timeInfo) != 0 {
58+
lastBatchUpdateTime, err := time.Parse(time.RFC3339, timeInfo)
59+
if err != nil {
60+
return nil, err
61+
}
62+
nextBatchUpdateTime := lastBatchUpdateTime.Add(batchInterval)
63+
if now := time.Now(); now.Before(nextBatchUpdateTime) {
64+
log.Info("next batch update time not yet reached", "now", now, "nextBatchUpdateTime", nextBatchUpdateTime)
65+
return &ctrl.Result{RequeueAfter: nextBatchUpdateTime.Sub(now)}, nil
66+
}
67+
log.Info("next batch update time reached", "BatchUpdateTime", nextBatchUpdateTime)
68+
}
69+
}
70+
71+
totalSize, updatedSize, recheck, err := component.GetResourcesInfo(r, ctx, pool, configHash)
72+
if err != nil {
73+
return nil, err
74+
} else if recheck {
75+
return &ctrl.Result{RequeueAfter: constants.PendingRequeueDuration}, err
76+
} else if totalSize <= 0 {
77+
return nil, nil
78+
}
79+
80+
batchPercentage := pool.Spec.NodeManagerConfig.NodePoolRollingUpdatePolicy.BatchPercentage
81+
updateProgress := component.GetUpdateProgress(newStatus)
82+
delta, newUpdateProgress, currentBatchIndex := calculateDesiredUpdatedDelta(totalSize, updatedSize, batchPercentage, updateProgress)
83+
component.SetUpdateProgress(newStatus, newUpdateProgress)
84+
log.Info("update in progress", "component", component.GetName(), "hash", configHash,
85+
"updateProgress", newUpdateProgress, "totalSize", totalSize, "updatedSize", updatedSize,
86+
"batchPercentage", batchPercentage, "currentBatchIndex", currentBatchIndex, "delta", delta)
87+
88+
var ctrlResult *ctrl.Result
89+
if delta == 0 {
90+
patch := client.MergeFrom(pool.DeepCopy())
91+
newUpdateProgress = min((currentBatchIndex+1)*batchPercentage, 100)
92+
component.SetUpdateProgress(newStatus, newUpdateProgress)
93+
if newUpdateProgress != 100 {
94+
component.SetBatchUpdateLastTimeInfo(pool, time.Now().Format(time.RFC3339))
95+
interval := max(batchInterval, constants.PendingRequeueDuration)
96+
ctrlResult = &ctrl.Result{RequeueAfter: interval}
97+
log.Info("current batch update has completed", "progress", newUpdateProgress, "currentBatchIndex", currentBatchIndex, "nextUpdateTime", time.Now().Add(interval))
98+
} else {
99+
component.SetUpdateInProgressInfo(pool, "")
100+
component.SetBatchUpdateLastTimeInfo(pool, "")
101+
log.Info("all batch update has completed", "component", component.GetName(), "hash", configHash)
102+
}
103+
if err := r.Patch(ctx, pool, patch); err != nil {
104+
return nil, fmt.Errorf("failed to patch pool: %w", err)
105+
}
106+
} else if delta > 0 {
107+
recheck, err := component.PerformBatchUpdate(r, ctx, pool, int(delta))
108+
if err != nil {
109+
return nil, err
110+
} else if recheck {
111+
ctrlResult = &ctrl.Result{RequeueAfter: constants.PendingRequeueDuration}
112+
}
113+
}
114+
115+
return ctrlResult, patchComponentStatus(r, ctx, pool, newStatus)
116+
}
117+
118+
func patchComponentStatus(r client.Client, ctx context.Context, pool *tfv1.GPUPool, newStatus *tfv1.PoolComponentStatus) error {
119+
patch := client.MergeFrom(pool.DeepCopy())
120+
pool.Status.ComponentStatus = *newStatus
121+
if err := r.Status().Patch(ctx, pool, patch); err != nil {
122+
return fmt.Errorf("failed to patch pool status: %w", err)
123+
}
124+
return nil
125+
}
126+
127+
func getUpdatePolicy(pool *tfv1.GPUPool) (bool, time.Duration) {
128+
autoUpdate := false
129+
batchInterval := time.Duration(600) * time.Second
130+
131+
if pool.Spec.NodeManagerConfig != nil {
132+
updatePolicy := pool.Spec.NodeManagerConfig.NodePoolRollingUpdatePolicy
133+
if updatePolicy != nil {
134+
if updatePolicy.AutoUpdate != nil {
135+
autoUpdate = *updatePolicy.AutoUpdate
136+
}
137+
138+
duration, err := time.ParseDuration(updatePolicy.BatchInterval)
139+
if err == nil {
140+
batchInterval = duration
141+
}
142+
}
143+
}
144+
145+
return autoUpdate, batchInterval
146+
}
147+
148+
// calculateDesiredUpdatedDelta works out how many more resources must be
// updated to complete the current batch of a rolling update.
//
// total is the number of resources, updatedSize how many already run the new
// config, batchPercentage the share of total handled per batch, and
// updateProgress the progress (0-100) recorded so far.
//
// It returns, in order:
//   - delta: resources still to update in the current batch (never negative),
//   - the possibly advanced progress: when more resources are already updated
//     than the recorded progress implies, progress is moved forward one batch
//     at a time (capped at 100) until the batch target catches up,
//   - the index of the current batch (progress / batchPercentage).
//
// A non-positive batchPercentage yields (0, updateProgress, 0) instead of a
// division by zero. If updatedSize exceeds what any batch can target, the
// function stops at progress 100 with a delta of 0 instead of looping forever.
func calculateDesiredUpdatedDelta(total int, updatedSize int, batchPercentage int32, updateProgress int32) (int32, int32, int32) {
	if batchPercentage <= 0 {
		return 0, updateProgress, 0
	}

	batchSize := int32(getValueFromPercent(int(batchPercentage), total, true))
	newUpdateProgress := updateProgress
	for {
		currentBatchIndex := newUpdateProgress / batchPercentage
		desiredSize := min((currentBatchIndex+1)*batchSize, int32(total))
		delta := desiredSize - int32(updatedSize)
		if delta >= 0 {
			return delta, newUpdateProgress, currentBatchIndex
		}
		// If the rolling update policy changed or new nodes were added during
		// the update, more resources are updated than the progress implies,
		// so move the progress forward.
		if newUpdateProgress >= 100 {
			// Progress cannot advance any further; nothing is left to update.
			return 0, newUpdateProgress, currentBatchIndex
		}
		newUpdateProgress = min(newUpdateProgress+batchPercentage, 100)
	}
}

// getValueFromPercent returns percent% of total, rounded up when roundUp is
// true and rounded down otherwise.
func getValueFromPercent(percent int, total int, roundUp bool) int {
	value := float64(percent) * float64(total) / 100
	if roundUp {
		return int(math.Ceil(value))
	}
	return int(math.Floor(value))
}

0 commit comments

Comments
 (0)