diff --git a/exp/checks/check.go b/exp/checks/check.go new file mode 100644 index 00000000..fa03226c --- /dev/null +++ b/exp/checks/check.go @@ -0,0 +1,37 @@ +package checks + +import ( + "fmt" + "time" +) + +// Check stores the information necessary to run a check. +type Check struct { + CheckID string `json:"check_id"` // Required + StartTime time.Time `json:"start_time"` // Required + Image string `json:"image"` // Required + Target string `json:"target"` // Required + Timeout int `json:"timeout"` // Required + AssetType string `json:"assettype"` // Optional + Options string `json:"options"` // Optional + RequiredVars []string `json:"required_vars"` // Optional + Metadata map[string]string `json:"metadata"` // Optional + RunTime int64 +} + +func (j *Check) logTrace(msg, action string) string { + if j.RunTime == 0 { + j.RunTime = time.Now().Unix() + } + return fmt.Sprintf( + "event=checkTrace checkID=%s target=%s assetType=%s checkImage=%s queuedTime=%d runningTime=%d action=%s msg=\"%s\"", + j.CheckID, + j.Target, + j.AssetType, + j.Image, + j.RunTime-j.StartTime.Unix(), + time.Now().Unix()-j.RunTime, + action, + msg, + ) +} diff --git a/exp/checks/consumer.go b/exp/checks/consumer.go new file mode 100644 index 00000000..db6ba215 --- /dev/null +++ b/exp/checks/consumer.go @@ -0,0 +1,221 @@ +/* +Copyright 2023 Adevinta +*/ + +package checks + +import ( + "context" + "encoding/json" + "errors" + "sync" + "sync/atomic" + "time" + + "github.com/adevinta/vulcan-agent/v2/log" +) + +var ( + ErrInvalidMessage = errors.New("invalid message") + ErrUnprocessableMessage = errors.New("unprocessable check message") + // ErrMaxTimeNoRead is returned by a queue reader when there were no + // messages available in the queue for more than the configured amount of + // time. + ErrMaxTimeNoRead = errors.New("no messages available in the queue for more than the max time") +) + +type Message struct { + Body string + TimesRead int + // Processed will be written when a Consumer considers a message processed. + // The value written will be true if the message is considered to be + // successfully processed and can be deleted from the queue, otherwise the + // value written will be false and the message should not be deleted from + // queue so the consumer can retry processing it in the future. + Processed chan<- bool +} + +type Queue interface { + ReadMessage(ctx context.Context) (*Message, error) +} + +type CheckRunner interface { + Run(check Check, timesRead int) error +} + +type ConsumerCfg struct { + MaxConcurrentChecks int + MaxReadTime *time.Duration +} + +type Consumer struct { + *sync.RWMutex + cfg ConsumerCfg + queue Queue + wg *sync.WaitGroup + lastMessageReceived *time.Time + log log.Logger + Runner CheckRunner + nProcessingMessages uint32 +} + +// NewConsumer creates a new Reader with the given processor, queueARN and config. +func NewConsumer(log log.Logger, cfg ConsumerCfg, queue Queue, runner CheckRunner) *Consumer { + return &Consumer{ + cfg: cfg, + RWMutex: &sync.RWMutex{}, + Runner: runner, + log: log, + wg: &sync.WaitGroup{}, + queue: queue, + lastMessageReceived: nil, + nProcessingMessages: 0, + } +} + +// Start starts reading messages from the queue. It will stop reading from the +// queue when the passed in context is canceled. The caller can use the returned +// channel to track when the reader stopped reading from the queue and all the +// checks being run are finished. +func (r *Consumer) Start(ctx context.Context) <-chan error { + finished := make(chan error, 1) + if r.cfg.MaxConcurrentChecks < 1 { + finished <- errors.New("MaxConcurrentChecks must be greater than 0") + return finished + } + done := make(chan error, 1) + r.wg.Add(1) + go r.consume(ctx, done) + go func() { + err := <-done + r.wg.Wait() + finished <- err + close(finished) + }() + return finished +} + +func (r *Consumer) consume(ctx context.Context, done chan<- error) { + defer r.wg.Done() + if r.Runner == nil { + done <- errors.New("message processor is missing") + close(done) + return + } + tokens := make(chan struct{}, r.cfg.MaxConcurrentChecks) + for i := 0; i < r.cfg.MaxConcurrentChecks; i++ { + tokens <- struct{}{} + } + var ( + err error + msg Message + ) +loop: + for { + select { + case <-ctx.Done(): + err = ctx.Err() + break loop + case <-tokens: + msg, err = r.readMessage(ctx) + if err == ErrMaxTimeNoRead { + r.log.Infof("reader stopped: max time without reading messages elapsed") + break loop + } + if err != nil { + break loop + } + r.wg.Add(1) + go r.process(msg, tokens) + } + } + done <- err + close(done) +} + +func (r *Consumer) process(msg Message, done chan<- struct{}) { + atomic.AddUint32(&r.nProcessingMessages, 1) + defer func() { + // Decrement the number of messages being processed, see: + // https://golang.org/src/sync/atomic/doc.go?s=3841:3896#L87 + atomic.AddUint32(&r.nProcessingMessages, ^uint32(0)) + // Signal the message has been processed. + done <- struct{}{} + // Signal de goroutine has exited. + r.wg.Done() + }() + check := Check{ + RunTime: time.Now().Unix(), + } + err := json.Unmarshal([]byte(msg.Body), &check) + if err != nil { + r.log.Errorf("unable to parse message %q: %v", msg.Body, err) + msg.Processed <- true + return + } + err = r.Runner.Run(check, msg.TimesRead) + // A nil Processed channel means the queue reader that returned the message + // is not interested in knowing when a message was processed. + if msg.Processed == nil { + return + } + del := true + if err != nil { + del = true + r.log.Errorf("error processing message: %+v", err) + if errors.Is(err, ErrUnprocessableMessage) { + del = false + } + } + msg.Processed <- del +} + +func (r *Consumer) readMessage(ctx context.Context) (Message, error) { + maxTimeNoRead := r.cfg.MaxReadTime + start := time.Now() + var msg Message + for { + readMsg, err := r.queue.ReadMessage(ctx) + if err != nil { + if errors.Is(err, ErrInvalidMessage) { + r.log.Errorf("error reading message: %w") + continue + } + return msg, err + } + if readMsg != nil { + msg = *readMsg + break + } + // Check if we need to stop the reader because we exceed the maximun time + // trying to read a message. + now := time.Now() + n := atomic.LoadUint32(&r.nProcessingMessages) + if maxTimeNoRead != nil && now.Sub(start) > *maxTimeNoRead && n == 0 { + return Message{}, ErrMaxTimeNoRead + } + } + now := time.Now() + r.setLastMessageReceived(&now) + return msg, nil +} + +func (r *Consumer) setLastMessageReceived(t *time.Time) { + r.Lock() + r.lastMessageReceived = t + r.Unlock() +} + +// LastMessageReceived returns the time where the last message was received by +// the Reader. If no message was received so far it returns nil. +func (r *Consumer) LastMessageReceived() *time.Time { + r.RLock() + defer r.RUnlock() + return r.lastMessageReceived +} + +// SetMessageProcessor sets the queue's message processor. It must be +// set before calling [*Reader.StartReading]. +func (r *Consumer) SetMessageProcessor(p CheckRunner) { + r.Runner = p +} diff --git a/exp/checks/consumer_test.go b/exp/checks/consumer_test.go new file mode 100644 index 00000000..54cfef58 --- /dev/null +++ b/exp/checks/consumer_test.go @@ -0,0 +1,168 @@ +/* +Copyright 2021 Adevinta +*/ + +package checks + +import ( + "context" + "errors" + "fmt" + "sync" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + + "github.com/adevinta/vulcan-agent/v2/log" +) + +type mockQueue struct { + readMessage func(ctx context.Context) (*Message, error) +} + +func (m mockQueue) ReadMessage(ctx context.Context) (*Message, error) { + return m.readMessage(ctx) +} + +type TrackDeletedMessage struct { +} + +type mockMessage struct { + Body string + Err error + TimesRead int +} + +type inMemQueue struct { + sync.Mutex + wg sync.WaitGroup + current int + Messages []mockMessage + Processed []Message + NotProcessed []Message +} + +// Wait waits for all the onogoing goroutines that are wating for a message to +// be consumed to finish. +func (i *inMemQueue) Wait() { + i.wg.Wait() +} + +func (i *inMemQueue) ReadMessage(_ context.Context) (*Message, error) { + i.Lock() + defer i.Unlock() + if i.current >= len(i.Messages) { + return nil, nil + } + msg := i.Messages[i.current] + i.current++ + + err := msg.Err + if err != nil { + return nil, err + } + processed := make(chan bool, 1) + result := &Message{ + Body: msg.Body, + TimesRead: msg.TimesRead, + Processed: processed, + } + i.wg.Add(1) + go func(msg Message) { + defer i.wg.Done() + del := <-processed + i.Lock() + defer i.Unlock() + if del { + i.Processed = append(i.Processed, msg) + return + } + i.NotProcessed = append(i.NotProcessed, msg) + }(*result) + return result, nil +} + +type mockRunner struct { + run func(check Check, timesRead int) error +} + +func (m mockRunner) Run(check Check, timesRead int) error { + return m.run(check, timesRead) +} + +func TestConsumer(t *testing.T) { + type fields struct { + queue Queue + cfg ConsumerCfg + log log.Logger + Runner CheckRunner + } + tests := []struct { + name string + fields fields + ctx context.Context + want error + // wantState checks the internal states of a consumer for a given test. + // if the state is the expected it returns an empty string, if it isn't, + // it returns the difference between the wanted and the state the + // consumer got. + wantState func(c *Consumer) string + }{ + { + name: "RunsChecksReadFromQueueAndStops", + fields: fields{ + log: &log.NullLog{}, + cfg: ConsumerCfg{ + MaxConcurrentChecks: 10, + MaxReadTime: toPtr(1 * time.Second), + }, + queue: &inMemQueue{ + Messages: []mockMessage{ + { + Body: string(mustMarshal(checkFixture)), + }, + }, + }, + Runner: mockRunner{ + run: func(check Check, timesRead int) error { + return nil + }, + }, + }, + ctx: context.Background(), + want: ErrMaxTimeNoRead, + wantState: func(c *Consumer) string { + q := c.queue.(*inMemQueue) + q.Wait() + gotProcessed := q.Processed + gotNotProcessed := q.NotProcessed + wantProcessed := []Message{{ + Body: string(mustMarshal(checkFixture)), + }} + wantNotProcessed := []Message{} + processedDiff := cmp.Diff(wantProcessed, gotProcessed) + notProcessedDiff := cmp.Diff(wantNotProcessed, gotNotProcessed) + return fmt.Sprintf("%s%s", processedDiff, notProcessedDiff) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := NewConsumer(tt.fields.log, tt.fields.cfg, tt.fields.queue, tt.fields.Runner) + done := c.Start(tt.ctx) + got := <-done + if !errors.Is(tt.want, got) { + t.Fatalf("error want!=got, %+v!=%+v", tt.want, got) + } + stateDiff := tt.wantState(c) + if stateDiff != "" { + t.Fatalf("want state!=got state, diff %s", stateDiff) + } + }) + } +} + +func toPtr[T any](v T) *T { + return &v +} diff --git a/exp/checks/runner.go b/exp/checks/runner.go new file mode 100644 index 00000000..550c12ce --- /dev/null +++ b/exp/checks/runner.go @@ -0,0 +1,339 @@ +/* +Copyright 2021 Adevinta +*/ + +package checks + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/adevinta/vulcan-agent/v2/backend" + "github.com/adevinta/vulcan-agent/v2/log" + "github.com/adevinta/vulcan-agent/v2/stateupdater" +) + +var ( + // ErrCheckWithSameID is returned when the runner is about to run a check + // with and ID equal to the ID of an already running check. + ErrCheckWithSameID = errors.New("check with a same ID is already running") + + // DefaultMaxMessageProcessedTimes defines the maximun number of times the + // processor tries to processe a checks message before it declares the check + // as failed. + DefaultMaxMessageProcessedTimes = 200 +) + +type checkAborter struct { + cancels sync.Map +} + +func (c *checkAborter) Add(checkID string, cancel context.CancelFunc) error { + _, exists := c.cancels.LoadOrStore(checkID, cancel) + if exists { + return ErrCheckWithSameID + } + return nil +} + +func (c *checkAborter) Remove(checkID string) { + c.cancels.Delete(checkID) +} + +func (c *checkAborter) Exist(checkID string) bool { + _, ok := c.cancels.Load(checkID) + return ok +} + +func (c *checkAborter) Abort(checkID string) { + v, ok := c.cancels.Load(checkID) + if !ok { + return + } + cancel := v.(context.CancelFunc) + cancel() +} + +func (c *checkAborter) AbortAll() { + c.cancels.Range(func(_, v interface{}) bool { + cancel := v.(context.CancelFunc) + cancel() + return true + }) +} + +// Running returns the number the checks that are in a given point of time being +// tracked by the Aborter component. In other words the number of checks +// running. +func (c *checkAborter) Running() int { + count := 0 + c.cancels.Range(func(_, v interface{}) bool { + count++ + return true + }) + return count +} + +type CheckStateUpdater interface { + UpdateState(stateupdater.CheckState) error + UploadCheckData(checkID, kind string, startedAt time.Time, content []byte) (string, error) + CheckStatusTerminal(ID string) bool + FlushCheckStatus(ID string) error + UpdateCheckStatusTerminal(stateupdater.CheckState) +} + +// AbortedChecks defines the shape of the component needed by a Runner in order +// to know if a check is aborted before it is exected. +type AbortedChecks interface { + IsAborted(ID string) (bool, error) +} + +// Runner runs the checks associated to a concreate message by receiving calls +// to it ProcessMessage function. +type Runner struct { + Backend backend.Backend + Logger log.Logger + CheckUpdater CheckStateUpdater + cAborter *checkAborter + abortedChecks AbortedChecks + defaultTimeout time.Duration + maxMessageProcessedTimes int +} + +// RunnerConfig contains config parameters for a Runner. +type RunnerConfig struct { + DefaultTimeout int + MaxProcessMessageTimes int +} + +// New creates a Runner initialized with the given log, backend and +// maximun number of tokens. The maximum number of tokens is the maximun number +// jobs that the Runner can execute at the same time. +func New(logger log.Logger, backend backend.Backend, checkUpdater CheckStateUpdater, + aborted AbortedChecks, cfg RunnerConfig) *Runner { + if cfg.MaxProcessMessageTimes < 1 { + cfg.MaxProcessMessageTimes = DefaultMaxMessageProcessedTimes + } + return &Runner{ + Backend: backend, + CheckUpdater: checkUpdater, + cAborter: &checkAborter{ + cancels: sync.Map{}, + }, + abortedChecks: aborted, + Logger: logger, + maxMessageProcessedTimes: cfg.MaxProcessMessageTimes, + defaultTimeout: time.Duration(cfg.DefaultTimeout * int(time.Second)), + } +} + +// Abort aborts a check if it is running. +func (cr *Runner) Abort(ID string) { + cr.cAborter.Abort(ID) +} + +// AbortAll aborts all the checks that are running. +func (cr *Runner) AbortAll() { + cr.cAborter.AbortAll() +} + +// Run executes the job. +func (cr *Runner) Run(check Check, timesRead int) error { + var err error + readMsg := fmt.Sprintf("check read from queue #[%d] times", timesRead) + cr.Logger.Debugf(check.logTrace(readMsg, "read")) + // Check if the message has been processed more than the maximum defined + // times. + if timesRead > cr.maxMessageProcessedTimes { + status := stateupdater.StatusFailed + err = cr.CheckUpdater.UpdateState( + stateupdater.CheckState{ + ID: check.CheckID, + Status: &status, + }) + if err != nil { + err = fmt.Errorf("error updating the status of the check: %s: %w", check.CheckID, err) + return err + } + cr.Logger.Errorf("error max processed times exceeded for check: %s", check.CheckID) + // We flush the terminal status and send the final status to the writer. + err = cr.CheckUpdater.FlushCheckStatus(check.CheckID) + if err != nil { + err = fmt.Errorf("error deleting the terminal status of the check: %s: %w", check.CheckID, err) + return err + } + return nil + } + // Check if the check has been aborted. + aborted, err := cr.abortedChecks.IsAborted(check.CheckID) + if err != nil { + err = fmt.Errorf("error querying aborted checks %w", err) + return err + } + + if aborted { + status := stateupdater.StatusAborted + err = cr.CheckUpdater.UpdateState( + stateupdater.CheckState{ + ID: check.CheckID, + Status: &status, + }) + if err != nil { + err = fmt.Errorf("error updating the status of the check %s: %w", check.CheckID, err) + return err + } + cr.Logger.Infof("check %s already aborted", check.CheckID) + return nil + } + + var timeout time.Duration + if check.Timeout != 0 { + timeout = time.Duration(check.Timeout * int(time.Second)) + } else { + timeout = cr.defaultTimeout + } + + // Create the context under which the backend will execute the check. The + // context will be cancelled either because the function cancel will be + // called by the aborter or because the timeout for the check has elapsed. + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + err = cr.cAborter.Add(check.CheckID, cancel) + // The above function can only return an error if the check already exists. + // So we just avoid executing it twice. + if err != nil { + return nil + } + ctName, ctVersion, err := getChecktypeInfo(check.Image) + if err != nil { + cr.cAborter.Remove(check.CheckID) + return fmt.Errorf("unable to read checktype info: %w", err) + } + runParams := backend.RunParams{ + CheckID: check.CheckID, + Target: check.Target, + Image: check.Image, + AssetType: check.AssetType, + Options: check.Options, + RequiredVars: check.RequiredVars, + CheckTypeName: ctName, + ChecktypeVersion: ctVersion, + } + cr.Logger.Infof(check.logTrace("running check", "running")) + finished, err := cr.Backend.Run(ctx, runParams) + if err != nil { + cr.cAborter.Remove(check.CheckID) + return fmt.Errorf("unable to run check: %w", err) + } + var logsLink string + // The finished channel is written by the backend when a check has finished. + // The value written to the channel contains the logs of the check(stdin and + // stdout) plus a field Error indicanting if there were any unexpected error + // during the execution. If that error is not nil, the backend was unable to + // retrieve the output of the check so the Output field will be nil. + res := <-finished + // When the check is finished it can not be aborted anymore + // so we remove it from aborter. + cr.cAborter.Remove(check.CheckID) + // Try always to upload the logs of the check if present. + if res.Output != nil { + logsLink, err = cr.CheckUpdater.UploadCheckData(check.CheckID, "logs", check.StartTime, res.Output) + if err != nil { + err = fmt.Errorf("unable to store the logs of the check %s: %w", check.CheckID, err) + // We return to retry the log upload later. + return err + } + + // Set the link for the logs of the check. + err = cr.CheckUpdater.UpdateState(stateupdater.CheckState{ + ID: check.CheckID, + Raw: &logsLink, + }) + if err != nil { + err = fmt.Errorf("unable to update the link to the logs of the check %s: %w", check.CheckID, err) + return err + } + cr.Logger.Debugf(check.logTrace(logsLink, "raw_logs")) + } + + // We query if the check has sent any status update with a terminal status. + isTerminal := cr.CheckUpdater.CheckStatusTerminal(check.CheckID) + + // Check if the backend returned any not expected error while running the check. + execErr := res.Error + if execErr != nil && + !errors.Is(execErr, context.DeadlineExceeded) && + !errors.Is(execErr, context.Canceled) && + !errors.Is(execErr, backend.ErrNonZeroExitCode) { + return execErr + } + + // The times this component has to set the state of a check are + // when the check is canceled, timeout or finished with an exit status code + // different from 0. That's because, in those cases, it's possible for the + // check to not have had time to set the state by itself. + var status string + if errors.Is(execErr, context.DeadlineExceeded) { + status = stateupdater.StatusTimeout + } + if errors.Is(execErr, context.Canceled) { + status = stateupdater.StatusAborted + } + if errors.Is(execErr, backend.ErrNonZeroExitCode) { + status = stateupdater.StatusFailed + } + // Ensure the check sent a status update with a terminal status. + if status == "" && !isTerminal { + status = stateupdater.StatusFailed + } + // If the check was not canceled or aborted we just finish its execution. + if status == "" { + // We signal the CheckUpdater that we don't need it to store that + // information any more. + err = cr.CheckUpdater.FlushCheckStatus(check.CheckID) + if err != nil { + err = fmt.Errorf("error deleting the terminal status of the check%s: %w", check.CheckID, err) + return err + } + return err + } + err = cr.CheckUpdater.UpdateState(stateupdater.CheckState{ + ID: check.CheckID, + Status: &status, + }) + if err != nil { + err = fmt.Errorf("error updating the status of the check %s: %w", check.CheckID, err) + return err + } + // We signal the CheckUpdater that we don't need it to store that + // information any more. + err = cr.CheckUpdater.FlushCheckStatus(check.CheckID) + if err != nil { + err = fmt.Errorf("error deleting the terminal status of the check %s: %w", check.CheckID, err) + return err + } + return err +} + +// Running returns the current number of checks running. +func (cr *Runner) Running() int { + return cr.cAborter.Running() +} + +// getChecktypeInfo extracts checktype data from a Docker image URI. +func getChecktypeInfo(imageURI string) (checktypeName string, checktypeVersion string, err error) { + domain, path, tag, err := backend.ParseImage(imageURI) + if err != nil { + err = fmt.Errorf("unable to parse image %s - %w", imageURI, err) + return + } + checktypeName = fmt.Sprintf("%s/%s", domain, path) + if domain == "docker.io" { + checktypeName = path + } + checktypeVersion = tag + return +} diff --git a/exp/checks/runner_test.go b/exp/checks/runner_test.go new file mode 100644 index 00000000..e2656241 --- /dev/null +++ b/exp/checks/runner_test.go @@ -0,0 +1,954 @@ +/* +Copyright 2021 Adevinta +*/ + +package checks + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sync" + "testing" + "time" + + "github.com/adevinta/vulcan-agent/v2/backend" + "github.com/adevinta/vulcan-agent/v2/log" + "github.com/adevinta/vulcan-agent/v2/stateupdater" + "github.com/google/go-cmp/cmp" + "github.com/google/uuid" +) + +var ( + errUnexpectedTest = errors.New("unexpected") + errUploadingLogsTests = errors.New("error uploading logs") + checkFixture = Check{ + CheckID: uuid.NewString(), + StartTime: time.Now(), + Image: "job1:latest", + Target: "example.com", + Timeout: 60, + AssetType: "hostname", + Options: "{}", + RequiredVars: []string{"var1"}, + } +) + +type CheckRaw struct { + Raw []byte + CheckID string + StartTime time.Time +} + +type inMemChecksUpdater struct { + updates []stateupdater.CheckState + raws []CheckRaw + terminalStatus map[string]stateupdater.CheckState + uploadLogError error +} + +func (im *inMemChecksUpdater) UpdateState(cs stateupdater.CheckState) error { + status := "" + if cs.Status != nil { + status = *cs.Status + } else { + storedCheckStatus, ok := im.terminalStatus[cs.ID] + if ok { + status = *storedCheckStatus.Status + } + } + if _, ok := stateupdater.TerminalStatuses[status]; ok { + im.UpdateCheckStatusTerminal(cs) + return nil + } + if im.updates == nil { + im.updates = make([]stateupdater.CheckState, 0) + } + im.updates = append(im.updates, cs) + return nil +} + +func (im *inMemChecksUpdater) UploadCheckData(checkID string, kind string, stime time.Time, raw []byte) (string, error) { + if im.uploadLogError != nil { + return "", im.uploadLogError + } + + if im.raws != nil { + im.raws = make([]CheckRaw, 0) + } + im.raws = append(im.raws, CheckRaw{ + Raw: raw, + CheckID: checkID, + StartTime: stime, + }) + return fmt.Sprintf("%s/%s", checkID, kind), nil +} + +func (im *inMemChecksUpdater) CheckStatusTerminal(ID string) bool { + _, ok := im.terminalStatus[ID] + return ok +} + +func (im *inMemChecksUpdater) UpdateCheckStatusTerminal(s stateupdater.CheckState) { + if im.terminalStatus == nil { + im.terminalStatus = make(map[string]stateupdater.CheckState) + } + cs, ok := im.terminalStatus[s.ID] + if !ok { + im.terminalStatus[s.ID] = s + return + } + + // We update the existing CheckState. + cs.Merge(s) + + im.terminalStatus[cs.ID] = cs +} + +func (im *inMemChecksUpdater) FlushCheckStatus(ID string) error { + if im.updates == nil { + im.updates = make([]stateupdater.CheckState, 0) + } + terminalStatus, ok := im.terminalStatus[ID] + if ok { + im.updates = append(im.updates, terminalStatus) + } + return nil +} + +type mockChecksUpdater struct { + stateUpdater func(cs stateupdater.CheckState) error + checkRawUpload func(checkID, kind string, startedAt time.Time, content []byte) (string, error) + checkTerminalChecker func(ID string) bool + checkTerminalDeleter func(ID string) error + checkTerminalUpdater func(cs stateupdater.CheckState) +} + +func (m *mockChecksUpdater) UpdateState(cs stateupdater.CheckState) error { + return m.stateUpdater(cs) +} + +func (m *mockChecksUpdater) UploadCheckData(checkID, kind string, stime time.Time, raw []byte) (string, error) { + return m.checkRawUpload(checkID, kind, stime, raw) +} + +func (m *mockChecksUpdater) CheckStatusTerminal(ID string) bool { + return m.checkTerminalChecker(ID) +} + +func (m *mockChecksUpdater) FlushCheckStatus(ID string) error { + return m.checkTerminalDeleter(ID) +} + +func (m *mockChecksUpdater) UpdateCheckStatusTerminal(cs stateupdater.CheckState) { + m.checkTerminalUpdater(cs) +} + +type mockBackend struct { + CheckRunner func(ctx context.Context, params backend.RunParams) (<-chan backend.RunResult, error) +} + +func (mb *mockBackend) Run(ctx context.Context, params backend.RunParams) (<-chan backend.RunResult, error) { + return mb.CheckRunner(ctx, params) +} + +type jRunnerStateChecker func(r *Runner) string + +type inMemAbortedChecks struct { + aborted map[string]struct{} + err error +} + +func (ia *inMemAbortedChecks) IsAborted(ID string) (bool, error) { + if ia.err != nil { + return false, ia.err + } + _, ok := ia.aborted[ID] + return ok, nil +} + +func TestRunner_ProcessMessage(t *testing.T) { + type fields struct { + Backend backend.Backend + Tokens chan interface{} + Logger log.Logger + CheckUpdater CheckStateUpdater + cAborter *checkAborter + aborted AbortedChecks + defaultTimeout time.Duration + maxMessageProcessedTimes int + } + type args struct { + check Check + timesRead int + } + tests := []struct { + name string + fields fields + args args + want error + wantState jRunnerStateChecker + }{ + { + name: "RunsACheck", + fields: fields{ + Backend: &mockBackend{ + CheckRunner: func(ctx context.Context, params backend.RunParams) (<-chan backend.RunResult, error) { + res := make(chan backend.RunResult) + go func() { + output, err := json.Marshal(params) + if err != nil { + panic(err) + } + results := backend.RunResult{ + Output: output, + } + res <- results + }() + return res, nil + }, + }, + cAborter: &checkAborter{ + cancels: sync.Map{}, + }, + aborted: &inMemAbortedChecks{make(map[string]struct{}), nil}, + defaultTimeout: time.Duration(10 * time.Second), + Tokens: make(chan interface{}, 10), + Logger: &log.NullLog{}, + CheckUpdater: &inMemChecksUpdater{ + terminalStatus: map[string]stateupdater.CheckState{ + checkFixture.CheckID: { + ID: checkFixture.CheckID, + Status: str2ptr(stateupdater.StatusFinished), + }, + }, + }, + }, + args: args{ + check: checkFixture, + }, + want: nil, + wantState: func(r *Runner) string { + updater := r.CheckUpdater.(*inMemChecksUpdater) + gotRaws := updater.raws + wantRunParams := backend.RunParams{ + CheckID: checkFixture.CheckID, + CheckTypeName: "job1", + ChecktypeVersion: "latest", + Image: checkFixture.Image, + Options: checkFixture.Options, + Target: checkFixture.Target, + AssetType: checkFixture.AssetType, + RequiredVars: checkFixture.RequiredVars, + } + wantRaws := []CheckRaw{{ + Raw: mustMarshalRunParams(wantRunParams), + CheckID: checkFixture.CheckID, + StartTime: checkFixture.StartTime, + }} + gotUpdates := updater.updates + rawLink := fmt.Sprintf("%s/logs", checkFixture.CheckID) + wantUpdates := []stateupdater.CheckState{ + { + ID: checkFixture.CheckID, + Raw: &rawLink, + Status: str2ptr(stateupdater.StatusFinished), + }, + } + rawsDiff := cmp.Diff(wantRaws, gotRaws) + updateDiff := cmp.Diff(wantUpdates, gotUpdates) + return fmt.Sprintf("%s%s", rawsDiff, updateDiff) + }, + }, + { + name: "TerminatesACheckIfNeeded", + fields: fields{ + Backend: &mockBackend{ + CheckRunner: func(ctx context.Context, params backend.RunParams) (<-chan backend.RunResult, error) { + res := make(chan backend.RunResult) + go func() { + output, err := json.Marshal(params) + if err != nil { + panic(err) + } + results := backend.RunResult{ + Output: output, + } + res <- results + }() + return res, nil + }, + }, + cAborter: &checkAborter{ + cancels: sync.Map{}, + }, + aborted: &inMemAbortedChecks{make(map[string]struct{}), nil}, + defaultTimeout: time.Duration(10 * time.Second), + Tokens: make(chan interface{}, 10), + Logger: &log.NullLog{}, + CheckUpdater: &inMemChecksUpdater{}, + }, + args: args{ + check: checkFixture, + }, + want: nil, + wantState: func(r *Runner) string { + updater := r.CheckUpdater.(*inMemChecksUpdater) + gotRaws := updater.raws + wantRunParams := backend.RunParams{ + CheckID: checkFixture.CheckID, + CheckTypeName: "job1", + ChecktypeVersion: "latest", + Image: checkFixture.Image, + Options: checkFixture.Options, + Target: checkFixture.Target, + AssetType: checkFixture.AssetType, + RequiredVars: checkFixture.RequiredVars, + } + wantRaws := []CheckRaw{{ + Raw: mustMarshalRunParams(wantRunParams), + CheckID: checkFixture.CheckID, + StartTime: checkFixture.StartTime, + }} + gotUpdates := updater.updates + rawLink := fmt.Sprintf("%s/logs", checkFixture.CheckID) + wantUpdates := []stateupdater.CheckState{ + { + ID: checkFixture.CheckID, + Raw: &rawLink, + }, + { + ID: checkFixture.CheckID, + Status: str2ptr(stateupdater.StatusFailed), + }, + } + rawsDiff := cmp.Diff(wantRaws, gotRaws) + updateDiff := cmp.Diff(wantUpdates, gotUpdates) + return fmt.Sprintf("%s%s", rawsDiff, updateDiff) + }, + }, + { + name: "ReturnsAnError", + fields: fields{ + Backend: &mockBackend{ + CheckRunner: func(ctx context.Context, params backend.RunParams) (<-chan backend.RunResult, error) { + res := make(chan backend.RunResult) + go func() { + results := backend.RunResult{ + Error: errUnexpectedTest, + } + res <- results + }() + return res, nil + }, + }, + cAborter: &checkAborter{ + cancels: sync.Map{}, + }, + aborted: &inMemAbortedChecks{make(map[string]struct{}), nil}, + defaultTimeout: time.Duration(10 * time.Second), + Tokens: make(chan interface{}, 10), + Logger: &log.NullLog{}, + CheckUpdater: &inMemChecksUpdater{}, + }, + args: args{ + check: checkFixture, + }, + want: errUnexpectedTest, + wantState: func(r *Runner) string { + updater := r.CheckUpdater.(*inMemChecksUpdater) + gotRaws := updater.raws + var wantRaws []CheckRaw + gotUpdates := updater.updates + var wantUpdates []stateupdater.CheckState + rawsDiff := cmp.Diff(wantRaws, gotRaws) + updateDiff := cmp.Diff(wantUpdates, gotUpdates) + return fmt.Sprintf("%s%s", rawsDiff, updateDiff) + }, + }, + { + name: "UpdatesStateWhenCheckTimedout", + fields: fields{ + Backend: &mockBackend{ + CheckRunner: func(ctx context.Context, params backend.RunParams) (<-chan backend.RunResult, error) { + res := make(chan backend.RunResult) + go func() { + output, err := json.Marshal(params) + if err != nil { + panic(err) + } + results := backend.RunResult{ + Output: output, + Error: context.DeadlineExceeded, + } + res <- results + }() + return res, nil + }, + }, + cAborter: &checkAborter{ + cancels: sync.Map{}, + }, + aborted: &inMemAbortedChecks{make(map[string]struct{}), nil}, + defaultTimeout: time.Duration(10 * time.Second), + Tokens: make(chan interface{}, 10), + Logger: &log.NullLog{}, + CheckUpdater: &inMemChecksUpdater{}, + }, + args: args{ + check: checkFixture, + }, + want: nil, + wantState: func(r *Runner) string { + updater := r.CheckUpdater.(*inMemChecksUpdater) + gotRaws := updater.raws + wantRunParams := backend.RunParams{ + CheckID: checkFixture.CheckID, + CheckTypeName: "job1", + ChecktypeVersion: "latest", + Image: checkFixture.Image, + Options: checkFixture.Options, + Target: checkFixture.Target, + AssetType: checkFixture.AssetType, + RequiredVars: checkFixture.RequiredVars, + } + wantRaws := []CheckRaw{{ + Raw: mustMarshalRunParams(wantRunParams), + CheckID: checkFixture.CheckID, + StartTime: checkFixture.StartTime, + }} + gotUpdates := updater.updates + state := stateupdater.StatusTimeout + rawLink := fmt.Sprintf("%s/logs", checkFixture.CheckID) + wantUpdates := []stateupdater.CheckState{ + { + ID: checkFixture.CheckID, + Raw: &rawLink, + }, + { + ID: checkFixture.CheckID, + Status: &state, + // Raw: &rawLink, + }, + } + rawsDiff := cmp.Diff(wantRaws, gotRaws) + updateDiff := cmp.Diff(wantUpdates, gotUpdates) + return fmt.Sprintf("%s%s", rawsDiff, updateDiff) + }, + }, + { + name: "UpdatesStateWhenCheckCanceled", + fields: fields{ + Backend: &mockBackend{ + CheckRunner: func(ctx context.Context, params backend.RunParams) (<-chan backend.RunResult, error) { + res := make(chan backend.RunResult) + go func() { + output, err := json.Marshal(params) + if err != nil { + panic(err) + } + results := backend.RunResult{ + Output: output, + Error: context.Canceled, + } + res <- results + }() + return res, nil + }, + }, + cAborter: &checkAborter{ + cancels: sync.Map{}, + }, + aborted: &inMemAbortedChecks{make(map[string]struct{}), nil}, + defaultTimeout: time.Duration(10 * time.Second), + Tokens: make(chan interface{}, 10), + Logger: &log.NullLog{}, + CheckUpdater: &inMemChecksUpdater{}, + }, + args: args{ + check: checkFixture, + }, + want: nil, + wantState: func(r *Runner) string { + updater := r.CheckUpdater.(*inMemChecksUpdater) + gotRaws := updater.raws + wantRunParams := backend.RunParams{ + CheckID: checkFixture.CheckID, + CheckTypeName: "job1", + ChecktypeVersion: "latest", + Image: checkFixture.Image, + Options: checkFixture.Options, + Target: checkFixture.Target, + AssetType: checkFixture.AssetType, + RequiredVars: checkFixture.RequiredVars, + } + wantRaws := []CheckRaw{{ + Raw: mustMarshalRunParams(wantRunParams), + CheckID: checkFixture.CheckID, + StartTime: checkFixture.StartTime, + }} + gotUpdates := updater.updates + state := stateupdater.StatusAborted + rawLink := fmt.Sprintf("%s/logs", checkFixture.CheckID) + wantUpdates := []stateupdater.CheckState{ + { + ID: checkFixture.CheckID, + Raw: &rawLink, + }, + { + ID: checkFixture.CheckID, + Status: &state, + // Raw: &rawLink, + }, + } + rawsDiff := cmp.Diff(wantRaws, gotRaws) + updateDiff := cmp.Diff(wantUpdates, gotUpdates) + return fmt.Sprintf("%s%s", rawsDiff, updateDiff) + }, + }, + { + name: "DoesNotRunAbortedChecks", + fields: fields{ + Backend: &mockBackend{ + CheckRunner: func(ctx context.Context, params backend.RunParams) (<-chan backend.RunResult, error) { + res := make(chan backend.RunResult) + go func() { + output, err := json.Marshal(params) + if err != nil { + panic(err) + } + results := backend.RunResult{ + Output: output, + Error: context.Canceled, + } + res <- results + }() + return res, nil + }, + }, + cAborter: &checkAborter{ + cancels: sync.Map{}, + }, + aborted: &inMemAbortedChecks{ + aborted: map[string]struct{}{ + checkFixture.CheckID: {}, + }, + }, + defaultTimeout: time.Duration(10 * time.Second), + Tokens: make(chan interface{}, 10), + Logger: &log.NullLog{}, + CheckUpdater: &inMemChecksUpdater{}, + }, + args: args{ + check: checkFixture, + }, + want: nil, + wantState: func(r *Runner) string { + updater := r.CheckUpdater.(*inMemChecksUpdater) + gotRaws := updater.raws + var wantRaws []CheckRaw + gotUpdates := updater.updates + state := stateupdater.StatusAborted + wantUpdates := []stateupdater.CheckState{ + { + ID: checkFixture.CheckID, + Status: &state, + }, + } + rawsDiff := cmp.Diff(wantRaws, gotRaws) + updateDiff := cmp.Diff(wantUpdates, gotUpdates) + return fmt.Sprintf("%s%s", rawsDiff, updateDiff) + }, + }, + { + name: "ReturnErrorWhenUnableToUpdateStatus", + fields: fields{ + Backend: &mockBackend{ + CheckRunner: func(ctx context.Context, params backend.RunParams) (<-chan backend.RunResult, error) { + res := make(chan backend.RunResult) + go func() { + output, err := json.Marshal(params) + if err != nil { + panic(err) + } + results := backend.RunResult{ + Output: output, + } + res <- results + }() + return res, nil + }, + }, + cAborter: &checkAborter{ + cancels: sync.Map{}, + }, + aborted: &inMemAbortedChecks{ + aborted: map[string]struct{}{ + checkFixture.CheckID: {}, + }, + }, + defaultTimeout: time.Duration(10 * time.Second), + Tokens: make(chan interface{}, 10), + Logger: &log.NullLog{}, + CheckUpdater: &mockChecksUpdater{ + stateUpdater: func(cs stateupdater.CheckState) error { + return errUnexpectedTest + }, + checkRawUpload: func(checkID string, kind string, stime time.Time, raw []byte) (string, error) { + return "link", nil + }, + checkTerminalChecker: func(ID string) bool { + return false + }, + checkTerminalDeleter: func(string) error { + return nil + }, + checkTerminalUpdater: func(cs stateupdater.CheckState) { + }, + }, + }, + args: args{ + check: checkFixture, + }, + want: errUnexpectedTest, + wantState: func(r *Runner) string { + return "" + }, + }, + { + name: "UpdatesLogsWhenUnexpectedError", + fields: fields{ + Backend: &mockBackend{ + CheckRunner: func(ctx context.Context, params backend.RunParams) (<-chan backend.RunResult, error) { + res := make(chan backend.RunResult) + go func() { + results := backend.RunResult{ + Output: []byte("logs"), + Error: errUnexpectedTest, + } + res <- results + }() + return res, nil + }, + }, + cAborter: &checkAborter{ + cancels: sync.Map{}, + }, + aborted: &inMemAbortedChecks{ + aborted: map[string]struct{}{}, + }, + defaultTimeout: time.Duration(10 * time.Second), + Tokens: make(chan interface{}, 10), + Logger: &log.NullLog{}, + CheckUpdater: &inMemChecksUpdater{}, + maxMessageProcessedTimes: 1, + }, + args: args{ + check: checkFixture, + }, + want: errUnexpectedTest, + wantState: func(r *Runner) string { + updater := r.CheckUpdater.(*inMemChecksUpdater) + gotRaws := updater.raws + wantRaws := []CheckRaw{{ + Raw: []byte("logs"), + CheckID: checkFixture.CheckID, + StartTime: checkFixture.StartTime, + }} + gotUpdates := updater.updates + rawLink := fmt.Sprintf("%s/logs", checkFixture.CheckID) + wantUpdates := []stateupdater.CheckState{ + { + ID: checkFixture.CheckID, + Raw: &rawLink, + }, + } + rawsDiff := cmp.Diff(wantRaws, gotRaws) + updateDiff := cmp.Diff(wantUpdates, gotUpdates) + return fmt.Sprintf("%s%s", rawsDiff, updateDiff) + }, + }, + { + name: "DoesNotRunChecksProcessedTooManyTimes", + fields: fields{ + Backend: &mockBackend{ + CheckRunner: func(ctx context.Context, params backend.RunParams) (<-chan backend.RunResult, error) { + res := make(chan backend.RunResult) + go func() { + output, err := json.Marshal(params) + if err != nil { + panic(err) + } + results := backend.RunResult{ + Output: output, + Error: context.Canceled, + } + res <- results + }() + return res, nil + }, + }, + cAborter: &checkAborter{ + cancels: sync.Map{}, + }, + aborted: &inMemAbortedChecks{ + aborted: map[string]struct{}{}, + }, + defaultTimeout: time.Duration(10 * time.Second), + maxMessageProcessedTimes: 1, + Tokens: make(chan interface{}, 10), + Logger: &log.NullLog{}, + CheckUpdater: &inMemChecksUpdater{}, + }, + args: args{ + check: checkFixture, + timesRead: 2, + }, + want: nil, + wantState: func(r *Runner) string { + updater := r.CheckUpdater.(*inMemChecksUpdater) + gotRaws := updater.raws + var wantRaws []CheckRaw + gotUpdates := updater.updates + state := stateupdater.StatusFailed + wantUpdates := []stateupdater.CheckState{ + { + ID: checkFixture.CheckID, + Status: &state, + }, + } + rawsDiff := cmp.Diff(wantRaws, gotRaws) + updateDiff := cmp.Diff(wantUpdates, gotUpdates) + return fmt.Sprintf("%s%s", rawsDiff, updateDiff) + }, + }, + { + name: "UpdatesStateWhenErrorUpdatedStatus", + fields: fields{ + Backend: &mockBackend{ + CheckRunner: func(ctx context.Context, params backend.RunParams) (<-chan backend.RunResult, error) { + res := make(chan backend.RunResult) + go func() { + output, err := json.Marshal(params) + if err != nil { + panic(err) + } + results := backend.RunResult{ + Output: output, + Error: backend.ErrNonZeroExitCode, + } + res <- results + }() + return res, nil + }, + }, + cAborter: &checkAborter{ + cancels: sync.Map{}, + }, + aborted: &inMemAbortedChecks{make(map[string]struct{}), nil}, + defaultTimeout: time.Duration(10 * time.Second), + Tokens: make(chan interface{}, 10), + Logger: &log.NullLog{}, + CheckUpdater: &inMemChecksUpdater{}, + }, + args: args{ + check: checkFixture, + }, + want: nil, + wantState: func(r *Runner) string { + updater := r.CheckUpdater.(*inMemChecksUpdater) + gotRaws := updater.raws + wantRunParams := backend.RunParams{ + CheckID: checkFixture.CheckID, + CheckTypeName: "job1", + ChecktypeVersion: "latest", + Image: checkFixture.Image, + Options: checkFixture.Options, + Target: checkFixture.Target, + AssetType: checkFixture.AssetType, + RequiredVars: checkFixture.RequiredVars, + } + wantRaws := []CheckRaw{{ + Raw: mustMarshalRunParams(wantRunParams), + CheckID: checkFixture.CheckID, + StartTime: checkFixture.StartTime, + }} + gotUpdates := updater.updates + state := stateupdater.StatusFailed + rawLink := fmt.Sprintf("%s/logs", checkFixture.CheckID) + wantUpdates := []stateupdater.CheckState{ + { + ID: checkFixture.CheckID, + Raw: &rawLink, + }, + { + ID: checkFixture.CheckID, + Status: &state, + }, + } + rawsDiff := cmp.Diff(wantRaws, gotRaws) + updateDiff := cmp.Diff(wantUpdates, gotUpdates) + return fmt.Sprintf("%s%s", rawsDiff, updateDiff) + }, + }, + { + name: "ReturnsErrorWhenUnableToUploadLogs", + fields: fields{ + Backend: &mockBackend{ + CheckRunner: func(ctx context.Context, params backend.RunParams) (<-chan backend.RunResult, error) { + res := make(chan backend.RunResult) + go func() { + output, err := json.Marshal(params) + if err != nil { + panic(err) + } + results := backend.RunResult{ + Output: output, + } + res <- results + }() + return res, nil + }, + }, + cAborter: &checkAborter{ + cancels: sync.Map{}, + }, + aborted: &inMemAbortedChecks{make(map[string]struct{}), nil}, + defaultTimeout: time.Duration(10 * time.Second), + Tokens: make(chan interface{}, 10), + Logger: &log.NullLog{}, + CheckUpdater: &inMemChecksUpdater{ + terminalStatus: map[string]stateupdater.CheckState{ + checkFixture.CheckID: { + ID: checkFixture.CheckID, + Status: str2ptr(stateupdater.StatusFinished), + }, + }, + uploadLogError: errUploadingLogsTests, + }, + }, + args: args{ + check: checkFixture, + }, + want: errUploadingLogsTests, + wantState: func(r *Runner) string { + updater := r.CheckUpdater.(*inMemChecksUpdater) + gotRaws := updater.raws + var wantRaws []CheckRaw + gotUpdates := updater.updates + // No update with a FAILED terminal status will be sent because + // it will be reattempted. + var wantUpdates []stateupdater.CheckState + rawsDiff := cmp.Diff(wantRaws, gotRaws) + updateDiff := cmp.Diff(wantUpdates, gotUpdates) + return fmt.Sprintf("%s%s", rawsDiff, updateDiff) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cr := &Runner{ + Backend: tt.fields.Backend, + Logger: tt.fields.Logger, + CheckUpdater: tt.fields.CheckUpdater, + abortedChecks: tt.fields.aborted, + cAborter: tt.fields.cAborter, + defaultTimeout: tt.fields.defaultTimeout, + maxMessageProcessedTimes: tt.fields.maxMessageProcessedTimes, + } + got := cr.Run(tt.args.check, tt.args.timesRead) + if !errors.Is(got, tt.want) { + t.Fatalf("error want!=got, %+v!=%+v", tt.want, got) + } + stateDiff := tt.wantState(cr) + if stateDiff != "" { + t.Fatalf("want state!=got state, diff %s", stateDiff) + } + }) + } +} + +func mustMarshal(params Check) []byte { + res, err := json.Marshal(params) + if err != nil { + panic(err) + } + return res +} + +func mustMarshalRunParams(params backend.RunParams) []byte { + res, err := json.Marshal(params) + if err != nil { + panic(err) + } + return res +} + +func str2ptr(str string) *string { + return &str +} + +func TestRunner_getChecktypeInfo(t *testing.T) { + tests := []struct { + image string + wants []string + wantsErr bool + }{ + { + image: "check", + wants: []string{"check", "latest"}, + }, + { + image: "check:1", + wants: []string{"check", "1"}, + }, + { + image: "vulcan/check1", + wants: []string{"vulcan/check1", "latest"}, + }, + { // Should be error? + image: "docker.io/check1", + wants: []string{"check1", "latest"}, + }, + { + image: "docker.io/vulcansec/check1", + wants: []string{"vulcansec/check1", "latest"}, + }, + { + image: "artifactory.com/check1", + wants: []string{"artifactory.com/check1", "latest"}, + }, + { + image: "artifactory.com/vulcan/check1", + wants: []string{"artifactory.com/vulcan/check1", "latest"}, + }, + { + image: "artifactory.com/vulcan/check1:1", + wants: []string{"artifactory.com/vulcan/check1", "1"}, + }, + { + image: "artifactory.com:1234/vulcan/check1:1", + wants: []string{"artifactory.com:1234/vulcan/check1", "1"}, + }, + { + image: "http://artifactory.com:1234/vulcan/check1:1", + wants: []string{"", ""}, + wantsErr: true, + }, + { + image: "image with spaces", + wants: []string{"", ""}, + wantsErr: true, + }, + } + for _, c := range tests { + n, v, err := getChecktypeInfo(c.image) + if c.wantsErr != (err != nil) { + t.Errorf("image:%s want error %v and got %v", c.image, c.wantsErr, err) + } + if diff := cmp.Diff(c.wants, []string{n, v}); diff != "" { + t.Errorf("image:%s want state!=got state, diff %s", c.image, diff) + } + } +} diff --git a/exp/queue/queue.go b/exp/queue/queue.go new file mode 100644 index 00000000..bfffd9e6 --- /dev/null +++ b/exp/queue/queue.go @@ -0,0 +1,199 @@ +package queue + +import ( + "context" + "errors" + "fmt" + "strconv" + "sync" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/arn" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sqs/sqsiface" + + "github.com/adevinta/vulcan-agent/v2/config" + "github.com/adevinta/vulcan-agent/v2/exp/checks" + "github.com/adevinta/vulcan-agent/v2/log" +) + +const ( + MaxQuantumDelta = 3 // in seconds +) + +type SQSReader struct { + receiveParams sqs.ReceiveMessageInput + visibilityTimeout int + processMessageQuantum int + trackingMessagesWG *sync.WaitGroup + log log.Logger + sqs sqsiface.SQSAPI +} + +func NewSQSReader(log log.Logger, cfg config.SQSReader) (*SQSReader, error) { + sess, err := session.NewSession() + if err != nil { + err = fmt.Errorf("error creating AWSSSession, %w", err) + return nil, err + } + arn, err := arn.Parse(cfg.ARN) + if err != nil { + return nil, fmt.Errorf("error parsing SQS queue ARN: %v", err) + } + + awsCfg := aws.NewConfig() + if arn.Region != "" { + awsCfg = awsCfg.WithRegion(arn.Region) + } + if cfg.Endpoint != "" { + awsCfg = awsCfg.WithEndpoint(cfg.Endpoint) + } + + params := &sqs.GetQueueUrlInput{ + QueueName: aws.String(arn.Resource), + } + if arn.AccountID != "" { + params.SetQueueOwnerAWSAccountId(arn.AccountID) + } + + srv := sqs.New(sess, awsCfg) + resp, err := srv.GetQueueUrl(params) + if err != nil { + return nil, fmt.Errorf("error retrieving SQS queue URL: %v", err) + } + delta := cfg.VisibilityTimeout - cfg.ProcessQuantum + if delta < MaxQuantumDelta { + err := errors.New("difference between visibility timeout and quantum is too short") + return nil, err + } + receiveParams := sqs.ReceiveMessageInput{ + QueueUrl: aws.String(*resp.QueueUrl), + MaxNumberOfMessages: aws.Int64(1), + WaitTimeSeconds: aws.Int64(0), + VisibilityTimeout: aws.Int64(int64(cfg.VisibilityTimeout)), + AttributeNames: []*string{aws.String("ApproximateReceiveCount")}, + } + reader := &SQSReader{ + receiveParams: receiveParams, + sqs: srv, + visibilityTimeout: cfg.VisibilityTimeout, + log: log, + trackingMessagesWG: &sync.WaitGroup{}, + } + return reader, nil +} + +func (r *SQSReader) trackMessage(msg sqs.Message, done <-chan bool) { + defer r.trackingMessagesWG.Done() + timer := time.NewTimer(time.Duration(r.processMessageQuantum) * time.Second) +loop: + for { + select { + case <-timer.C: + extendTime := int64(r.visibilityTimeout) + input := &sqs.ChangeMessageVisibilityInput{ + QueueUrl: r.receiveParams.QueueUrl, + ReceiptHandle: msg.ReceiptHandle, + VisibilityTimeout: &extendTime, + } + _, err := r.sqs.ChangeMessageVisibility(input) + if err != nil { + r.log.Errorf("extending message visibility time for message with id: %s, error: %+v", *msg.MessageId, err) + break loop + } + timer.Reset(time.Duration(r.processMessageQuantum) * time.Second) + case delete := <-done: + timer.Stop() + if !delete { + r.log.Errorf("unexpected error processing message with id: %s, message not deleted", *msg.MessageId) + break loop + } + r.log.Debugf("deleting message with id %s", *msg.MessageId) + input := &sqs.DeleteMessageInput{ + QueueUrl: r.receiveParams.QueueUrl, + ReceiptHandle: msg.ReceiptHandle, + } + _, err := r.sqs.DeleteMessage(input) + if err != nil { + r.log.Errorf("deleting message with id: %s, error: %+v", *msg.MessageId, err) + break loop + } + break loop + } + } +} + +func (r *SQSReader) ReadMessage(ctx context.Context) (*checks.Message, error) { + var msg *sqs.Message + resp, err := r.sqs.ReceiveMessageWithContext(ctx, &r.receiveParams) + if err != nil { + if errors.Is(err, context.Canceled) { + return nil, err + } + if aerr, ok := err.(awserr.Error); ok && aerr.Code() == request.CanceledErrorCode { + return nil, context.Canceled + } + return nil, err + } + + if len(resp.Messages) == 0 { + return nil, nil + } + msg = resp.Messages[0] + err = validateSQSMessage(msg) + if err != nil { + // If the message is invalid we delete it from the queue. + if msg.ReceiptHandle == nil { + delErr := errors.New("can't delete invalid message, receipt handle is empty") + return nil, errors.Join(delErr, err) + } + // Invalid message delete from queue without processing. + _, err := r.sqs.DeleteMessage(&sqs.DeleteMessageInput{ + ReceiptHandle: msg.ReceiptHandle, + QueueUrl: r.receiveParams.QueueUrl, + }) + if err != nil { + delErr := fmt.Errorf("deleting invalid message: %w", err) + return nil, errors.Join(delErr, err) + } + return nil, err + } + m := checks.Message{Body: *msg.Body} + var n int + if rc, ok := msg.Attributes["ApproximateReceiveCount"]; ok && rc != nil { + n, err = strconv.Atoi(*rc) + if err != nil { + return nil, fmt.Errorf("error reading ApproximateReceiveCount msg attribute %v", err) + } + } + m.TimesRead = n + processed := make(chan bool, 1) + m.Processed = processed + r.trackingMessagesWG.Add(1) + go r.trackMessage(*msg, processed) + return &m, nil +} + +func validateSQSMessage(msg *sqs.Message) error { + if msg == nil { + invalidErr := errors.New("%w: unexpected empty message") + return errors.Join(checks.ErrInvalidMessage, invalidErr) + } + if msg.Body == nil { + invalidErr := errors.New("unexpected empty body message") + return errors.Join(checks.ErrInvalidMessage, invalidErr) + } + if msg.MessageId == nil { + invalidErr := errors.New("unexpected empty message id") + return errors.Join(checks.ErrInvalidMessage, invalidErr) + } + if msg.ReceiptHandle == nil { + invalidErr := errors.New("unexpected empty receipt handle") + return errors.Join(checks.ErrInvalidMessage, invalidErr) + } + return nil +}