diff --git a/pool.go b/pool.go index ffee418..ab635d4 100644 --- a/pool.go +++ b/pool.go @@ -137,7 +137,7 @@ type Pool[C any] struct { mu sync.Mutex pendingAdds int - sessions map[string]Worker[C] // sessionID → pinned worker + registry SessionRegistry[C] // sessionID → pinned worker inflight map[string]chan struct{} // sessionID → broadcast channel lastAccessed map[string]time.Time // sessionID → last Acquire time (for TTL) activeConns map[string]int32 // sessionID → active connections (for TTL) @@ -168,7 +168,7 @@ func New[C any](factory WorkerFactory[C], opts ...Option) (*Pool[C], error) { p := &Pool[C]{ factory: factory, cfg: cfg, - sessions: make(map[string]Worker[C]), + registry: newLocalRegistry[C](), inflight: make(map[string]chan struct{}), lastAccessed: make(map[string]time.Time), activeConns: make(map[string]int32), // initialize activeConns map @@ -243,13 +243,20 @@ func (p *Pool[C]) wireWorker(w Worker[C]) { // Blocks until a worker is available or ctx is cancelled. func (p *Pool[C]) Acquire(ctx context.Context, sessionID string) (*Session[C], error) { for { + var w Worker[C] + var err error p.mu.Lock() // ── FAST PATH ────────────────────────────────────────────────────── // Session already pinned: return the existing worker immediately. // Also update the last-accessed time so the TTL sweeper doesn't // evict an actively-used session. - if w, ok := p.sessions[sessionID]; ok { + w, err = p.registry.Get(ctx, sessionID) + if err != nil { + p.mu.Unlock() + return nil, fmt.Errorf("herd: Acquire(%q): directory lookup failed: %w", sessionID, err) + } + if w != nil { p.touchSession(sessionID) p.activeConns[sessionID]++ p.mu.Unlock() @@ -283,7 +290,6 @@ func (p *Pool[C]) Acquire(ctx context.Context, sessionID string) (*Session[C], e p.maybeScaleUp() // Block until a free worker arrives or we time out. 
- var w Worker[C] select { case w = <-p.available: case <-ctx.Done(): @@ -302,7 +308,7 @@ func (p *Pool[C]) Acquire(ctx context.Context, sessionID string) (*Session[C], e // This prevents giving a dead handle to the caller (and to any // goroutines waiting on the inflight channel). hCtx, hCancel := context.WithTimeout(ctx, 3*time.Second) - err := w.Healthy(hCtx) + err = w.Healthy(hCtx) hCancel() if err != nil { @@ -319,7 +325,11 @@ func (p *Pool[C]) Acquire(ctx context.Context, sessionID string) (*Session[C], e // Pin the worker to this session, record access time, and broadcast. p.mu.Lock() - p.sessions[sessionID] = w + if err = p.registry.Put(ctx, sessionID, w); err != nil { + p.mu.Unlock() + close(ch) // NOTE(review): w was received from p.available and is not handed back on this path — confirm the worker is not leaked + return nil, fmt.Errorf("herd: Acquire(%q): failed to pin session: %w", sessionID, err) + } p.lastAccessed[sessionID] = time.Now() // Increment active connections immediately @@ -343,7 +353,7 @@ func (p *Pool[C]) Acquire(ctx context.Context, sessionID string) (*Session[C], e // available channel. Internal; external callers use Session.Release(). func (p *Pool[C]) release(sessionID string, w Worker[C]) { p.mu.Lock() - delete(p.sessions, sessionID) + _ = p.registry.Delete(context.Background(), sessionID) delete(p.lastAccessed, sessionID) // Validate the worker wasn't evicted by a crash or health check isValid := false @@ -398,8 +408,9 @@ func (p *Pool[C]) releaseConn(sessionID string) { // sessionID, then calls the user-supplied crash handler. 
func (p *Pool[C]) onCrash(sessionID string) { p.mu.Lock() - w, hadSession := p.sessions[sessionID] - delete(p.sessions, sessionID) + w, _ := p.registry.Get(context.Background(), sessionID) + hadSession := w != nil + _ = p.registry.Delete(context.Background(), sessionID) // If another Acquire is in-flight for this sessionID, close its channel // so the waiting goroutine unblocks and returns an error rather than @@ -500,11 +511,11 @@ func (p *Pool[C]) healthCheckLoop() { select { case <-ticker.C: p.mu.Lock() - workers := make([]Worker[C], len(p.workers)) - copy(workers, p.workers) + workersSnapshot := make([]Worker[C], len(p.workers)) + copy(workersSnapshot, p.workers) p.mu.Unlock() - for _, w := range workers { + for _, w := range workersSnapshot { hCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) err := w.Healthy(hCtx) cancel() @@ -512,6 +523,9 @@ func (p *Pool[C]) healthCheckLoop() { log.Printf("[pool] health-check: worker %s unhealthy (%v) — closing", w.ID(), err) _ = w.Close() p.removeWorker(w) + // A session pinned to this worker should also be unpinned, but the + // registry is keyed by sessionID and offers no worker -> session lookup. + // NOTE(review): this relies on the next Acquire's health check to clean up; the fast path looks like it returns pinned workers unchecked — confirm. 
p.maybeScaleUp() } } @@ -538,7 +552,7 @@ func (p *Pool[C]) Stats() PoolStats { return PoolStats{ TotalWorkers: len(p.workers), AvailableWorkers: len(p.available), - ActiveSessions: len(p.sessions), + ActiveSessions: p.registry.Len(), // number of sessions currently pinned in the registry InflightAcquires: len(p.inflight), Node: nodeStats, } diff --git a/pool_test.go b/pool_test.go index 8b2ef87..8142b48 100644 --- a/pool_test.go +++ b/pool_test.go @@ -100,19 +100,25 @@ func newTestPool(t *testing.T, workers ...*stubWorker) *Pool[*stubClient] { ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) // ensure no goroutine leaks after the test exits + cfg := defaultConfig() + cfg.max = len(workers) // Size the pool to exactly the number of workers provided + if cfg.max == 0 { + cfg.max = 1 // Default to 1 if no workers provided + } + p := &Pool[*stubClient]{ factory: factory, - cfg: defaultConfig(), - sessions: make(map[string]Worker[*stubClient]), + cfg: cfg, + registry: newLocalRegistry[*stubClient](), inflight: make(map[string]chan struct{}), lastAccessed: make(map[string]time.Time), - activeConns: make(map[string]int32), - workers: make([]Worker[*stubClient], 0, len(workers)), - available: make(chan Worker[*stubClient], len(workers)), + activeConns: make(map[string]int32), // initialize activeConns map + workers: make([]Worker[*stubClient], 0, cfg.max), + available: make(chan Worker[*stubClient], cfg.max), done: make(chan struct{}), - ctx: ctx, - cancel: cancel, } + p.ctx = ctx + p.cancel = cancel // Manually wire workers (same logic as New → wireWorker, minus crash hookup) for _, w := range workers { @@ -175,6 +181,10 @@ func TestSameSessionSingleflight(t *testing.T) { if stats.AvailableWorkers != 1 { t.Errorf("expected 1 available worker (w2 untouched), got %d", stats.AvailableWorkers) } + // Verify one session is pinned + if n := pool.registry.Len(); n != 1 { + t.Errorf("expected 1 session pinned, got %d", n) + } } // 
--------------------------------------------------------------------------- @@ -228,8 +238,9 @@ func TestDifferentSessionsIsolated(t *testing.T) { if stats.AvailableWorkers != 0 { t.Errorf("expected 0 available workers (all pinned), got %d", stats.AvailableWorkers) } - if stats.ActiveSessions != 3 { - t.Errorf("expected 3 active sessions, got %d", stats.ActiveSessions) + // Verify 3 sessions are pinned + if n := pool.registry.Len(); n != 3 { + t.Errorf("expected 3 sessions pinned, got %d", n) } } @@ -258,10 +269,8 @@ func TestCrashDuringAcquire(t *testing.T) { } // The session must not exist in the map - pool.mu.Lock() - _, exists := pool.sessions["session-y"] - pool.mu.Unlock() - if exists { + w_dead, _ := pool.registry.Get(context.Background(), "session-y") + if w_dead != nil { t.Error("session-y should not exist in session map after failed Acquire") } } @@ -286,6 +295,14 @@ func TestReleaseReturnsWorkerToPool(t *testing.T) { if got := pool.Stats().AvailableWorkers; got != 0 { t.Fatalf("expected 0 available after Acquire, got %d", got) } + // Verify session is pinned + if n := pool.registry.Len(); n != 1 { + t.Fatalf("expected 1 session pinned, got %d", n) + } + sessions, _ := pool.registry.List(context.Background()) + if worker, ok := sessions["session-z"]; !ok || worker != w { + t.Fatalf("session-z should be pinned to worker w1") + } sess.Release() @@ -295,10 +312,8 @@ func TestReleaseReturnsWorkerToPool(t *testing.T) { } // And the session should be gone from the map - pool.mu.Lock() - _, exists := pool.sessions["session-z"] - pool.mu.Unlock() - if exists { + w_gone, _ := pool.registry.Get(context.Background(), "session-z") + if w_gone != nil { t.Error("session-z should not exist in session map after Release") } } @@ -334,11 +349,9 @@ func TestTTLSweepExpiresSessions(t *testing.T) { pool.sweepExpired() // Session should be gone from the affinity map. 
- pool.mu.Lock() - _, stillExists := pool.sessions["session-ttl"] - pool.mu.Unlock() - if stillExists { - t.Error("expected session-ttl to be evicted by TTL sweeper") + w_evicted, _ := pool.registry.Get(context.Background(), "session-ttl") + if w_evicted != nil { + t.Error("session-ttl should have been evicted by TTL sweeper") } // Worker should be back in the available channel. diff --git a/registry.go b/registry.go new file mode 100644 index 0000000..5e1d872 --- /dev/null +++ b/registry.go @@ -0,0 +1,79 @@ +package herd + +import ( + "context" + "sync" +) + +// SessionRegistry tracks which workers are pinned to which session IDs. +// In a distributed setup (Enterprise), this registry is shared across multiple nodes. +type SessionRegistry[C any] interface { + // Get returns the worker pinned to sessionID. + // Returns (nil, nil) if no session exists for this ID. + Get(ctx context.Context, sessionID string) (Worker[C], error) + + // Put pins a worker to a sessionID. + Put(ctx context.Context, sessionID string, w Worker[C]) error + + // Delete removes the pinning for sessionID. + Delete(ctx context.Context, sessionID string) error + + // List returns a snapshot of all currently active sessions. + // Primarily used for background health checks and cleanup. + List(ctx context.Context) (map[string]Worker[C], error) + + // Len returns the number of active sessions. + Len() int +} + +// localRegistry is the default in-memory implementation for OSS. +// Every method takes the registry's own RWMutex, so it is safe for +// concurrent use by itself; the Pool call sites shown in this change +// additionally hold p.mu while calling it. 
+type localRegistry[C any] struct { + mu sync.RWMutex + sessions map[string]Worker[C] +} + +func newLocalRegistry[C any]() *localRegistry[C] { + return &localRegistry[C]{ + sessions: make(map[string]Worker[C]), + } +} + +func (r *localRegistry[C]) Get(_ context.Context, sessionID string) (Worker[C], error) { + r.mu.RLock() + defer r.mu.RUnlock() + return r.sessions[sessionID], nil +} + +func (r *localRegistry[C]) Put(_ context.Context, sessionID string, w Worker[C]) error { + r.mu.Lock() + defer r.mu.Unlock() + r.sessions[sessionID] = w + return nil +} + +func (r *localRegistry[C]) Delete(_ context.Context, sessionID string) error { + r.mu.Lock() + defer r.mu.Unlock() + delete(r.sessions, sessionID) + return nil +} + +func (r *localRegistry[C]) List(_ context.Context) (map[string]Worker[C], error) { + r.mu.RLock() + defer r.mu.RUnlock() + // Return a copy to avoid external mutation of the internal map + copyMap := make(map[string]Worker[C], len(r.sessions)) + for k, v := range r.sessions { + copyMap[k] = v + } + return copyMap, nil +} + +func (r *localRegistry[C]) Len() int { + r.mu.RLock() + defer r.mu.RUnlock() + return len(r.sessions) +} diff --git a/ttl.go b/ttl.go index b08530e..55561c7 100644 --- a/ttl.go +++ b/ttl.go @@ -27,6 +27,7 @@ package herd import ( + "context" "log" "time" ) @@ -61,6 +62,7 @@ func (p *Pool[C]) sweepExpired() { sessionID string worker Worker[C] } + sessions, _ := p.registry.List(context.Background()) // NOTE(review): List error is dropped; if a registry returns an empty/nil map on failure, every expired sid takes the !ok branch below and only its timestamp is deleted — confirm this is acceptable for non-local registries for sid, lastSeen := range p.lastAccessed { if now.Sub(lastSeen) < p.cfg.ttl { continue @@ -69,7 +71,7 @@ func (p *Pool[C]) sweepExpired() { if p.activeConns[sid] > 0 { continue } - w, ok := p.sessions[sid] + w, ok := sessions[sid] if !ok { // Session already released — clean up orphaned timestamp delete(p.lastAccessed, sid) @@ -79,7 +81,7 @@ func (p *Pool[C]) sweepExpired() { sessionID string worker Worker[C] }{sid, w}) - delete(p.sessions, sid) + _ = p.registry.Delete(context.Background(), sid) delete(p.lastAccessed, sid) } p.mu.Unlock()