diff --git a/go.mod b/go.mod index 21ace7d39..7f9dda235 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/volcano-sh/kthena go 1.24.0 require ( - github.com/agiledragon/gomonkey/v2 v2.13.0 github.com/alicebob/miniredis/v2 v2.35.0 github.com/cespare/xxhash v1.1.0 github.com/gammazero/deque v1.0.0 diff --git a/go.sum b/go.sum index 4bc5a1abe..a25678eaa 100644 --- a/go.sum +++ b/go.sum @@ -12,8 +12,6 @@ github.com/Masterminds/sprig/v3 v3.3.0 h1:mQh0Yrg1XPo6vjYXgtf5OtijNAKJRNcTdOOGZe github.com/Masterminds/sprig/v3 v3.3.0/go.mod h1:Zy1iXRYNqNLUolqCpL4uhk6SHUMAOSCzdgBfDb35Lz0= github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= -github.com/agiledragon/gomonkey/v2 v2.13.0 h1:B24Jg6wBI1iB8EFR1c+/aoTg7QN/Cum7YffG8KMIyYo= -github.com/agiledragon/gomonkey/v2 v2.13.0/go.mod h1:ap1AmDzcVOAz1YpeJ3TCzIgstoaWLA6jbbgxfB4w2iY= github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI= github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= @@ -110,7 +108,6 @@ github.com/google/pprof v0.0.0-20250923004556-9e5a51aed1e8 h1:ZI8gCoCjGzPsum4L21 github.com/google/pprof v0.0.0-20250923004556-9e5a51aed1e8/go.mod h1:I6V7YzU0XDpsHqbsyrghnFZLO1gwK6NPTNvmetQIk9U= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 h1:JeSE6pjso5THxAzdVpqr6/geYxZytqFMBCOtn/ujyeo= github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674/go.mod h1:r4w70xmWCQKmi1ONH4KIaBptdivuRPyosB9RmPlGEwA= github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= @@ -129,7 +126,6 @@ github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8Hm github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= -github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= @@ -219,8 +215,6 @@ github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys= github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= -github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= -github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ= github.com/spaolacci/murmur3 
v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/spf13/cast v1.8.0 h1:gEN9K4b8Xws4EX0+a0reLmhq8moKn7ntRlQYgjPeCDk= @@ -279,7 +273,6 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA= golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -309,7 +302,6 @@ golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= golang.org/x/time v0.13.0 h1:eUlYslOIt32DgYD6utsuUeHs4d7AsEYLuIAdg7FlYgI= golang.org/x/time v0.13.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= diff --git a/licenses/github.com/agiledragon/gomonkey/v2/LICENSE b/licenses/github.com/agiledragon/gomonkey/v2/LICENSE deleted file mode 100644 index d75dc90e6..000000000 --- a/licenses/github.com/agiledragon/gomonkey/v2/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2018 Zhang Xiaolong - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. 
diff --git a/pkg/kthena-router/controller/modelserver_controller_test.go b/pkg/kthena-router/controller/modelserver_controller_test.go index 99bbf8c61..431674eaa 100644 --- a/pkg/kthena-router/controller/modelserver_controller_test.go +++ b/pkg/kthena-router/controller/modelserver_controller_test.go @@ -21,7 +21,6 @@ import ( "testing" "time" - "github.com/agiledragon/gomonkey/v2" dto "github.com/prometheus/client_model/go" "github.com/stretchr/testify/assert" corev1 "k8s.io/api/core/v1" @@ -35,11 +34,28 @@ import ( kthenafake "github.com/volcano-sh/kthena/client-go/clientset/versioned/fake" informersv1alpha1 "github.com/volcano-sh/kthena/client-go/informers/externalversions" aiv1alpha1 "github.com/volcano-sh/kthena/pkg/apis/networking/v1alpha1" - "github.com/volcano-sh/kthena/pkg/kthena-router/backend" "github.com/volcano-sh/kthena/pkg/kthena-router/datastore" "github.com/volcano-sh/kthena/pkg/kthena-router/utils" ) +type fakePodRuntimeInspector struct{} + +func (fakePodRuntimeInspector) GetPodMetrics(_ string, _ *corev1.Pod, _ map[string]*dto.Histogram) (map[string]float64, map[string]*dto.Histogram) { + return map[string]float64{ + utils.GPUCacheUsage: 0.5, + utils.RequestWaitingNum: 10, + utils.RequestRunningNum: 5, + }, nil +} + +func (fakePodRuntimeInspector) GetPodModels(_ string, _ *corev1.Pod) ([]string, error) { + return []string{"test-model"}, nil +} + +func newStoreWithMockBackend() datastore.Store { + return datastore.New(datastore.WithPodRuntimeInspector(fakePodRuntimeInspector{})) +} + func TestModelServerController_ModelServerLifecycle(t *testing.T) { // Create fake clients kubeClient := kubefake.NewSimpleClientset() @@ -50,7 +66,7 @@ func TestModelServerController_ModelServerLifecycle(t *testing.T) { kthenaInformerFactory := informersv1alpha1.NewSharedInformerFactory(kthenaClient, 0) // Create store - store := datastore.New() + store := newStoreWithMockBackend() // Create controller controller := NewModelServerController( @@ -259,9 +275,6 @@ func TestModelServerController_ModelServerLifecycle(t *testing.T) { } func TestModelServerController_PodLifecycle(t *testing.T) { - patch := setupMockBackend() - defer patch.Reset() - // Create fake clients kubeClient := kubefake.NewSimpleClientset() kthenaClient := kthenafake.NewSimpleClientset() @@ -290,7 +303,7 @@ func TestModelServerController_PodLifecycle(t *testing.T) { kthenaInformerFactory := informersv1alpha1.NewSharedInformerFactory(kthenaClient, 0) // Create store - store := datastore.New() + store := newStoreWithMockBackend() // Create controller controller := NewModelServerController( @@ -506,7 +519,7 @@ func TestModelServerController_ErrorHandling(t *testing.T) { kthenaInformerFactory := informersv1alpha1.NewSharedInformerFactory(kthenaClient, 0) // Create store - store := datastore.New() + store := newStoreWithMockBackend() // Create controller controller := NewModelServerController( @@ -550,7 +563,7 @@ func TestModelServerController_WorkQueueProcessing(t *testing.T) { kthenaInformerFactory := informersv1alpha1.NewSharedInformerFactory(kthenaClient, 0) // Create store - store := datastore.New() + store := newStoreWithMockBackend() // Create controller controller := NewModelServerController( @@ -625,7 +638,7 @@ func TestModelServerController_PodSelectionLogic(t *testing.T) { kthenaInformerFactory := informersv1alpha1.NewSharedInformerFactory(kthenaClient, 0) // Create store - store := datastore.New() + store := newStoreWithMockBackend() // Create controller controller := NewModelServerController( @@ -815,7 +828,7 @@ func 
TestModelServerController_ComprehensiveLifecycleTest(t *testing.T) { defer close(stopCh) // Create controller and store - store := datastore.New() + store := newStoreWithMockBackend() controller := NewModelServerController( kthenaInformerFactory, kubeInformerFactory, @@ -899,9 +912,6 @@ func TestModelServerController_ComprehensiveLifecycleTest(t *testing.T) { // 3. Then we sync the second modelserver (ms2) // 4. Verify that GetPodsByModelServer(ms2) returns all pods correctly func TestModelServerController_SharedPods(t *testing.T) { - patch := setupMockBackend() - defer patch.Reset() - // Create fake clients kubeClient := kubefake.NewSimpleClientset() kthenaClient := kthenafake.NewSimpleClientset() @@ -992,7 +1002,7 @@ func TestModelServerController_SharedPods(t *testing.T) { kthenaInformerFactory := informersv1alpha1.NewSharedInformerFactory(kthenaClient, 0) // Create store - store := datastore.New() + store := newStoreWithMockBackend() // Create controller controller := NewModelServerController( @@ -1121,19 +1131,3 @@ func waitForObjectInCache(t *testing.T, timeout time.Duration, checkFunc func() } } } - -// Helper function to setup mock for backend calls -func setupMockBackend() *gomonkey.Patches { - patch := gomonkey.NewPatches() - patch.ApplyFunc(backend.GetPodMetrics, func(backend string, pod *corev1.Pod, previousHistogram map[string]*dto.Histogram) (map[string]float64, map[string]*dto.Histogram) { - return map[string]float64{ - utils.GPUCacheUsage: 0.5, - utils.RequestWaitingNum: 10, - utils.RequestRunningNum: 5, - }, map[string]*dto.Histogram{} - }) - patch.ApplyFunc(backend.GetPodModels, func(backend string, pod *corev1.Pod) ([]string, error) { - return []string{"test-model"}, nil - }) - return patch -} diff --git a/pkg/kthena-router/datastore/ordering_test.go b/pkg/kthena-router/datastore/ordering_test.go index 5b8e2a192..d19173481 100644 --- a/pkg/kthena-router/datastore/ordering_test.go +++ b/pkg/kthena-router/datastore/ordering_test.go @@ -20,7 +20,6 @@ import ( "sync" "testing" - "github.com/agiledragon/gomonkey/v2" dto "github.com/prometheus/client_model/go" "github.com/stretchr/testify/assert" "istio.io/istio/pkg/util/sets" @@ -29,7 +28,6 @@ import ( "k8s.io/apimachinery/pkg/types" aiv1alpha1 "github.com/volcano-sh/kthena/pkg/apis/networking/v1alpha1" - "github.com/volcano-sh/kthena/pkg/kthena-router/backend" "github.com/volcano-sh/kthena/pkg/kthena-router/utils" ) @@ -59,28 +57,24 @@ func createTestModelServer(namespace, name string, engine aiv1alpha1.InferenceEn } } -// Helper function to setup mock for backend calls -func setupMockBackend() *gomonkey.Patches { - patch := gomonkey.NewPatches() - patch.ApplyFunc(backend.GetPodMetrics, func(backend string, pod *corev1.Pod, previousHistogram map[string]*dto.Histogram) (map[string]float64, map[string]*dto.Histogram) { - return map[string]float64{ - utils.GPUCacheUsage: 0.5, - utils.RequestWaitingNum: 10, - utils.RequestRunningNum: 5, - }, map[string]*dto.Histogram{} - }) - patch.ApplyFunc(backend.GetPodModels, func(backend string, pod *corev1.Pod) ([]string, error) { - return []string{"test-model"}, nil - }) - return patch +func newStoreWithMockBackend() *store { + return New(WithPodRuntimeInspector(&fakePodRuntimeInspector{ + metricsFn: func(_ string, _ *corev1.Pod, _ map[string]*dto.Histogram) (map[string]float64, map[string]*dto.Histogram) { + return map[string]float64{ + utils.GPUCacheUsage: 0.5, + utils.RequestWaitingNum: 10, + utils.RequestRunningNum: 5, + }, nil + }, + modelsFn: func(_ string, _ *corev1.Pod) ([]string, 
error) { + return []string{"test-model"}, nil + }, + })).(*store) } // Test Case 1: ModelServer added first, then Pod func TestStore_AddModelServerFirst_ThenPod(t *testing.T) { - patch := setupMockBackend() - defer patch.Reset() - - s := New().(*store) + s := newStoreWithMockBackend() // Step 1: Add ModelServer first ms := createTestModelServer("default", "model1", aiv1alpha1.VLLM) @@ -117,10 +111,7 @@ func TestStore_AddModelServerFirst_ThenPod(t *testing.T) { // Test Case 2: Pod added first, then ModelServer // Note: Current implementation expects ModelServer to exist before Pod func TestStore_AddPodFirst_ThenModelServer(t *testing.T) { - patch := setupMockBackend() - defer patch.Reset() - - s := New().(*store) + s := newStoreWithMockBackend() ms := createTestModelServer("default", "model1", aiv1alpha1.VLLM) pod := createTestPod("default", "pod1") @@ -155,10 +146,7 @@ func TestStore_AddPodFirst_ThenModelServer(t *testing.T) { // Test Case 3: Multiple Pods added with ModelServer func TestStore_MultiplePods_ThenModelServer(t *testing.T) { - patch := setupMockBackend() - defer patch.Reset() - - s := New().(*store) + s := newStoreWithMockBackend() ms := createTestModelServer("default", "model1", aiv1alpha1.VLLM) pod1 := createTestPod("default", "pod1") @@ -198,10 +186,7 @@ func TestStore_MultiplePods_ThenModelServer(t *testing.T) { // Test Case 4: ModelServer with multiple Pods added together func TestStore_ModelServerWithMultiplePods_AddedTogether(t *testing.T) { - patch := setupMockBackend() - defer patch.Reset() - - s := New().(*store) + s := newStoreWithMockBackend() ms := createTestModelServer("default", "model1", aiv1alpha1.VLLM) pod1 := createTestPod("default", "pod1") @@ -233,10 +218,7 @@ func TestStore_ModelServerWithMultiplePods_AddedTogether(t *testing.T) { // Test Case 5: Pod belongs to multiple ModelServers func TestStore_PodBelongsToMultipleModelServers(t *testing.T) { - patch := setupMockBackend() - defer patch.Reset() - - s := New().(*store) + s := newStoreWithMockBackend() ms1 := createTestModelServer("default", "model1", aiv1alpha1.VLLM) ms2 := createTestModelServer("default", "model2", aiv1alpha1.VLLM) @@ -272,10 +254,7 @@ func TestStore_PodBelongsToMultipleModelServers(t *testing.T) { // Test Case 6: Pod with multiple ModelServers func TestStore_PodWithMultipleModelServers_ThenAddModelServers(t *testing.T) { - patch := setupMockBackend() - defer patch.Reset() - - s := New().(*store) + s := newStoreWithMockBackend() ms1 := createTestModelServer("default", "model1", aiv1alpha1.VLLM) ms2 := createTestModelServer("default", "model2", aiv1alpha1.VLLM) @@ -318,10 +297,7 @@ func TestStore_PodWithMultipleModelServers_ThenAddModelServers(t *testing.T) { // Test Case 7: Update operations - changing Pod's ModelServer associations func TestStore_UpdatePodModelServerAssociations(t *testing.T) { - patch := setupMockBackend() - defer patch.Reset() - - s := New().(*store) + s := newStoreWithMockBackend() ms1 := createTestModelServer("default", "model1", aiv1alpha1.VLLM) ms2 := createTestModelServer("default", "model2", aiv1alpha1.VLLM) @@ -364,10 +340,7 @@ func TestStore_UpdatePodModelServerAssociations(t *testing.T) { // Test Case 8: Interleaved operations func TestStore_InterleavedOperations(t *testing.T) { - patch := setupMockBackend() - defer patch.Reset() - - s := New().(*store) + s := newStoreWithMockBackend() ms1 := createTestModelServer("default", "model1", aiv1alpha1.VLLM) ms2 := createTestModelServer("default", "model2", aiv1alpha1.VLLM) @@ -422,10 +395,7 @@ func 
TestStore_InterleavedOperations(t *testing.T) { // Test Case 9: Deletion scenarios func TestStore_DeletionScenarios(t *testing.T) { - patch := setupMockBackend() - defer patch.Reset() - - s := New().(*store) + s := newStoreWithMockBackend() ms1 := createTestModelServer("default", "model1", aiv1alpha1.VLLM) ms2 := createTestModelServer("default", "model2", aiv1alpha1.VLLM) @@ -515,10 +485,7 @@ func TestStore_EdgeCases(t *testing.T) { // Test Case 11: random operations (simulated) func TestStore_RandomOperations(t *testing.T) { - patch := setupMockBackend() - defer patch.Reset() - - s := New().(*store) + s := newStoreWithMockBackend() // Simulate rapid add/update operations that might happen concurrently ms := createTestModelServer("default", "model1", aiv1alpha1.VLLM) diff --git a/pkg/kthena-router/datastore/store.go b/pkg/kthena-router/datastore/store.go index 273222ece..32337467c 100644 --- a/pkg/kthena-router/datastore/store.go +++ b/pkg/kthena-router/datastore/store.go @@ -132,6 +132,32 @@ type EventData struct { // CallbackFunc is the type of function that can be registered as a callback type CallbackFunc func(data EventData) +// PodRuntimeInspector fetches runtime metrics and loaded models for a pod. +type PodRuntimeInspector interface { + GetPodMetrics(engine string, pod *corev1.Pod, previousHistogram map[string]*dto.Histogram) (map[string]float64, map[string]*dto.Histogram) + GetPodModels(engine string, pod *corev1.Pod) ([]string, error) +} + +type realPodRuntimeInspector struct{} + +func (realPodRuntimeInspector) GetPodMetrics(engine string, pod *corev1.Pod, previousHistogram map[string]*dto.Histogram) (map[string]float64, map[string]*dto.Histogram) { + return backend.GetPodMetrics(engine, pod, previousHistogram) +} + +func (realPodRuntimeInspector) GetPodModels(engine string, pod *corev1.Pod) ([]string, error) { + return backend.GetPodModels(engine, pod) +} + +type Option func(*store) + +func WithPodRuntimeInspector(inspector PodRuntimeInspector) Option { + return func(s *store) { + if inspector != nil { + s.podRuntimeInspector = inspector + } + } +} + // Store is an interface for storing and retrieving data type Store interface { // Add modelServer which are selected by modelServer.Spec.WorkloadSelector @@ -282,12 +308,13 @@ type store struct { // model -> RequestPriorityQueue requestWaitingQueue sync.Map tokenTracker TokenTracker + podRuntimeInspector PodRuntimeInspector rootCtx context.Context // Lifecycle context for queue goroutines, set by Run() fairnessQueueConfig FairnessQueueConfig } -func New() Store { - return &store{ +func New(opts ...Option) Store { + s := &store{ modelServer: sync.Map{}, pods: sync.Map{}, routeInfo: make(map[string]*modelRouteInfo), @@ -303,8 +330,22 @@ func New() Store { requestWaitingQueue: sync.Map{}, // Create token tracker with environment-based configuration tokenTracker: createTokenTracker(), + podRuntimeInspector: realPodRuntimeInspector{}, fairnessQueueConfig: createFairnessQueueConfig(), } + for _, opt := range opts { + if opt != nil { + opt(s) + } + } + return s +} + +func (s *store) getPodRuntimeInspector() PodRuntimeInspector { + if s.podRuntimeInspector == nil { + return realPodRuntimeInspector{} + } + return s.podRuntimeInspector } // createFairnessQueueConfig reads fairness queue configuration from environment variables. 
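The store.go hunk above is the heart of the refactor: a PodRuntimeInspector interface, a functional Option, and a variadic New replace the gomonkey patches removed from the test files. A minimal sketch of wiring a canned inspector through that seam, placed here before the call-site hunk below — the interface, option, and constructor names come from store.go in this patch; the fake type, test name, and values are illustrative stand-ins for the fakePodRuntimeInspector defined in the test files:

package datastore

import (
	"testing"

	dto "github.com/prometheus/client_model/go"
	corev1 "k8s.io/api/core/v1"

	"github.com/volcano-sh/kthena/pkg/kthena-router/utils"
)

// cannedInspector is an illustrative fake; the tests in this patch use
// fakePodRuntimeInspector instead.
type cannedInspector struct{}

func (cannedInspector) GetPodMetrics(_ string, _ *corev1.Pod, _ map[string]*dto.Histogram) (map[string]float64, map[string]*dto.Histogram) {
	// Returning a nil histogram map is safe: the call-site hunk below guards
	// both return values before applying them to the pod.
	return map[string]float64{utils.GPUCacheUsage: 0.5}, nil
}

func (cannedInspector) GetPodModels(_ string, _ *corev1.Pod) ([]string, error) {
	return []string{"test-model"}, nil
}

func TestSeamSketch(t *testing.T) {
	// Tests inject the fake; production callers keep calling New() and get
	// realPodRuntimeInspector by default.
	s := New(WithPodRuntimeInspector(cannedInspector{})).(*store)
	if s.podRuntimeInspector == nil {
		t.Fatal("inspector should be set")
	}
}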
@@ -1131,9 +1172,13 @@ func (s *store) updatePodMetrics(pod *PodInfo) { } previousHistogram := getPreviousHistogram(pod) - gaugeMetrics, histogramMetrics := backend.GetPodMetrics(pod.engine, pod.Pod, previousHistogram) - updateGaugeMetricsInfo(pod, gaugeMetrics) - updateHistogramMetrics(pod, histogramMetrics) + gaugeMetrics, histogramMetrics := s.getPodRuntimeInspector().GetPodMetrics(pod.engine, pod.Pod, previousHistogram) + if gaugeMetrics != nil { + updateGaugeMetricsInfo(pod, gaugeMetrics) + } + if histogramMetrics != nil { + updateHistogramMetrics(pod, histogramMetrics) + } } func (s *store) updatePodModels(podInfo *PodInfo) { @@ -1142,7 +1187,7 @@ func (s *store) updatePodModels(podInfo *PodInfo) { return } - models, err := backend.GetPodModels(podInfo.engine, podInfo.Pod) + models, err := s.getPodRuntimeInspector().GetPodModels(podInfo.engine, podInfo.Pod) if err != nil { klog.V(4).Infof("failed to get models of pod %s/%s", podInfo.Pod.GetNamespace(), podInfo.Pod.GetName()) } diff --git a/pkg/kthena-router/datastore/store_test.go b/pkg/kthena-router/datastore/store_test.go index 38f46b822..d2d688c7d 100644 --- a/pkg/kthena-router/datastore/store_test.go +++ b/pkg/kthena-router/datastore/store_test.go @@ -27,11 +27,9 @@ import ( "testing" "time" - "github.com/agiledragon/gomonkey/v2" dto "github.com/prometheus/client_model/go" "github.com/stretchr/testify/assert" aiv1alpha1 "github.com/volcano-sh/kthena/pkg/apis/networking/v1alpha1" - "github.com/volcano-sh/kthena/pkg/kthena-router/backend" "github.com/volcano-sh/kthena/pkg/kthena-router/utils" "istio.io/istio/pkg/util/sets" corev1 "k8s.io/api/core/v1" @@ -193,6 +191,26 @@ func TestStoreUpdatePodMetrics(t *testing.T) { s := &store{ pods: sync.Map{}, modelServer: sync.Map{}, + podRuntimeInspector: &fakePodRuntimeInspector{ + metricsFn: func(_ string, _ *corev1.Pod, _ map[string]*dto.Histogram) (map[string]float64, map[string]*dto.Histogram) { + return map[string]float64{ + utils.GPUCacheUsage: 0.8, + utils.RequestWaitingNum: 15, + utils.RequestRunningNum: 10, + utils.TPOT: 120, + utils.TTFT: 210, + }, map[string]*dto.Histogram{ + utils.TPOT: { + SampleSum: &sum2, + SampleCount: &count2, + }, + utils.TTFT: { + SampleSum: &sum2, + SampleCount: &count2, + }, + } + }, + }, } podName := types.NamespacedName{ @@ -209,27 +227,6 @@ func TestStoreUpdatePodMetrics(t *testing.T) { pods: sets.New[types.NamespacedName](podName), }) - patch := gomonkey.NewPatches() - patch.ApplyFunc(backend.GetPodMetrics, func(backend string, pod *corev1.Pod, previousHistogram map[string]*dto.Histogram) (map[string]float64, map[string]*dto.Histogram) { - return map[string]float64{ - utils.GPUCacheUsage: 0.8, - utils.RequestWaitingNum: 15, - utils.RequestRunningNum: 10, - utils.TPOT: 120, - utils.TTFT: 210, - }, map[string]*dto.Histogram{ - utils.TPOT: { - SampleSum: &sum2, - SampleCount: &count2, - }, - utils.TTFT: { - SampleSum: &sum2, - SampleCount: &count2, - }, - } - }) - defer patch.Reset() - s.updatePodMetrics(&podinfo) name := types.NamespacedName{ @@ -1468,8 +1465,34 @@ func TestStoreMatchModelServer(t *testing.T) { } } -func newStore() *store { - return New().(*store) +type fakePodRuntimeInspector struct { + metricsFn func(string, *corev1.Pod, map[string]*dto.Histogram) (map[string]float64, map[string]*dto.Histogram) + modelsFn func(string, *corev1.Pod) ([]string, error) + metricsCalls int + modelsCalls int +} + +func (f *fakePodRuntimeInspector) GetPodMetrics(engine string, pod *corev1.Pod, previousHistogram map[string]*dto.Histogram) (map[string]float64, 
map[string]*dto.Histogram) { + f.metricsCalls++ + if f.metricsFn == nil { + return nil, nil + } + return f.metricsFn(engine, pod, previousHistogram) +} + +func (f *fakePodRuntimeInspector) GetPodModels(engine string, pod *corev1.Pod) ([]string, error) { + f.modelsCalls++ + if f.modelsFn == nil { + return nil, nil + } + return f.modelsFn(engine, pod) +} + +func newStore(inspector ...PodRuntimeInspector) *store { + if len(inspector) == 0 || inspector[0] == nil { + return New().(*store) + } + return New(WithPodRuntimeInspector(inspector[0])).(*store) } func TestAddOrUpdatePod_MetricsPreservedOnUpdate(t *testing.T) { @@ -1579,37 +1602,26 @@ func TestAddOrUpdatePod_MetricsPreservedOnUpdate(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - s := newStore() + inspector := &fakePodRuntimeInspector{ + metricsFn: func(_ string, _ *corev1.Pod, _ map[string]*dto.Histogram) (map[string]float64, map[string]*dto.Histogram) { + return tc.initialMetrics, tc.initialHist + }, + modelsFn: func(_ string, _ *corev1.Pod) ([]string, error) { + return tc.initialModels, nil + }, + } + s := newStore(inspector) ms := createTestModelServer("default", "ms1", aiv1alpha1.VLLM) s.AddOrUpdateModelServer(ms, sets.New[types.NamespacedName]()) - // Patch backend calls for the initial add - patch := gomonkey.NewPatches() - patch.ApplyFunc(backend.GetPodMetrics, func(_ string, _ *corev1.Pod, _ map[string]*dto.Histogram) (map[string]float64, map[string]*dto.Histogram) { - return tc.initialMetrics, tc.initialHist - }) - patch.ApplyFunc(backend.GetPodModels, func(_ string, _ *corev1.Pod) ([]string, error) { - return tc.initialModels, nil - }) - pod := createTestPod("default", "pod1") err := s.AddOrUpdatePod(pod, []*aiv1alpha1.ModelServer{ms}) assert.NoError(t, err) - - patch.Reset() - - // Backend should NOT be called during an update. If it is, fail loudly. - patch2 := gomonkey.NewPatches() - patch2.ApplyFunc(backend.GetPodMetrics, func(_ string, _ *corev1.Pod, _ map[string]*dto.Histogram) (map[string]float64, map[string]*dto.Histogram) { - t.Fatal("backend.GetPodMetrics must not be called on pod update") - return nil, nil - }) - patch2.ApplyFunc(backend.GetPodModels, func(_ string, _ *corev1.Pod) ([]string, error) { - t.Fatal("backend.GetPodModels must not be called on pod update") - return nil, nil - }) - defer patch2.Reset() + assert.Equal(t, 1, inspector.metricsCalls, "backend metrics should be fetched on initial pod add") + assert.Equal(t, 1, inspector.modelsCalls, "backend models should be fetched on initial pod add") + inspector.metricsCalls = 0 + inspector.modelsCalls = 0 // Simulate a pod update (e.g. 
label change) updatedPod := pod.DeepCopy() @@ -1619,6 +1631,8 @@ func TestAddOrUpdatePod_MetricsPreservedOnUpdate(t *testing.T) { err = s.AddOrUpdatePod(updatedPod, []*aiv1alpha1.ModelServer{ms}) assert.NoError(t, err) + assert.Equal(t, 0, inspector.metricsCalls, "backend.GetPodMetrics must not be called on pod update") + assert.Equal(t, 0, inspector.modelsCalls, "backend.GetPodModels must not be called on pod update") podInfo := s.GetPodInfo(utils.GetNamespaceName(updatedPod)) assert.NotNil(t, podInfo) @@ -1652,34 +1666,28 @@ func TestAddOrUpdatePod_MetricsPreservedOnUpdate(t *testing.T) { } func TestAddOrUpdatePod_NewPodStillFetchesMetrics(t *testing.T) { - s := newStore() + inspector := &fakePodRuntimeInspector{ + metricsFn: func(_ string, _ *corev1.Pod, _ map[string]*dto.Histogram) (map[string]float64, map[string]*dto.Histogram) { + return map[string]float64{ + utils.GPUCacheUsage: 0.3, + utils.RequestRunningNum: 2, + }, map[string]*dto.Histogram{} + }, + modelsFn: func(_ string, _ *corev1.Pod) ([]string, error) { + return []string{"base-model"}, nil + }, + } + s := newStore(inspector) ms := createTestModelServer("default", "ms1", aiv1alpha1.VLLM) s.AddOrUpdateModelServer(ms, sets.New[types.NamespacedName]()) - metricsCalled := false - modelsCalled := false - - patch := gomonkey.NewPatches() - patch.ApplyFunc(backend.GetPodMetrics, func(_ string, _ *corev1.Pod, _ map[string]*dto.Histogram) (map[string]float64, map[string]*dto.Histogram) { - metricsCalled = true - return map[string]float64{ - utils.GPUCacheUsage: 0.3, - utils.RequestRunningNum: 2, - }, map[string]*dto.Histogram{} - }) - patch.ApplyFunc(backend.GetPodModels, func(_ string, _ *corev1.Pod) ([]string, error) { - modelsCalled = true - return []string{"base-model"}, nil - }) - defer patch.Reset() - pod := createTestPod("default", "fresh-pod") err := s.AddOrUpdatePod(pod, []*aiv1alpha1.ModelServer{ms}) assert.NoError(t, err) - assert.True(t, metricsCalled, "backend.GetPodMetrics must be called for new pods") - assert.True(t, modelsCalled, "backend.GetPodModels must be called for new pods") + assert.Equal(t, 1, inspector.metricsCalls, "backend.GetPodMetrics must be called for new pods") + assert.Equal(t, 1, inspector.modelsCalls, "backend.GetPodModels must be called for new pods") podInfo := s.GetPodInfo(utils.GetNamespaceName(pod)) assert.InDelta(t, 0.3, podInfo.GetGPUCacheUsage(), 1e-9) @@ -1687,48 +1695,40 @@ func TestAddOrUpdatePod_NewPodStillFetchesMetrics(t *testing.T) { } func TestAddOrUpdatePod_ModelServerChangePreservesMetrics(t *testing.T) { - s := newStore() + inspector := &fakePodRuntimeInspector{ + metricsFn: func(_ string, _ *corev1.Pod, _ map[string]*dto.Histogram) (map[string]float64, map[string]*dto.Histogram) { + return map[string]float64{ + utils.GPUCacheUsage: 0.6, + utils.RequestWaitingNum: 5, + utils.RequestRunningNum: 10, + utils.TPOT: 0.04, + utils.TTFT: 0.2, + }, map[string]*dto.Histogram{} + }, + modelsFn: func(_ string, _ *corev1.Pod) ([]string, error) { + return []string{"model-a"}, nil + }, + } + s := newStore(inspector) ms1 := createTestModelServer("default", "ms1", aiv1alpha1.VLLM) ms2 := createTestModelServer("default", "ms2", aiv1alpha1.VLLM) s.AddOrUpdateModelServer(ms1, sets.New[types.NamespacedName]()) s.AddOrUpdateModelServer(ms2, sets.New[types.NamespacedName]()) - patch := gomonkey.NewPatches() - patch.ApplyFunc(backend.GetPodMetrics, func(_ string, _ *corev1.Pod, _ map[string]*dto.Histogram) (map[string]float64, map[string]*dto.Histogram) { - return map[string]float64{ - 
utils.GPUCacheUsage: 0.6, - utils.RequestWaitingNum: 5, - utils.RequestRunningNum: 10, - utils.TPOT: 0.04, - utils.TTFT: 0.2, - }, map[string]*dto.Histogram{} - }) - patch.ApplyFunc(backend.GetPodModels, func(_ string, _ *corev1.Pod) ([]string, error) { - return []string{"model-a"}, nil - }) - pod := createTestPod("default", "pod1") err := s.AddOrUpdatePod(pod, []*aiv1alpha1.ModelServer{ms1}) assert.NoError(t, err) - - patch.Reset() - - // Block backend calls during the reassignment update - patch2 := gomonkey.NewPatches() - patch2.ApplyFunc(backend.GetPodMetrics, func(_ string, _ *corev1.Pod, _ map[string]*dto.Histogram) (map[string]float64, map[string]*dto.Histogram) { - t.Fatal("backend.GetPodMetrics must not be called on pod update") - return nil, nil - }) - patch2.ApplyFunc(backend.GetPodModels, func(_ string, _ *corev1.Pod) ([]string, error) { - t.Fatal("backend.GetPodModels must not be called on pod update") - return nil, nil - }) - defer patch2.Reset() + assert.Equal(t, 1, inspector.metricsCalls, "backend metrics should be fetched on initial pod add") + assert.Equal(t, 1, inspector.modelsCalls, "backend models should be fetched on initial pod add") + inspector.metricsCalls = 0 + inspector.modelsCalls = 0 // Move pod from ms1 to ms2 err = s.AddOrUpdatePod(pod, []*aiv1alpha1.ModelServer{ms2}) assert.NoError(t, err) + assert.Equal(t, 0, inspector.metricsCalls, "backend.GetPodMetrics must not be called on pod update") + assert.Equal(t, 0, inspector.modelsCalls, "backend.GetPodModels must not be called on pod update") podInfo := s.GetPodInfo(utils.GetNamespaceName(pod)) assert.InDelta(t, 0.6, podInfo.GetGPUCacheUsage(), 1e-9, diff --git a/pkg/kthena-router/router/router_test.go b/pkg/kthena-router/router/router_test.go index b0030a263..cea1ed4c8 100644 --- a/pkg/kthena-router/router/router_test.go +++ b/pkg/kthena-router/router/router_test.go @@ -19,7 +19,6 @@ package router import ( "bytes" "encoding/json" - "errors" "flag" "fmt" "io" @@ -30,154 +29,44 @@ import ( "strconv" "testing" - "github.com/agiledragon/gomonkey/v2" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "istio.io/istio/pkg/util/sets" corev1 "k8s.io/api/core/v1" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" "k8s.io/klog/v2" aiv1alpha1 "github.com/volcano-sh/kthena/pkg/apis/networking/v1alpha1" "github.com/volcano-sh/kthena/pkg/kthena-router/accesslog" - "github.com/volcano-sh/kthena/pkg/kthena-router/common" "github.com/volcano-sh/kthena/pkg/kthena-router/connectors" "github.com/volcano-sh/kthena/pkg/kthena-router/datastore" - "github.com/volcano-sh/kthena/pkg/kthena-router/scheduler" - "github.com/volcano-sh/kthena/pkg/kthena-router/scheduler/framework" - "github.com/volcano-sh/kthena/pkg/kthena-router/scheduler/plugins/conf" ) func TestMain(m *testing.M) { gin.SetMode(gin.TestMode) klog.InitFlags(nil) - // Set klog verbosity level to 4 flag.Set("v", "4") - flag.Parse() // Parse flags to apply the klog level - routerConfig, _ := conf.ParseRouterConfig("../scheduler/testdata/configmap.yaml") - patch1 := gomonkey.ApplyFunc(conf.ParseRouterConfig, func(configMapPath string) (*conf.RouterConfiguration, error) { - return routerConfig, nil - }) - defer patch1.Reset() - - pluginsWeight, plugins, pluginConfig, _ := conf.LoadSchedulerConfig(&routerConfig.Scheduler) - patch2 := gomonkey.ApplyFunc(conf.LoadSchedulerConfig, func() (map[string]int, []string, map[string]runtime.RawExtension, error) { - return pluginsWeight, plugins, pluginConfig, nil - }) 
- defer patch2.Reset() - - // Run the tests + flag.Parse() exitCode := m.Run() - // Exit with the appropriate code os.Exit(exitCode) } -func buildPodInfo(name string, ip string) *datastore.PodInfo { - pod := &corev1.Pod{ - ObjectMeta: v1.ObjectMeta{ - Name: name, - }, - Status: corev1.PodStatus{ - PodIP: ip, - }, - } - - return &datastore.PodInfo{ - Pod: pod, - } -} - -func TestProxyModelEndpoint(t *testing.T) { - gin.SetMode(gin.TestMode) - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - req, _ := http.NewRequest("POST", "/", nil) - modelReq := ModelRequest{"model": "test"} - r := NewRouter(datastore.New(), "testdata/comfigmap.yaml") - hookPatch := gomonkey.ApplyMethod(r.scheduler, "RunPostHooks", func(s scheduler.Scheduler, ctx *framework.Context, index int) {}) - defer hookPatch.Reset() - - tests := []struct { - name string - ctx *framework.Context - proxyPatch func() *gomonkey.Patches - wantErr error - }{ - { - name: "BestPods are set, aggregated mode success", - ctx: &framework.Context{ - Model: "test", - Prompt: common.ChatMessage{Text: "test"}, - BestPods: []*datastore.PodInfo{buildPodInfo("decode1", "1.1.1.1")}, - }, - proxyPatch: func() *gomonkey.Patches { - patches := gomonkey.ApplyFunc(connectors.BuildDecodeRequest, func(c *gin.Context, req *http.Request, modelRequest ModelRequest) *http.Request { - return req - }) - patches.ApplyFunc(isStreaming, func(modelRequest ModelRequest) bool { - return false - }) - patches.ApplyFunc(proxyRequest, func(c *gin.Context, req *http.Request, podIP string, port int32, stream bool) error { - return nil - }) - return patches - }, - wantErr: nil, - }, - { - name: "BestPods proxy returns error", - ctx: &framework.Context{ - Model: "test", - Prompt: common.ChatMessage{Text: "test"}, - BestPods: []*datastore.PodInfo{buildPodInfo("decode1", "1.1.1.1")}, - }, - proxyPatch: func() *gomonkey.Patches { - patches := gomonkey.ApplyFunc(connectors.BuildDecodeRequest, func(c *gin.Context, req *http.Request, modelRequest ModelRequest) *http.Request { - return req - }) - patches.ApplyFunc(isStreaming, func(modelRequest ModelRequest) bool { - return false - }) - patches.ApplyFunc(proxyRequest, func(c *gin.Context, req *http.Request, podIP string, port int32, stream bool) error { - return errors.New("proxy error") - }) - return patches - }, - wantErr: errors.New("request to all pods failed"), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var patch *gomonkey.Patches - if tt.proxyPatch != nil { - patch = tt.proxyPatch() - defer patch.Reset() - } - err := r.proxyModelEndpoint(c, req, tt.ctx, modelReq, int32(8080)) - if tt.wantErr != nil { - assert.Error(t, err) - assert.Equal(t, tt.wantErr.Error(), err.Error()) - } else { - assert.NoError(t, err) - } - }) - } -} - // setupTestRouter initializes a router and its dependencies for testing. -func setupTestRouter(backendHandler http.Handler) (*Router, datastore.Store, *httptest.Server) { +// It uses a mock HTTP server as the backend, following the community's recommendation +// to avoid hacky dependency injection. +func setupTestRouter(t *testing.T, backendHandler http.Handler) (*Router, datastore.Store, *httptest.Server) { gin.SetMode(gin.TestMode) backend := httptest.NewServer(backendHandler) store := datastore.New() - router := NewRouter(store, "") + router := NewRouter(store, "../scheduler/testdata/configmap.yaml") return router, store, backend } func TestRouter_HandlerFunc_AggregatedMode(t *testing.T) { - // 1. Setup backend mock + // 1. 
Setup backend mock server backendHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "/v1/chat/completions", r.URL.Path) body, _ := io.ReadAll(r.Body) @@ -187,7 +76,7 @@ func TestRouter_HandlerFunc_AggregatedMode(t *testing.T) { w.WriteHeader(http.StatusOK) fmt.Fprint(w, `{"id":"response-id"}`) }) - router, store, backend := setupTestRouter(backendHandler) + router, store, backend := setupTestRouter(t, backendHandler) defer backend.Close() backendURL, _ := url.Parse(backend.URL) @@ -243,7 +132,7 @@ func TestRouter_HandlerFunc_AggregatedMode(t *testing.T) { } func TestRouter_HandlerFunc_DisaggregatedMode(t *testing.T) { - // 1. Setup backend mock + // 1. Setup backend mock server prefillReqs := 0 decodeReqs := 0 backendHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -267,7 +156,7 @@ func TestRouter_HandlerFunc_DisaggregatedMode(t *testing.T) { fmt.Fprint(w, `data: {"id":"decode-resp"}`) } }) - router, store, backend := setupTestRouter(backendHandler) + router, store, backend := setupTestRouter(t, backendHandler) defer backend.Close() backendURL, _ := url.Parse(backend.URL) @@ -348,7 +237,7 @@ func TestRouter_HandlerFunc_DisaggregatedMode(t *testing.T) { } func TestRouter_HandlerFunc_ModelNotFound(t *testing.T) { - router, _, backend := setupTestRouter(nil) + router, _, backend := setupTestRouter(t, nil) defer backend.Close() w := httptest.NewRecorder() @@ -369,7 +258,7 @@ func TestRouter_HandlerFunc_ScheduleFailure(t *testing.T) { // This should not be called t.Error("backend should not be called on schedule failure") }) - router, store, backend := setupTestRouter(backendHandler) + router, store, backend := setupTestRouter(t, backendHandler) defer backend.Close() backendURL, _ := url.Parse(backend.URL) diff --git a/pkg/model-serving-controller/controller/model_serving_controller.go b/pkg/model-serving-controller/controller/model_serving_controller.go index ae34204c2..dee1b2e7d 100644 --- a/pkg/model-serving-controller/controller/model_serving_controller.go +++ b/pkg/model-serving-controller/controller/model_serving_controller.go @@ -66,12 +66,25 @@ const ( RoleIDKey = "RoleID" ) +// PodGroupManager is the interface for managing PodGroups. +// This interface allows for dependency injection in tests. 
+type PodGroupManager interface { + CreateOrUpdatePodGroup(ctx context.Context, ms *workloadv1alpha1.ModelServing, pgName string) (error, time.Duration) + DeletePodGroup(ctx context.Context, ms *workloadv1alpha1.ModelServing, servingGroupName string) error + CleanupPodGroups(ctx context.Context, ms *workloadv1alpha1.ModelServing) error + HasPodGroupCRD() bool + GetPodGroupInformer() cache.SharedIndexInformer + Run(parentCtx context.Context) error + GenerateTaskName(roleName string, roleIndex int) string + AnnotatePodWithPodGroup(pod *corev1.Pod, ms *workloadv1alpha1.ModelServing, groupName, taskName string) +} + type ModelServingController struct { kubeClientSet kubernetes.Interface modelServingClient clientset.Interface syncHandler func(ctx context.Context, msKey string) error - podGroupManager *podgroupmanager.Manager + podGroupManager PodGroupManager podsLister listerv1.PodLister podsInformer cache.SharedIndexInformer servicesLister listerv1.ServiceLister diff --git a/pkg/model-serving-controller/controller/model_serving_controller_harness_test.go b/pkg/model-serving-controller/controller/model_serving_controller_harness_test.go new file mode 100644 index 000000000..6e92487c6 --- /dev/null +++ b/pkg/model-serving-controller/controller/model_serving_controller_harness_test.go @@ -0,0 +1,123 @@ +/* +Copyright The Volcano Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package controller + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + apiextfake "k8s.io/apiextensions-apiserver/pkg/client/clientset/clientset/fake" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/client-go/kubernetes/fake" + "k8s.io/client-go/tools/cache" + volcanofake "volcano.sh/apis/pkg/client/clientset/versioned/fake" + + kthenafake "github.com/volcano-sh/kthena/client-go/clientset/versioned/fake" + workloadv1alpha1 "github.com/volcano-sh/kthena/pkg/apis/workload/v1alpha1" + testhelper "github.com/volcano-sh/kthena/pkg/model-serving-controller/utils/test" +) + +type testControllerHarness struct { + t *testing.T + ctx context.Context + cancel context.CancelFunc + controller *ModelServingController + kubeClient *fake.Clientset + kthenaClient *kthenafake.Clientset + volcanoClient *volcanofake.Clientset + apiextClient *apiextfake.Clientset +} + +func newTestController(t *testing.T, modelServings ...*workloadv1alpha1.ModelServing) *testControllerHarness { + t.Helper() + + objects := make([]runtime.Object, 0, len(modelServings)) + for _, ms := range modelServings { + objects = append(objects, ms.DeepCopy()) + } + + kubeClient := fake.NewSimpleClientset() + kthenaClient := kthenafake.NewSimpleClientset(objects...) 
+ volcanoClient := volcanofake.NewSimpleClientset() + apiextClient := apiextfake.NewSimpleClientset(testhelper.CreatePodGroupCRD()) + + controller, err := NewModelServingController(kubeClient, kthenaClient, volcanoClient, apiextClient) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + harness := &testControllerHarness{ + t: t, + ctx: ctx, + cancel: cancel, + controller: controller, + kubeClient: kubeClient, + kthenaClient: kthenaClient, + volcanoClient: volcanoClient, + apiextClient: apiextClient, + } + + go controller.Run(ctx, 0) + + syncCtx, syncCancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(syncCancel) + if !cache.WaitForCacheSync(syncCtx.Done(), + controller.podsInformer.HasSynced, + controller.servicesInformer.HasSynced, + controller.modelServingsInformer.HasSynced, + ) { + cancel() + t.Fatalf("timed out waiting for informer caches to sync") + } + require.Eventually(t, func() bool { + return controller.initialSync + }, 5*time.Second, 10*time.Millisecond, "timed out waiting for initial sync") + + t.Cleanup(cancel) + return harness +} + +func namespacedKey(namespace, name string) string { + return namespace + "/" + name +} + +func (h *testControllerHarness) expectQueuedKey(key string) { + h.t.Helper() + + var seen []interface{} + defer func() { + for _, item := range seen { + h.controller.workqueue.Add(item) + } + }() + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if h.controller.workqueue.Len() == 0 { + time.Sleep(10 * time.Millisecond) + continue + } + item, shutdown := h.controller.workqueue.Get() + require.False(h.t, shutdown, "workqueue shut down while waiting for %s", key) + h.controller.workqueue.Done(item) + h.controller.workqueue.Forget(item) + itemKey, ok := item.(string) + if ok && itemKey == key { + return + } + seen = append(seen, item) + } + h.t.Fatalf("timed out waiting for %s", key) +} diff --git a/pkg/model-serving-controller/controller/model_serving_controller_test.go b/pkg/model-serving-controller/controller/model_serving_controller_test.go index 5741deeb2..6c1a637d1 100644 --- a/pkg/model-serving-controller/controller/model_serving_controller_test.go +++ b/pkg/model-serving-controller/controller/model_serving_controller_test.go @@ -16,13 +16,12 @@ package controller import ( "context" "fmt" - "reflect" "sort" "testing" "time" - "github.com/agiledragon/gomonkey/v2" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/volcano-sh/kthena/pkg/model-serving-controller/podgroupmanager" testhelper "github.com/volcano-sh/kthena/pkg/model-serving-controller/utils/test" corev1 "k8s.io/api/core/v1" @@ -56,9 +55,159 @@ type resourceSpec struct { labels map[string]string } +type testQueue interface { + Len() int + Get() (item interface{}, shutdown bool) + Done(item interface{}) + Forget(item interface{}) +} + +func newModelServingForDeleteTest(namespace, name string) *workloadv1alpha1.ModelServing { + return &workloadv1alpha1.ModelServing{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + UID: types.UID(fmt.Sprintf("%s-uid", name)), + }, + } +} + +func newPodGroupForDeleteTest(ms *workloadv1alpha1.ModelServing, groupName string, ownerUID types.UID) *schedulingv1beta1.PodGroup { + if ownerUID == "" { + ownerUID = ms.UID + } + return &schedulingv1beta1.PodGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: groupName, + Namespace: ms.Namespace, + Labels: map[string]string{ + workloadv1alpha1.ModelServingNameLabelKey: 
ms.Name, + workloadv1alpha1.GroupNameLabelKey: groupName, + }, + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: workloadv1alpha1.SchemeGroupVersion.String(), + Kind: workloadv1alpha1.ModelServingKind.Kind, + Name: ms.Name, + UID: ownerUID, + }, + }, + }, + } +} + +func drainWorkqueue(t *testing.T, queue testQueue) { + t.Helper() + for queue.Len() > 0 { + item, shutdown := queue.Get() + require.False(t, shutdown) + queue.Done(item) + queue.Forget(item) + } +} + +func assertQueueEmpty(t *testing.T, queue testQueue) { + t.Helper() + require.Equal(t, 0, queue.Len()) +} + +func assertQueuedKey(t *testing.T, queue testQueue, key string) { + t.Helper() + require.Greater(t, queue.Len(), 0, "expected %s to be queued", key) + + item, shutdown := queue.Get() + require.False(t, shutdown) + queue.Done(item) + queue.Forget(item) + + actualKey, ok := item.(string) + require.True(t, ok, "expected queued item to be a string key") + require.Equal(t, key, actualKey) +} + +func assertQueueStaysEmpty(t *testing.T, queue testQueue, duration time.Duration) { + t.Helper() + deadline := time.Now().Add(duration) + for time.Now().Before(deadline) { + require.Equal(t, 0, queue.Len()) + time.Sleep(10 * time.Millisecond) + } +} + +// fakePodGroupManager is a test double for PodGroupManager +type fakePodGroupManager struct { + createOrUpdateFunc func(ctx context.Context, ms *workloadv1alpha1.ModelServing, pgName string) (error, time.Duration) + deleteFunc func(ctx context.Context, ms *workloadv1alpha1.ModelServing, servingGroupName string) error + cleanupFunc func(ctx context.Context, ms *workloadv1alpha1.ModelServing) error + hasCRD bool +} + +func (f *fakePodGroupManager) CreateOrUpdatePodGroup(ctx context.Context, ms *workloadv1alpha1.ModelServing, pgName string) (error, time.Duration) { + if f.createOrUpdateFunc != nil { + return f.createOrUpdateFunc(ctx, ms, pgName) + } + return nil, 0 +} + +func (f *fakePodGroupManager) DeletePodGroup(ctx context.Context, ms *workloadv1alpha1.ModelServing, servingGroupName string) error { + if f.deleteFunc != nil { + return f.deleteFunc(ctx, ms, servingGroupName) + } + return nil +} + +func (f *fakePodGroupManager) CleanupPodGroups(ctx context.Context, ms *workloadv1alpha1.ModelServing) error { + if f.cleanupFunc != nil { + return f.cleanupFunc(ctx, ms) + } + return nil +} + +func (f *fakePodGroupManager) HasPodGroupCRD() bool { + return f.hasCRD +} + +func (f *fakePodGroupManager) GetPodGroupInformer() cache.SharedIndexInformer { + return nil +} + +func (f *fakePodGroupManager) Run(parentCtx context.Context) error { + return nil +} + +func (f *fakePodGroupManager) GenerateTaskName(roleName string, roleIndex int) string { + return fmt.Sprintf("%s-%d", roleName, roleIndex) +} + +func (f *fakePodGroupManager) AnnotatePodWithPodGroup(pod *corev1.Pod, ms *workloadv1alpha1.ModelServing, groupName, taskName string) { + if pod.Annotations == nil { + pod.Annotations = make(map[string]string) + } + pod.Annotations["scheduling.volcano.sh/task-name"] = taskName +} + +func TestNewTestController_HasSyncedQueueAndStores(t *testing.T) { + h := newTestController(t) + + require.NotNil(t, h.controller) + require.NotNil(t, h.kubeClient) + require.NotNil(t, h.kthenaClient) + require.NotNil(t, h.controller.workqueue) + require.NotNil(t, h.controller.store) + require.Equal(t, 0, h.controller.workqueue.Len()) + require.True(t, h.controller.podsInformer.HasSynced()) + require.True(t, h.controller.servicesInformer.HasSynced()) + require.True(t, h.controller.modelServingsInformer.HasSynced()) 
+ require.True(t, h.controller.initialSync) +} + func TestCreateOrUpdatePodGroupByServingGroupRequeue(t *testing.T) { - controller := &ModelServingController{ - podGroupManager: &podgroupmanager.Manager{}, + h := newTestController(t) + controller := h.controller + controller.podGroupManager = &fakePodGroupManager{ + createOrUpdateFunc: func(_ context.Context, _ *workloadv1alpha1.ModelServing, _ string) (error, time.Duration) { + return fmt.Errorf("retry"), 50 * time.Millisecond + }, } ms := &workloadv1alpha1.ModelServing{ ObjectMeta: metav1.ObjectMeta{ @@ -67,40 +216,12 @@ func TestCreateOrUpdatePodGroupByServingGroupRequeue(t *testing.T) { }, } - called := false - var delay time.Duration - patches := gomonkey.NewPatches() - patches.ApplyMethod(reflect.TypeOf(controller.podGroupManager), "CreateOrUpdatePodGroup", func(_ *podgroupmanager.Manager, _ context.Context, _ *workloadv1alpha1.ModelServing, _ string) (error, time.Duration) { - return fmt.Errorf("retry"), 2 * time.Second - }) - patches.ApplyPrivateMethod(reflect.TypeOf(controller), "enqueueModelServingAfter", func(_ *ModelServingController, _ *workloadv1alpha1.ModelServing, duration time.Duration) { - called = true - delay = duration - }) - defer patches.Reset() - err := controller.createOrUpdatePodGroupByServingGroup(context.Background(), ms, "ms-0") assert.NoError(t, err) - assert.True(t, called) - assert.Equal(t, 2*time.Second, delay) + h.expectQueuedKey(namespacedKey(ms.Namespace, ms.Name)) } func TestCreatePodAlreadyExistsRequeues(t *testing.T) { - kubeClient := kubefake.NewSimpleClientset() - informerFactory := informers.NewSharedInformerFactory(kubeClient, 0) - podInformer := informerFactory.Core().V1().Pods() - - stopCh := make(chan struct{}) - defer close(stopCh) - informerFactory.Start(stopCh) - informerFactory.WaitForCacheSync(stopCh) - - controller := &ModelServingController{ - kubeClientSet: kubeClient, - podsLister: podInformer.Lister(), - podsInformer: podInformer.Informer(), - } - ms := &workloadv1alpha1.ModelServing{ ObjectMeta: metav1.ObjectMeta{ Name: "ms", @@ -109,10 +230,20 @@ func TestCreatePodAlreadyExistsRequeues(t *testing.T) { }, } + h := newTestController(t, ms) + controller := h.controller + existing := &corev1.Pod{ ObjectMeta: metav1.ObjectMeta{ Name: "ms-entry-0", Namespace: "default", + Labels: map[string]string{ + workloadv1alpha1.ModelServingNameLabelKey: ms.Name, + workloadv1alpha1.GroupNameLabelKey: "ms-0", + workloadv1alpha1.RoleLabelKey: "role", + workloadv1alpha1.RoleIDKey: "role-0", + workloadv1alpha1.EntryLabelKey: utils.Entry, + }, OwnerReferences: []metav1.OwnerReference{ { APIVersion: workloadv1alpha1.SchemeGroupVersion.String(), @@ -123,9 +254,24 @@ func TestCreatePodAlreadyExistsRequeues(t *testing.T) { }, } - _, err := kubeClient.CoreV1().Pods("default").Create(context.Background(), existing, metav1.CreateOptions{}) + _, err := h.kubeClient.CoreV1().Pods("default").Create(context.Background(), existing, metav1.CreateOptions{}) assert.NoError(t, err) - assert.NoError(t, podInformer.Informer().GetIndexer().Add(existing)) + require.Eventually(t, func() bool { + _, err := controller.podsLister.Pods("default").Get(existing.Name) + return err == nil + }, 2*time.Second, 10*time.Millisecond) + drainQueue := func() { + for controller.workqueue.Len() > 0 { + item, shutdown := controller.workqueue.Get() + require.False(t, shutdown) + controller.workqueue.Done(item) + controller.workqueue.Forget(item) + } + } + drainQueue() + require.Eventually(t, func() bool { + return controller.workqueue.Len() == 
0 + }, 2*time.Second, 10*time.Millisecond) newPod := existing.DeepCopy() newPod.OwnerReferences = []metav1.OwnerReference{ @@ -136,56 +282,44 @@ func TestCreatePodAlreadyExistsRequeues(t *testing.T) { }, } - called := false - patches := gomonkey.NewPatches() - patches.ApplyPrivateMethod(reflect.TypeOf(controller), "enqueueModelServingAfter", func(_ *ModelServingController, _ *workloadv1alpha1.ModelServing, _ time.Duration) { - called = true - }) - defer patches.Reset() - err = controller.createPod(context.Background(), ms, "ms-0", "role", "role-0", newPod, true, nil, "entry") assert.NoError(t, err) - assert.True(t, called) + h.expectQueuedKey(namespacedKey(ms.Namespace, ms.Name)) } func TestDeletePodGroupEnqueues(t *testing.T) { - controller := &ModelServingController{} - ms := &workloadv1alpha1.ModelServing{ - ObjectMeta: metav1.ObjectMeta{ - Name: "ms", - Namespace: "default", - UID: types.UID("ms-uid"), - }, - } - podGroup := &schedulingv1beta1.PodGroup{ - ObjectMeta: metav1.ObjectMeta{ - Name: "ms-0", - Namespace: "default", - Labels: map[string]string{ - workloadv1alpha1.ModelServingNameLabelKey: ms.Name, - workloadv1alpha1.GroupNameLabelKey: "ms-0", - }, - }, - } + ms := newModelServingForDeleteTest("default", "ms") + h := newTestController(t, ms) + controller := h.controller + + require.Eventually(t, func() bool { + _, err := controller.modelServingLister.ModelServings(ms.Namespace).Get(ms.Name) + return err == nil + }, 2*time.Second, 10*time.Millisecond) + drainWorkqueue(t, controller.workqueue) + assertQueueEmpty(t, controller.workqueue) + + podGroup := newPodGroupForDeleteTest(ms, "ms-0", ms.UID) + controller.deletePodGroup(podGroup) + h.expectQueuedKey(namespacedKey(ms.Namespace, ms.Name)) +} - called := false - patches := gomonkey.NewPatches() - patches.ApplyPrivateMethod(reflect.TypeOf(controller), "getModelServingAndResourceDetails", func(_ *ModelServingController, _ metav1.Object) (*workloadv1alpha1.ModelServing, string, string, string) { - return ms, "ms-0", "", "" - }) - patches.ApplyPrivateMethod(reflect.TypeOf(controller), "shouldSkipHandling", func(_ *ModelServingController, _ *workloadv1alpha1.ModelServing, _ string, _ metav1.Object) bool { - return false - }) - patches.ApplyPrivateMethod(reflect.TypeOf(controller), "handleDeletionInProgress", func(_ *ModelServingController, _ *workloadv1alpha1.ModelServing, _ string, _ string, _ string) bool { - return false - }) - patches.ApplyPrivateMethod(reflect.TypeOf(controller), "enqueueModelServing", func(_ *ModelServingController, _ *workloadv1alpha1.ModelServing) { - called = true - }) - defer patches.Reset() +func TestDeletePodGroupOwnerMismatchDoesNotEnqueue(t *testing.T) { + ms := newModelServingForDeleteTest("default", "ms") + h := newTestController(t, ms) + controller := h.controller + + require.Eventually(t, func() bool { + _, err := controller.modelServingLister.ModelServings(ms.Namespace).Get(ms.Name) + return err == nil + }, 2*time.Second, 10*time.Millisecond) + drainWorkqueue(t, controller.workqueue) + assertQueueEmpty(t, controller.workqueue) + podGroup := newPodGroupForDeleteTest(ms, "ms-0", types.UID("other-uid")) controller.deletePodGroup(podGroup) - assert.True(t, called) + + assertQueueStaysEmpty(t, controller.workqueue, 200*time.Millisecond) } func TestIsServingGroupOutdated(t *testing.T) { @@ -4466,14 +4600,6 @@ func TestScaleDownRolesRunningStatusDeprioritized(t *testing.T) { assert.NoError(t, err) } - // Track which roles are deleted (targeted for deletion) - var deletedRoleIDs []string - patch := 
gomonkey.NewPatches() - patch.ApplyMethod(reflect.TypeOf(controller), "DeleteRole", func(_ *ModelServingController, ctx context.Context, ms *workloadv1alpha1.ModelServing, groupName, roleName, roleID string) { - deletedRoleIDs = append(deletedRoleIDs, roleID) - }) - defer patch.Reset() - // Target role (using first role's spec) targetRole := workloadv1alpha1.Role{ Name: "prefill", @@ -4492,20 +4618,45 @@ func TestScaleDownRolesRunningStatusDeprioritized(t *testing.T) { delete(allRoleIDs, remaining) } var expectedDeletedRoleIDs []string + var expectedDeleteSelectors []string for id := range allRoleIDs { expectedDeletedRoleIDs = append(expectedDeletedRoleIDs, id) + expectedDeleteSelectors = append(expectedDeleteSelectors, labels.SelectorFromSet(map[string]string{ + workloadv1alpha1.GroupNameLabelKey: groupName, + workloadv1alpha1.RoleLabelKey: "prefill", + workloadv1alpha1.RoleIDKey: id, + }).String()) } - // Verify correct number of deletions - numToDelete := len(tt.existingIndices) - tt.expectedCount - assert.Equal(t, numToDelete, len(deletedRoleIDs), - "Expected %d deletions, got %d", numToDelete, len(deletedRoleIDs)) + var actualDeletedRoleIDs []string + var actualDeleteSelectors []string + for _, action := range kubeClient.Actions() { + if !action.Matches("delete-collection", "pods") { + continue + } + deleteAction, ok := action.(kubetesting.DeleteCollectionAction) + require.True(t, ok) + actualDeleteSelectors = append(actualDeleteSelectors, deleteAction.GetListRestrictions().Labels.String()) + } + for _, idx := range tt.existingIndices { + roleID := fmt.Sprintf("prefill-%d", idx) + if controller.store.GetRoleStatus(nsn, groupName, "prefill", roleID) == datastore.RoleDeleting { + actualDeletedRoleIDs = append(actualDeletedRoleIDs, roleID) + } + } + + // Verify correct roles were marked deleting and targeted through pod delete actions. 
+ assert.Len(t, actualDeletedRoleIDs, len(tt.existingIndices)-tt.expectedCount) - // Verify the correct roles were targeted for deletion - sort.Strings(deletedRoleIDs) + sort.Strings(actualDeletedRoleIDs) sort.Strings(expectedDeletedRoleIDs) - assert.Equal(t, expectedDeletedRoleIDs, deletedRoleIDs, - "%s: Expected deleted roles %v, got %v", tt.description, expectedDeletedRoleIDs, deletedRoleIDs) + assert.Equal(t, expectedDeletedRoleIDs, actualDeletedRoleIDs, + "%s: expected deleted roles %v, got %v", tt.description, expectedDeletedRoleIDs, actualDeletedRoleIDs) + + sort.Strings(actualDeleteSelectors) + sort.Strings(expectedDeleteSelectors) + assert.Equal(t, expectedDeleteSelectors, actualDeleteSelectors, + "%s: expected pod delete selectors %v, got %v", tt.description, expectedDeleteSelectors, actualDeleteSelectors) }) } } @@ -5727,30 +5878,10 @@ func TestDeleteRoleRollbackOnFailure(t *testing.T) { volcanoClient := volcanofake.NewSimpleClientset() apiextfake := apiextfake.NewSimpleClientset(testhelper.CreatePodGroupCRD()) - // Create informer factories - kubeInformerFactory := informers.NewSharedInformerFactory(client, 0) - kthenaInformerFactory := informersv1alpha1.NewSharedInformerFactory(kthenaClient, 0) - // Create controller controller, err := NewModelServingController(client, kthenaClient, volcanoClient, apiextfake) assert.NoError(t, err) - stop := make(chan struct{}) - defer close(stop) - - go controller.Run(context.Background(), 1) - - // Start informers - kthenaInformerFactory.Start(stop) - kubeInformerFactory.Start(stop) - - // Wait for cache sync - cache.WaitForCacheSync(stop, - controller.modelServingsInformer.HasSynced, - controller.podsInformer.HasSynced, - controller.servicesInformer.HasSynced, - ) - if tt.podDeletionError != nil { client.PrependReactor("delete-collection", "pods", func(action kubetesting.Action) (handled bool, ret runtime.Object, err error) { return true, nil, tt.podDeletionError @@ -5805,6 +5936,9 @@ func TestDeleteRoleRollbackOnFailure(t *testing.T) { }, } + drainWorkqueue(t, controller.workqueue) + assertQueueEmpty(t, controller.workqueue) + _, err = client.CoreV1().Pods("default").Create(context.TODO(), pod, metav1.CreateOptions{}) assert.NoError(t, err) @@ -5813,23 +5947,45 @@ func TestDeleteRoleRollbackOnFailure(t *testing.T) { err = controller.servicesInformer.GetIndexer().Add(service) assert.NoError(t, err) - queue := []string{} - patch := gomonkey.NewPatches() - patch.ApplyPrivateMethod(reflect.TypeOf(&ModelServingController{}), "enqueueModelServing", func(ms *workloadv1alpha1.ModelServing, duration time.Duration) { - queue = append(queue, ms.Name) - }) - defer patch.Reset() + startAction := len(client.Actions()) controller.DeleteRole(context.Background(), ms, groupName, roleName, roleID) finalStatus := controller.store.GetRoleStatus(nsn, groupName, roleName, roleID) assert.Equal(t, tt.expectedFinalStatus, finalStatus) - queueLen := len(queue) if tt.expectEnqueueCalled { - assert.True(t, queueLen > 0, "should enqueue") + assertQueuedKey(t, controller.workqueue, namespacedKey(ms.Namespace, ms.Name)) + assertQueueEmpty(t, controller.workqueue) + } else { + assertQueueStaysEmpty(t, controller.workqueue, 100*time.Millisecond) + } + + expectedDeleteSelector := labels.SelectorFromSet(map[string]string{ + workloadv1alpha1.GroupNameLabelKey: groupName, + workloadv1alpha1.RoleLabelKey: roleName, + workloadv1alpha1.RoleIDKey: roleID, + }).String() + var podDeleteSelectors []string + var serviceDeleteNames []string + for _, action := range 
client.Actions()[startAction:] { + switch { + case action.Matches("delete-collection", "pods"): + deleteAction, ok := action.(kubetesting.DeleteCollectionAction) + require.True(t, ok) + podDeleteSelectors = append(podDeleteSelectors, deleteAction.GetListRestrictions().Labels.String()) + case action.Matches("delete", "services"): + deleteAction, ok := action.(kubetesting.DeleteAction) + require.True(t, ok) + serviceDeleteNames = append(serviceDeleteNames, deleteAction.GetName()) + } + } + + assert.Equal(t, []string{expectedDeleteSelector}, podDeleteSelectors) + if tt.podDeletionError != nil { + assert.Empty(t, serviceDeleteNames) } else { - assert.Equal(t, 0, queueLen, "should not enqueue") + assert.Equal(t, []string{service.Name}, serviceDeleteNames) } }) } @@ -6170,6 +6326,7 @@ func TestDeleteServingGroupRollbackOnFailure(t *testing.T) { podDeletionError error serviceDeletionError error expectedFinalStatus datastore.ServingGroupStatus + expectError bool expectEnqueueCalled bool description string }{ @@ -6180,6 +6337,7 @@ func TestDeleteServingGroupRollbackOnFailure(t *testing.T) { podDeletionError: nil, serviceDeletionError: nil, expectedFinalStatus: datastore.ServingGroupRunning, + expectError: true, expectEnqueueCalled: true, description: "failed to delete pod group, should rollback to original status and re-enqueue", }, @@ -6190,6 +6348,7 @@ func TestDeleteServingGroupRollbackOnFailure(t *testing.T) { podDeletionError: fmt.Errorf("failed to delete pods"), serviceDeletionError: nil, expectedFinalStatus: datastore.ServingGroupCreating, + expectError: true, expectEnqueueCalled: true, description: "failed to delete pods, should rollback to original status and re-enqueue", }, @@ -6200,6 +6359,7 @@ func TestDeleteServingGroupRollbackOnFailure(t *testing.T) { podDeletionError: nil, serviceDeletionError: fmt.Errorf("failed to delete services"), expectedFinalStatus: datastore.ServingGroupRunning, + expectError: true, expectEnqueueCalled: true, description: "failed to delete services, should rollback to original status and re-enqueue", }, @@ -6210,6 +6370,7 @@ func TestDeleteServingGroupRollbackOnFailure(t *testing.T) { podDeletionError: nil, serviceDeletionError: nil, expectedFinalStatus: datastore.ServingGroupDeleting, + expectError: false, expectEnqueueCalled: false, description: "all deletions succeed, no rollback needed", }, @@ -6220,44 +6381,22 @@ func TestDeleteServingGroupRollbackOnFailure(t *testing.T) { client := kubefake.NewSimpleClientset() kthenaClient := kthenafake.NewSimpleClientset() volcanoClient := volcanofake.NewSimpleClientset() - apiextfake := apiextfake.NewSimpleClientset(testhelper.CreatePodGroupCRD()) - - // Create informer factories - kubeInformerFactory := informers.NewSharedInformerFactory(client, 0) - kthenaInformerFactory := informersv1alpha1.NewSharedInformerFactory(kthenaClient, 0) + apiextClient := apiextfake.NewSimpleClientset(testhelper.CreatePodGroupCRD()) // Create controller - controller, err := NewModelServingController(client, kthenaClient, volcanoClient, apiextfake) + controller, err := NewModelServingController(client, kthenaClient, volcanoClient, apiextClient) assert.NoError(t, err) - stop := make(chan struct{}) - defer close(stop) - - go controller.Run(context.Background(), 1) - - // Start informers - kthenaInformerFactory.Start(stop) - kubeInformerFactory.Start(stop) - - // Wait for cache sync - cache.WaitForCacheSync(stop, - controller.modelServingsInformer.HasSynced, - controller.podsInformer.HasSynced, - controller.servicesInformer.HasSynced, - ) + 
podGroupManager := podgroupmanager.NewManager(client, volcanoClient, apiextClient, nil) + controller.podGroupManager = &fakePodGroupManager{ + deleteFunc: podGroupManager.DeletePodGroup, + } - // Mock PodGroup deletion behavior using gomonkey - var patch *gomonkey.Patches if tt.podGroupDeletionError != nil { - patch = gomonkey.ApplyMethod(reflect.TypeOf(controller.podGroupManager), "DeletePodGroup", func(_ *podgroupmanager.Manager, ctx context.Context, ms *workloadv1alpha1.ModelServing, servingGroupName string) error { - return tt.podGroupDeletionError - }) - } else { - patch = gomonkey.ApplyMethod(reflect.TypeOf(controller.podGroupManager), "DeletePodGroup", func(_ *podgroupmanager.Manager, ctx context.Context, ms *workloadv1alpha1.ModelServing, servingGroupName string) error { - return nil + volcanoClient.PrependReactor("delete", "podgroups", func(action kubetesting.Action) (handled bool, ret runtime.Object, err error) { + return true, nil, tt.podGroupDeletionError }) } - defer patch.Reset() if tt.podDeletionError != nil { client.PrependReactor("delete-collection", "pods", func(action kubetesting.Action) (handled bool, ret runtime.Object, err error) { @@ -6307,31 +6446,79 @@ func TestDeleteServingGroupRollbackOnFailure(t *testing.T) { }, } + drainWorkqueue(t, controller.workqueue) + assertQueueEmpty(t, controller.workqueue) + _, err = client.CoreV1().Pods("default").Create(context.TODO(), pod, metav1.CreateOptions{}) assert.NoError(t, err) + err = controller.podsInformer.GetIndexer().Add(pod) + assert.NoError(t, err) _, err = client.CoreV1().Services("default").Create(context.TODO(), service, metav1.CreateOptions{}) assert.NoError(t, err) err = controller.servicesInformer.GetIndexer().Add(service) assert.NoError(t, err) - queue := []string{} - patchEnqueue := gomonkey.NewPatches() - patchEnqueue.ApplyPrivateMethod(reflect.TypeOf(&ModelServingController{}), "enqueueModelServing", func(ms *workloadv1alpha1.ModelServing, duration time.Duration) { - queue = append(queue, ms.Name) - }) - defer patchEnqueue.Reset() + startAction := len(client.Actions()) + startVolcanoAction := len(volcanoClient.Actions()) - _ = controller.deleteServingGroup(context.Background(), ms, sgName) + err = controller.deleteServingGroup(context.Background(), ms, sgName) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } finalStatus := controller.store.GetServingGroupStatus(nsn, sgName) assert.Equal(t, tt.expectedFinalStatus, finalStatus, "final ServingGroup status should match expected") - queueLen := len(queue) if tt.expectEnqueueCalled { - assert.True(t, queueLen > 0, "should have enqueued for reconcile") + assertQueuedKey(t, controller.workqueue, namespacedKey(ms.Namespace, ms.Name)) + assertQueueEmpty(t, controller.workqueue) + } else { + assertQueueStaysEmpty(t, controller.workqueue, 100*time.Millisecond) + } + + expectedDeleteSelector := labels.SelectorFromSet(map[string]string{ + workloadv1alpha1.GroupNameLabelKey: sgName, + }).String() + var podDeleteSelectors []string + var serviceDeleteNames []string + for _, action := range client.Actions()[startAction:] { + switch { + case action.Matches("delete-collection", "pods"): + deleteAction, ok := action.(kubetesting.DeleteCollectionAction) + require.True(t, ok) + podDeleteSelectors = append(podDeleteSelectors, deleteAction.GetListRestrictions().Labels.String()) + case action.Matches("delete", "services"): + deleteAction, ok := action.(kubetesting.DeleteAction) + require.True(t, ok) + serviceDeleteNames = append(serviceDeleteNames, 
deleteAction.GetName()) + } + } + var podGroupDeleteNames []string + for _, action := range volcanoClient.Actions()[startVolcanoAction:] { + if !action.Matches("delete", "podgroups") { + continue + } + deleteAction, ok := action.(kubetesting.DeleteAction) + require.True(t, ok) + podGroupDeleteNames = append(podGroupDeleteNames, deleteAction.GetName()) + } + + assert.Equal(t, []string{sgName}, podGroupDeleteNames) + + if tt.podGroupDeletionError != nil { + assert.Empty(t, podDeleteSelectors) + assert.Empty(t, serviceDeleteNames) + return + } + + assert.Equal(t, []string{expectedDeleteSelector}, podDeleteSelectors) + if tt.podDeletionError != nil { + assert.Empty(t, serviceDeleteNames) } else { - assert.Equal(t, 0, queueLen, "should not have enqueued for reconcile") + assert.Equal(t, []string{service.Name}, serviceDeleteNames) } }) }
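The hunks above rely on shared test helpers (newTestController, drainWorkqueue, assertQueueEmpty, assertQueueStaysEmpty, assertQueuedKey, namespacedKey, and the harness method h.expectQueuedKey) plus the fakePodGroupManager stub, all introduced elsewhere in this change and not visible here. What follows is a minimal sketch of plausible definitions inferred from the call sites, assuming the controller's workqueue is a typed string workqueue (client-go v0.30+) and that the workloadv1alpha1 import path matches the repository layout; the real helpers may differ.

package modelserving

import (
	"context"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
	"k8s.io/client-go/util/workqueue"

	// Assumed import path for the ModelServing API types.
	workloadv1alpha1 "github.com/volcano-sh/kthena/pkg/apis/workload/v1alpha1"
)

// fakePodGroupManager stubs the controller's pod-group manager, dispatching
// to optional function fields so each test overrides only what it needs.
type fakePodGroupManager struct {
	createOrUpdateFunc func(ctx context.Context, ms *workloadv1alpha1.ModelServing, groupName string) (error, time.Duration)
	deleteFunc         func(ctx context.Context, ms *workloadv1alpha1.ModelServing, groupName string) error
}

func (f *fakePodGroupManager) CreateOrUpdatePodGroup(ctx context.Context, ms *workloadv1alpha1.ModelServing, groupName string) (error, time.Duration) {
	if f.createOrUpdateFunc != nil {
		return f.createOrUpdateFunc(ctx, ms, groupName)
	}
	return nil, 0
}

func (f *fakePodGroupManager) DeletePodGroup(ctx context.Context, ms *workloadv1alpha1.ModelServing, groupName string) error {
	if f.deleteFunc != nil {
		return f.deleteFunc(ctx, ms, groupName)
	}
	return nil
}

// namespacedKey mirrors the "namespace/name" key format used by the workqueue.
func namespacedKey(namespace, name string) string {
	return namespace + "/" + name
}

// drainWorkqueue empties the queue so a test asserts only on items enqueued
// by the step under test.
func drainWorkqueue(t *testing.T, queue workqueue.TypedRateLimitingInterface[string]) {
	t.Helper()
	for queue.Len() > 0 {
		item, shutdown := queue.Get()
		require.False(t, shutdown)
		queue.Forget(item)
		queue.Done(item)
	}
}

// assertQueueEmpty checks that nothing is currently queued.
func assertQueueEmpty(t *testing.T, queue workqueue.TypedRateLimitingInterface[string]) {
	t.Helper()
	require.Equal(t, 0, queue.Len())
}

// assertQueueStaysEmpty guards against delayed enqueues (e.g. AddAfter)
// landing shortly after the step under test.
func assertQueueStaysEmpty(t *testing.T, queue workqueue.TypedRateLimitingInterface[string], window time.Duration) {
	t.Helper()
	require.Never(t, func() bool { return queue.Len() > 0 }, window, 10*time.Millisecond)
}

// assertQueuedKey waits for a key to arrive (covering both Add and AddAfter
// paths) and verifies it is the expected one.
func assertQueuedKey(t *testing.T, queue workqueue.TypedRateLimitingInterface[string], want string) {
	t.Helper()
	require.Eventually(t, func() bool { return queue.Len() > 0 }, 2*time.Second, 10*time.Millisecond)
	item, shutdown := queue.Get()
	require.False(t, shutdown)
	defer queue.Done(item)
	queue.Forget(item)
	require.Equal(t, want, item)
}

Replacing the gomonkey patches with these queue-level assertions is what lets the dependency (and its vendored license, removed above) be dropped entirely: the tests now observe real enqueue behavior through the fake clientsets rather than intercepting private methods at runtime.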