Skip to content

Commit

Permalink
feat: checking user defined roles and policies for referential integr…
Browse files Browse the repository at this point in the history
…ity (argoproj#20825) (argoproj#22132)

Signed-off-by: Mike Cutsail <[email protected]>
  • Loading branch information
devopsjedi authored Mar 3, 2025
1 parent 2bcaa19 commit 561cbef
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 23 deletions.
18 changes: 3 additions & 15 deletions controller/appcontroller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ import (
cacheutil "github.com/argoproj/argo-cd/v3/util/cache"
appstatecache "github.com/argoproj/argo-cd/v3/util/cache/appstate"
"github.com/argoproj/argo-cd/v3/util/settings"
utilTest "github.com/argoproj/argo-cd/v3/util/test"
)

var testEnableEventList []string = argo.DefaultEnableEventList()
Expand Down Expand Up @@ -1316,21 +1317,8 @@ func TestSetOperationStateOnDeletedApp(t *testing.T) {
assert.True(t, patched)
}

type logHook struct {
entries []logrus.Entry
}

func (h *logHook) Levels() []logrus.Level {
return []logrus.Level{logrus.WarnLevel}
}

func (h *logHook) Fire(entry *logrus.Entry) error {
h.entries = append(h.entries, *entry)
return nil
}

func TestSetOperationStateLogRetries(t *testing.T) {
hook := logHook{}
hook := utilTest.LogHook{}
logrus.AddHook(&hook)
t.Cleanup(func() {
logrus.StandardLogger().ReplaceHooks(logrus.LevelHooks{})
Expand All @@ -1348,7 +1336,7 @@ func TestSetOperationStateLogRetries(t *testing.T) {
})
ctrl.setOperationState(newFakeApp(), &v1alpha1.OperationState{Phase: synccommon.OperationSucceeded})
assert.True(t, patched)
assert.Contains(t, hook.entries[0].Message, "fake error")
assert.Contains(t, hook.Entries[0].Message, "fake error")
}

func TestNeedRefreshAppStatus(t *testing.T) {
Expand Down
45 changes: 37 additions & 8 deletions util/rbac/rbac.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ type CasbinEnforcer interface {
EnableEnforce(bool)
AddFunction(name string, function govaluate.ExpressionFunction)
GetGroupingPolicy() ([][]string, error)
GetAllRoles() ([]string, error)
GetImplicitPermissionsForUser(user string, domain ...string) ([][]string, error)
}

const (
Expand Down Expand Up @@ -154,16 +156,16 @@ func (e *Enforcer) invalidateCache(actions ...func()) {
e.enforcerCache.Flush()
}

func (e *Enforcer) getCabinEnforcer(project string, policy string) CasbinEnforcer {
res, err := e.tryGetCabinEnforcer(project, policy)
func (e *Enforcer) getCasbinEnforcer(project string, policy string) CasbinEnforcer {
res, err := e.tryGetCasbinEnforcer(project, policy)
if err != nil {
panic(err)
}
return res
}

// tryGetCabinEnforcer returns the cached enforcer for the given optional project and project policy.
func (e *Enforcer) tryGetCabinEnforcer(project string, policy string) (CasbinEnforcer, error) {
// tryGetCasbinEnforcer returns the cached enforcer for the given optional project and project policy.
func (e *Enforcer) tryGetCasbinEnforcer(project string, policy string) (CasbinEnforcer, error) {
e.lock.Lock()
defer e.lock.Unlock()
var cached *cachedEnforcer
Expand Down Expand Up @@ -252,10 +254,32 @@ func (e *Enforcer) EnableEnforce(s bool) {

// LoadPolicy executes casbin.Enforcer functionality and will invalidate cache if required.
func (e *Enforcer) LoadPolicy() error {
_, err := e.tryGetCabinEnforcer("", "")
_, err := e.tryGetCasbinEnforcer("", "")
return err
}

// CheckUserDefinedRoleReferentialIntegrity iterates over roles and policies to validate the existence of a matching policy subject for every defined role
func CheckUserDefinedRoleReferentialIntegrity(e CasbinEnforcer) error {
allRoles, err := e.GetAllRoles()
if err != nil {
return err
}
notFound := make([]string, 0)
for _, roleName := range allRoles {
permissions, err := e.GetImplicitPermissionsForUser(roleName)
if err != nil {
return err
}
if len(permissions) == 0 {
notFound = append(notFound, roleName)
}
}
if len(notFound) > 0 {
return fmt.Errorf("user defined roles not found in policies: %s", strings.Join(notFound, ","))
}
return nil
}

// Glob match func
func globMatchFunc(args ...any) (any, error) {
if len(args) < 2 {
Expand Down Expand Up @@ -301,7 +325,7 @@ func (e *Enforcer) SetClaimsEnforcerFunc(claimsEnforcer ClaimsEnforcerFunc) {
// Enforce is a wrapper around casbin.Enforce to additionally enforce a default role and a custom
// claims function
func (e *Enforcer) Enforce(rvals ...any) bool {
return enforce(e.getCabinEnforcer("", ""), e.defaultRole, e.claimsEnforcerFunc, rvals...)
return enforce(e.getCasbinEnforcer("", ""), e.defaultRole, e.claimsEnforcerFunc, rvals...)
}

// EnforceErr is a convenience helper to wrap a failed enforcement with a detailed error about the request
Expand Down Expand Up @@ -348,7 +372,7 @@ func (e *Enforcer) EnforceRuntimePolicy(project string, policy string, rvals ...
// user-defined policy. This allows any explicit denies of the built-in, and user-defined policies
// to override the run-time policy. Runs normal enforcement if run-time policy is empty.
func (e *Enforcer) CreateEnforcerWithRuntimePolicy(project string, policy string) CasbinEnforcer {
return e.getCabinEnforcer(project, policy)
return e.getCasbinEnforcer(project, policy)
}

// EnforceWithCustomEnforcer wraps enforce with an custom enforcer
Expand Down Expand Up @@ -509,10 +533,15 @@ func (e *Enforcer) syncUpdate(cm *corev1.ConfigMap, onUpdated func(cm *corev1.Co

// ValidatePolicy verifies a policy string is acceptable to casbin
func ValidatePolicy(policy string) error {
_, err := newEnforcerSafe(globMatchFunc, newBuiltInModel(), newAdapter("", "", policy))
casbinEnforcer, err := newEnforcerSafe(globMatchFunc, newBuiltInModel(), newAdapter("", "", policy))
if err != nil {
return fmt.Errorf("policy syntax error: %s", policy)
}

// Check for referential integrity
if err := CheckUserDefinedRoleReferentialIntegrity(casbinEnforcer); err != nil {
log.Warning(err.Error())
}
return nil
}

Expand Down
40 changes: 40 additions & 0 deletions util/rbac/rbac_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"testing"
"time"

"github.com/argoproj/argo-cd/v3/util/test"

"github.com/golang-jwt/jwt/v5"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -271,6 +273,44 @@ func TestNoPolicy(t *testing.T) {
assert.False(t, enf.Enforce("admin", "applications", "delete", "foo/bar"))
}

// TestValidatePolicyCheckUserDefinedPolicyReferentialIntegrity adds a hook into logrus.StandardLogger and validates
// policies with and without referential integrity issues. Log entries are searched to verify expected outcomes.
func TestValidatePolicyCheckUserDefinedPolicyReferentialIntegrity(t *testing.T) {
// Policy with referential integrity
policy := `
p, role:depA, *, get, foo/obj, allow
p, role:depB, *, get, foo/obj, deny
g, depA, role:depA
g, depB, role:depB
`
hook := test.LogHook{}
log.AddHook(&hook)
t.Cleanup(func() {
log.StandardLogger().ReplaceHooks(log.LevelHooks{})
})
require.NoError(t, ValidatePolicy(policy))
assert.Empty(t, hook.GetRegexMatchesInEntries("user defined roles not found in policies"))

// Policy with a role reference which transitively associates to policies
policy = `
p, role:depA, *, get, foo/obj, allow
p, role:depB, *, get, foo/obj, deny
g, depC, role:depC
g, role:depC, role:depA
`
require.NoError(t, ValidatePolicy(policy))
assert.Empty(t, hook.GetRegexMatchesInEntries("user defined roles not found in policies"))

// Policy with a role reference which has no associated policies
policy = `
p, role:depA, *, get, foo/obj, allow
p, role:depB, *, get, foo/obj, deny
g, depC, role:depC
`
require.NoError(t, ValidatePolicy(policy))
assert.Len(t, hook.GetRegexMatchesInEntries("user defined roles not found in policies: role:depC"), 1)
}

// TestClaimsEnforcerFunc tests
func TestClaimsEnforcerFunc(t *testing.T) {
kubeclientset := fake.NewClientset()
Expand Down
27 changes: 27 additions & 0 deletions util/test/testutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ import (
"io"
"net/http"
"net/http/httptest"
"regexp"
"testing"
"time"

log "github.com/sirupsen/logrus"

"github.com/go-jose/go-jose/v3"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -269,3 +272,27 @@ func generateJWTToken(issuer string) (string, error) {
}
return tokenString, nil
}

type LogHook struct {
Entries []log.Entry
}

func (h *LogHook) Levels() []log.Level {
return []log.Level{log.WarnLevel}
}

func (h *LogHook) Fire(entry *log.Entry) error {
h.Entries = append(h.Entries, *entry)
return nil
}

func (h *LogHook) GetRegexMatchesInEntries(match string) []string {
re := regexp.MustCompile(match)
matches := make([]string, 0)
for _, entry := range h.Entries {
if re.Match([]byte(entry.Message)) {
matches = append(matches, entry.Message)
}
}
return matches
}

0 comments on commit 561cbef

Please sign in to comment.