From 336dcfaaadc38b71b7724d7041cf3e635d7823fb Mon Sep 17 00:00:00 2001 From: Ryan Gang Date: Thu, 26 Sep 2024 21:25:44 +0530 Subject: [PATCH] feat: add ApiVersionsAssertion file --- .../apiversions_response_assertion.go | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 internal/assertions/apiversions_response_assertion.go diff --git a/internal/assertions/apiversions_response_assertion.go b/internal/assertions/apiversions_response_assertion.go new file mode 100644 index 0000000..a351c94 --- /dev/null +++ b/internal/assertions/apiversions_response_assertion.go @@ -0,0 +1,71 @@ +package assertions + +import ( + "fmt" + + kafkaapi "github.com/codecrafters-io/kafka-tester/protocol/api" + "github.com/codecrafters-io/tester-utils/logger" +) + +type ApiVersionsResponseAssertion struct { + ActualValue kafkaapi.ApiVersionsResponse + ExpectedValue kafkaapi.ApiVersionsResponse +} + +func NewApiVersionsResponseAssertion(actualValue kafkaapi.ApiVersionsResponse, expectedValue kafkaapi.ApiVersionsResponse) ApiVersionsResponseAssertion { + return ApiVersionsResponseAssertion{ActualValue: actualValue, ExpectedValue: expectedValue} +} + +var apiKeyNames = map[int16]string{ + 1: "FETCH", + 18: "API_VERSIONS", + 75: "DESCRIBE_TOPIC_PARTITIONS", +} + +var errorCodes = map[int]string{ + 0: "NO_ERROR", +} + +func (a ApiVersionsResponseAssertion) Evaluate(fields []string, AssertApiVersionsResponseKey bool, logger *logger.Logger) error { + if Contains(fields, "ErrorCode") { + if a.ActualValue.ErrorCode != a.ExpectedValue.ErrorCode { + return fmt.Errorf("Expected %s to be %d, got %d", "ErrorCode", a.ExpectedValue.ErrorCode, a.ActualValue.ErrorCode) + } + + errorCodeName, ok := errorCodes[int(a.ActualValue.ErrorCode)] + if !ok { + errorCodeName = "UNKNOWN" + } + logger.Successf("✓ Error code: %d (%s)", a.ActualValue.ErrorCode, errorCodeName) + } + + if AssertApiVersionsResponseKey { + if len(a.ActualValue.ApiKeys) < len(a.ExpectedValue.ApiKeys) { + return fmt.Errorf("Expected API keys array to include atleast %d keys, got %d", len(a.ExpectedValue.ApiKeys), len(a.ActualValue.ApiKeys)) + } + logger.Successf("✓ API keys array length: %d", len(a.ActualValue.ApiKeys)) + + for _, expectedApiVersionKey := range a.ExpectedValue.ApiKeys { + found := false + for _, actualApiVersionKey := range a.ActualValue.ApiKeys { + if actualApiVersionKey.ApiKey == expectedApiVersionKey.ApiKey { + found = true + if actualApiVersionKey.MaxVersion < expectedApiVersionKey.MaxVersion { + return fmt.Errorf("Expected API version %v to be supported for %s, got %v", expectedApiVersionKey.MaxVersion, apiKeyNames[expectedApiVersionKey.ApiKey], actualApiVersionKey.MaxVersion) + } + logger.Successf("✓ API version %v is supported for %s", actualApiVersionKey.MaxVersion, apiKeyNames[expectedApiVersionKey.ApiKey]) + + if actualApiVersionKey.MinVersion < expectedApiVersionKey.MinVersion { + return fmt.Errorf("Expected API version %v to be supported for %s, got %v", expectedApiVersionKey.MinVersion, apiKeyNames[expectedApiVersionKey.ApiKey], actualApiVersionKey.MinVersion) + } + logger.Successf("✓ API version %v is supported for %s", actualApiVersionKey.MinVersion, apiKeyNames[expectedApiVersionKey.ApiKey]) + } + } + if !found { + return fmt.Errorf("Expected APIVersionsResponseKey array to include API key %d (%s)", expectedApiVersionKey.ApiKey, apiKeyNames[expectedApiVersionKey.ApiKey]) + } + } + } + + return nil +}