diff --git a/internal/controllers/cmk/tenant_controller_test.go b/internal/controllers/cmk/tenant_controller_test.go index 4f66dc6c..be0b36f4 100644 --- a/internal/controllers/cmk/tenant_controller_test.go +++ b/internal/controllers/cmk/tenant_controller_test.go @@ -5,13 +5,14 @@ import ( "strings" "testing" - "github.com/google/uuid" + "github.com/openkcm/common-sdk/pkg/auth" "github.com/stretchr/testify/assert" multitenancy "github.com/bartventer/gorm-multitenancy/v8" "github.com/openkcm/cmk/internal/api/cmkapi" "github.com/openkcm/cmk/internal/config" + "github.com/openkcm/cmk/internal/constants" "github.com/openkcm/cmk/internal/model" "github.com/openkcm/cmk/internal/repo" "github.com/openkcm/cmk/internal/repo/sql" @@ -19,23 +20,27 @@ import ( cmkContext "github.com/openkcm/cmk/utils/context" ) -func startAPITenant(t *testing.T) (*multitenancy.DB, cmkapi.ServeMux) { +func startAPITenant(t *testing.T) (*multitenancy.DB, cmkapi.ServeMux, *testutils.TestSigningKeyStorage) { t.Helper() db, _, dbCfg := testutils.NewTestDB(t, testutils.TestDBConfig{ CreateDatabase: true, }, testutils.WithGenerateTenants(10)) + keyStorage := testutils.NewTestSigningKeyStorage(t) return db, testutils.NewAPIServer(t, db, testutils.TestAPIServerConfig{ - Config: config.Config{Database: dbCfg}, - }) + Config: config.Config{Database: dbCfg}, + EnableClientDataMW: true, + SigningKeyStorage: keyStorage, + }), keyStorage } func TestGetTenants(t *testing.T) { - db, sv := startAPITenant(t) + db, sv, keyStorage := startAPITenant(t) r := sql.NewRepository(db) var tenants []model.Tenant + var headers http.Header err := r.List(t.Context(), model.Tenant{}, &tenants, *repo.NewQuery()) assert.NoError(t, err) @@ -55,13 +60,26 @@ func TestGetTenants(t *testing.T) { assert.NoError(t, err) } + clientData := &auth.ClientData{ + Identifier: "user-123", + Email: "bob@example.com", + GivenName: "Bob", + FamilyName: "Builder", + Groups: []string{"sysadmin", "some-other-group"}, + } + // Get private key for signing test requests + privateKey, ok := keyStorage.GetPrivateKey(0) + assert.True(t, ok, "test key should exist") + + headers = testutils.NewSignedClientDataHeadersFromStruct(t, clientData, privateKey, 0) + assert.NotEmpty(t, headers) + t.Run("Should 200 on list tenants", func(t *testing.T) { w := testutils.MakeHTTPRequest(t, sv, testutils.RequestOptions{ Method: http.MethodGet, Endpoint: "/tenants", Tenant: tenants[0].ID, - AdditionalContext: testutils.GetClientMap("test", - []string{"sysadmin", "othergroup"}), + Headers: headers, }) assert.Equal(t, http.StatusOK, w.Code) @@ -74,20 +92,31 @@ func TestGetTenants(t *testing.T) { Method: http.MethodGet, Endpoint: "/tenants", Tenant: "non-existing-tenant-id", - AdditionalContext: testutils.GetClientMap("test", - []string{"sysadmin", "othergroup"}), + Headers: headers, }) assert.Equal(t, http.StatusForbidden, w.Code) }) t.Run("Should 403 on list tenants without permission", func(t *testing.T) { + notAllowedClientData := &auth.ClientData{ + Identifier: "user-123", + Email: "bob@example.com", + GivenName: "Bob", + FamilyName: "Builder", + Groups: []string{"test", "some-other-test-group"}, + } + // Get private key for signing test requests + privateKey, ok := keyStorage.GetPrivateKey(0) + assert.True(t, ok, "test key should exist") + + headersNotAllowed := testutils.NewSignedClientDataHeadersFromStruct(t, notAllowedClientData, privateKey, 0) + w := testutils.MakeHTTPRequest(t, sv, testutils.RequestOptions{ Method: http.MethodGet, Endpoint: "/tenants", Tenant: tenants[0].ID, - AdditionalContext: testutils.GetClientMap("test", - []string{"othergroup"}), + Headers: headersNotAllowed, }) assert.Equal(t, http.StatusForbidden, w.Code) @@ -98,42 +127,68 @@ func TestGetTenants(t *testing.T) { } func TestGetTenantInfo(t *testing.T) { - db, sv := startAPITenant(t) + db, sv, keyStorage := startAPITenant(t) r := sql.NewRepository(db) var tenant model.Tenant + var headers http.Header _, err := r.First(t.Context(), &tenant, *repo.NewQuery()) assert.NoError(t, err) tenantCtx := cmkContext.CreateTenantContext(t.Context(), tenant.ID) - authClient := testutils.NewAuthClient(tenantCtx, t, r, testutils.WithTenantAdminRole()) + tenant.IssuerURL = "https://testissuer.example.com" + _, err = r.Patch(tenantCtx, &tenant, *repo.NewQuery()) + assert.NoError(t, err) group := testutils.NewGroup(func(group *model.Group) { group.IAMIdentifier = "sysadmin" + group.Role = constants.TenantAdminRole }) err = r.Create(tenantCtx, group) assert.NoError(t, err) - t.Run("Should 403 on get tenant info that does not exist", func(t *testing.T) { - w := testutils.MakeHTTPRequest(t, sv, testutils.RequestOptions{ - Method: http.MethodGet, - Endpoint: "/tenantInfo", - Tenant: "nonexistent-tenant-id", - AdditionalContext: authClient.GetClientMap( - testutils.WithAdditionalGroup(uuid.NewString())), - }) - - assert.Equal(t, http.StatusForbidden, w.Code) - }) + clientData := &auth.ClientData{ + Identifier: "user-123", + Email: "bob@example.com", + GivenName: "Bob", + FamilyName: "Builder", + Groups: []string{group.IAMIdentifier, "some-other-test-group"}, + } + // Get private key for signing test requests + privateKey, ok := keyStorage.GetPrivateKey(0) + assert.True(t, ok, "test key should exist") + headers = testutils.NewSignedClientDataHeadersFromStruct(t, clientData, privateKey, 0) + assert.NotEmpty(t, headers) + + clientDataNoGroups := &auth.ClientData{ + Identifier: "user-123", + Email: "bob@example.com", + GivenName: "Bob", + FamilyName: "Builder", + Groups: []string{}, + } + headersNoGroups := testutils.NewSignedClientDataHeadersFromStruct(t, clientDataNoGroups, privateKey, 0) + assert.NotEmpty(t, headersNoGroups) + + clientDataInvalidGroup := &auth.ClientData{ + Identifier: "user-123", + Email: "bob@example.com", + GivenName: "Bob", + FamilyName: "Builder", + Groups: []string{"not-existing-group"}, + } + headersInvalidGroup := testutils.NewSignedClientDataHeadersFromStruct(t, clientDataInvalidGroup, privateKey, 0) + assert.NotEmpty(t, headersInvalidGroup) - t.Run("Should 403 on get tenant info without a user group", func(t *testing.T) { + t.Run("Should 403 on get tenant info that does not exist", func(t *testing.T) { w := testutils.MakeHTTPRequest(t, sv, testutils.RequestOptions{ Method: http.MethodGet, Endpoint: "/tenantInfo", Tenant: "nonexistent-tenant-id", + Headers: headers, }) assert.Equal(t, http.StatusForbidden, w.Code) @@ -144,12 +199,12 @@ func TestGetTenantInfo(t *testing.T) { Method: http.MethodGet, Endpoint: "/tenantInfo", Tenant: tenant.ID, - AdditionalContext: authClient.GetClientMap( - testutils.WithAdditionalGroup(uuid.NewString())), + Headers: headers, }) assert.Equal(t, http.StatusOK, w.Code) resp := testutils.GetJSONBody[cmkapi.Tenant](t, w) + assert.NotNil(t, resp.Id) assert.Equal(t, tenant.ID, *resp.Id) assert.NotNil(t, resp.Role) expectedRole := strings.TrimPrefix(string(tenant.Role), "ROLE_") @@ -157,22 +212,33 @@ func TestGetTenantInfo(t *testing.T) { assert.Equal(t, tenant.Name, resp.Name) }) - t.Run("Should 403 on get tenant by valid ID and no client data", func(t *testing.T) { + t.Run("Should 403 on get tenant info without a user group", func(t *testing.T) { w := testutils.MakeHTTPRequest(t, sv, testutils.RequestOptions{ Method: http.MethodGet, Endpoint: "/tenantInfo", Tenant: tenant.ID, + Headers: headersNoGroups, }) assert.Equal(t, http.StatusForbidden, w.Code) }) + t.Run("Should 500 on get tenant by valid ID and no client data", func(t *testing.T) { + w := testutils.MakeHTTPRequest(t, sv, testutils.RequestOptions{ + Method: http.MethodGet, + Endpoint: "/tenantInfo", + Tenant: tenant.ID, + }) + + assert.Equal(t, http.StatusInternalServerError, w.Code) + }) + t.Run("Should 403 on get tenant by valid ID and no valid group", func(t *testing.T) { w := testutils.MakeHTTPRequest(t, sv, testutils.RequestOptions{ - Method: http.MethodGet, - Endpoint: "/tenantInfo", - Tenant: tenant.ID, - AdditionalContext: authClient.GetClientMap(testutils.WithOverriddenGroup(1)), + Method: http.MethodGet, + Endpoint: "/tenantInfo", + Tenant: tenant.ID, + Headers: headersInvalidGroup, }) assert.Equal(t, http.StatusForbidden, w.Code) diff --git a/internal/controllers/cmk/userinfo_controller_test.go b/internal/controllers/cmk/userinfo_controller_test.go index f7bad66d..d225f898 100644 --- a/internal/controllers/cmk/userinfo_controller_test.go +++ b/internal/controllers/cmk/userinfo_controller_test.go @@ -11,48 +11,56 @@ import ( "github.com/openkcm/cmk/internal/api/cmkapi" "github.com/openkcm/cmk/internal/config" - "github.com/openkcm/cmk/internal/constants" "github.com/openkcm/cmk/internal/model" "github.com/openkcm/cmk/internal/repo/sql" "github.com/openkcm/cmk/internal/testutils" cmkcontext "github.com/openkcm/cmk/utils/context" ) -func startAPIUserInfo(t *testing.T) (*multitenancy.DB, cmkapi.ServeMux, string) { +func startAPIUserInfo(t *testing.T) (*multitenancy.DB, cmkapi.ServeMux, string, *testutils.TestSigningKeyStorage) { t.Helper() db, tenants, dbCfg := testutils.NewTestDB(t, testutils.TestDBConfig{ CreateDatabase: true, }) + keyStorage := testutils.NewTestSigningKeyStorage(t) + return db, testutils.NewAPIServer(t, db, testutils.TestAPIServerConfig{ - Config: config.Config{Database: dbCfg}, - }), tenants[0] + Config: config.Config{Database: dbCfg}, + EnableClientDataMW: true, + SigningKeyStorage: keyStorage, + }), tenants[0], keyStorage } func TestGetUserInfo(t *testing.T) { - db, sv, tenant := startAPIUserInfo(t) + db, sv, tenant, keyStorage := startAPIUserInfo(t) r := sql.NewRepository(db) ctx := cmkcontext.CreateTenantContext(t.Context(), tenant) + // Get private key for signing test requests + privateKey, ok := keyStorage.GetPrivateKey(0) + assert.True(t, ok, "test key should exist") + t.Run("Should 200 on get user info with good client data", func(t *testing.T) { group := testutils.NewGroup(func(_ *model.Group) {}) testutils.CreateTestEntities(ctx, t, r, group) + clientData := &auth.ClientData{ + Identifier: "user-123", + Email: "bob@example.com", + GivenName: "Bob", + FamilyName: "Builder", + Groups: []string{group.IAMIdentifier, "some-other-group"}, + } + headers := testutils.NewSignedClientDataHeadersFromStruct(t, clientData, privateKey, 0) + w := testutils.MakeHTTPRequest(t, sv, testutils.RequestOptions{ Method: http.MethodGet, Endpoint: "/userInfo", Tenant: tenant, - AdditionalContext: map[any]any{ - constants.ClientData: &auth.ClientData{ - Identifier: "user-123", - Email: "bob@example.com", - GivenName: "Bob", - FamilyName: "Builder", - Groups: []string{group.IAMIdentifier, "some-other-group"}, - }, - }, + Headers: headers, }) assert.Equal(t, http.StatusOK, w.Code) @@ -66,19 +74,20 @@ func TestGetUserInfo(t *testing.T) { }) t.Run("Should 200 on get user info without group", func(t *testing.T) { + clientData := &auth.ClientData{ + Identifier: "user-123", + Email: "bob@example.com", + GivenName: "Bob", + FamilyName: "Builder", + Groups: []string{"some-other-group"}, + } + headers := testutils.NewSignedClientDataHeadersFromStruct(t, clientData, privateKey, 0) + w := testutils.MakeHTTPRequest(t, sv, testutils.RequestOptions{ Method: http.MethodGet, Endpoint: "/userInfo", Tenant: tenant, - AdditionalContext: map[any]any{ - constants.ClientData: &auth.ClientData{ - Identifier: "user-123", - Email: "bob@example.com", - GivenName: "Bob", - FamilyName: "Builder", - Groups: []string{"some-other-group"}, - }, - }, + Headers: headers, }) assert.Equal(t, http.StatusOK, w.Code) diff --git a/internal/testutils/api.go b/internal/testutils/api.go index 3b77460c..310b9c38 100644 --- a/internal/testutils/api.go +++ b/internal/testutils/api.go @@ -14,6 +14,7 @@ import ( "github.com/getkin/kin-openapi/openapi3filter" "github.com/openkcm/common-sdk/pkg/commoncfg" "github.com/openkcm/common-sdk/pkg/commongrpc" + "github.com/openkcm/common-sdk/pkg/storage/keyvalue" "github.com/openkcm/plugin-sdk/pkg/catalog" "github.com/stretchr/testify/assert" @@ -40,9 +41,11 @@ const TestCertURL = "https://aia.pki.co.test.com/aia/TEST%20Cloud%20Root%20CA.cr const TestHostPrefix = "https://kms.test/cmk/v1/" type TestAPIServerConfig struct { - Plugins []catalog.BuiltInPlugin // Plugins only set if needed - GRPCCon *commongrpc.DynamicClientConn // GRPCClient only set if needed - Config config.Config + Plugins []catalog.BuiltInPlugin // Plugins only set if needed + GRPCCon *commongrpc.DynamicClientConn // GRPCClient only set if needed + Config config.Config + EnableClientDataMW bool // Enable ClientDataMiddleware (default: false for backward compatibility) + SigningKeyStorage keyvalue.ReadOnlyStringToBytesStorage // Optional: provide custom signing key storage } // NewAPIServer creates a new API server with the given database connection @@ -105,12 +108,14 @@ func NewAPIServer( controller := cmk.NewAPIController(tb.Context(), authzRepo, &cfg, factory, migrator, svcRegistry, authzAPILoader) - return startAPIServer(tb, controller) + return startAPIServer(tb, controller, testCfg) } +//nolint:funlen func startAPIServer( tb testing.TB, controller *cmk.APIController, + testCfg TestAPIServerConfig, ) cmkapi.ServeMux { tb.Helper() @@ -141,11 +146,38 @@ func startAPIServer( IncludeResponseStatus: true, }, }), + } + + // Middlewares are applied from last to first. + // Keep Authz before ClientData in the slice so ClientData runs first at request time. + mws = append(mws, middleware.AuthzMiddleware(controller), middleware.LoggingMiddleware(), middleware.PanicRecoveryMiddleware(), middleware.InjectMultiTenancy(), middleware.InjectRequestID(), + ) + + // Add ClientDataMiddleware if enabled. + // It must be appended after Authz in the slice so it runs before Authz. + if testCfg.EnableClientDataMW { + signingKeyStorage := testCfg.SigningKeyStorage + if signingKeyStorage == nil { + // Create default test signing key storage if not provided + signingKeyStorage = NewTestSigningKeyStorage(tb) + } + + // Default auth context fields for testing + authContextFields := []string{"client_id", "issuer", "multitenancy_ref"} + + // Use test role getter + roleGetter := NewTestRoleGetter() + + mws = append(mws, middleware.ClientDataMiddleware( + signingKeyStorage, + authContextFields, + roleGetter, + )) } cmkapi.HandlerWithOptions(strictController, @@ -180,8 +212,8 @@ type RequestOptions struct { Endpoint string Tenant string // TenantID Body io.Reader // Only need to be set for POST/PATCH. Used with the WithString and WithJSON - Headers map[string]string - AdditionalContext map[any]any + Headers http.Header + AdditionalContext map[any]any // Deprecated: Use Headers with signed client data instead } // WithString is a helper function that converts a string to an io.Reader. @@ -222,14 +254,20 @@ func GetJSONBody[t any](tb testing.TB, w *httptest.ResponseRecorder) t { } // NewHTTPRequest builds an HTTP Request it sets default content-types for certain Methods +// +//nolint:cyclop func NewHTTPRequest(tb testing.TB, opt RequestOptions) *http.Request { tb.Helper() ctx := tb.Context() - //nolint: fatcontext - for k, v := range opt.AdditionalContext { - ctx = context.WithValue(ctx, k, v) + // Legacy support: inject AdditionalContext if provided and ClientDataMiddleware is not enabled + // When ClientDataMiddleware is enabled, AdditionalContext is ignored in favor of Headers + if len(opt.AdditionalContext) > 0 && opt.Headers == nil { + //nolint: fatcontext + for k, v := range opt.AdditionalContext { + ctx = context.WithValue(ctx, k, v) + } } r, err := http.NewRequestWithContext( @@ -253,8 +291,13 @@ func NewHTTPRequest(tb testing.TB, opt RequestOptions) *http.Request { assert.Fail(tb, "HTTP Method not supported!") } - for k, v := range opt.Headers { - r.Header.Add(k, v) + // Apply provided headers + if opt.Headers != nil { + for key, values := range opt.Headers { + for _, value := range values { + r.Header.Add(key, value) + } + } } return r diff --git a/internal/testutils/authz.go b/internal/testutils/authz.go index d077d36a..4a5cc76b 100644 --- a/internal/testutils/authz.go +++ b/internal/testutils/authz.go @@ -2,6 +2,9 @@ package testutils import ( "context" + "crypto/rsa" + "net/http" + "strconv" "testing" "github.com/google/uuid" @@ -172,3 +175,118 @@ func getClientData(identifier string, groupNames []string) *auth.ClientData { Groups: groupNames, } } + +// NewSignedClientDataHeaders generates HTTP headers with signed client data for testing +// This creates the x-client-data and x-client-data-signature headers that ClientDataMiddleware expects +// Uses RS256 algorithm (RSA + SHA-256) for signing +func NewSignedClientDataHeaders( + tb testing.TB, + clientData map[string]any, + privateKey *rsa.PrivateKey, + keyID int, +) http.Header { + tb.Helper() + + // Convert map to auth.ClientData struct + cd := &auth.ClientData{ + Identifier: getString(clientData, "identifier"), + Type: getString(clientData, "type"), + Email: getString(clientData, "email"), + Region: getString(clientData, "region"), + Groups: getStringSlice(clientData, "groups"), + KeyID: strconv.Itoa(keyID), + SignatureAlgorithm: auth.SignatureAlgorithmRS256, + AuthContext: getStringMap(clientData, "authContext"), + } + + // Generate signed headers using the auth package + clientDataHeader, signatureHeader, err := cd.Encode(privateKey) + if err != nil { + tb.Fatalf("Failed to encode and sign client data: %v", err) + } + + // Create HTTP headers + headers := http.Header{} + headers.Set(auth.HeaderClientData, clientDataHeader) + headers.Set(auth.HeaderClientDataSignature, signatureHeader) + + return headers +} + +// NewSignedClientDataHeadersFromStruct generates HTTP headers from an auth.ClientData struct +// This is a convenience function for tests that already have ClientData objects +func NewSignedClientDataHeadersFromStruct( + tb testing.TB, + clientData *auth.ClientData, + privateKey *rsa.PrivateKey, + keyID int, +) http.Header { + tb.Helper() + + // Set required fields for signing + clientData.KeyID = strconv.Itoa(keyID) + clientData.SignatureAlgorithm = auth.SignatureAlgorithmRS256 + + // Generate signed headers using the auth package + clientDataHeader, signatureHeader, err := clientData.Encode(privateKey) + if err != nil { + tb.Fatalf("Failed to encode and sign client data: %v", err) + } + + // Create HTTP headers + headers := http.Header{} + headers.Set(auth.HeaderClientData, clientDataHeader) + headers.Set(auth.HeaderClientDataSignature, signatureHeader) + + return headers +} + +// Helper to get string from map, returns empty string if not found or wrong type +func getString(m map[string]any, key string) string { + if val, ok := m[key]; ok { + if strVal, ok := val.(string); ok { + return strVal + } + } + return "" +} + +// Helper to get string slice from map, returns empty slice if not found or wrong type +func getStringSlice(m map[string]any, key string) []string { + if val, ok := m[key]; ok { + if sliceVal, ok := val.([]string); ok { + return sliceVal + } + // Handle []interface{} case + if ifaceSlice, ok := val.([]any); ok { + result := make([]string, 0, len(ifaceSlice)) + for _, item := range ifaceSlice { + if strVal, ok := item.(string); ok { + result = append(result, strVal) + } + } + return result + } + } + return []string{} +} + +// Helper to get string map from map, returns empty map if not found or wrong type +func getStringMap(m map[string]any, key string) map[string]string { + if val, ok := m[key]; ok { + if mapVal, ok := val.(map[string]string); ok { + return mapVal + } + // Handle map[string]interface{} case + if ifaceMap, ok := val.(map[string]any); ok { + result := make(map[string]string) + for k, v := range ifaceMap { + if strVal, ok := v.(string); ok { + result[k] = strVal + } + } + return result + } + } + return map[string]string{} +} diff --git a/internal/testutils/authz_test.go b/internal/testutils/authz_test.go new file mode 100644 index 00000000..9de4ff39 --- /dev/null +++ b/internal/testutils/authz_test.go @@ -0,0 +1,174 @@ +//nolint:testpackage +package testutils + +import ( + "strconv" + "testing" + + "github.com/openkcm/common-sdk/pkg/auth" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/openkcm/cmk/internal/constants" + "github.com/openkcm/cmk/internal/model" +) + +func TestAuthClientData_GetClientMapAppliesOptions(t *testing.T) { + group := NewGroup(func(g *model.Group) { + g.IAMIdentifier = "group-a" + }) + authClient := AuthClientData{ + Group: group, + GroupID: group.ID.String(), + Identifier: "user-a", + } + + clientMap := authClient.GetClientMap( + WithAdditionalGroup("group-b"), + WithOverriddenIdentifier("user-b"), + ) + + clientData, ok := clientMap[constants.ClientData].(*auth.ClientData) + require.True(t, ok) + require.NotNil(t, clientData) + assert.Equal(t, "user-b", clientData.Identifier) + assert.Equal(t, []string{"group-a", "group-b"}, clientData.Groups) +} + +func TestWithOverriddenGroupGeneratesRequestedCount(t *testing.T) { + authClient := AuthClientData{ + Group: NewGroup(func(g *model.Group) { + g.IAMIdentifier = "seed-group" + }), + Identifier: "seed-user", + } + + clientMap := authClient.GetClientMap(WithOverriddenGroup(3)) + clientData, ok := clientMap[constants.ClientData].(*auth.ClientData) + require.True(t, ok) + require.NotNil(t, clientData) + require.Len(t, clientData.Groups, 3) + + for _, g := range clientData.Groups { + assert.NotEmpty(t, g) + } +} + +func TestAuthClientOptionsAndFactory(t *testing.T) { + t.Run("roles", func(t *testing.T) { + auditor := newAuthClient(WithAuditorRole()) + assert.Equal(t, constants.TenantAuditorRole, auditor.Group.Role) + + keyAdmin := newAuthClient(WithKeyAdminRole()) + assert.Equal(t, constants.KeyAdminRole, keyAdmin.Group.Role) + + tenantAdmin := newAuthClient(WithTenantAdminRole()) + assert.Equal(t, constants.TenantAdminRole, tenantAdmin.Group.Role) + }) + + t.Run("identifier", func(t *testing.T) { + custom := newAuthClient(WithIdentifier("custom-id")) + assert.Equal(t, "custom-id", custom.Identifier) + }) +} + +func TestWithAuthClientDataKC(t *testing.T) { + authClient := newAuthClient(WithTenantAdminRole()) + kc := &model.KeyConfiguration{} + + WithAuthClientDataKC(authClient)(kc) + + assert.Equal(t, authClient.Group.ID, kc.AdminGroupID) + assert.Equal(t, authClient.Group.ID, kc.AdminGroup.ID) + assert.Equal(t, authClient.Group.IAMIdentifier, kc.AdminGroup.IAMIdentifier) +} + +func TestClientMapHelpers(t *testing.T) { + clientMap := GetClientMap("id-a", []string{"g1", "g2"}) + clientData, ok := clientMap[constants.ClientData].(*auth.ClientData) + require.True(t, ok) + require.NotNil(t, clientData) + assert.Equal(t, "id-a", clientData.Identifier) + assert.Equal(t, []string{"g1", "g2"}, clientData.Groups) + + groupless := GetGrouplessClientMap() + grouplessData, ok := groupless[constants.ClientData].(*auth.ClientData) + require.True(t, ok) + require.NotNil(t, grouplessData) + assert.Empty(t, grouplessData.Groups) + assert.NotEmpty(t, grouplessData.Identifier) + + invalid := GetInvalidClientMap() + invalidData, ok := invalid[constants.ClientData].(*auth.ClientData) + require.True(t, ok) + require.NotNil(t, invalidData) + assert.NotEmpty(t, invalidData.Identifier) + assert.Len(t, invalidData.Groups, 2) +} + +func TestNewSignedClientDataHeaders(t *testing.T) { + privateKey, _, err := GenerateTestKeyPair() + require.NoError(t, err) + + input := map[string]any{ + "identifier": "user-1", + "type": "user", + "email": "user@example.com", + "region": "eu", + "groups": []any{"g1", "g2"}, + "authContext": map[string]any{ + "issuer": "issuer-1", + "bad": 42, + }, + } + + headers := NewSignedClientDataHeaders(t, input, privateKey, 7) + clientDataHeader := headers.Get(auth.HeaderClientData) + signature := headers.Get(auth.HeaderClientDataSignature) + require.NotEmpty(t, clientDataHeader) + require.NotEmpty(t, signature) + + decoded, err := auth.DecodeFrom(clientDataHeader) + require.NoError(t, err) + require.NotNil(t, decoded) + assert.Equal(t, "user-1", decoded.Identifier) + assert.Equal(t, []string{"g1", "g2"}, decoded.Groups) + assert.Equal(t, "7", decoded.KeyID) + assert.Equal(t, auth.SignatureAlgorithmRS256, decoded.SignatureAlgorithm) + assert.Equal(t, map[string]string{"issuer": "issuer-1"}, decoded.AuthContext) + + err = decoded.Verify(&privateKey.PublicKey, signature) + require.NoError(t, err) +} + +func TestNewSignedClientDataHeadersFromStructMutatesAndSigns(t *testing.T) { + privateKey, _, err := GenerateTestKeyPair() + require.NoError(t, err) + + clientData := &auth.ClientData{ + Identifier: "user-2", + Groups: []string{"g1"}, + } + + headers := NewSignedClientDataHeadersFromStruct(t, clientData, privateKey, 2) + require.NotEmpty(t, headers.Get(auth.HeaderClientData)) + require.NotEmpty(t, headers.Get(auth.HeaderClientDataSignature)) + assert.Equal(t, strconv.Itoa(2), clientData.KeyID) + assert.Equal(t, auth.SignatureAlgorithmRS256, clientData.SignatureAlgorithm) +} + +func TestMapParsingHelpers(t *testing.T) { + assert.Equal(t, "ok", getString(map[string]any{"k": "ok"}, "k")) + assert.Empty(t, getString(map[string]any{"k": 1}, "k")) + assert.Empty(t, getString(map[string]any{}, "missing")) + + assert.Equal(t, []string{"a"}, getStringSlice(map[string]any{"k": []string{"a"}}, "k")) + assert.Equal(t, []string{"a", "b"}, getStringSlice(map[string]any{"k": []any{"a", "b", 1}}, "k")) + assert.Equal(t, []string{}, getStringSlice(map[string]any{"k": 10}, "k")) + assert.Equal(t, []string{}, getStringSlice(map[string]any{}, "missing")) + + assert.Equal(t, map[string]string{"a": "b"}, getStringMap(map[string]any{"k": map[string]string{"a": "b"}}, "k")) + assert.Equal(t, map[string]string{"a": "b"}, getStringMap(map[string]any{"k": map[string]any{"a": "b", "x": 1}}, "k")) + assert.Equal(t, map[string]string{}, getStringMap(map[string]any{"k": 10}, "k")) + assert.Equal(t, map[string]string{}, getStringMap(map[string]any{}, "missing")) +} diff --git a/internal/testutils/keys.go b/internal/testutils/keys.go new file mode 100644 index 00000000..1faad793 --- /dev/null +++ b/internal/testutils/keys.go @@ -0,0 +1,142 @@ +package testutils + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "os" + "path/filepath" + "testing" + + "github.com/openkcm/common-sdk/pkg/commonfs/loader" + "github.com/openkcm/common-sdk/pkg/storage/keyvalue" + "github.com/stretchr/testify/require" + + "github.com/openkcm/cmk/internal/constants" +) + +// GenerateTestKeyPair generates an RSA key pair for testing +// Returns a 2048-bit RSA private key suitable for RS256 signing +func GenerateTestKeyPair() (*rsa.PrivateKey, *rsa.PublicKey, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, nil, err + } + return privateKey, &privateKey.PublicKey, nil +} + +// TestSigningKeyStorage holds test signing keys in memory +type TestSigningKeyStorage struct { + storage keyvalue.ReadOnlyStringToBytesStorage + loader *loader.Loader + tempDir string + privateKeys map[int]*rsa.PrivateKey + cleanupFunc func() +} + +// Get retrieves a public key by ID +// Returns (value, found) to match keyvalue.ReadStorage interface +func (t *TestSigningKeyStorage) Get(keyID string) ([]byte, bool) { + return t.storage.Get(keyID) +} + +// IsEmpty returns whether the storage is empty +func (t *TestSigningKeyStorage) IsEmpty() bool { + return t.storage.IsEmpty() +} + +// List returns all key IDs in the storage +func (t *TestSigningKeyStorage) List() []string { + return t.storage.List() +} + +// GetPrivateKey retrieves a private key by ID for test signing +func (t *TestSigningKeyStorage) GetPrivateKey(keyID int) (*rsa.PrivateKey, bool) { + key, ok := t.privateKeys[keyID] + return key, ok +} + +// Cleanup stops the loader and removes temporary files +func (t *TestSigningKeyStorage) Cleanup() { + if t.cleanupFunc != nil { + t.cleanupFunc() + } +} + +// NewTestSigningKeyStorage creates a signing key storage with pre-generated test keys. +// Generates 1 key pair (keyID 0). Tests can opt into rotation scenarios elsewhere if needed. +// Returns storage that implements keyvalue.ReadOnlyStringToBytesStorage interface +// +//nolint:funcorder +func NewTestSigningKeyStorage(tb testing.TB) *TestSigningKeyStorage { + tb.Helper() + + tmpDir := tb.TempDir() + privateKeys := make(map[int]*rsa.PrivateKey) + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(tb, err, "failed to generate private key") + + privateKeys[0] = privateKey + + // Write public key to PEM file + pubASN1, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey) + require.NoError(tb, err, "failed to marshal public key") + + pubPEM := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubASN1}) + keyFile := filepath.Join(tmpDir, "0.pem") + + err = os.WriteFile(keyFile, pubPEM, 0o600) + require.NoError(tb, err, "failed to write public key file") + + // Create memory storage and loader for public keys + memoryStorage := keyvalue.NewMemoryStorage[string, []byte]() + signingKeysLoader, err := loader.Create( + loader.OnPath(tmpDir), + loader.WithExtension("pem"), + loader.WithKeyIDType(loader.FileNameWithoutExtension), + loader.WithStorage(memoryStorage), + ) + require.NoError(tb, err, "failed to create signing keys loader") + + err = signingKeysLoader.Start() + require.NoError(tb, err, "failed to load signing keys") + + storage := &TestSigningKeyStorage{ + storage: memoryStorage, + loader: signingKeysLoader, + tempDir: tmpDir, + privateKeys: privateKeys, + cleanupFunc: func() { + _ = signingKeysLoader.Close() + }, + } + + tb.Cleanup(storage.Cleanup) + + return storage +} + +// TestRoleGetter is a mock RoleGetter for testing that always returns a default role +type TestRoleGetter struct { + DefaultRole constants.Role +} + +// GetRoleFromIAM returns the configured default role (or TenantAdminRole if not set) +func (t *TestRoleGetter) GetRoleFromIAM(ctx context.Context, iamIdentifiers []string) (constants.Role, error) { + if t.DefaultRole != "" { + return t.DefaultRole, nil + } + return constants.TenantAdminRole, nil +} + +// NewTestRoleGetter creates a TestRoleGetter with the default role set to TenantAdminRole. +// +//nolint:funcorder +func NewTestRoleGetter() *TestRoleGetter { + return &TestRoleGetter{ + DefaultRole: constants.TenantAdminRole, + } +}