diff --git a/pkg/model-serving-controller/plugins/lws_labels_plugin.go b/pkg/model-serving-controller/plugins/lws_labels_plugin.go
new file mode 100644
index 000000000..3d9011a98
--- /dev/null
+++ b/pkg/model-serving-controller/plugins/lws_labels_plugin.go
@@ -0,0 +1,110 @@
+/*
+Copyright The Volcano Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package plugins
+
+import (
+	"context"
+	"fmt"
+	"strconv"
+
+	workloadv1alpha1 "github.com/volcano-sh/kthena/pkg/apis/workload/v1alpha1"
+	"github.com/volcano-sh/kthena/pkg/model-serving-controller/utils"
+)
+
+const (
+	// LWSLabelsPluginName is the name under which the LWS labels plugin is registered.
+	LWSLabelsPluginName = "lws-labels"
+
+	// Standard LWS label keys as defined by the LeaderWorkerSet API.
+	LWSLabelName        = "leaderworkerset.sigs.k8s.io/name"         // set to the ModelServing name
+	LWSLabelGroupIndex  = "leaderworkerset.sigs.k8s.io/group-index"  // trailing ordinal of the serving group name
+	LWSLabelWorkerIndex = "leaderworkerset.sigs.k8s.io/worker-index" // trailing ordinal of the pod name
+	LWSLabelGroupKey    = "leaderworkerset.sigs.k8s.io/group-key"    // "<ModelServing name>-<group index>"
+)
+
+// LWSLabelsPlugin injects standard LeaderWorkerSet labels into pods
+// created for LWS workloads, enabling compatibility with the LWS
+// ecosystem (monitoring, logging, controllers).
+type LWSLabelsPlugin struct {
+	name string // value returned by Name; copied from PluginSpec.Name
+}
+
+func init() {
+	DefaultRegistry.Register(LWSLabelsPluginName, NewLWSLabelsPlugin) // registers the constructor under LWSLabelsPluginName
+}
+
+// NewLWSLabelsPlugin constructs the LWS labels plugin from a PluginSpec.
+// Only spec.Name is read; the plugin takes no other configuration.
+func NewLWSLabelsPlugin(spec workloadv1alpha1.PluginSpec) (Plugin, error) {
+	plugin := &LWSLabelsPlugin{name: spec.Name}
+	return plugin, nil
+}
+
+func (p *LWSLabelsPlugin) Name() string {
+	return p.name
+}
+
+// OnPodCreate stamps the pod in req with the four standard LWS labels.
+// A label key the pod already carries keeps its current value.
+func (p *LWSLabelsPlugin) OnPodCreate(_ context.Context, req *HookRequest) error {
+	// Without a request, a pod and its ModelServing there is nothing to label.
+	if req == nil || req.Pod == nil || req.ModelServing == nil {
+		return nil
+	}
+
+	// Group index: trailing ordinal of the serving group name (e.g. "my-lws-0" → 0).
+	_, group := utils.GetParentNameAndOrdinal(req.ServingGroup)
+	if group < 0 {
+		return fmt.Errorf("cannot extract group index from serving group name %q", req.ServingGroup)
+	}
+
+	// Worker index: trailing ordinal of the pod name (e.g. "my-lws-0-default-0-1" → 1).
+	_, worker := utils.GetParentNameAndOrdinal(req.Pod.Name)
+	if worker < 0 {
+		return fmt.Errorf("cannot extract worker index from pod name %q", req.Pod.Name)
+	}
+
+	// Both ordinals parsed; make sure the pod has a label map to write into.
+	owner, groupIdx := req.ModelServing.Name, strconv.Itoa(group)
+	if req.Pod.Labels == nil {
+		req.Pod.Labels = map[string]string{}
+	}
+
+	// The group key is "<ModelServing name>-<group index>"; setIfAbsent never overwrites.
+	for _, kv := range [...][2]string{
+		{LWSLabelName, owner},
+		{LWSLabelGroupIndex, groupIdx},
+		{LWSLabelWorkerIndex, strconv.Itoa(worker)},
+		{LWSLabelGroupKey, owner + "-" + groupIdx},
+	} {
+		setIfAbsent(req.Pod.Labels, kv[0], kv[1])
+	}
+	return nil
+}
+
+// OnPodReady is a no-op for the LWS labels plugin.
+func (p *LWSLabelsPlugin) OnPodReady(_ context.Context, _ *HookRequest) error {
+	return nil
+}
+
+// setIfAbsent stores value under key unless labels already has that key.
+func setIfAbsent(labels map[string]string, key, value string) {
+	if _, present := labels[key]; present {
+		return
+	}
+	labels[key] = value
+}
diff --git a/pkg/model-serving-controller/plugins/lws_labels_plugin_test.go b/pkg/model-serving-controller/plugins/lws_labels_plugin_test.go
new file mode 100644
index 000000000..538f66d49
--- /dev/null
+++ b/pkg/model-serving-controller/plugins/lws_labels_plugin_test.go
@@ -0,0 +1,211 @@
+/*
+Copyright The Volcano Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package plugins
+
+import (
+	"context"
+	"testing"
+
+	corev1 "k8s.io/api/core/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+
+	workloadv1alpha1 "github.com/volcano-sh/kthena/pkg/apis/workload/v1alpha1"
+)
+
+func TestLWSLabelsPluginOnPodCreate(t *testing.T) {
+	tests := []struct {
+		name           string
+		req            *HookRequest
+		expectLabels   map[string]string // labels that must be present with exactly these values; others are ignored
+		expectError    bool              // whether OnPodCreate must return a non-nil error
+		expectNoChange bool              // stop after the error check; set for the nil-input cases
+	}{
+		{
+			name: "entry pod gets all four LWS labels",
+			req: &HookRequest{
+				ModelServing: &workloadv1alpha1.ModelServing{
+					ObjectMeta: metav1.ObjectMeta{Name: "my-lws"},
+				},
+				ServingGroup: "my-lws-0",
+				RoleName:     "default",
+				RoleID:       "default-0",
+				IsEntry:      true,
+				Pod: &corev1.Pod{
+					ObjectMeta: metav1.ObjectMeta{
+						Name:   "my-lws-0-default-0-0",
+						Labels: map[string]string{},
+					},
+				},
+			},
+			expectLabels: map[string]string{
+				LWSLabelName:        "my-lws",
+				LWSLabelGroupIndex:  "0",
+				LWSLabelWorkerIndex: "0",
+				LWSLabelGroupKey:    "my-lws-0",
+			},
+		},
+		{
+			name: "worker pod gets correct worker index",
+			req: &HookRequest{
+				ModelServing: &workloadv1alpha1.ModelServing{
+					ObjectMeta: metav1.ObjectMeta{Name: "my-lws"},
+				},
+				ServingGroup: "my-lws-2",
+				RoleName:     "default",
+				RoleID:       "default-0",
+				IsEntry:      false,
+				Pod: &corev1.Pod{
+					ObjectMeta: metav1.ObjectMeta{
+						Name:   "my-lws-2-default-0-3",
+						Labels: map[string]string{},
+					},
+				},
+			},
+			expectLabels: map[string]string{
+				LWSLabelName:        "my-lws",
+				LWSLabelGroupIndex:  "2",
+				LWSLabelWorkerIndex: "3",
+				LWSLabelGroupKey:    "my-lws-2",
+			},
+		},
+		{
+			name: "existing user labels are not overwritten",
+			req: &HookRequest{
+				ModelServing: &workloadv1alpha1.ModelServing{
+					ObjectMeta: metav1.ObjectMeta{Name: "my-lws"},
+				},
+				ServingGroup: "my-lws-0",
+				Pod: &corev1.Pod{
+					ObjectMeta: metav1.ObjectMeta{
+						Name: "my-lws-0-default-0-0",
+						Labels: map[string]string{
+							LWSLabelName:       "user-override",
+							LWSLabelGroupIndex: "99",
+							"custom-label":     "keep-me",
+						},
+					},
+				},
+			},
+			expectLabels: map[string]string{
+				LWSLabelName:        "user-override", // pre-set by the user, must be preserved
+				LWSLabelGroupIndex:  "99",            // pre-set by the user, must be preserved
+				LWSLabelWorkerIndex: "0",             // absent before the call, injected from the pod name
+				LWSLabelGroupKey:    "my-lws-0",      // absent before the call, injected from name and real group index
+				"custom-label":      "keep-me",       // unrelated label, must be untouched
+			},
+		},
+		{
+			name: "nil labels map is initialized",
+			req: &HookRequest{
+				ModelServing: &workloadv1alpha1.ModelServing{
+					ObjectMeta: metav1.ObjectMeta{Name: "test"},
+				},
+				ServingGroup: "test-1",
+				Pod: &corev1.Pod{
+					ObjectMeta: metav1.ObjectMeta{
+						Name: "test-1-role-0-0",
+					},
+				},
+			},
+			expectLabels: map[string]string{
+				LWSLabelName:        "test",
+				LWSLabelGroupIndex:  "1",
+				LWSLabelWorkerIndex: "0",
+				LWSLabelGroupKey:    "test-1",
+			},
+		},
+		{
+			name:           "nil request is safe",
+			req:            nil,
+			expectNoChange: true,
+		},
+		{
+			name: "nil pod is safe",
+			req: &HookRequest{
+				ModelServing: &workloadv1alpha1.ModelServing{
+					ObjectMeta: metav1.ObjectMeta{Name: "x"},
+				},
+				Pod: nil,
+			},
+			expectNoChange: true,
+		},
+		{
+			name: "nil ModelServing is safe",
+			req: &HookRequest{
+				ModelServing: nil,
+				Pod: &corev1.Pod{
+					ObjectMeta: metav1.ObjectMeta{Name: "x-0-r-0-0"},
+				},
+			},
+			expectNoChange: true,
+		},
+	}
+
+	plugin := &LWSLabelsPlugin{name: LWSLabelsPluginName} // built directly, bypassing the registry
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := plugin.OnPodCreate(context.Background(), tt.req)
+
+			if tt.expectError && err == nil {
+				t.Fatalf("expected error, got nil")
+			}
+			if !tt.expectError && err != nil {
+				t.Fatalf("unexpected error: %v", err)
+			}
+			if tt.expectNoChange { // nil-input cases: there is no pod to inspect
+				return
+			}
+
+			pod := tt.req.Pod // OnPodCreate writes labels onto this pod in place
+			for key, want := range tt.expectLabels {
+				got, ok := pod.Labels[key]
+				if !ok {
+					t.Errorf("label %s missing", key)
+				} else if got != want {
+					t.Errorf("label %s = %q, want %q", key, got, want)
+				}
+			}
+		})
+	}
+}
+
+func TestLWSLabelsPluginReadyNoop(t *testing.T) {
+	plugin := &LWSLabelsPlugin{name: LWSLabelsPluginName}
+	if err := plugin.OnPodReady(context.Background(), &HookRequest{}); err != nil { // empty request must still succeed
+		t.Fatalf("expected no error, got %v", err)
+	}
+}
+
+func TestLWSLabelsPluginRegistration(t *testing.T) {
+	factory, ok := DefaultRegistry.factories[LWSLabelsPluginName] // registered by init() in lws_labels_plugin.go
+	if !ok {
+		t.Fatalf("plugin %s not registered in DefaultRegistry", LWSLabelsPluginName)
+	}
+
+	spec := workloadv1alpha1.PluginSpec{
+		Name: LWSLabelsPluginName,
+		Type: workloadv1alpha1.PluginTypeBuiltIn,
+	}
+	p, err := factory(spec)
+	if err != nil {
+		t.Fatalf("factory returned error: %v", err)
+	}
+	if p.Name() != LWSLabelsPluginName { // the plugin reports the name taken from the spec
+		t.Fatalf("plugin name = %q, want %q", p.Name(), LWSLabelsPluginName)
+	}
+}