Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 11 additions & 20 deletions rest-api/api/internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -430,32 +430,22 @@ func (c *Config) GetOrInitJWTOriginConfig() *cauth.JWTOriginConfig {
log.Panic().Err(err).Msg("Invalid issuers configuration")
}

// First pass: collect all static org names (lowercased) from all issuers
reservedOrgNames := make(map[string]bool)
for _, issuerCfg := range issuersConfig {
for _, mapping := range issuerCfg.ClaimMappings {
if mapping.OrgName != "" {
reservedOrgNames[strings.ToLower(mapping.OrgName)] = true
}
}
}
// First pass: reserve the static org names declared in the config file.
c.JwtOriginConfig.ReplaceReservedOrgs(c.configStaticOrgNames())

// Second pass: create jwksConfigs and assign reservedOrgNames only to those with dynamic mappings
// Second pass: create jwksConfigs. AddJwksConfig wires the shared
// reserved-org set into each config (only consulted for dynamic mappings).
for _, issuerCfg := range issuersConfig {
origin, _ := issuerCfg.GetOrigin() // Already validated
jwksTimeout, _ := issuerCfg.GetJWKSTimeout()

// Normalize org names in claim mappings and check for dynamic mappings
// Normalize org names in claim mappings
normalizedMappings := make([]cauth.ClaimMapping, len(issuerCfg.ClaimMappings))
hasDynamicMapping := false
for i, mapping := range issuerCfg.ClaimMappings {
normalizedMappings[i] = mapping
if mapping.OrgName != "" {
normalizedMappings[i].OrgName = strings.ToLower(mapping.OrgName)
}
if mapping.OrgAttribute != "" {
hasDynamicMapping = true
}
}

jwksCfg := cauth.NewJwksConfig(
Expand All @@ -470,11 +460,6 @@ func (c *Config) GetOrInitJWTOriginConfig() *cauth.JWTOriginConfig {
jwksCfg.JWKSTimeout = jwksTimeout
jwksCfg.ClaimMappings = normalizedMappings

// Only assign reservedOrgNames to configs with dynamic claim mappings
if hasDynamicMapping {
jwksCfg.ReservedOrgNames = reservedOrgNames
}

c.JwtOriginConfig.AddJwksConfig(jwksCfg)
}

Expand Down Expand Up @@ -588,6 +573,12 @@ func (c *Config) ValidateIssuersConfig(issuers []IssuerConfig) error {
if len(issuer.ClaimMappings) > 0 && origin != cauth.TokenOriginCustom {
return fmt.Errorf("issuer %s: claimMappings are only allowed for custom origin issuers (origin: %s)", issuer.Name, origin)
}
if origin == cauth.TokenOriginCustom && len(issuer.ClaimMappings) == 0 {
return fmt.Errorf("issuer %s: claimMappings are required for custom origin issuers", issuer.Name)
}
if origin == cauth.TokenOriginCustom && issuer.ServiceAccount {
return fmt.Errorf("issuer %s: serviceAccount is not supported for custom origin issuers; use claimMappings[].isServiceAccount", issuer.Name)
}

// Validate JWKS timeout if specified
if issuer.JWKSTimeout != "" {
Expand Down
87 changes: 87 additions & 0 deletions rest-api/api/internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"testing"
"time"

cauth "github.com/NVIDIA/infra-controller/rest-api/auth/pkg/config"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/spf13/viper"
Expand Down Expand Up @@ -61,6 +62,92 @@ func TestNewConfig(t *testing.T) {
}
}

func TestValidateIssuersConfig_CustomClaimMappingRules(t *testing.T) {
cfg := &Config{v: newViper()}

validCustom := IssuerConfig{
Name: "external-idp",
Origin: cauth.TokenOriginCustom,
JWKS: "https://idp.example.com/jwks",
Issuer: "https://idp.example.com",
ClaimMappings: []cauth.ClaimMapping{{
OrgName: "tenant-a",
Roles: []string{"TENANT_ADMIN"},
}},
}

tests := []struct {
name string
issuer IssuerConfig
wantErr string
}{
{
name: "custom issuer with mapping is valid",
issuer: validCustom,
},
{
name: "custom issuer without mappings is invalid",
issuer: IssuerConfig{
Name: "external-idp",
Origin: cauth.TokenOriginCustom,
JWKS: "https://idp.example.com/jwks",
Issuer: "https://idp.example.com",
},
wantErr: "claimMappings are required",
},
{
name: "non-custom issuer with mappings is invalid",
issuer: IssuerConfig{
Name: "kas-idp",
Origin: cauth.TokenOriginKasLegacy,
JWKS: "https://idp.example.com/jwks",
Issuer: "https://idp.example.com",
ClaimMappings: []cauth.ClaimMapping{{
OrgName: "tenant-a",
Roles: []string{"TENANT_ADMIN"},
}},
},
wantErr: "claimMappings are only allowed for custom origin issuers",
},
{
name: "custom issuer rejects issuer-level service account",
issuer: IssuerConfig{
Name: "external-idp",
Origin: cauth.TokenOriginCustom,
JWKS: "https://idp.example.com/jwks",
Issuer: "https://idp.example.com",
ServiceAccount: true,
ClaimMappings: []cauth.ClaimMapping{{
OrgName: "tenant-a",
Roles: []string{"TENANT_ADMIN"},
}},
},
wantErr: "serviceAccount is not supported for custom origin issuers",
},
{
name: "non-custom issuer without mappings remains valid",
issuer: IssuerConfig{
Name: "kas-idp",
Origin: cauth.TokenOriginKasLegacy,
JWKS: "https://idp.example.com/jwks",
Issuer: "https://idp.example.com",
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := cfg.ValidateIssuersConfig([]IssuerConfig{tt.issuer})
if tt.wantErr == "" {
require.NoError(t, err)
return
}
require.Error(t, err)
assert.Contains(t, err.Error(), tt.wantErr)
})
}
}

func TestConfig_WatchConfigFile(t *testing.T) {
const initialSitePhoneHomeURL = "http://initial.example/phone_home"

Expand Down
219 changes: 219 additions & 0 deletions rest-api/api/internal/config/issuer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package config

import (
"context"
"fmt"
"strings"
"time"

cauth "github.com/NVIDIA/infra-controller/rest-api/auth/pkg/config"
cdb "github.com/NVIDIA/infra-controller/rest-api/db/pkg/db"
cdbm "github.com/NVIDIA/infra-controller/rest-api/db/pkg/db/model"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)

// dbClaimMappingsToAuth copies persisted claim mappings into the auth shape.
func dbClaimMappingsToAuth(in []cdbm.IssuerClaimMapping) []cauth.ClaimMapping {
out := make([]cauth.ClaimMapping, len(in))
for i, cm := range in {
out[i] = cauth.ClaimMapping{
OrgAttribute: cm.OrgAttribute,
OrgDisplayAttribute: cm.OrgDisplayAttribute,
OrgName: cm.OrgName,
OrgDisplayName: cm.OrgDisplayName,
RolesAttribute: cm.RolesAttribute,
Roles: cm.Roles,
IsServiceAccount: cm.IsServiceAccount,
}
}
return out
}

// jwksConfigForIssuer builds a JwksConfig from a DB Issuer.
func (c *Config) jwksConfigForIssuer(iss *cdbm.Issuer) *cauth.JwksConfig {
jwksCfg := cauth.NewJwksConfig(iss.Name, iss.JWKSURL, iss.IssuerURL, iss.Origin, iss.ServiceAccount, iss.Audiences, iss.Scopes)
if iss.JWKSTimeout != "" {
d, err := time.ParseDuration(iss.JWKSTimeout)
if err == nil {
jwksCfg.JWKSTimeout = d
}
}

mappings := dbClaimMappingsToAuth(iss.ClaimMappings)
for i := range mappings {
mappings[i].OrgName = strings.ToLower(mappings[i].OrgName)
}
jwksCfg.ClaimMappings = mappings
// The shared reserved-org set is wired in by AddJwksConfig when this config
// is installed (via ApplyIssuer/SeedIssuersFromDB).
return jwksCfg
}

// issuerToConfig maps a DB Issuer back to an IssuerConfig for validation.
func (c *Config) issuerToConfig(iss *cdbm.Issuer) IssuerConfig {
return IssuerConfig{
Name: iss.Name,
Origin: iss.Origin,
JWKS: iss.JWKSURL,
Issuer: iss.IssuerURL,
ServiceAccount: iss.ServiceAccount,
Audiences: iss.Audiences,
Scopes: iss.Scopes,
JWKSTimeout: iss.JWKSTimeout,
ClaimMappings: dbClaimMappingsToAuth(iss.ClaimMappings),
AllowDuplicateStaticOrgNames: iss.AllowDuplicateStaticOrgNames,
}
}

// ValidateRegisteredIssuer validates candidate against the static config issuers
// and all registered DB issuers, excluding excludeID.
func (c *Config) ValidateRegisteredIssuer(ctx context.Context, dbSession *cdb.Session, tx *cdb.Tx, candidate *cdbm.Issuer, excludeID *uuid.UUID) error {
combined := c.GetIssuersConfig()

existing, err := cdbm.NewIssuerDAO(dbSession).GetAll(ctx, tx)
if err != nil {
return err
}
for i := range existing {
if excludeID != nil && existing[i].ID == *excludeID {
continue
}
combined = append(combined, c.issuerToConfig(&existing[i]))
}
combined = append(combined, c.issuerToConfig(candidate))

return c.ValidateIssuersConfig(combined)
}

// ApplyIssuer hot-applies a DB Issuer into the live JWT origin map.
func (c *Config) ApplyIssuer(iss *cdbm.Issuer) error {
joCfg := c.GetOrInitJWTOriginConfig()
if joCfg == nil {
return fmt.Errorf("JWT origin config not initialized")
}
jwksCfg := c.jwksConfigForIssuer(iss)
joCfg.AddJwksConfig(jwksCfg)
err := jwksCfg.UpdateJWKS()
if err != nil {
return err
}
return nil
}

// RemoveIssuer removes an issuer from the live JWT origin map.
func (c *Config) RemoveIssuer(issuerURL string) {
if c.JwtOriginConfig != nil {
c.JwtOriginConfig.RemoveConfig(issuerURL)
}
}

// configStaticOrgNames returns the lowercased static org names declared by the
// statically-configured issuers. Config is immutable after boot, so this is
// cheap to recompute on demand.
func (c *Config) configStaticOrgNames() map[string]bool {
out := make(map[string]bool)
for _, issuerCfg := range c.GetIssuersConfig() {
for _, mapping := range issuerCfg.ClaimMappings {
if mapping.OrgName != "" {
out[strings.ToLower(mapping.OrgName)] = true
}
}
}
return out
}

// RebuildReservedOrgs recomputes the reserved org set from config and DB static orgs.
func (c *Config) RebuildReservedOrgs(ctx context.Context, dbSession *cdb.Session) error {
joCfg := c.GetOrInitJWTOriginConfig()
if joCfg == nil {
return fmt.Errorf("JWT origin config not initialized")
}

issuers, err := cdbm.NewIssuerDAO(dbSession).GetAll(ctx, nil)
if err != nil {
return err
}

union := c.configStaticOrgNames()
for i := range issuers {
for _, cm := range issuers[i].ClaimMappings {
if cm.OrgName != "" {
union[strings.ToLower(cm.OrgName)] = true
}
}
}

joCfg.ReplaceReservedOrgs(union)
return nil
}

// SeedIssuersFromDB applies all registered DB issuers into the live JWT origin map at startup.
// A statically-configured issuer URL is skipped; a JWKS fetch failure is non-fatal.
func (c *Config) SeedIssuersFromDB(ctx context.Context, dbSession *cdb.Session) error {
joCfg := c.GetOrInitJWTOriginConfig()
if joCfg == nil {
return fmt.Errorf("JWT origin config not initialized")
}

issuers, err := cdbm.NewIssuerDAO(dbSession).GetAll(ctx, nil)
if err != nil {
return err
}

issDAO := cdbm.NewIssuerDAO(dbSession)
for i := range issuers {
ctxErr := ctx.Err()
if ctxErr != nil {
return ctxErr
}

iss := &issuers[i]
if c.IsStaticIssuer(iss.IssuerURL) {
log.Warn().Str("issuer", iss.IssuerURL).Msg("Skipping DB issuer that is statically configured")
continue
}
jwksCfg := c.jwksConfigForIssuer(iss)
joCfg.AddJwksConfig(jwksCfg)
status := cdbm.IssuerStatusReady
uerr := jwksCfg.UpdateJWKSWithContext(ctx)
if uerr != nil {
ctxErr = ctx.Err()
if ctxErr != nil {
return ctxErr
}
status = cdbm.IssuerStatusPending
log.Warn().Err(uerr).Str("issuer", iss.IssuerURL).
Msg("Failed to fetch JWKS for DB-registered issuer at boot; will lazy-refresh on first use")
}
if iss.Status != status {
_, uerr = cdb.WithTxResult(ctx, dbSession, func(tx *cdb.Tx) (*cdbm.Issuer, error) {
return issDAO.Update(ctx, tx, cdbm.IssuerUpdateInput{IssuerID: iss.ID, Status: &status})
})
if uerr != nil {
return uerr
}
}
}

rerr := c.RebuildReservedOrgs(ctx, dbSession)
if rerr != nil {
return rerr
}

log.Info().Int("count", len(issuers)).Msg("Seeded DB-registered issuers into JWT origin config")
return nil
}

// IsStaticIssuer reports whether the given issuer URL is configured as a static issuer.
func (c *Config) IsStaticIssuer(issuerURL string) bool {
for _, issuerCfg := range c.GetIssuersConfig() {
if issuerCfg.Issuer == issuerURL {
return true
}
}
return false
}
Loading