diff --git a/go.mod b/go.mod index 85a8b04be8..19158cd9a1 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/erikgeiser/promptkit v0.9.0 github.com/evanphx/json-patch/v5 v5.9.0 github.com/fergusstrange/embedded-postgres v1.30.0 + github.com/gammazero/deque v0.2.1 github.com/go-git/go-billy/v5 v5.6.0 github.com/go-git/go-git/v5 v5.12.0 github.com/go-playground/validator/v10 v10.23.0 diff --git a/go.sum b/go.sum index a37d47a0ec..5e1808ddad 100644 --- a/go.sum +++ b/go.sum @@ -382,6 +382,8 @@ github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/ github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/gabriel-vasile/mimetype v1.4.6 h1:3+PzJTKLkvgjeTbts6msPJt4DixhT4YtFNf1gtGe3zc= github.com/gabriel-vasile/mimetype v1.4.6/go.mod h1:JX1qVKqZd40hUPpAfiNTe0Sne7hdfKSbOqqmkq8GCXc= +github.com/gammazero/deque v0.2.1 h1:qSdsbG6pgp6nL7A0+K/B7s12mcCY/5l5SIUpMOl+dC0= +github.com/gammazero/deque v0.2.1/go.mod h1:LFroj8x4cMYCukHJDbxFCkT+r9AndaJnFMuZDV34tuU= github.com/gkampitakis/ciinfo v0.3.0 h1:gWZlOC2+RYYttL0hBqcoQhM7h1qNkVqvRCV1fOvpAv8= github.com/gkampitakis/ciinfo v0.3.0/go.mod h1:1NIwaOcFChN4fa/B0hEBdAb6npDlFL8Bwx4dfRLRqAo= github.com/gkampitakis/go-diff v1.3.2 h1:Qyn0J9XJSDTgnsgHRdz9Zp24RaJeKMUHg2+PDZZdC4M= diff --git a/internal/engine/actions/context.go b/internal/engine/actions/context.go new file mode 100644 index 0000000000..6cda76678f --- /dev/null +++ b/internal/engine/actions/context.go @@ -0,0 +1,89 @@ +// SPDX-FileCopyrightText: Copyright 2023 The Minder Authors +// SPDX-License-Identifier: Apache-2.0 + +package actions + +import ( + "context" + "errors" + "sync" + + engif "github.com/mindersec/minder/internal/engine/interfaces" +) + +// SharedActionsContextKey is the key used to store the shared actions context +// in the context.Context. +type SharedActionsContextKey struct{} + +// SharedFlusherKey is the key used to store the shared flusher +type SharedFlusherKey string + +type sharedFlusher struct { + flusher engif.AggregatingAction + items []any +} + +// SharedActionsContext is the shared actions context. +type SharedActionsContext struct { + shared map[SharedFlusherKey]*sharedFlusher + mux sync.Mutex +} + +// WithSharedActionsContext returns a new context.Context with the shared actions +// context set. +func WithSharedActionsContext(ctx context.Context) (context.Context, *SharedActionsContext) { + sac := &SharedActionsContext{} + return context.WithValue(ctx, SharedActionsContextKey{}, sac), sac +} + +// GetSharedActionsContext returns the shared actions context from the context.Context. +func GetSharedActionsContext(ctx context.Context) *SharedActionsContext { + ctxVal := ctx.Value(SharedActionsContextKey{}) + if ctxVal == nil { + return nil + } + + v, ok := ctxVal.(*SharedActionsContext) + if !ok { + return nil + } + + return v +} + +// ShareAndRegister adds a shared value to the shared actions context. It may +// also register a flusher if it does not exist. +func (sac *SharedActionsContext) ShareAndRegister(key SharedFlusherKey, flusher engif.AggregatingAction, item any) { + sac.mux.Lock() + defer sac.mux.Unlock() + + f, ok := sac.shared[key] + if !ok { + f = &sharedFlusher{ + flusher: flusher, + items: []any{item}, + } + sac.shared[key] = f + return + } + + f.items = append(f.items, item) +} + +// Flush returns all the shared values and clears the shared actions context. +func (sac *SharedActionsContext) Flush() error { + sac.mux.Lock() + defer sac.mux.Unlock() + var errs []error + + for key, f := range sac.shared { + err := f.flusher.Flush(f.items...) + if err != nil { + errs = append(errs, err) + } + + delete(sac.shared, key) + } + + return errors.Join(errs...) +} diff --git a/internal/engine/executor.go b/internal/engine/executor.go index 52f58e710b..ca2f3d7c64 100644 --- a/internal/engine/executor.go +++ b/internal/engine/executor.go @@ -139,21 +139,23 @@ func (e *executor) EvalEntityEvent(ctx context.Context, inf *entities.EntityInfo return fmt.Errorf("error while retrieving profiles and rule instances: %w", err) } + sacctx, sac := actions.WithSharedActionsContext(ctx) + // For each profile, get the profileEvalStatus first. Then, if the profileEvalStatus is nil // evaluate each rule and store the outcome in the database. If profileEvalStatus is non-nil, // just store it for all rules without evaluation. for _, profile := range profileAggregates { - profileEvalStatus := e.profileEvalStatus(ctx, inf, profile) + profileEvalStatus := e.profileEvalStatus(sacctx, inf, profile) for _, rule := range profile.Rules { - if err := e.evaluateRule(ctx, inf, provider, &profile, &rule, ruleEngineCache, profileEvalStatus); err != nil { + if err := e.evaluateRule(sacctx, inf, provider, &profile, &rule, ruleEngineCache, profileEvalStatus); err != nil { return fmt.Errorf("error evaluating entity event: %w", err) } } } - return nil + return sac.Flush() } func (e *executor) evaluateRule( diff --git a/internal/engine/interfaces/interface.go b/internal/engine/interfaces/interface.go index 77c4c12f53..d2365bd585 100644 --- a/internal/engine/interfaces/interface.go +++ b/internal/engine/interfaces/interface.go @@ -31,6 +31,13 @@ type Action interface { params ActionsParams, metadata *json.RawMessage) (json.RawMessage, error) } +// AggregatingAction is the interface for an action that aggregates multiple +// pieces to form a final action. Normally this will come from the result of a +// `Do` call on an action. +type AggregatingAction interface { + Flush(item ...any) error +} + // ActionCmd is the type that defines what effect an action should have type ActionCmd string