diff --git a/pkg/kthena-router/datastore/store.go b/pkg/kthena-router/datastore/store.go index 852ec52dd..97beab116 100644 --- a/pkg/kthena-router/datastore/store.go +++ b/pkg/kthena-router/datastore/store.go @@ -981,6 +981,11 @@ func selectFromWeightedSlice(weights []uint32) int { totalWeight += int(weight) } + // Guard against divide-by-zero when all weights are zero + if totalWeight == 0 { + return 0 + } + randomNum := rng.Intn(totalWeight) for i, weight := range weights { diff --git a/pkg/kthena-router/datastore/store_test.go b/pkg/kthena-router/datastore/store_test.go index 20f1ca088..a22f6c6bb 100644 --- a/pkg/kthena-router/datastore/store_test.go +++ b/pkg/kthena-router/datastore/store_test.go @@ -1217,3 +1217,69 @@ func TestStoreMatchModelServer(t *testing.T) { }) } } + +func TestSelectFromWeightedSlice(t *testing.T) { + tests := []struct { + name string + weights []uint32 + expectPanic bool + expectedResult int // only checked when not expecting panic and result is deterministic + checkRange bool + minResult int + maxResult int + }{ + { + name: "all weights zero does not panic", + weights: []uint32{0, 0, 0}, + expectPanic: false, + expectedResult: 0, + }, + { + name: "single zero weight does not panic", + weights: []uint32{0}, + expectPanic: false, + expectedResult: 0, + }, + { + name: "normal weighted selection returns valid index", + weights: []uint32{1, 2, 3}, + expectPanic: false, + checkRange: true, + minResult: 0, + maxResult: 2, + }, + { + name: "single non-zero weight returns index 0", + weights: []uint32{5}, + expectPanic: false, + expectedResult: 0, + }, + { + name: "empty weights returns 0", + weights: []uint32{}, + expectPanic: false, + expectedResult: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.expectPanic { + assert.Panics(t, func() { + selectFromWeightedSlice(tt.weights) + }) + return + } + + assert.NotPanics(t, func() { + result := selectFromWeightedSlice(tt.weights) + if tt.checkRange { + assert.GreaterOrEqual(t, result, tt.minResult) + assert.LessOrEqual(t, result, tt.maxResult) + } else { + assert.Equal(t, tt.expectedResult, result) + } + }) + }) + } +}