Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions pkg/model-serving-controller/plugins/lws_labels_plugin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*

Check failure on line 1 in pkg/model-serving-controller/plugins/lws_labels_plugin.go

View workflow job for this annotation

GitHub Actions / build

File is not properly formatted (gofmt)
Copyright The Volcano Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package plugins

import (
"context"
"fmt"
"strconv"

workloadv1alpha1 "github.com/volcano-sh/kthena/pkg/apis/workload/v1alpha1"
"github.com/volcano-sh/kthena/pkg/model-serving-controller/utils"
)

const (
// LWSLabelsPluginName is the registered name of the LWS labels plugin.
LWSLabelsPluginName = "lws-labels"

// Standard LWS label keys as defined by the LeaderWorkerSet API.
LWSLabelName = "leaderworkerset.sigs.k8s.io/name"
LWSLabelGroupIndex = "leaderworkerset.sigs.k8s.io/group-index"
LWSLabelWorkerIndex = "leaderworkerset.sigs.k8s.io/worker-index"
LWSLabelGroupKey = "leaderworkerset.sigs.k8s.io/group-key"
)

// LWSLabelsPlugin injects standard LeaderWorkerSet labels into pods
// created for LWS workloads, enabling compatibility with the LWS
// ecosystem (monitoring, logging, controllers).
type LWSLabelsPlugin struct {
name string
}

func init() {
DefaultRegistry.Register(LWSLabelsPluginName, NewLWSLabelsPlugin)
}

// NewLWSLabelsPlugin constructs the LWS labels plugin from a PluginSpec.
// This plugin does not require any configuration.
func NewLWSLabelsPlugin(spec workloadv1alpha1.PluginSpec) (Plugin, error) {
return &LWSLabelsPlugin{name: spec.Name}, nil
}
Comment on lines +46 to +54
Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As implemented, this plugin only runs when it is explicitly listed in ModelServing.spec.plugins. The current LWS->ModelServing translation (pkg/model-serving-controller/controller/lws_controller.go:constructModelServing) does not add this plugin, so LWS-originated pods will still miss the leaderworkerset.sigs.k8s.io/* labels by default. Consider wiring this in (e.g., have the LWS controller inject a built-in PluginSpec{Name: "lws-labels", Type: BuiltIn} or add defaulting based on the LWS ownerRef), otherwise the PR won't actually fix #759 as described.

Copilot uses AI. Check for mistakes.

func (p *LWSLabelsPlugin) Name() string { return p.name }

// OnPodCreate injects the four standard LWS labels into the pod.
// Labels are merged safely: existing user-defined labels are never overwritten.
func (p *LWSLabelsPlugin) OnPodCreate(_ context.Context, req *HookRequest) error {
if req == nil || req.Pod == nil || req.ModelServing == nil {
return nil
}

// Derive label values from the HookRequest context.
lwsName := req.ModelServing.Name

// Extract group index from the serving group name (e.g. "my-lws-0" → "0").
_, groupIndex := utils.GetParentNameAndOrdinal(req.ServingGroup)
if groupIndex < 0 {
return fmt.Errorf("cannot extract group index from serving group name %q", req.ServingGroup)
}
groupIndexStr := strconv.Itoa(groupIndex)

// Extract worker index from the pod name (trailing ordinal, e.g. "my-lws-0-default-0-1" → "1").
_, workerIndex := utils.GetParentNameAndOrdinal(req.Pod.Name)
if workerIndex < 0 {
return fmt.Errorf("cannot extract worker index from pod name %q", req.Pod.Name)
}
Comment on lines +60 to +79
Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OnPodCreate returns an error if it cannot parse req.ServingGroup or req.Pod.Name. Because plugin errors abort pod creation (createPod returns the error), this makes a best-effort label-injection feature capable of breaking workloads. Prefer treating parsing failures as a no-op (or gating on a known LWS marker/ownerRef) and only injecting labels when the expected naming pattern is present.

Copilot uses AI. Check for mistakes.
workerIndexStr := strconv.Itoa(workerIndex)

// Group key uniquely identifies the group within the LWS.
groupKey := fmt.Sprintf("%s-%s", lwsName, groupIndexStr)
Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

groupKey is recomputed from ModelServing.Name + parsed group index, but the serving group name is already available as req.ServingGroup (and is what pods/services are keyed off elsewhere). Using req.ServingGroup directly would be more robust (e.g., if the ModelServing name ever diverges from the serving group prefix) and avoids constructing a potentially inconsistent key.

Suggested change
groupKey := fmt.Sprintf("%s-%s", lwsName, groupIndexStr)
groupKey := req.ServingGroup

Copilot uses AI. Check for mistakes.

// Ensure labels map is initialized.
if req.Pod.Labels == nil {
req.Pod.Labels = map[string]string{}
}

// Merge safely: do not overwrite existing user-defined labels.
setIfAbsent(req.Pod.Labels, LWSLabelName, lwsName)
setIfAbsent(req.Pod.Labels, LWSLabelGroupIndex, groupIndexStr)
setIfAbsent(req.Pod.Labels, LWSLabelWorkerIndex, workerIndexStr)
setIfAbsent(req.Pod.Labels, LWSLabelGroupKey, groupKey)

return nil
}

// OnPodReady is a no-op for the LWS labels plugin.
func (p *LWSLabelsPlugin) OnPodReady(_ context.Context, _ *HookRequest) error {
return nil
}

// setIfAbsent sets a label only if the key is not already present,
// preserving any user-defined value.
func setIfAbsent(labels map[string]string, key, value string) {
if _, exists := labels[key]; !exists {
labels[key] = value
}
}
211 changes: 211 additions & 0 deletions pkg/model-serving-controller/plugins/lws_labels_plugin_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
/*

Check failure on line 1 in pkg/model-serving-controller/plugins/lws_labels_plugin_test.go

View workflow job for this annotation

GitHub Actions / build

File is not properly formatted (gofmt)
Copyright The Volcano Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package plugins

import (
"context"
"testing"

corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

workloadv1alpha1 "github.com/volcano-sh/kthena/pkg/apis/workload/v1alpha1"
)

func TestLWSLabelsPluginOnPodCreate(t *testing.T) {
tests := []struct {
name string
req *HookRequest
expectLabels map[string]string
expectError bool
expectNoChange bool // true if req/pod is nil and no mutation expected
}{
{
name: "entry pod gets all four LWS labels",
req: &HookRequest{
ModelServing: &workloadv1alpha1.ModelServing{
ObjectMeta: metav1.ObjectMeta{Name: "my-lws"},
},
ServingGroup: "my-lws-0",
RoleName: "default",
RoleID: "default-0",
IsEntry: true,
Pod: &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "my-lws-0-default-0-0",
Labels: map[string]string{},
},
},
},
expectLabels: map[string]string{
LWSLabelName: "my-lws",
LWSLabelGroupIndex: "0",
LWSLabelWorkerIndex: "0",
LWSLabelGroupKey: "my-lws-0",
},
},
{
name: "worker pod gets correct worker index",
req: &HookRequest{
ModelServing: &workloadv1alpha1.ModelServing{
ObjectMeta: metav1.ObjectMeta{Name: "my-lws"},
},
ServingGroup: "my-lws-2",
RoleName: "default",
RoleID: "default-0",
IsEntry: false,
Pod: &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "my-lws-2-default-0-3",
Labels: map[string]string{},
},
},
},
expectLabels: map[string]string{
LWSLabelName: "my-lws",
LWSLabelGroupIndex: "2",
LWSLabelWorkerIndex: "3",
LWSLabelGroupKey: "my-lws-2",
},
},
{
name: "existing user labels are not overwritten",
req: &HookRequest{
ModelServing: &workloadv1alpha1.ModelServing{
ObjectMeta: metav1.ObjectMeta{Name: "my-lws"},
},
ServingGroup: "my-lws-0",
Pod: &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "my-lws-0-default-0-0",
Labels: map[string]string{
LWSLabelName: "user-override",
LWSLabelGroupIndex: "99",
"custom-label": "keep-me",
},
},
},
},
expectLabels: map[string]string{
LWSLabelName: "user-override", // preserved
LWSLabelGroupIndex: "99", // preserved
LWSLabelWorkerIndex: "0", // injected
LWSLabelGroupKey: "my-lws-0", // injected
Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test case asserts that group-index can be user-overridden ("99") while the plugin still injects group-key based on the parsed serving group name ("my-lws-0"). That produces an internally inconsistent label set (group-key no longer corresponds to group-index), which can break LWS tooling. Either (a) derive group-key from the effective group-index value when present, or (b) skip injecting group-key when group-index is already set to something else, and adjust the expectation accordingly.

Suggested change
LWSLabelGroupKey: "my-lws-0", // injected

Copilot uses AI. Check for mistakes.
"custom-label": "keep-me", // untouched
},
},
{
name: "nil labels map is initialized",
req: &HookRequest{
ModelServing: &workloadv1alpha1.ModelServing{
ObjectMeta: metav1.ObjectMeta{Name: "test"},
},
ServingGroup: "test-1",
Pod: &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "test-1-role-0-0",
},
},
},
expectLabels: map[string]string{
LWSLabelName: "test",
LWSLabelGroupIndex: "1",
LWSLabelWorkerIndex: "0",
LWSLabelGroupKey: "test-1",
},
},
{
name: "nil request is safe",
req: nil,
expectNoChange: true,
},
{
name: "nil pod is safe",
req: &HookRequest{
ModelServing: &workloadv1alpha1.ModelServing{
ObjectMeta: metav1.ObjectMeta{Name: "x"},
},
Pod: nil,
},
expectNoChange: true,
},
{
name: "nil ModelServing is safe",
req: &HookRequest{
ModelServing: nil,
Pod: &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{Name: "x-0-r-0-0"},
},
},
expectNoChange: true,
},
}

plugin := &LWSLabelsPlugin{name: LWSLabelsPluginName}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := plugin.OnPodCreate(context.Background(), tt.req)

if tt.expectError && err == nil {
t.Fatalf("expected error, got nil")
}
if !tt.expectError && err != nil {
t.Fatalf("unexpected error: %v", err)
}
if tt.expectNoChange {
return
}

pod := tt.req.Pod
for key, want := range tt.expectLabels {
got, ok := pod.Labels[key]
if !ok {
t.Errorf("label %s missing", key)
} else if got != want {
t.Errorf("label %s = %q, want %q", key, got, want)
}
}
Comment on lines +175 to +182
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For more robust test assertions, it's better to compare the entire labels map using reflect.DeepEqual. This ensures that no unexpected labels are added and that the final state of the labels is exactly as expected. The current loop-based check only verifies that a subset of expected labels exist, but it wouldn't catch any extra, erroneously added labels.

Note: You will need to import the reflect package to use this function.

if !reflect.DeepEqual(pod.Labels, tt.expectLabels) {
	t.Errorf("labels mismatch.\nGot:  %v\nWant: %v", pod.Labels, tt.expectLabels)
}

})
}
}

func TestLWSLabelsPluginReadyNoop(t *testing.T) {
plugin := &LWSLabelsPlugin{name: LWSLabelsPluginName}
if err := plugin.OnPodReady(context.Background(), &HookRequest{}); err != nil {
t.Fatalf("expected no error, got %v", err)
}
}

func TestLWSLabelsPluginRegistration(t *testing.T) {
factory, ok := DefaultRegistry.factories[LWSLabelsPluginName]
if !ok {
t.Fatalf("plugin %s not registered in DefaultRegistry", LWSLabelsPluginName)
}

spec := workloadv1alpha1.PluginSpec{
Name: LWSLabelsPluginName,
Type: workloadv1alpha1.PluginTypeBuiltIn,
}
p, err := factory(spec)
if err != nil {
t.Fatalf("factory returned error: %v", err)
}
if p.Name() != LWSLabelsPluginName {
t.Fatalf("plugin name = %q, want %q", p.Name(), LWSLabelsPluginName)
}
}
Loading