From a02c4ecf4ad76a0c1e8fa01f623e0f847201789d Mon Sep 17 00:00:00 2001 From: Niaz Khan Date: Tue, 28 Apr 2026 09:06:11 +0100 Subject: [PATCH] chore: add unit tests for controller authz failures --- internal/controllers/cmk/authz_test.go | 478 +++++++++++++++++++++++++ internal/daemon/mux.go | 5 + internal/testutils/api.go | 4 +- internal/testutils/authz.go | 135 +++++++ 4 files changed, 620 insertions(+), 2 deletions(-) create mode 100644 internal/controllers/cmk/authz_test.go diff --git a/internal/controllers/cmk/authz_test.go b/internal/controllers/cmk/authz_test.go new file mode 100644 index 00000000..93f76fa9 --- /dev/null +++ b/internal/controllers/cmk/authz_test.go @@ -0,0 +1,478 @@ +//go:build !unit + +package cmk_test + +import ( + "context" + "fmt" + "net/http" + "strings" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + + multitenancy "github.com/bartventer/gorm-multitenancy/v8" + + "github.com/openkcm/cmk/internal/authz" + "github.com/openkcm/cmk/internal/config" + "github.com/openkcm/cmk/internal/constants" + "github.com/openkcm/cmk/internal/daemon" + "github.com/openkcm/cmk/internal/repo" + "github.com/openkcm/cmk/internal/repo/sql" + "github.com/openkcm/cmk/internal/testutils" + cmkcontext "github.com/openkcm/cmk/utils/context" +) + +func startAPIAuthz(t *testing.T) (*multitenancy.DB, *daemon.ServeMux, string) { + t.Helper() + + db, tenants, dbCfg := testutils.NewTestDB(t, testutils.TestDBConfig{ + CreateDatabase: true, + }) + + sv := testutils.NewAPIServer(t, db, testutils.TestAPIServerConfig{ + Config: config.Config{Database: dbCfg}, + }) + + return db, sv, tenants[0] +} + +// authzEndpoints returns all restricted API endpoints for authz testing. +// Each entry maps to a restriction in authz.RestrictionsByAPI. +func authzEndpoints() []testutils.AuthzTestEndpoint { + keyID := uuid.New().String() + keyConfigID := uuid.New().String() + systemID := uuid.New().String() + workflowID := uuid.New().String() + groupID := uuid.New().String() + + return []testutils.AuthzTestEndpoint{ + // --- Keys --- + { + Method: http.MethodGet, + Endpoint: "/keys?keyConfigurationID=" + keyConfigID, + }, + { + Method: http.MethodPost, + Endpoint: "/keys", + Body: `{ + "name": "test-key", + "keyConfigurationID": "` + keyConfigID + `" + }`, + }, + { + Method: http.MethodGet, + Endpoint: "/keys/" + keyID, + }, + { + Method: http.MethodPatch, + Endpoint: "/keys/" + keyID, + Body: `{"name": "updated"}`, + }, + { + Method: http.MethodDelete, + Endpoint: "/keys/" + keyID, + }, + { + Method: http.MethodGet, + Endpoint: "/keys/" + keyID + "/importParams", + }, + { + Method: http.MethodPost, + Endpoint: "/keys/" + keyID + "/importKeyMaterial", + Body: `{"encryptedKeyMaterial": "dGVzdA==", "importToken": "dGVzdA=="}`, + }, + { + Method: http.MethodGet, + Endpoint: "/keys/" + keyID + "/versions", + }, + // NOTE: POST /keys/{keyID}/versions and GET /keys/{keyID}/versions/{version} + // are defined in the authz mapping but not registered as API routes. + + // --- Key Labels --- + { + Method: http.MethodGet, + Endpoint: "/key/" + keyID + "/labels", + }, + { + Method: http.MethodPost, + Endpoint: "/key/" + keyID + "/labels", + Body: `{"labels": {"env": "test"}}`, + }, + { + Method: http.MethodDelete, + Endpoint: "/key/" + keyID + "/label/testlabel", + }, + + // --- Key Configurations --- + { + Method: http.MethodGet, + Endpoint: "/keyConfigurations", + }, + { + Method: http.MethodPost, + Endpoint: "/keyConfigurations", + Body: `{ + "name": "test-kc", + "keyAlgorithm": "AES", + "provider": "TEST" + }`, + }, + { + Method: http.MethodGet, + Endpoint: "/keyConfigurations/" + keyConfigID, + }, + { + Method: http.MethodPatch, + Endpoint: "/keyConfigurations/" + keyConfigID, + Body: `{"name": "updated"}`, + }, + { + Method: http.MethodDelete, + Endpoint: "/keyConfigurations/" + keyConfigID, + }, + { + Method: http.MethodGet, + Endpoint: "/keyConfigurations/" + keyConfigID + "/tags", + }, + { + Method: http.MethodPut, + Endpoint: "/keyConfigurations/" + keyConfigID + "/tags", + Body: `{"tags": {"env": "test"}}`, + }, + { + Method: http.MethodGet, + Endpoint: "/keyConfigurations/" + keyConfigID + "/certificates", + }, + + // --- Systems --- + { + Method: http.MethodGet, + Endpoint: "/systems", + }, + { + Method: http.MethodGet, + Endpoint: "/systems/" + systemID, + }, + { + Method: http.MethodPatch, + Endpoint: "/systems/" + systemID + "/link", + Body: `{"keyConfigurationID": "` + keyConfigID + `"}`, + }, + { + Method: http.MethodDelete, + Endpoint: "/systems/" + systemID + "/link", + }, + { + Method: http.MethodPost, + Endpoint: "/systems/" + systemID + "/recoveryActions", + Body: `{"action": "RECOVER"}`, + }, + { + Method: http.MethodGet, + Endpoint: "/systems/" + systemID + "/recoveryActions", + }, + + // --- Workflows --- + { + Method: http.MethodPost, + Endpoint: "/workflows", + Body: `{ + "actionType": "UNLINK", + "artifactID": "` + systemID + `", + "artifactType": "SYSTEM" + }`, + }, + { + Method: http.MethodGet, + Endpoint: "/workflows", + }, + { + Method: http.MethodPost, + Endpoint: "/workflows/check", + Body: `{ + "actionType": "UNLINK", + "artifactID": "` + systemID + `", + "artifactType": "SYSTEM" + }`, + }, + { + Method: http.MethodGet, + Endpoint: "/workflows/" + workflowID, + }, + // NOTE: GET/POST /workflows/{workflowID}/approvers are defined in the + // authz mapping but not registered as API routes. + { + Method: http.MethodPost, + Endpoint: "/workflows/" + workflowID + "/state", + Body: `{"state": "APPROVED"}`, + }, + + // --- Groups --- + { + Method: http.MethodGet, + Endpoint: "/groups", + }, + { + Method: http.MethodPost, + Endpoint: "/groups", + Body: `{ + "name": "test-group", + "iamIdentifier": "test-iam-id", + "role": "KEY_ADMINISTRATOR" + }`, + }, + { + Method: http.MethodPost, + Endpoint: "/groups/iamCheck", + Body: `{"iamIdentifiers": ["test-id"]}`, + }, + { + Method: http.MethodGet, + Endpoint: "/groups/" + groupID, + }, + { + Method: http.MethodPatch, + Endpoint: "/groups/" + groupID, + Body: `{"name": "updated"}`, + }, + { + Method: http.MethodDelete, + Endpoint: "/groups/" + groupID, + }, + + // --- Tenant Configurations --- + { + Method: http.MethodGet, + Endpoint: "/tenantConfigurations/keystores", + }, + { + Method: http.MethodGet, + Endpoint: "/tenantConfigurations/workflow", + }, + { + Method: http.MethodPatch, + Endpoint: "/tenantConfigurations/workflow", + Body: `{"enabled": true}`, + }, + + // --- Tenant Info --- + { + Method: http.MethodGet, + Endpoint: "/tenantInfo", + }, + } +} + +// TestAuthzEndpointCoverage ensures that every registered API endpoint with an +// authz restriction is covered by authzEndpoints(). This fails when a new +// restricted endpoint is added to the server but not to the test fixture, +// preventing gaps in authz test coverage. +func TestAuthzEndpointCoverage(t *testing.T) { + _, sv, tenant := startAPIAuthz(t) + endpoints := authzEndpoints() + + // Build the set of patterns already covered by authzEndpoints. + covered := make(map[string]struct{}) + for _, ep := range endpoints { + req := testutils.NewHTTPRequest(t, testutils.RequestOptions{ + Method: ep.Method, + Endpoint: ep.Endpoint, + Tenant: tenant, + }) + + _, pattern := sv.Handler(req) + pattern = strings.Replace(pattern, sv.BaseURL, "", 1) + covered[pattern] = struct{}{} + } + + // For each authz restriction, check that its route is either unregistered + // (not a real API route) or already covered by the test fixture. + for apiKey := range authz.RestrictionsByAPI { + parts := strings.SplitN(apiKey, " ", 2) + method, path := parts[0], parts[1] + + // Build a concrete URL by replacing path params with dummy UUIDs. + concrete := testutils.SubstitutePathParams(path) + + req := testutils.NewHTTPRequest(t, testutils.RequestOptions{ + Method: method, + Endpoint: concrete, + Tenant: tenant, + }) + + _, pattern := sv.Handler(req) + pattern = strings.Replace(pattern, sv.BaseURL, "", 1) + + // If the resolved pattern doesn't match the restriction key the + // route is not registered on the ServeMux — skip it. + if pattern != apiKey { + continue + } + + assert.Contains(t, covered, apiKey, + "authz restriction %q is a registered route but has no entry in authzEndpoints(); "+ + "add a test entry to ensure authz is verified for this endpoint", apiKey) + } +} + +// TestAuthzBlocked verifies that each restricted endpoint returns 403 Forbidden +// when accessed by a role that does not have the required permission. +// The blocked roles are automatically derived from the authz policy data. +func TestAuthzBlocked(t *testing.T) { + db, sv, tenant := startAPIAuthz(t) + ctx := cmkcontext.CreateTenantContext(t.Context(), tenant) + r := sql.NewRepository(db) + + runAuthzBlockedTests(t, sv, tenant, r, ctx, authzEndpoints()) +} + +// runAuthzBlockedTests runs authorization failure tests for the provided +// endpoints. For each endpoint it uses the ServeMux to resolve the registered +// pattern, determines which roles should be blocked based on the policy data, +// and asserts that each blocked role receives 403 Forbidden. +func runAuthzBlockedTests( + t *testing.T, + sv *daemon.ServeMux, + tenant string, + r repo.Repo, + ctx context.Context, + endpoints []testutils.AuthzTestEndpoint, +) { + t.Helper() + + for _, ep := range endpoints { + req := testutils.NewHTTPRequest(t, testutils.RequestOptions{ //nolint:contextcheck + Method: ep.Method, + Endpoint: ep.Endpoint, + Tenant: tenant, + }) + + _, pattern := sv.Handler(req) + pattern = strings.Replace(pattern, sv.BaseURL, "", 1) + + restriction, exists := authz.RestrictionsByAPI[pattern] + if !exists { + t.Fatalf( + "no authz restriction found for pattern %q on %s %s", + pattern, ep.Method, ep.Endpoint, + ) + } + + blockedRoles := testutils.GetBlockedRoles( + restriction.APIResourceTypeName, restriction.APIAction, + ) + if len(blockedRoles) == 0 { + t.Logf("all roles are allowed for %q (%s:%s), skipping", + pattern, restriction.APIResourceTypeName, restriction.APIAction) + + continue + } + + for _, role := range blockedRoles { + testName := fmt.Sprintf( + "%s_%s_blocked_for_%s", ep.Method, testutils.CleanPath(ep.Endpoint), role, + ) + + t.Run(testName, func(t *testing.T) { + authClient := testutils.NewAuthClient(ctx, t, r, testutils.RoleAuthClientOpt(role)) + + w := testutils.MakeHTTPRequest(t, sv, testutils.RequestOptions{ //nolint:contextcheck + Method: ep.Method, + Endpoint: ep.Endpoint, + Tenant: tenant, + Body: testutils.WithBody(t, ep.Body), + AdditionalContext: authClient.GetClientMap(), + }) + + assert.Equal(t, http.StatusForbidden, w.Code, + "expected 403 for role %s on %s %s, got %d: %s", + role, ep.Method, ep.Endpoint, w.Code, w.Body.String()) + }) + } + } +} + +// TestAuthzAllowed is a sanity test that verifies each restricted endpoint does +// NOT return 403 Forbidden when accessed by an allowed role. This complements +// TestAuthzBlocked by confirming that authorization permits expected roles. +func TestAuthzAllowed(t *testing.T) { + db, sv, tenant := startAPIAuthz(t) + ctx := cmkcontext.CreateTenantContext(t.Context(), tenant) + r := sql.NewRepository(db) + + // Pre-create auth clients for all roles so that the authz loader + // picks them up on first tenant load (it caches per tenant). + authClients := map[constants.Role]testutils.AuthClientData{ + constants.KeyAdminRole: testutils.NewAuthClient(ctx, t, r, testutils.WithKeyAdminRole()), + constants.TenantAdminRole: testutils.NewAuthClient(ctx, t, r, testutils.WithTenantAdminRole()), + constants.TenantAuditorRole: testutils.NewAuthClient(ctx, t, r, testutils.WithAuditorRole()), + } + + runAuthzAllowedTests(t, sv, tenant, authzEndpoints(), authClients) +} + +// runAuthzAllowedTests is a sanity check: for each endpoint it picks the +// first allowed role and asserts the response is NOT 403 Forbidden. +func runAuthzAllowedTests( + t *testing.T, + sv *daemon.ServeMux, + tenant string, + endpoints []testutils.AuthzTestEndpoint, + authClients map[constants.Role]testutils.AuthClientData, +) { + t.Helper() + + for _, ep := range endpoints { + req := testutils.NewHTTPRequest(t, testutils.RequestOptions{ + Method: ep.Method, + Endpoint: ep.Endpoint, + Tenant: tenant, + }) + + _, pattern := sv.Handler(req) + pattern = strings.Replace(pattern, sv.BaseURL, "", 1) + + restriction, exists := authz.RestrictionsByAPI[pattern] + if !exists { + t.Fatalf( + "no authz restriction found for pattern %q on %s %s", + pattern, ep.Method, ep.Endpoint, + ) + } + + allowedRoles := testutils.GetAllowedRoles( + restriction.APIResourceTypeName, restriction.APIAction, + ) + if len(allowedRoles) == 0 { + continue + } + + // Sanity test: use the first allowed role and verify it is not blocked + role := allowedRoles[0] + testName := fmt.Sprintf( + "sanity_%s_%s_allowed_for_%s", ep.Method, testutils.CleanPath(ep.Endpoint), role, + ) + + t.Run(testName, func(t *testing.T) { + authClient := authClients[role] + + w := testutils.MakeHTTPRequest(t, sv, testutils.RequestOptions{ + Method: ep.Method, + Endpoint: ep.Endpoint, + Tenant: tenant, + Body: testutils.WithBody(t, ep.Body), + AdditionalContext: authClient.GetClientMap(), + }) + + // A non-403 means authz passed. A 403 with a non-FORBIDDEN code + // (e.g., ACTION_REQUIRE_WORKFLOW) is a business-logic denial, not authz. + if w.Code == http.StatusForbidden { + assert.NotContains(t, w.Body.String(), `"code":"FORBIDDEN"`, + "expected authz to allow role %s on %s %s, got authz 403: %s", + role, ep.Method, ep.Endpoint, w.Body.String()) + } + }) + } +} diff --git a/internal/daemon/mux.go b/internal/daemon/mux.go index b712ca8e..2ce3368c 100644 --- a/internal/daemon/mux.go +++ b/internal/daemon/mux.go @@ -23,6 +23,11 @@ func (m *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { m.httpServeMux.ServeHTTP(w, r) } +// Handler returns the handler and registered pattern that matches the request. +func (m *ServeMux) Handler(r *http.Request) (http.Handler, string) { + return m.httpServeMux.Handler(r) +} + func (m *ServeMux) HandleFunc( pattern string, handler func(http.ResponseWriter, *http.Request), diff --git a/internal/testutils/api.go b/internal/testutils/api.go index 3b77460c..890da4e0 100644 --- a/internal/testutils/api.go +++ b/internal/testutils/api.go @@ -50,7 +50,7 @@ func NewAPIServer( tb testing.TB, dbCon *multitenancy.DB, testCfg TestAPIServerConfig, -) cmkapi.ServeMux { +) *daemon.ServeMux { tb.Helper() cfg := testCfg.Config @@ -111,7 +111,7 @@ func NewAPIServer( func startAPIServer( tb testing.TB, controller *cmk.APIController, -) cmkapi.ServeMux { +) *daemon.ServeMux { tb.Helper() strictController := cmkapi.NewStrictHandlerWithOptions( diff --git a/internal/testutils/authz.go b/internal/testutils/authz.go index d077d36a..c4065248 100644 --- a/internal/testutils/authz.go +++ b/internal/testutils/authz.go @@ -2,11 +2,16 @@ package testutils import ( "context" + "fmt" + "io" + "slices" + "strings" "testing" "github.com/google/uuid" "github.com/openkcm/common-sdk/pkg/auth" + "github.com/openkcm/cmk/internal/authz" "github.com/openkcm/cmk/internal/constants" "github.com/openkcm/cmk/internal/model" "github.com/openkcm/cmk/internal/repo" @@ -172,3 +177,133 @@ func getClientData(identifier string, groupNames []string) *auth.ClientData { Groups: groupNames, } } + +// AuthzTestEndpoint defines an API endpoint to test for authorization failures. +type AuthzTestEndpoint struct { + // Method is the HTTP method (e.g., http.MethodGet) + Method string + // Endpoint is the URL path with any required path params filled in + // (e.g., "/keys?keyConfigurationID=xxx") + Endpoint string + // Body is an optional JSON request body for POST/PATCH/PUT requests + Body string +} + +// WithBody converts a JSON string to an io.Reader for use as a request body. +// Returns nil if the body is empty. +func WithBody(tb testing.TB, body string) io.Reader { + tb.Helper() + + if body == "" { + return nil + } + + return strings.NewReader(body) +} + +// allRoles returns all defined roles. +func allRoles() []constants.Role { + return []constants.Role{ + constants.KeyAdminRole, + constants.TenantAdminRole, + constants.TenantAuditorRole, + } +} + +// RoleAuthClientOpt maps a role to the corresponding AuthClientOpt. +func RoleAuthClientOpt(role constants.Role) AuthClientOpt { + switch role { + case constants.KeyAdminRole: + return WithKeyAdminRole() + case constants.TenantAdminRole: + return WithTenantAdminRole() + case constants.TenantAuditorRole: + return WithAuditorRole() + default: + panic(fmt.Sprintf("unsupported role: %s", role)) + } +} + +// GetAllowedRoles returns roles that have the given resource type + action +// based on the API policy data. +func GetAllowedRoles( + resourceType authz.APIResourceTypeName, + action authz.APIAction, +) []constants.Role { + allowed := make(map[constants.Role]struct{}) + + for _, policy := range authz.PolicyData.Policies { + for _, rt := range policy.ResourceTypes { + if rt.ID != resourceType { + continue + } + + if slices.Contains(rt.Actions, action) { + allowed[policy.Role] = struct{}{} + } + } + } + + var roles []constants.Role + for _, role := range allRoles() { + if _, ok := allowed[role]; ok { + roles = append(roles, role) + } + } + + return roles +} + +// GetBlockedRoles returns roles that do NOT have the given +// resource type + action. +func GetBlockedRoles( + resourceType authz.APIResourceTypeName, + action authz.APIAction, +) []constants.Role { + allowed := GetAllowedRoles(resourceType, action) + allowedSet := make(map[constants.Role]struct{}, len(allowed)) + for _, role := range allowed { + allowedSet[role] = struct{}{} + } + + var blocked []constants.Role + for _, role := range allRoles() { + if _, ok := allowedSet[role]; !ok { + blocked = append(blocked, role) + } + } + + return blocked +} + +// CleanPath returns a sanitized version of the path for use in test names. +func CleanPath(path string) string { + path = strings.ReplaceAll(path, "/", "_") + path = strings.ReplaceAll(path, "?", "_") + path = strings.ReplaceAll(path, "&", "_") + path = strings.ReplaceAll(path, "=", "_") + + if len(path) > 0 && path[0] == '_' { + path = path[1:] + } + + return path +} + +// SubstitutePathParams replaces path parameters (e.g., {keyID}) with dummy +// UUIDs so the path can be used in an HTTP request for route matching. +func SubstitutePathParams(path string) string { + result := path + for strings.Contains(result, "{") { + start := strings.Index(result, "{") + end := strings.Index(result, "}") + + if start == -1 || end == -1 { + break + } + + result = result[:start] + uuid.New().String() + result[end+1:] + } + + return result +}