diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go
index 91f73beaf6..81daedc322 100755
--- a/internal/daemon/daemon.go
+++ b/internal/daemon/daemon.go
@@ -116,6 +116,11 @@ type Daemon struct {
 	// triggers a zombie restart, debouncing transient gaps during handoffs.
 	// Only accessed from heartbeat loop goroutine - no sync needed.
 	mayorZombieCount int
+
+	// rigPool runs per-rig heartbeat operations (witness checks, refinery checks,
+	// polecat health, idle reaping, branch pruning) with bounded concurrency and
+	// per-rig context timeouts so one slow rig cannot block all others.
+	rigPool *RigWorkerPool
 }
 
 // sessionDeath records a detected session death for mass death analysis.
@@ -315,6 +320,7 @@ func New(config *Config) (*Daemon, error) {
 		restartTracker: restartTracker,
 		otelProvider:   otelProvider,
 		metrics:        dm,
+		rigPool:        newRigWorkerPool(0, 0, logger), // 0,0 -> package defaults: 10 workers, 30s per-rig timeout
 	}, nil
 }
 
@@ -1412,9 +1418,10 @@ func (d *Daemon) checkDeaconHeartbeat() {
 // Respects the rigs filter in daemon.json patrol config.
 func (d *Daemon) ensureWitnessesRunning() {
 	rigs := d.getPatrolRigs("witness")
-	for _, rigName := range rigs {
+	d.rigPool.runPerRig(d.ctx, rigs, func(ctx context.Context, rigName string) error {
 		d.ensureWitnessRunning(rigName)
-	}
+		return nil // NOTE(review): ensureWitnessRunning ignores ctx, so the per-rig timeout cannot interrupt it — TODO thread ctx through
+	})
 }
 
 // hasPendingEvents checks if there are pending .event files in the given channel directory.
@@ -1488,9 +1495,10 @@ func (d *Daemon) ensureWitnessRunning(rigName string) {
 // Respects the rigs filter in daemon.json patrol config.
 func (d *Daemon) ensureRefineriesRunning() {
 	rigs := d.getPatrolRigs("refinery")
-	for _, rigName := range rigs {
+	d.rigPool.runPerRig(d.ctx, rigs, func(ctx context.Context, rigName string) error {
 		d.ensureRefineryRunning(rigName)
-	}
+		return nil // NOTE(review): ensureRefineryRunning also ignores ctx — timeout only bounds the pool's bookkeeping
+	})
 }
 
 // ensureRefineryRunning ensures the refinery for a specific rig is running.
@@ -1626,7 +1634,7 @@ func (d *Daemon) killDeaconSessions() {
 // killWitnessSessions kills leftover witness tmux sessions for all rigs.
 // Called when the witness patrol is disabled. (hq-2mstj)
 func (d *Daemon) killWitnessSessions() {
-	for _, rigName := range d.getKnownRigs() {
+	d.rigPool.runPerRig(d.ctx, d.getKnownRigs(), func(ctx context.Context, rigName string) error {
 		name := session.WitnessSessionName(session.PrefixFor(rigName))
 		exists, _ := d.tmux.HasSession(name)
 		if exists {
@@ -1635,13 +1643,14 @@ func (d *Daemon) killWitnessSessions() {
 				d.logger.Printf("Error killing %s session: %v", name, err)
 			}
 		}
-	}
+		return nil // best-effort: kill failures are logged above, deliberately not propagated
+	})
 }
 
 // killRefinerySessions kills leftover refinery tmux sessions for all rigs.
 // Called when the refinery patrol is disabled. (hq-2mstj)
 func (d *Daemon) killRefinerySessions() {
-	for _, rigName := range d.getKnownRigs() {
+	d.rigPool.runPerRig(d.ctx, d.getKnownRigs(), func(ctx context.Context, rigName string) error {
 		name := session.RefinerySessionName(session.PrefixFor(rigName))
 		exists, _ := d.tmux.HasSession(name)
 		if exists {
@@ -1650,7 +1659,8 @@ func (d *Daemon) killRefinerySessions() {
 				d.logger.Printf("Error killing %s session: %v", name, err)
 			}
 		}
-	}
+		return nil // best-effort, mirrors killWitnessSessions: errors logged, not returned
+	})
 }
 
 // killDefaultPrefixGhosts kills tmux sessions that use the default "gt" prefix
@@ -2234,10 +2244,10 @@ func KillOrphanedDaemons(townRoot string) (int, error) {
 // When a crash is detected, the polecat is automatically restarted.
 // This provides faster recovery than waiting for GUPP timeout or Witness detection.
 func (d *Daemon) checkPolecatSessionHealth() {
-	rigs := d.getKnownRigs()
-	for _, rigName := range rigs {
+	d.rigPool.runPerRig(d.ctx, d.getKnownRigs(), func(ctx context.Context, rigName string) error {
 		d.checkRigPolecatHealth(rigName)
-	}
+		return nil // NOTE(review): checkRigPolecatHealth does not accept ctx yet — timeout cannot cut it short
+	})
 }
 
 // checkRigPolecatHealth checks polecat session health for a specific rig.
@@ -2511,12 +2521,12 @@ Restart deferred to stuck-agent-dog plugin for context-aware recovery.`,
 // This reaper checks heartbeat state and kills sessions idle longer than the threshold.
 func (d *Daemon) reapIdlePolecats() {
 	opCfg := d.loadOperationalConfig().GetDaemonConfig()
-	timeout := opCfg.PolecatIdleSessionTimeoutD()
+	idleTimeout := opCfg.PolecatIdleSessionTimeoutD() // renamed: avoid shadowing confusion with the pool's per-rig timeout
 
-	rigs := d.getKnownRigs()
-	for _, rigName := range rigs {
-		d.reapRigIdlePolecats(rigName, timeout)
-	}
+	d.rigPool.runPerRig(d.ctx, d.getKnownRigs(), func(ctx context.Context, rigName string) error {
+		d.reapRigIdlePolecats(rigName, idleTimeout)
+		return nil // NOTE(review): reapRigIdlePolecats does not take ctx — per-rig timeout cannot interrupt it
+	})
 }
 
 // reapRigIdlePolecats checks all polecats in a rig and kills idle sessions.
@@ -2688,11 +2698,12 @@ func (d *Daemon) pruneStaleBranches() {
 		}
 	}
 
-	// Prune in each rig's git directory
-	for _, rigName := range d.getKnownRigs() {
+	// Prune in each rig's git directory (parallel — each rig is independent).
+	d.rigPool.runPerRig(d.ctx, d.getKnownRigs(), func(ctx context.Context, rigName string) error {
 		rigPath := filepath.Join(d.config.TownRoot, rigName)
 		pruneInDir(rigPath, rigName)
-	}
+		return nil // NOTE(review): pruneInDir now runs concurrently across rigs — confirm it is safe for parallel use
+	})
 
 	// Also prune in the town root itself (mayor clone)
 	pruneInDir(d.config.TownRoot, "town-root")
diff --git a/internal/daemon/worker.go b/internal/daemon/worker.go
new file mode 100644
index 0000000000..58dcdf1cfc
--- /dev/null
+++ b/internal/daemon/worker.go
@@ -0,0 +1,99 @@
+package daemon
+
+import (
+	"context"
+	"log"
+	"sync"
+	"time"
+)
+
+const (
+	defaultRigConcurrency = 10
+	defaultRigTimeout     = 30 * time.Second // per rig invocation, not per heartbeat tick
+)
+
+// RigWorkerPool runs per-rig heartbeat operations with bounded concurrency
+// and per-rig context timeouts. This prevents a slow or hung rig from
+// blocking heartbeat operations on all other rigs.
+//
+// With N rigs and a serial loop, the heartbeat takes O(N × max_op_time).
+// With the pool, it takes O(max_op_time) — one slow rig no longer gates all others.
+type RigWorkerPool struct {
+	concurrency int           // max rigs processed simultaneously
+	timeout     time.Duration // per-rig context deadline
+	logger      *log.Logger   // may be nil; logging is then skipped entirely
+}
+
+// newRigWorkerPool creates a RigWorkerPool.
+// Zero or negative values for concurrency and timeout fall back to package defaults.
+func newRigWorkerPool(concurrency int, timeout time.Duration, logger *log.Logger) *RigWorkerPool {
+	if concurrency <= 0 {
+		concurrency = defaultRigConcurrency
+	}
+	if timeout <= 0 {
+		timeout = defaultRigTimeout
+	}
+	return &RigWorkerPool{
+		concurrency: concurrency,
+		timeout:     timeout,
+		logger:      logger,
+	}
+}
+
+// runPerRig executes fn once for each rig, with bounded concurrency and per-rig timeouts.
+//
+// Each invocation of fn receives a child context derived from parent with the pool's
+// per-rig timeout applied. If fn respects its context (checks ctx.Done()), it will
+// be canceled when the timeout fires.
+//
+// runPerRig blocks until all goroutines complete. Each per-rig failure is logged
+// as it happens; a final summary line then reports the total error count.
+func (p *RigWorkerPool) runPerRig(
+	parent context.Context,
+	rigs []string,
+	fn func(ctx context.Context, rigName string) error,
+) {
+	if len(rigs) == 0 {
+		return
+	}
+
+	sem := make(chan struct{}, p.concurrency)
+	var wg sync.WaitGroup
+	var mu sync.Mutex
+	var errCount int
+
+	for _, r := range rigs {
+		wg.Add(1)
+		go func(rigName string) {
+			defer wg.Done()
+
+			// Acquire a worker slot; block until one is available.
+			sem <- struct{}{} // NOTE(review): ignores parent cancellation — a canceled parent still waits here for a slot
+			defer func() { <-sem }()
+
+			// Each rig gets its own timeout-bounded context so a slow rig
+			// can be signaled to stop without affecting other rigs.
+			ctx, cancel := context.WithTimeout(parent, p.timeout)
+			defer cancel()
+
+			if err := fn(ctx, rigName); err != nil {
+				mu.Lock()
+				errCount++
+				mu.Unlock()
+				if p.logger != nil {
+					p.logger.Printf("rig_worker: %s: %v", rigName, err)
+				}
+			}
+		}(r)
+	}
+
+	wg.Wait()
+
+	mu.Lock() // wg.Wait() already orders the errCount writes; lock kept for clarity
+	count := errCount
+	mu.Unlock()
+
+	if count > 0 && p.logger != nil {
+		p.logger.Printf("rig_worker: %d/%d rig(s) had errors", count, len(rigs))
+	}
+}
diff --git a/internal/daemon/worker_test.go b/internal/daemon/worker_test.go
new file mode 100644
index 0000000000..fef71b3b3e
--- /dev/null
+++ b/internal/daemon/worker_test.go
@@ -0,0 +1,171 @@
+package daemon
+
+import (
+	"context"
+	"sync/atomic"
+	"testing"
+	"time"
+)
+
+// TestRigWorkerPoolConcurrencyLimit verifies that the pool never runs more than
+// the configured number of rigs simultaneously.
+func TestRigWorkerPoolConcurrencyLimit(t *testing.T) {
+	const (
+		numRigs    = 20
+		maxWorkers = 5
+	)
+
+	pool := newRigWorkerPool(maxWorkers, 10*time.Second, nil)
+
+	var active atomic.Int64
+	var peak atomic.Int64
+
+	rigs := make([]string, numRigs)
+	for i := range rigs {
+		rigs[i] = "rig" // duplicate names are fine: the pool does not deduplicate
+	}
+
+	pool.runPerRig(context.Background(), rigs, func(ctx context.Context, rigName string) error {
+		cur := active.Add(1)
+		// Record peak concurrency via CAS loop (another goroutine may race us).
+		for {
+			p := peak.Load()
+			if cur <= p || peak.CompareAndSwap(p, cur) {
+				break
+			}
+		}
+		time.Sleep(5 * time.Millisecond) // hold the slot briefly
+		active.Add(-1)
+		return nil
+	})
+
+	got := peak.Load()
+	if got > maxWorkers {
+		t.Errorf("peak concurrency %d exceeded limit %d", got, maxWorkers)
+	}
+	if got == 0 {
+		t.Error("no rigs were processed")
+	}
+}
+
+// TestRigWorkerPoolContextTimeout verifies that per-rig context timeouts fire and
+// allow the remaining rigs to proceed unblocked.
+func TestRigWorkerPoolContextTimeout(t *testing.T) {
+	const (
+		numRigs    = 5
+		rigTimeout = 50 * time.Millisecond
+		slowDelay  = 500 * time.Millisecond // much longer than timeout
+	)
+
+	pool := newRigWorkerPool(numRigs, rigTimeout, nil) // concurrency == numRigs: every rig starts immediately
+
+	var cancelled atomic.Int64
+	var completed atomic.Int64
+
+	rigs := make([]string, numRigs)
+	for i := range rigs {
+		rigs[i] = "rig"
+	}
+
+	pool.runPerRig(context.Background(), rigs, func(ctx context.Context, _ string) error {
+		select {
+		case <-time.After(slowDelay):
+			completed.Add(1)
+			return nil
+		case <-ctx.Done():
+			cancelled.Add(1)
+			return ctx.Err()
+		}
+	})
+
+	if cancelled.Load() == 0 {
+		t.Error("expected at least one rig to be cancelled by timeout, got 0")
+	}
+	// All rigs should have responded (either completed or cancelled), not hung.
+	total := cancelled.Load() + completed.Load()
+	if total != numRigs {
+		t.Errorf("expected %d rigs total, got %d (cancelled=%d completed=%d)",
+			numRigs, total, cancelled.Load(), completed.Load())
+	}
+}
+
+// TestRigWorkerPoolSlowRigDoesNotBlockOthers verifies that one slow rig does not
+// prevent the remaining rigs from completing within a reasonable wall-clock window.
+func TestRigWorkerPoolSlowRigDoesNotBlockOthers(t *testing.T) {
+	const (
+		slowRig    = "slow-rig"
+		rigTimeout = 200 * time.Millisecond
+		fastDelay  = 10 * time.Millisecond
+	)
+
+	pool := newRigWorkerPool(10, rigTimeout, nil)
+
+	var fastDone atomic.Int64
+
+	rigs := []string{slowRig, "fast-1", "fast-2", "fast-3", "fast-4"}
+
+	start := time.Now()
+	pool.runPerRig(context.Background(), rigs, func(ctx context.Context, rigName string) error {
+		if rigName == slowRig {
+			// Slow rig blocks until its context times out.
+			select {
+			case <-ctx.Done():
+				return ctx.Err()
+			case <-time.After(10 * time.Second): // never fires in practice
+				return nil
+			}
+		}
+		time.Sleep(fastDelay)
+		fastDone.Add(1)
+		return nil
+	})
+	elapsed := time.Since(start)
+
+	// All fast rigs must have completed.
+	if got := fastDone.Load(); got != 4 {
+		t.Errorf("expected 4 fast rigs to complete, got %d", got)
+	}
+
+	// Total elapsed should be dominated by rigTimeout (≈200ms), not by a serial
+	// execution of the slow rig (which would exceed rigTimeout without the pool).
+	// Allow 3× to account for test environment jitter.
+	limit := 3 * rigTimeout
+	if elapsed > limit {
+		t.Errorf("runPerRig took %v, expected < %v (slow rig should not block overall)", elapsed, limit)
+	}
+}
+
+// BenchmarkRigWorkerPool100RigsOneSlow measures the wall-clock time of a simulated
+// heartbeat tick with 100 rigs, where one rig is slow (100ms).
+//
+// Run with: go test ./internal/daemon/ -bench=BenchmarkRigWorkerPool100RigsOneSlow -benchtime=5s
+func BenchmarkRigWorkerPool100RigsOneSlow(b *testing.B) {
+	const (
+		numRigs    = 100
+		slowIndex  = 7
+		slowDelay  = 100 * time.Millisecond
+		fastDelay  = 1 * time.Millisecond
+		rigTimeout = 5 * time.Second
+	)
+
+	pool := newRigWorkerPool(defaultRigConcurrency, rigTimeout, nil)
+
+	rigs := make([]string, numRigs)
+	for i := range rigs {
+		rigs[i] = "rig"
+	}
+
+	b.ResetTimer()
+	for range b.N {
+		var next atomic.Int64 // fn runs concurrently on up to defaultRigConcurrency goroutines — a plain int here was a data race under -race
+		pool.runPerRig(context.Background(), rigs, func(ctx context.Context, _ string) error {
+			delay := fastDelay
+			if next.Add(1)-1 == slowIndex { // Add returns the new count; the 0-based slot index is count-1
+				delay = slowDelay
+			}
+			// (no separate increment: Add above already advanced the counter)
+			time.Sleep(delay)
+			return nil
+		})
+	}
+}