diff --git a/cmd/dev/app/rule_type/rttst.go b/cmd/dev/app/rule_type/rttst.go index 7c263bb956..2236414313 100644 --- a/cmd/dev/app/rule_type/rttst.go +++ b/cmd/dev/app/rule_type/rttst.go @@ -67,6 +67,7 @@ func CmdTest() *cobra.Command { testCmd.Flags().StringP("token", "t", "", "token to authenticate to the provider."+ "Can also be set via the TEST_AUTH_TOKEN environment variable.") testCmd.Flags().StringArrayP("data-source", "d", []string{}, "YAML file containing the data source to test the rule with") + testCmd.Flags().BoolP("debug", "", false, "Start REGO debugger (only works for REGO-based rules types)") if err := testCmd.MarkFlagRequired("rule-type"); err != nil { fmt.Fprintf(os.Stderr, "Error marking flag as required: %s\n", err) @@ -98,6 +99,7 @@ func testCmdRun(cmd *cobra.Command, _ []string) error { token := viper.GetString("test.auth.token") providerclass := cmd.Flag("provider") providerconfig := cmd.Flag("provider-config") + debug := cmd.Flag("debug").Value.String() == "true" dataSourceFileStrings, err := cmd.Flags().GetStringArray("data-source") if err != nil { @@ -197,7 +199,10 @@ func testCmdRun(cmd *cobra.Command, _ []string) error { // TODO: use cobra context here ctx := context.Background() - eng, err := rtengine.NewRuleTypeEngine(ctx, ruletype, prov, nil /*experiments*/, options.WithDataSources(dsRegistry)) + eng, err := rtengine.NewRuleTypeEngine(ctx, ruletype, prov, nil, /*experiments*/ + options.WithDataSources(dsRegistry), + options.WithDebugger(debug), + ) if err != nil { return fmt.Errorf("cannot create rule type engine: %w", err) } diff --git a/internal/engine/eval/rego/debug.go b/internal/engine/eval/rego/debug.go new file mode 100644 index 0000000000..311c057c45 --- /dev/null +++ b/internal/engine/eval/rego/debug.go @@ -0,0 +1,842 @@ +// SPDX-FileCopyrightText: Copyright 2024 The Minder Authors +// SPDX-License-Identifier: Apache-2.0 + +// Package rego provides the rego rule evaluator +package rego + +import ( + "bufio" + "context" + "errors" + "fmt" + "math" + "os" + "regexp" + "slices" + "strconv" + "strings" + + "github.com/open-policy-agent/opa/ast/location" + "github.com/open-policy-agent/opa/debug" + "github.com/open-policy-agent/opa/rego" + + "github.com/mindersec/minder/internal/util/cli" + "github.com/mindersec/minder/pkg/engine/v1/interfaces" +) + +type eventHandler struct { + ch chan *debug.Event +} + +func newEventHandler() *eventHandler { + return &eventHandler{ + ch: make(chan *debug.Event), + } +} + +func (eh *eventHandler) HandleEvent(event debug.Event) { + eh.ch <- &event +} + +func (eh *eventHandler) WaitFor( + ctx context.Context, + eventTypes ...debug.EventType, +) *debug.Event { + for { + select { + case e := <-eh.ch: + if slices.Contains(eventTypes, e.Type) { + return e + } + case <-ctx.Done(): + return nil + } + } +} + +var ( + errEmptySource = errors.New("empty source code") + errInvalidInput = errors.New("invalid input") + errInvalidInstr = errors.New("invalid instruction") + errInvalidBP = errors.New("invalid breakpoint") +) + +// Debug implements an interactive debugger for REGO-based evaluators. +func (e *Evaluator) Debug( + ctx context.Context, + _ *interfaces.Result, + input *Input, + funcs ...func(*rego.Rego), +) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + allOpts := make([]func(*rego.Rego), 0, len(e.regoOpts)+len(funcs)) + allOpts = append(allOpts, e.regoOpts...) + allOpts = append(allOpts, funcs...) + + ds, err := newDebugSession( + withPrompt("(mindbg)"), + withSource(e.cfg.Def), + withInput(input), + withQuery(e.reseval.getQueryString()), + withOpts(allOpts...), + ) + if err != nil { + return fmt.Errorf("error initializing debugger: %w", err) + } + + return ds.Start(ctx) +} + +type debugSession struct { + prompt string + src string + lines int + input *Input + query string + opts []debug.LaunchOption + + // fields initialized after starting the session + session debug.Session + eh *eventHandler +} + +type debugSessionOption func(*debugSession) error + +func withPrompt(prompt string) debugSessionOption { + return func(ds *debugSession) error { + ds.prompt = prompt + return nil + } +} + +func withSource(src string) debugSessionOption { + return func(ds *debugSession) error { + if len(src) == 0 { + return errEmptySource + } + ds.src = src + ds.lines = len(strings.Split(src, "\n")) + return nil + } +} + +func withInput(input any) debugSessionOption { + return func(ds *debugSession) error { + inner, ok := input.(*Input) + if !ok { + return fmt.Errorf("%w: wrong type %T", errInvalidInput, input) + } + ds.input = inner + return nil + } +} + +func withQuery(query string) debugSessionOption { + return func(ds *debugSession) error { + ds.query = query + return nil + } +} + +func withOpts(opts ...func(*rego.Rego)) debugSessionOption { + return func(ds *debugSession) error { + var res []debug.LaunchOption + if ds.opts == nil { + res = make([]debug.LaunchOption, 0, len(opts)) + } else { + res = ds.opts + } + + for _, opt := range opts { + res = append(res, debug.RegoOption(opt)) + } + + ds.opts = res + return nil + } +} + +func newDebugSession( + opts ...debugSessionOption, +) (*debugSession, error) { + ds := &debugSession{} + + for _, opt := range opts { + if err := opt(ds); err != nil { + return nil, err + } + } + + return ds, nil +} + +func (ds *debugSession) startDebugger( + ctx context.Context, +) error { + eh := newEventHandler() + debugger := debug.NewDebugger( + debug.SetEventHandler(eh.HandleEvent), + ) + launchProps := debug.LaunchEvalProperties{ + LaunchProperties: debug.LaunchProperties{ + StopOnEntry: false, + StopOnFail: false, + StopOnResult: true, + EnablePrint: true, + RuleIndexing: false, + }, + Input: ds.input, + Query: ds.query, + } + + session, err := debugger.LaunchEval(ctx, launchProps, ds.opts...) + if err != nil { + return err + } + + ds.session = session + ds.eh = eh + + return nil +} + +//nolint:gocyclo +func (ds *debugSession) Start(ctx context.Context) error { + err := ds.startDebugger(ctx) + if err != nil { + return fmt.Errorf("error launching debugger: %w", err) + } + + thr := debug.ThreadID(1) + fmt.Printf("%s ", ds.prompt) + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + + var b strings.Builder + switch { + case line == "": + // There's nothing to do here, but it is + // useful to let the user spam enter to see if + // it's working. + case line == "r": + err = ds.startDebugger(ctx) + if err != nil { + return fmt.Errorf("error restarting debugger: %w", err) + } + fmt.Fprintf(&b, "Restarted") + case line == "c": + if err := ds.session.Resume(thr); err != nil { + return fmt.Errorf("error resuming execution: %w", err) + } + + evt := ds.eh.WaitFor(ctx, + debug.ExceptionEventType, + debug.StoppedEventType, + debug.StdoutEventType, + ) + switch evt.Type { + case debug.ExceptionEventType: + fmt.Fprintf(&b, "\nException\n") + if err := printLocals(&b, ds.session, evt.Thread); err != nil { + return fmt.Errorf("error printing locals: %w", err) + } + case debug.StoppedEventType: + fmt.Fprintf(&b, "\nStopped\n") + if err := printLocals(&b, ds.session, evt.Thread); err != nil { + return fmt.Errorf("error printing locals: %w", err) + } + case debug.StdoutEventType: + fmt.Fprintf(&b, "\nFinished\n") + if err := printLocals(&b, ds.session, evt.Thread); err != nil { + return fmt.Errorf("error printing locals: %w", err) + } + fmt.Fprintf(&b, "\nResult: ") + err := printVar(&b, + fmt.Sprintf("%s.*", RegoQueryPrefix), + ds.session, + evt.Thread, + ) + if err != nil { + return fmt.Errorf("error printing variable: %w", err) + } + } + case line == "locals": + if err := printLocals(&b, ds.session, thr); err != nil { + return fmt.Errorf("error printing locals: %w", err) + } + case line == "bp": + bps, err := ds.session.Breakpoints() + if err != nil { + return fmt.Errorf("error getting breakpoints: %w", err) + } + printBreakpoints(&b, bps) + case line == "bt": + stack, err := ds.session.StackTrace(thr) + if err != nil { + return fmt.Errorf("error getting stack trace: %w", err) + } + printStackTrace(&b, stack, 10) + case line == "list", line == "l": + stack, err := ds.session.StackTrace(thr) + if err != nil { + return fmt.Errorf("error getting stack trace: %w", err) + } + printSource(&b, ds.src, stack) + case line == "trs": + threads, err := ds.session.Threads() + if err != nil { + return fmt.Errorf("error getting threads: %w", err) + } + printThreads(&b, threads) + + // "clearall" command currently removes all + // breakpoints, both user-defined and internal + // ones. This is not desirable for the very same + // reasons described in the comment related to the + // "next" command. + case line == "cla", + line == "clearall": + if err := ds.session.ClearBreakpoints(); err != nil { + return fmt.Errorf("error clearing breakpoints: %w", err) + } + + // "next" is a bit quirky, since it requires a few + // steps to function, namely: + // + // * adding a so called "internal breakpoint" + // * running until it's reached, and finally + // * removing the breakpoint + // + // Internal breakpoints should be managed separately + // from user-defined breakpoints, as the user should + // neither see them nor be allowed to remove them + // since it could invalidate some assumptions the code + // does around them. + // + // TODO: add two lists of breakpoints to + // `debugSession` struct and add routines to manage + // them. + case line == "n", + line == "next": + stack, err := ds.session.StackTrace(thr) + if err != nil { + return fmt.Errorf("error getting stack trace: %w", err) + } + if loc := getCurrentLocation(stack); loc != nil { + loc.Row += 1 // let's hope it always exists... + loc.Col = 0 + + // add internal breakpoint + bp, err := ds.session.AddBreakpoint(*loc) + if err != nil { + return fmt.Errorf("error setting breakpoint: %w", err) + } + + // resume execution + if err := ds.session.Resume(thr); err != nil { + return fmt.Errorf("error resuming execution: %w", err) + } + + evt := ds.eh.WaitFor(ctx, debug.StoppedEventType) + stack, err := ds.session.StackTrace(evt.Thread) + if err != nil { + return fmt.Errorf("error getting stack trace: %w", err) + } + + // clear internal breakpoint, even if + // we stopped for another reason. + if _, err := ds.session.RemoveBreakpoint(bp.ID()); err != nil { + return fmt.Errorf("error removing breakpoing: %w", err) + } + + printSource(&b, ds.src, stack) + } + case line == "s", + line == "sv": + go func() { + if err := ds.session.StepOver(thr); err != nil { + panic(err) + } + }() + evt := ds.eh.WaitFor(ctx, debug.StoppedEventType) + stack, err := ds.session.StackTrace(evt.Thread) + if err != nil { + return fmt.Errorf("error getting stack trace: %w", err) + } + printSource(&b, ds.src, stack) + case line == "si": + go func() { + if err := ds.session.StepIn(thr); err != nil { + panic(err) + } + }() + evt := ds.eh.WaitFor(ctx, debug.StoppedEventType) + stack, err := ds.session.StackTrace(evt.Thread) + if err != nil { + return fmt.Errorf("error getting stack trace: %w", err) + } + printSource(&b, ds.src, stack) + case line == "so": + go func() { + if err := ds.session.StepOut(thr); err != nil { + panic(err) + } + }() + evt := ds.eh.WaitFor(ctx, debug.StoppedEventType) + stack, err := ds.session.StackTrace(evt.Thread) + if err != nil { + return fmt.Errorf("error getting stack trace: %w", err) + } + printSource(&b, ds.src, stack) + case line == "q": + return fmt.Errorf("user abort") + case line == "h", + line == "help": + printHelp(&b) + case strings.HasPrefix(line, "p"): + varname, err := toVarName(line) + if err != nil { + fmt.Fprintln(&b, err) + continue + } + // printVar function accepts a regexp as + // variable name, allowing the caller to match + // multiple variables. + // + // We don't want to expose this functionality + // to the user, as the general case (fetching + // a specific variable) becomes awkward, + // requiring the user to specify the full + // regex. + // + // To solve this, we always wrap the received + // variable name in ^ and $. + r := fmt.Sprintf("^%s$", varname) + if err := printVar(&b, r, ds.session, thr); err != nil { + return fmt.Errorf("error printing variables: %w", err) + } + case strings.HasPrefix(line, "b"): + loc, err := toLocation(line, ds.lines) + if err != nil { + fmt.Fprintln(&b, err) + } else { + bp, err := ds.session.AddBreakpoint(*loc) + if err != nil { + return fmt.Errorf("error setting breakpoint: %w", err) + } + fmt.Fprintln(&b) + printBreakpoint(&b, bp) + } + + // "clear" command currently allows removing all + // breakpoints, both user-defined and internal + // ones. This is not desirable for the very same + // reasons described in the comment related to the + // "next" command. + case strings.HasPrefix(line, "cl "), + strings.HasPrefix(line, "clear "): + ids := make([]debug.BreakpointID, 0) + bps, err := ds.session.Breakpoints() + if err != nil { + return fmt.Errorf("error gettin breakpoints: %w", err) + } + for _, bp := range bps { + ids = append(ids, bp.ID()) + } + id, err := toBreakpointID(line, ids) + if err != nil { + fmt.Fprintln(&b, err) + } else { + if _, err := ds.session.RemoveBreakpoint(id); err != nil { + return fmt.Errorf("error removing breakpoint: %w", err) + } + } + default: + fmt.Fprintf(&b, "Invalid command: %s\nPress h for help\n", line) + } + + output := b.String() + if output != "" { + fmt.Printf("%s\n%s ", output, ds.prompt) + } else { + fmt.Printf("%s ", ds.prompt) + } + } + + return scanner.Err() +} + +func toLocation(line string, lineCount int) (*location.Location, error) { + num, ok := strings.CutPrefix(line, "b ") + if !ok { + return nil, fmt.Errorf(`%w: "%s"`, errInvalidInstr, line) + } + i, err := strconv.ParseInt(num, 10, 64) + if err != nil { + return nil, fmt.Errorf(`%w: invalid line "%s": %s`, errInvalidBP, num, err) + } + if i < 1 || int(i) > lineCount { + return nil, fmt.Errorf("%w: invalid line %d", errInvalidBP, i) + } + return &location.Location{File: "minder.rego", Row: int(i)}, nil +} + +func toBreakpointID(line string, ids []debug.BreakpointID) (debug.BreakpointID, error) { + num1, ok1 := strings.CutPrefix(line, "cl ") + num2, ok2 := strings.CutPrefix(line, "clear ") + if !ok1 && !ok2 { + return debug.BreakpointID(-1), fmt.Errorf(`%w: "%s"`, errInvalidInstr, line) + } + + var num string + if !ok1 { + num = num2 + } + if !ok2 { + num = num1 + } + + i, err := strconv.ParseInt(num, 10, 64) + if err != nil { + return debug.BreakpointID(-1), fmt.Errorf( + `%w: invalid breakpoint id %s`, + errInvalidBP, num, + ) + } + + if i < 1 { + return debug.BreakpointID(-1), fmt.Errorf( + "%w: negative line id", + errInvalidBP, + ) + } + + if !slices.Contains(ids, debug.BreakpointID(i)) { + return debug.BreakpointID(-1), fmt.Errorf( + "%w: breakpoint does not exist", + errInvalidBP, + ) + } + + return debug.BreakpointID(i), nil +} + +func toVarName(line string) (string, error) { + varname, ok := strings.CutPrefix(line, "p ") + if !ok { + return "", fmt.Errorf(`%w: "%s"`, errInvalidInstr, line) + } + return varname, nil +} + +func printBreakpoints(b *strings.Builder, bps []debug.Breakpoint) { + if len(bps) == 0 { + return + } + fmt.Fprintln(b) + for _, bp := range bps { + printBreakpoint(b, bp) + } +} + +func printBreakpoint(b *strings.Builder, bp debug.Breakpoint) { + fmt.Fprintf(b, "Breakpoint %d set at %s:%d\n", + bp.ID(), + bp.Location().File, + bp.Location().Row, + ) +} + +func printThreads(b *strings.Builder, threads []debug.Thread) { + if len(threads) == 0 { + return + } + fmt.Fprintln(b) + for _, thread := range threads { + fmt.Fprintf(b, "Thread %d\n", thread.ID()) + } +} + +func getCurrentLocation(stack debug.StackTrace) *location.Location { + if len(stack) == 0 { + return nil + } + + frame := stack[0] + return frame.Location() +} + +func printStackTrace(b *strings.Builder, stack debug.StackTrace, limit int) { + if len(stack) == 0 { + return + } + + fmt.Fprintln(b) + for _, frame := range stack[:limit] { + if loc := frame.Location(); loc != nil { + fmt.Fprintf(b, "Frame %d at %s:%d.%d\n", + frame.ID(), + loc.File, + loc.Row, + loc.Col, + ) + } + } + if len(stack) > limit { + fmt.Fprintf(b, "...\n") + } +} + +func printSource(b *strings.Builder, src string, stack debug.StackTrace) { + if len(stack) == 0 { + printSourceSimple(b, src) + return + } + + lines := strings.Split(src, "\n") + padding := int64(math.Floor(math.Log10(float64(len(lines)))) + 1) + + fmt.Fprintln(b) + frame := stack[0] + if loc := frame.Location(); loc != nil { + fmt.Fprintf(b, "Frame %d at %s:%d.%d\n", + frame.ID(), + loc.File, + loc.Row, + loc.Col, + ) + + for idx, line := range strings.Split(src, "\n") { + fmt.Fprintf(b, "%*d: %s", padding, idx+1, line) + if idx+1 == loc.Row { + theline := strings.Split(string(loc.Text), "\n")[0] + fmt.Fprintf(b, "\n%s%s", + strings.Repeat(" ", loc.Col+int(padding)+2-1), + cli.SimpleBoldStyle.Render(strings.Repeat("^", len(theline))), + ) + } + fmt.Fprintln(b) + } + } +} + +func printSourceSimple(b *strings.Builder, source string) { + fmt.Fprintln(b) + lines := strings.Split(source, "\n") + padding := int64(math.Floor(math.Log10(float64(len(lines)))) + 1) + for idx, line := range lines { + fmt.Fprintf(b, "%*d: %s\n", padding, idx+1, line) + } +} + +func printLocals(b *strings.Builder, s debug.Session, thrID debug.ThreadID) error { + trace, err := s.StackTrace(thrID) + if err != nil { + return fmt.Errorf("error getting stacktrace: %w", err) + } + + if len(trace) == 0 { + return nil + } + + // The first trace in the list is the one related to the + // current stack frame. + scopes, err := s.Scopes(trace[0].ID()) + if err != nil { + return fmt.Errorf("error getting scopes: %w", err) + } + + for _, scope := range scopes { + vars, err := s.Variables(scope.VariablesReference()) + if err != nil { + return fmt.Errorf("error getting variables: %w", err) + } + for _, v := range vars { + fmt.Fprintf(b, "%s %s = %s\n", v.Type(), v.Name(), v.Value()) + } + } + + return nil +} + +func printVar( + b *strings.Builder, + varname string, + s debug.Session, + thrID debug.ThreadID, +) error { + r, err := regexp.Compile(varname) + if err != nil { + return fmt.Errorf("error instantiating regex: %w", err) + } + + trace, err := s.StackTrace(thrID) + if err != nil { + return fmt.Errorf("error getting stacktrace: %w", err) + } + + if len(trace) == 0 { + return nil + } + + // The first trace in the list is the one related to the + // current stack frame. + scopes, err := s.Scopes(trace[0].ID()) + if err != nil { + return fmt.Errorf("error getting scopes: %w", err) + } + + for _, scope := range scopes { + if err := printVariablesInScope(b, r, s, scope.VariablesReference()); err != nil { + return err + } + } + + return nil +} + +func printVariablesInScope( + b *strings.Builder, + r *regexp.Regexp, + s debug.Session, + varRef debug.VarRef, +) error { + if varRef == 0 { + return nil + } + + vars, err := s.Variables(varRef) + if err != nil { + return fmt.Errorf("error getting variables: %w", err) + } + for _, v := range vars { + if r.MatchString(v.Name()) { + var b1 strings.Builder + if err := varToString(&b1, v, s, 0); err != nil { + return err + } + fmt.Fprintf(b, "%s %s = %s\n", v.Type(), v.Name(), b1.String()) + + // We break early here despite the fact that + // multiple variables might match the given + // `varname`. This is done to honour lexical + // scope, showing just the only variable that + // is actually being used for evaluation in + // the given frame. + return nil + } + } + + return nil +} + +func varToString( + b *strings.Builder, + v debug.Variable, + s debug.Session, + indentation int, +) error { + padding := strings.Repeat(" ", indentation) + switch v.Type() { + case "array": + return elementsToString(b, v, s, indentation, "[", "]", + func(elem debug.Variable) error { + fmt.Fprintf(b, " %s", padding) + err := varToString(b, elem, s, indentation) + if err != nil { + return err + } + fmt.Fprintf(b, ",\n") + return nil + }, + ) + case "set": + return elementsToString(b, v, s, indentation, "{", "}", + func(elem debug.Variable) error { + fmt.Fprintf(b, " %s", padding) + err := varToString(b, elem, s, indentation) + if err != nil { + return err + } + fmt.Fprintf(b, ",\n") + return nil + }, + ) + case "object": + return elementsToString(b, v, s, indentation, "{", "}", + func(elem debug.Variable) error { + fmt.Fprintf(b, " %s%s: ", padding, elem.Name()) + err := varToString(b, elem, s, indentation) + if err != nil { + return err + } + fmt.Fprintf(b, ",\n") + return nil + }, + ) + default: + fmt.Fprintf(b, "%s%s", padding, v.Value()) + } + + return nil +} + +func elementsToString( + b *strings.Builder, + v debug.Variable, + s debug.Session, + indentation int, + leftDelimiter string, + rightDelimiter string, + formatter func(debug.Variable) error, +) error { + padding := strings.Repeat(" ", indentation) + fmt.Fprintf(b, "%s%s\n", padding, leftDelimiter) + elems, err := s.Variables(v.VariablesReference()) + if err != nil { + return err + } + for _, elem := range elems { + if err := formatter(elem); err != nil { + return err + } + } + fmt.Fprintf(b, "%s%s", padding, rightDelimiter) + + return nil +} + +var helpMsg = ` +Controlling execution: + c ------------- continue + r ------------- restart debugging session + q ------------- quit + +Printing: + bt ------------ print stack trace (top 10) + trs ----------- print threads + list/l -------- list source + locals -------- print local variables + +Breakpoints: + bp ------------ show breakpoints + b ------- set breakpoint at line + clear/cl - clear breakpoint with id + clearall/cla -- clear all breakpoints + +Stepping: + n ------------- next line + s/sv ---------- step over + so ------------ step out + si ------------ step into + +Help: + help/h -------- print help +` + +func printHelp(b *strings.Builder) { + fmt.Fprint(b, helpMsg) +} diff --git a/internal/engine/eval/rego/eval.go b/internal/engine/eval/rego/eval.go index fd2a597360..726d1494c2 100644 --- a/internal/engine/eval/rego/eval.go +++ b/internal/engine/eval/rego/eval.go @@ -44,6 +44,15 @@ type Evaluator struct { regoOpts []func(*rego.Rego) reseval resultEvaluator datasources *v1datasources.DataSourceRegistry + debug bool +} + +var _ eoptions.HasDebuggerSupport = (*Evaluator)(nil) + +// SetDebugFlag implements `HasDebuggerSupport` interface. +func (e *Evaluator) SetDebugFlag(flag bool) error { + e.debug = flag + return nil } // Input is the input for the rego evaluator @@ -132,6 +141,26 @@ func (e *Evaluator) Eval( // If the evaluator has data sources defined, expose their functions regoFuncOptions = append(regoFuncOptions, buildDataSourceOptions(res, e.datasources)...) + input := &Input{ + Profile: pol, + Ingested: obj, + OutputFormat: e.cfg.ViolationFormat, + } + enrichInputWithEntityProps(input, entity) + + if e.debug { + err := e.Debug( + ctx, + res, + input, + regoFuncOptions..., + ) + if err != nil { + return nil, err + } + return nil, nil + } + // Create the rego object r := e.newRegoFromOptions( regoFuncOptions..., @@ -142,13 +171,6 @@ func (e *Evaluator) Eval( return nil, fmt.Errorf("could not prepare Rego: %w", err) } - input := &Input{ - Profile: pol, - Ingested: obj, - OutputFormat: e.cfg.ViolationFormat, - } - - enrichInputWithEntityProps(input, entity) rs, err := pq.Eval(ctx, rego.EvalInput(input)) if err != nil { return nil, fmt.Errorf("error evaluating profile. Might be wrong input: %w", err) diff --git a/internal/engine/eval/rego/result.go b/internal/engine/eval/rego/result.go index d713aa2557..300b903b3e 100644 --- a/internal/engine/eval/rego/result.go +++ b/internal/engine/eval/rego/result.go @@ -53,6 +53,7 @@ func (c ConstraintsViolationsFormat) String() string { } type resultEvaluator interface { + getQueryString() string getQuery() func(*rego.Rego) parseResult(rego.ResultSet, protoreflect.ProtoMessage) (*interfaces.EvaluationResult, error) } @@ -60,8 +61,12 @@ type resultEvaluator interface { type denyByDefaultEvaluator struct { } -func (*denyByDefaultEvaluator) getQuery() func(r *rego.Rego) { - return rego.Query(RegoQueryPrefix) +func (*denyByDefaultEvaluator) getQueryString() string { + return RegoQueryPrefix +} + +func (d *denyByDefaultEvaluator) getQuery() func(r *rego.Rego) { + return rego.Query(d.getQueryString()) } func (*denyByDefaultEvaluator) parseResult(rs rego.ResultSet, entity protoreflect.ProtoMessage, @@ -168,8 +173,12 @@ type constraintsEvaluator struct { format ConstraintsViolationsFormat } -func (*constraintsEvaluator) getQuery() func(r *rego.Rego) { - return rego.Query(fmt.Sprintf("%s.violations[details]", RegoQueryPrefix)) +func (*constraintsEvaluator) getQueryString() string { + return fmt.Sprintf("%s.violations[details]", RegoQueryPrefix) +} + +func (c *constraintsEvaluator) getQuery() func(r *rego.Rego) { + return rego.Query(c.getQueryString()) } func (c *constraintsEvaluator) parseResult(rs rego.ResultSet, _ protoreflect.ProtoMessage) (*interfaces.EvaluationResult, error) { diff --git a/internal/engine/options/options.go b/internal/engine/options/options.go index 0da6223418..b71b1eb3ab 100644 --- a/internal/engine/options/options.go +++ b/internal/engine/options/options.go @@ -35,6 +35,26 @@ func WithFlagsClient(client openfeature.IClient) Option { } } +// HasDebuggerSupport interface should be implemented by evaluation +// engines that support interactive debugger. Currently, only +// REGO-based engines should implement this. +type HasDebuggerSupport interface { + SetDebugFlag(bool) error +} + +// WithDebugger sets the evaluation engine to start an interactive +// debugging session. This MUST NOT be used in backend servers, and is +// only meant to be used in CLI tools. +func WithDebugger(flag bool) Option { + return func(e interfaces.Evaluator) error { + inner, ok := e.(HasDebuggerSupport) + if !ok { + return nil + } + return inner.SetDebugFlag(flag) + } +} + // SupportsDataSources interface advertises the fact that the implementer // can register data sources with the evaluator. type SupportsDataSources interface { diff --git a/internal/util/cli/styles.go b/internal/util/cli/styles.go index 4973c54b85..453db8c72b 100644 --- a/internal/util/cli/styles.go +++ b/internal/util/cli/styles.go @@ -27,7 +27,8 @@ var ( // Common styles var ( - CursorStyle = lipgloss.NewStyle().Foreground(SecondaryColor) + CursorStyle = lipgloss.NewStyle().Foreground(SecondaryColor) + SimpleBoldStyle = lipgloss.NewStyle().Bold(true) ) // Banner styles