Skip to content

Commit

Permalink
feat(wip): cancel tree runners
Browse files Browse the repository at this point in the history
  • Loading branch information
fabiankachlock committed Jul 29, 2024
1 parent de52685 commit f5d3618
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 37 deletions.
39 changes: 18 additions & 21 deletions pkg/runner/tree_runner.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package runner

import (
"errors"
"fmt"
"sync"
"sync/atomic"

Expand Down Expand Up @@ -31,7 +31,7 @@ type TaskTreeRunner struct {
// cancel is a channel that is used to cancel the execution of the task tree.
cancel chan bool
// cancelComplete is a channel that is used to signal that the cancel operation has completed.
cancelComplete chan error
cancelComplete chan bool
// hasError is a flag that indicates whether an error occurred during the execution of the task tree.
hasError atomic.Bool

Expand All @@ -54,7 +54,7 @@ func NewTaskTreeRunner(root *tasks.TaskTreeNode, p ConcurrencyProvider) *TaskTre
updates: make(chan *TreeStatusNode, 1000),
wasCanceled: atomic.Bool{},
cancel: make(chan bool),
cancelComplete: make(chan error),
cancelComplete: make(chan bool),

mutex: sync.RWMutex{},
}
Expand All @@ -68,10 +68,10 @@ func (r *TaskTreeRunner) Status() *TreeStatusNode {
return r.statusTree
}

func (r *TaskTreeRunner) Cancel() error {
func (r *TaskTreeRunner) Cancel() {
r.cancel <- true
close(r.cancel)
return <-r.cancelComplete
<-r.cancelComplete
}

func (r *TaskTreeRunner) updateTaskStatus(node *tasks.TaskTreeNode, status TaskStatus) {
Expand All @@ -93,20 +93,14 @@ func (r *TaskTreeRunner) setError(node *tasks.TaskTreeNode, err error) {

func (r *TaskTreeRunner) Start() error {
done := make(chan bool, 1)
errs := []error{}
errs := map[string]error{}
errMu := sync.Mutex{}
wg := sync.WaitGroup{}

// cleanup
defer func() {
if r.wasCanceled.Load() {
errMu.Lock()
if len(errs) > 0 {
r.cancelComplete <- errors.Join(errs...)
} else {
r.cancelComplete <- nil
}
errMu.Unlock()
r.cancelComplete <- true
}
close(r.updates)
close(r.cancelComplete)
Expand All @@ -129,15 +123,13 @@ func (r *TaskTreeRunner) Start() error {
r.mutex.Unlock()
go func(task *tasks.TaskTreeNode, cancel <-chan bool) {
// acquire a ticket to run the task
// fmt.Println("wait", task.NodeID())
ticket := r.tickets.Acquire()
// fmt.Println("start", task.NodeID())
r.updateTaskStatus(task, StatusRunning)
if err := task.Main.Run(cancel); err != nil {
// fmt.Println("error", task.NodeID(), err)
errMu.Lock()
errs = append(errs, err)
if !r.hasError.Load() {
errs[task.NodeID()] = err
fmt.Println(err)
if !r.hasError.Load() && !r.wasCanceled.Load() {
// first node erroring -> close the channel
close(r.scheduledNodes)
}
Expand All @@ -156,7 +148,6 @@ func (r *TaskTreeRunner) Start() error {
r.mutex.Lock()
delete(r.forwardCancel, task.NodeID())
r.mutex.Unlock()
// fmt.Println("release ticket", task.NodeID())
// release the ticket to be used by another channel
r.tickets.Release(ticket)
wg.Done()
Expand All @@ -172,8 +163,8 @@ func (r *TaskTreeRunner) Start() error {
case <-r.cancel:
// run was canceled - forward cancel to all tasks
r.wasCanceled.Store(true)
close(r.scheduledNodes)
for _, cancel := range r.forwardCancel {
// fmt.Println("cancel", id)
cancel <- true
}
return
Expand All @@ -195,8 +186,12 @@ func (r *TaskTreeRunner) Start() error {
done <- true
close(done)

if r.wasCanceled.Load() {
return tasks.ErrCancelled
}

if len(errs) > 0 {
return errors.Join(errs...)
return tasks.NewMultiTaskError(errs)
}
return nil
}
Expand All @@ -206,6 +201,8 @@ func (r *TaskTreeRunner) scheduleNext(node *tasks.TaskTreeNode) {
return // fail fast
}

r.mutex.Lock()
defer r.mutex.Unlock()
statusNode := findStatus(r.statusTree, node)
if allDone(r.statusTree) {
close(r.scheduledNodes)
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion pkg/tasks/command_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (ct commandTask) Run(cancel <-chan bool) error {
} else {
err = exec.Command("pkill", "-P", strconv.Itoa(ct.cmd.Process.Pid)).Run()
}
fmt.Println("killed", ct.cmd.Process.Pid, err)
// err = syscall.Kill(ct.cmd.Process.Pid, syscall.SIGKILL)
if err != nil {
// fall back to builtin kill
if err := ct.cmd.Process.Kill(); err != nil {
Expand Down
29 changes: 29 additions & 0 deletions pkg/tasks/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package tasks

import (
"errors"
"fmt"
"strings"
)

var (
ErrCancelled = errors.New("task cancelled")
)

type MultiTaskError struct {
Errors map[string]error
}

func (mte MultiTaskError) Error() string {
nodeIds := make([]string, 0, len(mte.Errors))
for id := range mte.Errors {
nodeIds = append(nodeIds, id)
}
return fmt.Sprintf("tasks %s failed", strings.Join(nodeIds, ", "))
}

func NewMultiTaskError(errors map[string]error) MultiTaskError {
return MultiTaskError{
Errors: errors,
}
}
5 changes: 1 addition & 4 deletions pkg/ui/quite_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,7 @@ func (m *quiteView) setupInterruptHandler() {
go func() {
for range c {
for _, runner := range m.runners {
err := runner.Cancel()
m.mu.Lock()
m.errs = append(m.errs, err)
m.mu.Unlock()
runner.Cancel()
}
break
}
Expand Down
11 changes: 6 additions & 5 deletions pkg/ui/static_tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"os"
"os/signal"
"sync"
"time"

"github.com/zwoo-hq/zwooc/pkg/tasks"
)
Expand All @@ -27,10 +26,10 @@ func newStaticTreeRunner(forest tasks.Collection, provider SimpleStatusProvider,

fmt.Printf("%s - %s\n", zwoocBranding, forest.GetName())
model.setupInterruptHandler()
execStart := time.Now()
hasError := false
// execStart := time.Now()
// hasError := false

start := time.Now()
// start := time.Now()
outputs := map[string]*tasks.CommandCapturer{}

// setup task pipes
Expand Down Expand Up @@ -58,7 +57,7 @@ func newStaticTreeRunner(forest tasks.Collection, provider SimpleStatusProvider,

// wait until everything is completed
model.wg.Wait()
execEnd := time.Now()
// execEnd := time.Now()

// TODO: add option in provider to find all errors
// if model.err != nil {
Expand Down Expand Up @@ -110,6 +109,7 @@ func (m *staticTreeView) ReceiveUpdates(c <-chan StatusUpdate, prefix string) {
fmt.Printf("%s %s %s\n", prefix, node.NodeID, canceledStyle.Render("was canceled"))
}
}
fmt.Println("updates done")
m.wg.Done()
}

Expand All @@ -120,6 +120,7 @@ func (m *staticTreeView) WaitForDone() {
m.err = err
m.mu.Unlock()
}
m.wg.Done()
}

func (m *staticTreeView) printFinalStatus() {
Expand Down
4 changes: 4 additions & 0 deletions pkg/ui/tree_progress.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ui

import (
"errors"
"fmt"
"os"
"os/signal"
Expand Down Expand Up @@ -49,6 +50,9 @@ func NewTreeProgressView(forest tasks.Collection, status SimpleStatusProvider, o
}

// TODO: done -display cancel or error or success
if errors.Is(model.err, tasks.ErrCancelled) {
fmt.Println("cancelled")
}

fmt.Println("done!!")
return nil
Expand Down
14 changes: 8 additions & 6 deletions pkg/ui/ui.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,14 @@ func NewRunner(forest tasks.Collection, provider SimpleStatusProvider, options V
}

if options.DisableTUI {
// TODO: use provided runner
newStaticTreeRunner(forest, options)
newStaticTreeRunner(forest, provider, options)
return
}

// try interactive view
if err := NewTreeProgressView(forest, provider, options); err != nil {
// fall back to static view
// TODO: use provided runner
newStaticTreeRunner(forest, options)
newStaticTreeRunner(forest, provider, options)
}
}

Expand Down Expand Up @@ -61,18 +59,22 @@ func (g SimpleStatusProvider) Start() {
close(g.start)
}

func (g SimpleStatusProvider) Cancel() {
func (g *SimpleStatusProvider) Cancel() {
if !g.wasCanceled {
g.wasCanceled = true
g.cancel <- struct{}{}
close(g.cancel)
g.wasCanceled = true
}
}

func (g SimpleStatusProvider) UpdateStatus(update StatusUpdate) {
g.status <- update
}

func (g SimpleStatusProvider) CloseUpdates() {
close(g.status)
}

func (g SimpleStatusProvider) Done(err error) {
g.done <- err
close(g.done)
Expand Down
10 changes: 10 additions & 0 deletions pkg/zwooc/ui_adapter.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package zwooc

import (
"sync"

"github.com/zwoo-hq/zwooc/pkg/config"
"github.com/zwoo-hq/zwooc/pkg/runner"
"github.com/zwoo-hq/zwooc/pkg/tasks"
Expand Down Expand Up @@ -49,15 +51,23 @@ func createForestRunner(forest tasks.Collection, maxConcurrency int) ui.SimpleSt
})

// forward updates
updatesWg := sync.WaitGroup{}
for _, r := range runners {
currentRunner := r
updatesWg.Add(1)
go func() {
for update := range currentRunner.Updates() {
statusProvider.UpdateStatus(runnerToStatusProvider(update))
}
updatesWg.Done()
}()
}

go func() {
updatesWg.Wait()
statusProvider.CloseUpdates()
}()

return statusProvider
}

Expand Down

0 comments on commit f5d3618

Please sign in to comment.