diff --git a/charts/session-manager/Chart.yaml b/charts/session-manager/Chart.yaml index 2218150..ba46793 100644 --- a/charts/session-manager/Chart.yaml +++ b/charts/session-manager/Chart.yaml @@ -15,7 +15,7 @@ type: application # This is the chart version. This version number should be incremented each time you make changes # to the chart and its templates, including the app version. # Versions are expected to follow Semantic Versioning (https://semver.org/) -version: 0.9.4 +version: 0.9.5 # This is the version number of the application being deployed. This version number should be # incremented each time you make changes to the application. Versions are not expected to diff --git a/charts/session-manager/values-dev.yaml b/charts/session-manager/values-dev.yaml index d9d30d3..f2b0d5d 100644 --- a/charts/session-manager/values-dev.yaml +++ b/charts/session-manager/values-dev.yaml @@ -313,11 +313,14 @@ config: idleSessionTimeout: 90m callbackURL: http://localhost:8080/sm/callback clientAuth: - clientID: "client-id" type: client_secret clientSecret: source: embedded value: secret + # Deprecated: ClientID is no longer used in the application code, but is still required in the config + # to backfill existing database entries during migration. + # It will be removed in a future release once the migration has been applied in all environments. + clientID: "client-id" sessionCookieTemplate: name: "__Host-Http-SESSION" path: "/" diff --git a/charts/session-manager/values.yaml b/charts/session-manager/values.yaml index af222f8..fb9f58d 100644 --- a/charts/session-manager/values.yaml +++ b/charts/session-manager/values.yaml @@ -322,11 +322,14 @@ config: idleSessionTimeout: 90m callbackURL: http://localhost:8080/sm/callback clientAuth: - clientID: "client-id" type: client_secret clientSecret: source: embedded value: secret + # Deprecated: ClientID is no longer used in the application code, but is still required in the config + # to backfill existing database entries during migration. + # It will be removed in a future release once the migration has been applied in all environments. + clientID: "client-id" sessionCookieTemplate: name: "__Host-Http-SESSION" path: "/" diff --git a/config.yaml b/config.yaml index b289f85..c9878f7 100644 --- a/config.yaml +++ b/config.yaml @@ -222,6 +222,9 @@ sessionManager: idleSessionTimeout: 90m callbackURL: http://localhost:8080/sm/callback clientAuth: + # Deprecated: ClientID is no longer used in the application code, but is still required in the config + # to backfill existing database entries during migration. + # It will be removed in a future release once the migration has been applied in all environments. clientID: "client-id" type: client_secret clientSecret: diff --git a/helm-tests/integration/helm-install_test.go b/helm-tests/integration/helm-install_test.go index 879024a..6d18121 100644 --- a/helm-tests/integration/helm-install_test.go +++ b/helm-tests/integration/helm-install_test.go @@ -76,7 +76,6 @@ func TestHelmInstall(t *testing.T) { "config.valkey.host.value": "valkey.default.svc.cluster.local:6379", "config.valkey.password.value": "", "config.sessionManager.callbackURL": "http://localhost:8080/sm/callback", - "config.sessionManager.clientAuth.clientID": "test-client", "config.sessionManager.clientAuth.clientSecret.value": "test-secret", "config.sessionManager.csrfSecret.value": "test-csrf-secret-at-least-thirty-two-bits", }, diff --git a/integration/grpc_test.go b/integration/grpc_test.go index d80d275..447d961 100644 --- a/integration/grpc_test.go +++ b/integration/grpc_test.go @@ -327,7 +327,7 @@ func startServer(t *testing.T, port int) (*stdgrpc.Server, *trust.Service, func( srv := stdgrpc.NewServer() oidcmappingv1.RegisterServiceServer(srv, grpc.NewOIDCMappingServer(service)) - sessionv1.RegisterServiceServer(srv, grpc.NewSessionServer(ctx, nil, trustRepo, time.Hour, "")) + sessionv1.RegisterServiceServer(srv, grpc.NewSessionServer(ctx, nil, trustRepo, time.Hour)) // start go func() { diff --git a/integration/session_grpc_test.go b/integration/session_grpc_test.go index 7a235e4..9a775f5 100644 --- a/integration/session_grpc_test.go +++ b/integration/session_grpc_test.go @@ -199,7 +199,7 @@ func startSessionServer(t *testing.T, port int) (*stdgrpc.Server, session.Reposi } srv := stdgrpc.NewServer() - sessionv1.RegisterServiceServer(srv, grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "")) + sessionv1.RegisterServiceServer(srv, grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute)) // start go func() { diff --git a/internal/business/business.go b/internal/business/business.go index a387e2e..01ed3ad 100644 --- a/internal/business/business.go +++ b/internal/business/business.go @@ -109,7 +109,6 @@ func internalMain(ctx context.Context, cfg *config.Config) error { sessionRepo, trustRepo, cfg.SessionManager.IdleSessionTimeout, - cfg.SessionManager.ClientAuth.ClientID, grpc.WithQueryParametersIntrospect(cfg.SessionManager.AdditionalQueryParametersIntrospect), grpc.WithTransportCredentials(credsBuilder), ) @@ -225,7 +224,9 @@ func newCredsBuilder(cfg *config.Config) (credentials.Builder, error) { return nil, fmt.Errorf("failed to load mTLS config: %w", err) } - return func(clientID string) credentials.TransportCredentials { return credentials.NewTLS(clientID, tlsConfig) }, nil + return func(clientID string) credentials.TransportCredentials { + return credentials.NewTLS(clientID, tlsConfig) + }, nil case "client_secret", "client_secret_post": secret, err := commoncfg.LoadValueFromSourceRef(cfg.SessionManager.ClientAuth.ClientSecret) if err != nil { @@ -237,7 +238,9 @@ func newCredsBuilder(cfg *config.Config) (credentials.Builder, error) { }, nil case "insecure": slog.Warn("insecure credentials are used. Do not use this in production") - return func(clientID string) credentials.TransportCredentials { return credentials.NewInsecure(clientID) }, nil + return func(clientID string) credentials.TransportCredentials { + return credentials.NewInsecure(clientID) + }, nil default: return nil, errors.New("unknown Client Auth type") } diff --git a/internal/business/business_test.go b/internal/business/business_test.go index da934f6..3b96efb 100644 --- a/internal/business/business_test.go +++ b/internal/business/business_test.go @@ -20,8 +20,7 @@ func TestLoadHTTPClient_MTLS(t *testing.T) { cfg := &config.Config{ SessionManager: config.SessionManager{ ClientAuth: config.ClientAuth{ - Type: "mtls", - ClientID: "test-client", + Type: "mtls", MTLS: &commoncfg.MTLS{ Cert: commoncfg.SourceRef{File: commoncfg.CredentialFile{Path: "/nonexistent/cert.pem"}}, CertKey: commoncfg.SourceRef{File: commoncfg.CredentialFile{Path: "/nonexistent/key.pem"}}, @@ -42,7 +41,6 @@ func TestLoadHTTPClient_ClientSecret(t *testing.T) { SessionManager: config.SessionManager{ ClientAuth: config.ClientAuth{ Type: "client_secret", - ClientID: "test-client", ClientSecret: commoncfg.SourceRef{Source: "embedded", Value: "test-secret"}, }, }, @@ -53,7 +51,7 @@ func TestLoadHTTPClient_ClientSecret(t *testing.T) { require.NotNil(t, builder) // Verify it's using our custom transport - creds := builder(cfg.SessionManager.ClientAuth.ClientID) + creds := builder("test-client") clientSecretCreds, ok := creds.(*credentials.ClientSecretPost) require.True(t, ok) @@ -65,8 +63,7 @@ func TestLoadHTTPClient_Insecure(t *testing.T) { cfg := &config.Config{ SessionManager: config.SessionManager{ ClientAuth: config.ClientAuth{ - Type: "insecure", - ClientID: "test-client", + Type: "insecure", }, }, } @@ -80,8 +77,7 @@ func TestLoadHTTPClient_UnknownType(t *testing.T) { cfg := &config.Config{ SessionManager: config.SessionManager{ ClientAuth: config.ClientAuth{ - Type: "unknown", - ClientID: "test-client", + Type: "unknown", }, }, } diff --git a/internal/business/server/grpc_server_test.go b/internal/business/server/grpc_server_test.go index 1044085..8ea000d 100644 --- a/internal/business/server/grpc_server_test.go +++ b/internal/business/server/grpc_server_test.go @@ -27,7 +27,7 @@ func TestStartGRPCServer_ContextCancellation(t *testing.T) { // Create minimal server instances oidcmappingsrv := grpc.NewOIDCMappingServer(nil) - sessionsrv := grpc.NewSessionServer(ctx, nil, nil, 0, "") + sessionsrv := grpc.NewSessionServer(ctx, nil, nil, 0) // Start the server in a goroutine errChan := make(chan error, 1) diff --git a/internal/config/config.go b/internal/config/config.go index 4d5cd63..6c223b7 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -113,7 +113,6 @@ type CookieTemplate struct { } type ClientAuth struct { - ClientID string `yaml:"clientID"` // Type defines how to authenticate the client. // Supported types are: // - mtls: Mutual TLS authentication @@ -123,6 +122,11 @@ type ClientAuth struct { MTLS *commoncfg.MTLS `yaml:"mTLS"` // ClientSecret contains the client secret source reference when Type is set to "clientSecret". ClientSecret commoncfg.SourceRef `yaml:"clientSecret"` + + // Deprecated: ClientID is no longer used in the application code, but is still required in the config + // to backfill existing database entries during migration. + // It will be removed in a future release once the migration has been applied in all environments. + ClientID string `yaml:"clientID"` } type Migrate struct { diff --git a/internal/dbtest/postgrestest/postgres.go b/internal/dbtest/postgrestest/postgres.go index 5c836f5..3ae4d7f 100644 --- a/internal/dbtest/postgrestest/postgres.go +++ b/internal/dbtest/postgrestest/postgres.go @@ -5,6 +5,8 @@ import ( "database/sql" "fmt" "log/slog" + "os" + "path/filepath" "time" "github.com/jackc/pgx/v5" @@ -89,6 +91,10 @@ func makeDBConn(ctx context.Context, port network.Port) *pgxpool.Pool { } func migrateDB(ctx context.Context, port network.Port) { + // Create a test config file for the migration + configCleanup := createTestConfig() + defer configCleanup() + db, err := sql.Open("pgx", connStr(port)) if err != nil { panic(err) @@ -122,3 +128,44 @@ func prepareDB(ctx context.Context, dbPool *pgxpool.Pool, port network.Port) { panic(err) } } + +// createTestConfig creates a temporary config file for the migration tests. +// It returns a cleanup function that should be called to remove the config file. +func createTestConfig() func() { + // Create a temporary directory for the config + tmpDir, err := os.MkdirTemp("", "session-manager-test-*") + if err != nil { + panic(fmt.Sprintf("Failed to create temp dir: %v", err)) + } + + // Create config.yaml with test client_id + configContent := `sessionManager: + clientAuth: + clientID: "test-client-id" +` + configPath := filepath.Join(tmpDir, "config.yaml") + err = os.WriteFile(configPath, []byte(configContent), 0644) + if err != nil { + os.RemoveAll(tmpDir) + panic(fmt.Sprintf("Failed to write config file: %v", err)) + } + + // Change to the temp directory so the config loader can find it + originalDir, err := os.Getwd() + if err != nil { + os.RemoveAll(tmpDir) + panic(fmt.Sprintf("Failed to get current directory: %v", err)) + } + + err = os.Chdir(tmpDir) + if err != nil { + os.RemoveAll(tmpDir) + panic(fmt.Sprintf("Failed to change directory: %v", err)) + } + + // Return cleanup function + return func() { + _ = os.Chdir(originalDir) + os.RemoveAll(tmpDir) + } +} diff --git a/internal/grpc/session.go b/internal/grpc/session.go index ab226a3..e92faf9 100644 --- a/internal/grpc/session.go +++ b/internal/grpc/session.go @@ -41,7 +41,6 @@ type SessionServer struct { queryParametersIntrospect []string idleSessionTimeout time.Duration allowHttpScheme bool - clientID string // cache introspection results introspectionCache *ttlcache.Cache[string, oidc.Introspection] @@ -52,7 +51,6 @@ func NewSessionServer( sessionRepo session.Repository, trustRepo trust.OIDCMappingRepository, idleSessionTimeout time.Duration, - clientID string, opts ...SessionServerOption, ) *SessionServer { s := &SessionServer{ @@ -60,7 +58,6 @@ func NewSessionServer( trustRepo: trustRepo, idleSessionTimeout: idleSessionTimeout, newCreds: func(clientID string) credentials.TransportCredentials { return credentials.NewInsecure(clientID) }, - clientID: clientID, } for _, opt := range opts { if opt != nil { @@ -220,7 +217,7 @@ func (s *SessionServer) getClientID(mapping *trust.OIDCMapping) string { return mapping.ClientID } - return s.clientID + return "" } func (s *SessionServer) httpClient(mapping *trust.OIDCMapping) *http.Client { diff --git a/internal/grpc/session_test.go b/internal/grpc/session_test.go index 2dda10f..f047cf2 100644 --- a/internal/grpc/session_test.go +++ b/internal/grpc/session_test.go @@ -31,7 +31,7 @@ func TestNewSessionServer(t *testing.T) { trustRepo := trustmock.NewInMemRepository() idleSessionTimeout := 90 * time.Minute - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, idleSessionTimeout, "") + server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, idleSessionTimeout) assert.NotNil(t, server) }) @@ -45,7 +45,6 @@ func TestNewSessionServer(t *testing.T) { sessionRepo, trustRepo, idleSessionTimeout, - "", grpc.WithQueryParametersIntrospect([]string{"param1", "param2"}), ) @@ -61,7 +60,6 @@ func TestNewSessionServer(t *testing.T) { sessionRepo, trustRepo, idleSessionTimeout, - "", nil, ) @@ -123,7 +121,7 @@ func TestGetSession(t *testing.T) { trustmock.WithTrust(sess.TenantID, mapping), ) - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "", + server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, grpc.WithAllowHttpScheme(true), ) @@ -194,7 +192,7 @@ func TestGetSession(t *testing.T) { trustmock.WithTrust(sess.TenantID, mapping), ) - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "", + server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, grpc.WithAllowHttpScheme(true), ) @@ -250,7 +248,7 @@ func TestGetSession(t *testing.T) { trustmock.WithTrust(sess.TenantID, mapping), ) - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "", + server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, grpc.WithAllowHttpScheme(true), ) @@ -274,7 +272,7 @@ func TestGetSession(t *testing.T) { ) trustRepo := trustmock.NewInMemRepository() - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute) req := &sessionv1.GetSessionRequest{ SessionId: "session-123", @@ -303,7 +301,7 @@ func TestGetSession(t *testing.T) { trustRepo := trustmock.NewInMemRepository() - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute) req := &sessionv1.GetSessionRequest{ SessionId: "session-789", @@ -330,7 +328,7 @@ func TestGetSession(t *testing.T) { trustRepo := trustmock.NewInMemRepository() - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute) req := &sessionv1.GetSessionRequest{ SessionId: "session-fail", @@ -361,7 +359,7 @@ func TestGetSession(t *testing.T) { // No mapping added to repo trustRepo := trustmock.NewInMemRepository() - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute) req := &sessionv1.GetSessionRequest{ SessionId: "session-no-provider", @@ -397,7 +395,7 @@ func TestGetSession(t *testing.T) { trustmock.WithTrust(sess.TenantID, mapping), ) - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute) req := &sessionv1.GetSessionRequest{ SessionId: "session-blocked", @@ -445,7 +443,7 @@ func TestGetSession(t *testing.T) { trustmock.WithTrust(sess.TenantID, mapping), ) - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute) req := &sessionv1.GetSessionRequest{ SessionId: "session-fingerprint", @@ -481,7 +479,7 @@ func TestGetSession(t *testing.T) { trustmock.WithTrust("wrong-tenant", mapping), ) - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute) req := &sessionv1.GetSessionRequest{ SessionId: "session-tenant", @@ -517,7 +515,7 @@ func TestGetSession(t *testing.T) { trustmock.WithTrust(sess.TenantID, mapping), ) - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute) req := &sessionv1.GetSessionRequest{ SessionId: "session-config-fail", @@ -569,7 +567,7 @@ func TestGetSession(t *testing.T) { trustmock.WithTrust(sess.TenantID, mapping), ) - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "", + server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, grpc.WithAllowHttpScheme(true), ) @@ -625,7 +623,7 @@ func TestGetSession(t *testing.T) { trustmock.WithTrust(sess.TenantID, mapping), ) - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "", + server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, grpc.WithAllowHttpScheme(true), ) @@ -674,7 +672,7 @@ func TestGetSession(t *testing.T) { trustmock.WithTrust(sess.TenantID, mapping), ) - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "", + server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, grpc.WithAllowHttpScheme(true), ) @@ -708,7 +706,6 @@ func TestWithQueryParametersIntrospect(t *testing.T) { sessionRepo, trustRepo, 90*time.Minute, - "", opt, ) @@ -731,7 +728,7 @@ func TestGetOIDCProvider(t *testing.T) { trustmock.WithTrust("tenant-123", mapping), ) - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute) req := &sessionv1.GetOIDCProviderRequest{ TenantId: "tenant-123", @@ -751,7 +748,7 @@ func TestGetOIDCProvider(t *testing.T) { sessionRepo := sessionmock.NewInMemRepository() trustRepo := trustmock.NewInMemRepository() - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute) req := &sessionv1.GetOIDCProviderRequest{ TenantId: "non-existent-tenant", @@ -770,7 +767,7 @@ func TestGetOIDCProvider(t *testing.T) { trustmock.WithGetError(errors.New("database connection error")), ) - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute) req := &sessionv1.GetOIDCProviderRequest{ TenantId: "tenant-123", diff --git a/internal/session/housekeeper_test.go b/internal/session/housekeeper_test.go index e466a63..e085a7b 100644 --- a/internal/session/housekeeper_test.go +++ b/internal/session/housekeeper_test.go @@ -106,6 +106,7 @@ func TestRefreshAccessToken(t *testing.T) { mapping := trust.OIDCMapping{ IssuerURL: discoveryServerURL, + ClientID: "test-client-id", Properties: map[string]string{ "test-param": "param-value", }, @@ -127,9 +128,7 @@ func TestRefreshAccessToken(t *testing.T) { require.NoError(t, err) cfg := &config.SessionManager{ - ClientAuth: config.ClientAuth{ - ClientID: "test-client-id", - }, + ClientAuth: config.ClientAuth{}, AdditionalQueryParametersToken: []string{"test-param"}, CSRFSecretParsed: []byte(testCSRFSecret), } @@ -170,9 +169,7 @@ func TestRefreshAccessToken(t *testing.T) { require.NoError(t, err) cfg := &config.SessionManager{ - ClientAuth: config.ClientAuth{ - ClientID: "test-client-id", - }, + ClientAuth: config.ClientAuth{}, CSRFSecretParsed: []byte(testCSRFSecret), } @@ -226,9 +223,7 @@ func TestRefreshAccessToken(t *testing.T) { require.NoError(t, err) cfg := &config.SessionManager{ - ClientAuth: config.ClientAuth{ - ClientID: "test-client-id", - }, + ClientAuth: config.ClientAuth{}, CSRFSecretParsed: []byte(testCSRFSecret), } @@ -283,9 +278,7 @@ func TestRefreshAccessToken(t *testing.T) { require.NoError(t, err) cfg := &config.SessionManager{ - ClientAuth: config.ClientAuth{ - ClientID: "test-client-id", - }, + ClientAuth: config.ClientAuth{}, AdditionalQueryParametersToken: []string{"missing-param"}, CSRFSecretParsed: []byte(testCSRFSecret), } diff --git a/internal/session/manager.go b/internal/session/manager.go index 12c2e12..ce5b32b 100644 --- a/internal/session/manager.go +++ b/internal/session/manager.go @@ -49,7 +49,6 @@ type Manager struct { sessionDuration time.Duration idleSessionTimeout time.Duration callbackURL *url.URL - clientID string queryParametersAuth []string queryParametersToken []string authContextKeys []string @@ -94,7 +93,6 @@ func NewManager( csrfCookieTemplate: cfg.CSRFCookieTemplate, loginCSRFCookieTemplate: cfg.LoginCSRFCookieTemplate, callbackURL: callbackURL, - clientID: cfg.ClientAuth.ClientID, newCreds: func(clientID string) credentials.TransportCredentials { return credentials.NewInsecure(clientID) }, csrfSecret: cfg.CSRFSecretParsed, } @@ -628,7 +626,7 @@ func (m *Manager) getClientID(mapping trust.OIDCMapping) string { return mapping.ClientID } - return m.clientID + return "" } func (m *Manager) exchangeCode(ctx context.Context, openidConf *oidc.Configuration, code, codeVerifier string, mapping trust.OIDCMapping) (tokenResponse, error) { diff --git a/internal/session/manager_cookie_test.go b/internal/session/manager_cookie_test.go index 9927f82..4df737f 100644 --- a/internal/session/manager_cookie_test.go +++ b/internal/session/manager_cookie_test.go @@ -435,10 +435,8 @@ func TestManager_Logout(t *testing.T) { { name: "Success - redirect to postLogoutURL when no end session endpoint", cfg: &config.SessionManager{ - CSRFSecretParsed: []byte(testCSRFSecret), - ClientAuth: config.ClientAuth{ - ClientID: testClientID, - }, + CSRFSecretParsed: []byte(testCSRFSecret), + PostLogoutRedirectURL: postLogoutURL, }, setupOIDCRepo: func(t *testing.T) *trustmock.Repository { t.Helper() @@ -469,9 +467,7 @@ func TestManager_Logout(t *testing.T) { cfg: &config.SessionManager{ CSRFSecretParsed: []byte(testCSRFSecret), PostLogoutRedirectURL: postLogoutURL, - ClientAuth: config.ClientAuth{ - ClientID: testClientID, - }, + ClientAuth: config.ClientAuth{}, }, setupOIDCRepo: func(t *testing.T) *trustmock.Repository { t.Helper() diff --git a/internal/session/manager_test.go b/internal/session/manager_test.go index cae05e7..442a647 100644 --- a/internal/session/manager_test.go +++ b/internal/session/manager_test.go @@ -51,6 +51,7 @@ func TestManager_Auth(t *testing.T) { oidcMapping := trust.OIDCMapping{ IssuerURL: oidcServer.URL, + ClientID: "my-client-id", Blocked: false, JWKSURI: "http://jwks.example.com", Audiences: []string{requestURI}, @@ -81,10 +82,8 @@ func TestManager_Auth(t *testing.T) { SessionDuration: time.Hour, CallbackURL: callbackURL, AdditionalQueryParametersAuthorize: []string{"paramAuth1"}, - ClientAuth: config.ClientAuth{ - ClientID: testClientID, - }, - CSRFSecretParsed: []byte(testCSRFSecret), + ClientAuth: config.ClientAuth{}, + CSRFSecretParsed: []byte(testCSRFSecret), }, tenantID: tenantID, fingerprint: "fingerprint", @@ -591,9 +590,7 @@ func TestManager_LogoutEdgeCases(t *testing.T) { cfg := &config.SessionManager{ CSRFSecretParsed: []byte(testCSRFSecret), - ClientAuth: config.ClientAuth{ - ClientID: testClientID, - }, + ClientAuth: config.ClientAuth{}, } m, err := session.NewManager(ctx, cfg, oidcMock, sessionMock, auditLogger) diff --git a/internal/trust/mapping.go b/internal/trust/mapping.go index b273807..8d400c5 100644 --- a/internal/trust/mapping.go +++ b/internal/trust/mapping.go @@ -7,9 +7,7 @@ type OIDCMapping struct { Audiences []string Properties map[string]string - // ClientID is a client_id property used for authentication. - // It is an optional value for the trust config. If the trust's client id is not specified, - // the application-global client id is used. + // ClientID is a mandatory property used for authentication. ClientID string } diff --git a/internal/trust/repository_test.go b/internal/trust/repository_test.go index 549e42d..075077f 100644 --- a/internal/trust/repository_test.go +++ b/internal/trust/repository_test.go @@ -4,6 +4,8 @@ import ( "context" "database/sql" "fmt" + "os" + "path/filepath" "github.com/jackc/pgx/v5/pgxpool" "github.com/pressly/goose/v3" @@ -61,10 +63,11 @@ func (m *RepoWrapper) Delete(ctx context.Context, tenantID string) error { // Get implements oidc.OIDCMappingRepository. func (m *RepoWrapper) Get(ctx context.Context, tenantID string) (trust.OIDCMapping, error) { if m.MockGet != nil { - _, err := m.MockGet(ctx, tenantID) + mapping, err := m.MockGet(ctx, tenantID) if err != nil { return trust.OIDCMapping{}, err } + return mapping, nil } return m.Repo.Get(ctx, tenantID) } @@ -114,6 +117,10 @@ func createRepo(ctx context.Context) (trust.OIDCMappingRepository, error) { } func migrateDB(ctx context.Context, connStr string) error { + // Create a test config file for the migration + configCleanup := createTestConfig() + defer configCleanup() + db, err := sql.Open("pgx", connStr) if err != nil { return err @@ -133,3 +140,44 @@ func migrateDB(ctx context.Context, connStr string) error { } return nil } + +// createTestConfig creates a temporary config file for the migration tests. +// It returns a cleanup function that should be called to remove the config file. +func createTestConfig() func() { + // Create a temporary directory for the config + tmpDir, err := os.MkdirTemp("", "session-manager-test-*") + if err != nil { + panic(fmt.Sprintf("Failed to create temp dir: %v", err)) + } + + // Create config.yaml with test client_id + configContent := `sessionManager: + clientAuth: + clientID: "test-client-id" +` + configPath := filepath.Join(tmpDir, "config.yaml") + err = os.WriteFile(configPath, []byte(configContent), 0644) + if err != nil { + os.RemoveAll(tmpDir) + panic(fmt.Sprintf("Failed to write config file: %v", err)) + } + + // Change to the temp directory so the config loader can find it + originalDir, err := os.Getwd() + if err != nil { + os.RemoveAll(tmpDir) + panic(fmt.Sprintf("Failed to get current directory: %v", err)) + } + + err = os.Chdir(tmpDir) + if err != nil { + os.RemoveAll(tmpDir) + panic(fmt.Sprintf("Failed to change directory: %v", err)) + } + + // Return cleanup function + return func() { + _ = os.Chdir(originalDir) + os.RemoveAll(tmpDir) + } +} diff --git a/sql/00005_client_id.go b/sql/00005_client_id.go new file mode 100644 index 0000000..b248cac --- /dev/null +++ b/sql/00005_client_id.go @@ -0,0 +1,54 @@ +package migrations + +import ( + "context" + "database/sql" + "errors" + + "github.com/openkcm/common-sdk/pkg/commoncfg" + "github.com/pressly/goose/v3" + + slogctx "github.com/veqryn/slog-context" + + "github.com/openkcm/session-manager/internal/config" +) + +func init() { + goose.AddMigrationContext(Up00005, Down00005) +} + +func Up00005(ctx context.Context, tx *sql.Tx) error { + clientID, err := readClientIDfromConfig() + if err != nil { + return err + } + slogctx.Debug(ctx, "Updating trust table with client_id", "client_id", clientID) + _, err = tx.ExecContext(ctx, "UPDATE trust SET client_id=$1 WHERE client_id IS NULL or client_id='';", clientID) + return err +} + +func Down00005(ctx context.Context, tx *sql.Tx) error { + return nil +} + +func readClientIDfromConfig() (string, error) { + // Load the config which contains the client_id + cfg := &config.Config{} + loader := commoncfg.NewLoader(cfg, commoncfg.WithPaths( + "/etc/session-manager", + "$HOME/.session-manager", + ".", + )) + if err := loader.LoadConfig(); err != nil { + return "", err + } + + // Read the client_id from the config + //nolint:staticcheck + clientID := cfg.SessionManager.ClientAuth.ClientID + if clientID == "" { + return "", errors.New("client_id is not set in the config") + } + + return clientID, nil +} diff --git a/sql/00005_client_id_test.go b/sql/00005_client_id_test.go new file mode 100644 index 0000000..4afe7e2 --- /dev/null +++ b/sql/00005_client_id_test.go @@ -0,0 +1,304 @@ +package migrations + +import ( + "context" + "database/sql" + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/jackc/pgx/v5/pgxpool" + "github.com/pressly/goose/v3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go/modules/postgres" + + _ "github.com/jackc/pgx/v5/stdlib" +) + +const ( + testDBHost = "localhost" + testDBUser = "postgres" + testDBPassword = "secret" + testDBName = "migration_test" + testDBSSLMode = "disable" +) + +// TestUp00005_WithConfig tests the migration when a valid config file exists +func TestUp00005_WithConfig(t *testing.T) { + ctx := t.Context() + db, pool, cleanup := setupTestDB(ctx, t) + defer cleanup() + + // Create test config with a specific client_id + configCleanup := createTestConfigWithClientID("test-client-123") + defer configCleanup() + + // Insert test data with various client_id states + _, err := pool.Exec(ctx, ` + INSERT INTO trust (tenant_id, blocked, issuer, jwks_uri, audiences, properties, client_id) + VALUES + ('tenant-null', false, 'issuer1', '', '{}', '{}', NULL), + ('tenant-empty', false, 'issuer2', '', '{}', '{}', ''), + ('tenant-existing', false, 'issuer3', '', '{}', '{}', 'existing-client-id') + `) + require.NoError(t, err) + + // Create a transaction for the migration + tx, err := db.BeginTx(ctx, nil) + require.NoError(t, err) + defer func() { _ = tx.Rollback() }() + + // Run the up migration + err = Up00005(ctx, tx) + require.NoError(t, err) + + // Commit the transaction + err = tx.Commit() + require.NoError(t, err) + + // Verify the results + t.Run("NULL client_id should be updated", func(t *testing.T) { + var clientID string + err := pool.QueryRow(ctx, "SELECT client_id FROM trust WHERE tenant_id = 'tenant-null'").Scan(&clientID) + require.NoError(t, err) + assert.Equal(t, "test-client-123", clientID) + }) + + t.Run("empty client_id should be updated", func(t *testing.T) { + var clientID string + err := pool.QueryRow(ctx, "SELECT client_id FROM trust WHERE tenant_id = 'tenant-empty'").Scan(&clientID) + require.NoError(t, err) + assert.Equal(t, "test-client-123", clientID) + }) + + t.Run("existing client_id should NOT be updated", func(t *testing.T) { + var clientID string + err := pool.QueryRow(ctx, "SELECT client_id FROM trust WHERE tenant_id = 'tenant-existing'").Scan(&clientID) + require.NoError(t, err) + assert.Equal(t, "existing-client-id", clientID, "existing client_id should not be overwritten") + }) +} + +// TestUp00005_WithoutConfig tests that migration handles missing config gracefully +func TestUp00005_WithoutConfig(t *testing.T) { + ctx := t.Context() + db, pool, cleanup := setupTestDB(ctx, t) + defer cleanup() + + // Ensure we're in a directory with no config file + tmpDir := t.TempDir() + + t.Chdir(tmpDir) + + // Insert test data + _, err := pool.Exec(ctx, ` + INSERT INTO trust (tenant_id, blocked, issuer, jwks_uri, audiences, properties, client_id) + VALUES ('tenant-no-config', false, 'issuer1', '', '{}', '{}', NULL) + `) + require.NoError(t, err) + + // Create a transaction for the migration + tx, err := db.BeginTx(ctx, nil) + require.NoError(t, err) + defer func() { _ = tx.Rollback() }() + + // Run the up migration - should not fail even without config + err = Up00005(ctx, tx) + + // The migration should handle missing config gracefully + // This test documents the current behavior - adjust assertion based on desired behavior + if err != nil { + assert.Contains(t, err.Error(), "Config File", "should indicate config file issue") + } +} + +// TestUp00005_EmptyClientIDInConfig tests migration when config has empty client_id +func TestUp00005_EmptyClientIDInConfig(t *testing.T) { + ctx := t.Context() + db, pool, cleanup := setupTestDB(ctx, t) + defer cleanup() + + // Create test config with empty client_id + configCleanup := createTestConfigWithClientID("") + defer configCleanup() + + // Insert test data + _, err := pool.Exec(ctx, ` + INSERT INTO trust (tenant_id, blocked, issuer, jwks_uri, audiences, properties, client_id) + VALUES ('tenant-test', false, 'issuer1', '', '{}', '{}', NULL) + `) + require.NoError(t, err) + + // Create a transaction for the migration + tx, err := db.BeginTx(ctx, nil) + require.NoError(t, err) + defer func() { _ = tx.Rollback() }() + + // Run the up migration - should fail because client_id is empty + err = Up00005(ctx, tx) + assert.Error(t, err, "should fail when client_id is empty in config") + assert.Contains(t, err.Error(), "client_id is not set") +} + +// TestDown00005 tests the down migration +func TestDown00005(t *testing.T) { + ctx := t.Context() + db, _, cleanup := setupTestDB(ctx, t) + defer cleanup() + + // Create a transaction for the migration + tx, err := db.BeginTx(ctx, nil) + require.NoError(t, err) + defer func() { _ = tx.Rollback() }() + + // Run the down migration + err = Down00005(ctx, tx) + + // Down migration is currently a no-op, so should succeed + assert.NoError(t, err) +} + +// TestUp00005_MultipleRows tests migration with many rows +func TestUp00005_MultipleRows(t *testing.T) { + ctx := t.Context() + db, pool, cleanup := setupTestDB(ctx, t) + defer cleanup() + + configCleanup := createTestConfigWithClientID("bulk-client-id") + defer configCleanup() + + // Insert multiple test rows + for i := range 10 { + clientID := "" + if i%3 == 0 { + clientID = fmt.Sprintf("existing-%d", i) + } + _, err := pool.Exec(ctx, ` + INSERT INTO trust (tenant_id, blocked, issuer, jwks_uri, audiences, properties, client_id) + VALUES ($1, false, $2, '', '{}', '{}', $3) + `, fmt.Sprintf("tenant-%d", i), fmt.Sprintf("issuer-%d", i), clientID) + require.NoError(t, err) + } + + // Run the migration + tx, err := db.BeginTx(ctx, nil) + require.NoError(t, err) + defer func() { _ = tx.Rollback() }() + + err = Up00005(ctx, tx) + require.NoError(t, err) + + err = tx.Commit() + require.NoError(t, err) + + // Verify all rows have client_id set + rows, err := pool.Query(ctx, "SELECT tenant_id, client_id FROM trust ORDER BY tenant_id") + require.NoError(t, err) + defer rows.Close() + + count := 0 + for rows.Next() { + var tenantID, clientID string + err := rows.Scan(&tenantID, &clientID) + require.NoError(t, err) + + assert.NotEmpty(t, clientID, "client_id should not be empty for tenant %s", tenantID) + + // Check if this tenant should have retained its existing client_id + var expectedID int + if _, err := fmt.Sscanf(tenantID, "tenant-%d", &expectedID); err == nil && expectedID%3 == 0 { + assert.Equal(t, fmt.Sprintf("existing-%d", expectedID), clientID) + } else { + assert.Equal(t, "bulk-client-id", clientID) + } + count++ + } + + assert.Equal(t, 10, count, "should have processed all 10 rows") +} + +// setupTestDB creates a test database and returns db connection, pool, and cleanup function +func setupTestDB(ctx context.Context, t *testing.T) (*sql.DB, *pgxpool.Pool, func()) { + t.Helper() + + // Start PostgreSQL container + pgContainer, err := postgres.Run( + ctx, + "postgres:17-alpine", + postgres.WithDatabase(testDBName), + postgres.WithUsername(testDBUser), + postgres.WithPassword(testDBPassword), + postgres.BasicWaitStrategies(), + ) + require.NoError(t, err) + + port, err := pgContainer.MappedPort(ctx, "5432") + require.NoError(t, err) + + connStr := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%s sslmode=%s", + testDBHost, testDBUser, testDBPassword, testDBName, port.Port(), testDBSSLMode) + + // Create pgx pool for queries + pool, err := pgxpool.New(ctx, connStr) + require.NoError(t, err) + + // Create sql.DB for migrations + db, err := sql.Open("pgx", connStr) + require.NoError(t, err) + + // Run migrations up to but not including 00005 + goose.SetBaseFS(FS) + err = goose.SetDialect("pgx") + require.NoError(t, err) + + // Manually run migrations 1-4 + err = goose.UpToContext(ctx, db, ".", 4) + require.NoError(t, err) + + cleanup := func() { + pool.Close() + db.Close() + _ = pgContainer.Terminate(ctx) + } + + return db, pool, cleanup +} + +// createTestConfigWithClientID creates a temporary config file with the specified client_id +func createTestConfigWithClientID(clientID string) func() { + tmpDir, err := os.MkdirTemp("", "migration-test-*") + if err != nil { + panic(fmt.Sprintf("Failed to create temp dir: %v", err)) + } + + configContent := fmt.Sprintf(`sessionManager: + clientAuth: + clientID: "%s" +`, clientID) + + configPath := filepath.Join(tmpDir, "config.yaml") + err = os.WriteFile(configPath, []byte(configContent), 0644) + if err != nil { + os.RemoveAll(tmpDir) + panic(fmt.Sprintf("Failed to write config file: %v", err)) + } + + originalDir, err := os.Getwd() + if err != nil { + os.RemoveAll(tmpDir) + panic(fmt.Sprintf("Failed to get current directory: %v", err)) + } + + err = os.Chdir(tmpDir) + if err != nil { + os.RemoveAll(tmpDir) + panic(fmt.Sprintf("Failed to change directory: %v", err)) + } + + return func() { + _ = os.Chdir(originalDir) + os.RemoveAll(tmpDir) + } +}