diff --git a/scheduler/pkg/kafka/gateway/infer.go b/scheduler/pkg/kafka/gateway/infer.go index 0134629715..f9d331ccd7 100644 --- a/scheduler/pkg/kafka/gateway/infer.go +++ b/scheduler/pkg/kafka/gateway/infer.go @@ -77,9 +77,25 @@ func GetIntConfigMapValue(configMap kafka.ConfigMap, key string, defaultValue in return defaultValue, nil } - value, err := strconv.Atoi(configMapValue.(string)) + if configMapValueInt, ok := configMapValue.(int); ok { + if configMapValueInt < 0 { + return -1, fmt.Errorf("%s: %d must not be negative", key, configMapValueInt) + } + return configMapValueInt, nil + } + + configMapValueStr, ok := configMapValue.(string) + if !ok { + return defaultValue, fmt.Errorf("%s key has wrong type: %T", key, configMapValue) + } + + value, err := strconv.Atoi(configMapValueStr) if err != nil { - return 0, err + return -1, fmt.Errorf("invalid value %s in %s with error: %v", configMapValueStr, key, err) + } + + if value < 0 { + return -1, fmt.Errorf("%s: %d must be bigger than 0", key, value) } return value, nil @@ -96,19 +112,19 @@ func NewInferKafkaHandler( ) (*InferKafkaHandler, error) { defaultReplicationFactor, err := util.GetIntEnvar(envDefaultReplicationFactor, defaultReplicationFactor) if err != nil { - return nil, err + return nil, fmt.Errorf("error getting default replication factor: %v", err) } replicationFactor, err := GetIntConfigMapValue(topicsConfigMap, replicationFactorKey, defaultReplicationFactor) if err != nil { - return nil, err + return nil, fmt.Errorf("invalid Kafka topic configuration: %v", err) } defaultNumPartitions, err := util.GetIntEnvar(envDefaultNumPartitions, defaultNumPartitions) if err != nil { - return nil, err + return nil, fmt.Errorf("error getting default number of partitions: %v", err) } numPartitions, err := GetIntConfigMapValue(topicsConfigMap, numPartitionsKey, defaultNumPartitions) if err != nil { - return nil, err + return nil, fmt.Errorf("invalid Kafka topic configuration: %w", err) } tlsClientOptions, err := util.CreateTLSClientOptions() if err != nil { diff --git a/scheduler/pkg/kafka/gateway/infer_test.go b/scheduler/pkg/kafka/gateway/infer_test.go new file mode 100644 index 0000000000..838115466d --- /dev/null +++ b/scheduler/pkg/kafka/gateway/infer_test.go @@ -0,0 +1,94 @@ +/* +Copyright (c) 2024 Seldon Technologies Ltd. + +Use of this software is governed by +(1) the license included in the LICENSE file or +(2) if the license included in the LICENSE file is the Business Source License 1.1, +the Change License after the Change Date as each is defined in accordance with the LICENSE file. +*/ + +package gateway + +import ( + "testing" + + "github.com/confluentinc/confluent-kafka-go/v2/kafka" + . "github.com/onsi/gomega" +) + +func TestGetIntConfigMapValue(t *testing.T) { + g := NewGomegaWithT(t) + + type test struct { + name string + configMap kafka.ConfigMap + key string + defaultValue int + wantValue int + wantError bool + } + + tests := []test{ + { + name: "success: string", + configMap: kafka.ConfigMap{"replicationFactor": "5"}, + key: replicationFactorKey, + defaultValue: 0, + wantValue: 5, + wantError: false, + }, + { + name: "fail: negative string", + configMap: kafka.ConfigMap{"replicationFactor": "-5"}, + key: replicationFactorKey, + defaultValue: 0, + wantValue: 0, + wantError: true, + }, + { + name: "fail: float string value", + configMap: kafka.ConfigMap{"replicationFactor": "5.0"}, + key: replicationFactorKey, + defaultValue: 0, + wantValue: 0, + wantError: true, + }, + { + name: "fail: string value", + configMap: kafka.ConfigMap{"replicationFactor": "---"}, + key: replicationFactorKey, + defaultValue: 0, + wantValue: 0, + wantError: true, + }, + { + name: "success: integer", + configMap: kafka.ConfigMap{"replicationFactor": 5}, + key: replicationFactorKey, + defaultValue: 0, + wantValue: 5, + wantError: false, + }, + { + name: "fail: negative integer", + configMap: kafka.ConfigMap{"replicationFactor": -5}, + key: replicationFactorKey, + defaultValue: 0, + wantValue: 0, + wantError: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + gotInt, err := GetIntConfigMapValue(test.configMap, test.key, test.defaultValue) + if test.wantError { + g.Expect(err).ToNot(BeNil()) + } else { + g.Expect(gotInt).To(Equal(test.wantValue)) + } + }) + + } + +}