Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 27 additions & 13 deletions pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ type Pool[C any] struct {

mu sync.Mutex
pendingAdds int
sessions map[string]Worker[C] // sessionID → pinned worker
registry SessionRegistry[C] // sessionID → pinned worker
inflight map[string]chan struct{} // sessionID → broadcast channel
lastAccessed map[string]time.Time // sessionID → last Acquire time (for TTL)
activeConns map[string]int32 // sessionID → active connections (for TTL)
Expand Down Expand Up @@ -168,7 +168,7 @@ func New[C any](factory WorkerFactory[C], opts ...Option) (*Pool[C], error) {
p := &Pool[C]{
factory: factory,
cfg: cfg,
sessions: make(map[string]Worker[C]),
registry: newLocalRegistry[C](),
inflight: make(map[string]chan struct{}),
lastAccessed: make(map[string]time.Time),
activeConns: make(map[string]int32), // initialize activeConns map
Expand Down Expand Up @@ -243,13 +243,20 @@ func (p *Pool[C]) wireWorker(w Worker[C]) {
// Blocks until a worker is available or ctx is cancelled.
func (p *Pool[C]) Acquire(ctx context.Context, sessionID string) (*Session[C], error) {
for {
var w Worker[C]
var err error
p.mu.Lock()

// ── FAST PATH ──────────────────────────────────────────────────────
// Session already pinned: return the existing worker immediately.
// Also update the last-accessed time so the TTL sweeper doesn't
// evict an actively-used session.
if w, ok := p.sessions[sessionID]; ok {
w, err = p.registry.Get(ctx, sessionID)
if err != nil {
p.mu.Unlock()
return nil, fmt.Errorf("herd: Acquire(%q): directory lookup failed: %w", sessionID, err)
}
if w != nil {
p.touchSession(sessionID)
p.activeConns[sessionID]++
p.mu.Unlock()
Expand Down Expand Up @@ -283,7 +290,6 @@ func (p *Pool[C]) Acquire(ctx context.Context, sessionID string) (*Session[C], e
p.maybeScaleUp()

// Block until a free worker arrives or we time out.
var w Worker[C]
select {
case w = <-p.available:
case <-ctx.Done():
Expand All @@ -302,7 +308,7 @@ func (p *Pool[C]) Acquire(ctx context.Context, sessionID string) (*Session[C], e
// This prevents giving a dead handle to the caller (and to any
// goroutines waiting on the inflight channel).
hCtx, hCancel := context.WithTimeout(ctx, 3*time.Second)
err := w.Healthy(hCtx)
err = w.Healthy(hCtx)
hCancel()

if err != nil {
Expand All @@ -319,7 +325,11 @@ func (p *Pool[C]) Acquire(ctx context.Context, sessionID string) (*Session[C], e

// Pin the worker to this session, record access time, and broadcast.
p.mu.Lock()
p.sessions[sessionID] = w
if err = p.registry.Put(ctx, sessionID, w); err != nil {
p.mu.Unlock()
close(ch)
return nil, fmt.Errorf("herd: Acquire(%q): failed to pin session: %w", sessionID, err)
}
p.lastAccessed[sessionID] = time.Now()

// Increment active connections immediately
Expand All @@ -343,7 +353,7 @@ func (p *Pool[C]) Acquire(ctx context.Context, sessionID string) (*Session[C], e
// available channel. Internal; external callers use Session.Release().
func (p *Pool[C]) release(sessionID string, w Worker[C]) {
p.mu.Lock()
delete(p.sessions, sessionID)
_ = p.registry.Delete(context.Background(), sessionID)
delete(p.lastAccessed, sessionID)
// Validate the worker wasn't evicted by a crash or health check
isValid := false
Expand Down Expand Up @@ -398,8 +408,9 @@ func (p *Pool[C]) releaseConn(sessionID string) {
// sessionID, then calls the user-supplied crash handler.
func (p *Pool[C]) onCrash(sessionID string) {
p.mu.Lock()
w, hadSession := p.sessions[sessionID]
delete(p.sessions, sessionID)
w, _ := p.registry.Get(context.Background(), sessionID)
hadSession := w != nil
_ = p.registry.Delete(context.Background(), sessionID)

// If another Acquire is in-flight for this sessionID, close its channel
// so the waiting goroutine unblocks and returns an error rather than
Expand Down Expand Up @@ -500,18 +511,21 @@ func (p *Pool[C]) healthCheckLoop() {
select {
case <-ticker.C:
p.mu.Lock()
workers := make([]Worker[C], len(p.workers))
copy(workers, p.workers)
workersSnapshot := make([]Worker[C], len(p.workers))
copy(workersSnapshot, p.workers)
p.mu.Unlock()

for _, w := range workers {
for _, w := range workersSnapshot {
hCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
err := w.Healthy(hCtx)
cancel()
if err != nil {
log.Printf("[pool] health-check: worker %s unhealthy (%v) — closing", w.ID(), err)
_ = w.Close()
p.removeWorker(w)
// If this worker was pinned to any session, it should be removed
// but our List doesn't map worker -> session easily.
// The next Acquire will fail health check and clean it up.
p.maybeScaleUp()
}
}
Expand All @@ -538,7 +552,7 @@ func (p *Pool[C]) Stats() PoolStats {
return PoolStats{
TotalWorkers: len(p.workers),
AvailableWorkers: len(p.available),
ActiveSessions: len(p.sessions),
ActiveSessions: p.registry.Len(), // Add Len() to registry or use alternative
InflightAcquires: len(p.inflight),
Node: nodeStats,
}
Expand Down
57 changes: 35 additions & 22 deletions pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,25 @@ func newTestPool(t *testing.T, workers ...*stubWorker) *Pool[*stubClient] {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel) // ensure no goroutine leaks after the test exits

cfg := defaultConfig()
cfg.max = len(workers) // Ensure cfg.max is at least the number of workers provided
if cfg.max == 0 {
cfg.max = 1 // Default to 1 if no workers provided
}

p := &Pool[*stubClient]{
factory: factory,
cfg: defaultConfig(),
sessions: make(map[string]Worker[*stubClient]),
cfg: cfg,
registry: newLocalRegistry[*stubClient](),
inflight: make(map[string]chan struct{}),
lastAccessed: make(map[string]time.Time),
activeConns: make(map[string]int32),
workers: make([]Worker[*stubClient], 0, len(workers)),
available: make(chan Worker[*stubClient], len(workers)),
activeConns: make(map[string]int32), // initialize activeConns map
workers: make([]Worker[*stubClient], 0, cfg.max),
available: make(chan Worker[*stubClient], cfg.max),
done: make(chan struct{}),
ctx: ctx,
cancel: cancel,
}
p.ctx = ctx
p.cancel = cancel

// Manually wire workers (same logic as New → wireWorker, minus crash hookup)
for _, w := range workers {
Expand Down Expand Up @@ -175,6 +181,10 @@ func TestSameSessionSingleflight(t *testing.T) {
if stats.AvailableWorkers != 1 {
t.Errorf("expected 1 available worker (w2 untouched), got %d", stats.AvailableWorkers)
}
// Verify one session is pinned
if n := pool.registry.Len(); n != 1 {
t.Errorf("expected 1 session pinned, got %d", n)
}
}

// ---------------------------------------------------------------------------
Expand Down Expand Up @@ -228,8 +238,9 @@ func TestDifferentSessionsIsolated(t *testing.T) {
if stats.AvailableWorkers != 0 {
t.Errorf("expected 0 available workers (all pinned), got %d", stats.AvailableWorkers)
}
if stats.ActiveSessions != 3 {
t.Errorf("expected 3 active sessions, got %d", stats.ActiveSessions)
// Verify 3 sessions are pinned
if n := pool.registry.Len(); n != 3 {
t.Errorf("expected 3 sessions pinned, got %d", n)
}
}

Expand Down Expand Up @@ -258,10 +269,8 @@ func TestCrashDuringAcquire(t *testing.T) {
}

// The session must not exist in the map
pool.mu.Lock()
_, exists := pool.sessions["session-y"]
pool.mu.Unlock()
if exists {
w_dead, _ := pool.registry.Get(context.Background(), "session-y")
if w_dead != nil {
t.Error("session-y should not exist in session map after failed Acquire")
}
}
Expand All @@ -286,6 +295,14 @@ func TestReleaseReturnsWorkerToPool(t *testing.T) {
if got := pool.Stats().AvailableWorkers; got != 0 {
t.Fatalf("expected 0 available after Acquire, got %d", got)
}
// Verify session is pinned
if n := pool.registry.Len(); n != 1 {
t.Fatalf("expected 1 session pinned, got %d", n)
}
sessions, _ := pool.registry.List(context.Background())
if worker, ok := sessions["session-z"]; !ok || worker != w {
t.Fatalf("session-z should be pinned to worker w1")
}

sess.Release()

Expand All @@ -295,10 +312,8 @@ func TestReleaseReturnsWorkerToPool(t *testing.T) {
}

// And the session should be gone from the map
pool.mu.Lock()
_, exists := pool.sessions["session-z"]
pool.mu.Unlock()
if exists {
w_gone, _ := pool.registry.Get(context.Background(), "session-z")
if w_gone != nil {
t.Error("session-z should not exist in session map after Release")
}
}
Expand Down Expand Up @@ -334,11 +349,9 @@ func TestTTLSweepExpiresSessions(t *testing.T) {
pool.sweepExpired()

// Session should be gone from the affinity map.
pool.mu.Lock()
_, stillExists := pool.sessions["session-ttl"]
pool.mu.Unlock()
if stillExists {
t.Error("expected session-ttl to be evicted by TTL sweeper")
w_evicted, _ := pool.registry.Get(context.Background(), "session-ttl")
if w_evicted != nil {
t.Error("session-ttl should have been evicted by TTL sweeper")
}

// Worker should be back in the available channel.
Expand Down
79 changes: 79 additions & 0 deletions registry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package herd

import (
"context"
"sync"
)

// SessionRegistry tracks which worker is pinned to which session ID.
// The Pool consults it on the Acquire fast path and updates it when a
// session is pinned, released, crashed, or TTL-swept.
// In a distributed setup (Enterprise), this registry is shared across multiple nodes.
type SessionRegistry[C any] interface {
	// Get returns the worker pinned to sessionID.
	// Returns (nil, nil) if no session exists for this ID.
	Get(ctx context.Context, sessionID string) (Worker[C], error)

	// Put pins a worker to a sessionID, replacing any existing pin for
	// that ID.
	Put(ctx context.Context, sessionID string, w Worker[C]) error

	// Delete removes the pinning for sessionID. Deleting an ID that is
	// not present is not an error.
	Delete(ctx context.Context, sessionID string) error

	// List returns a snapshot of all currently active sessions.
	// Primarily used for background health checks and cleanup.
	List(ctx context.Context) (map[string]Worker[C], error)

	// Len returns the number of active sessions.
	// NOTE(review): unlike the other methods, Len takes no context, so a
	// remote-backed implementation cannot honor cancellation or deadlines
	// here — consider adding ctx before a distributed backend ships.
	Len() int
}

// localRegistry is the default in-memory SessionRegistry implementation
// for OSS builds.
//
// It synchronizes access to its map with its own RWMutex, so it is safe
// for concurrent use on its own. The Pool currently also holds p.mu
// around its calls, which makes the inner lock uncontended in practice,
// but keeping it means this type stays correct if that coupling changes.
type localRegistry[C any] struct {
	mu sync.RWMutex
	sessions map[string]Worker[C] // sessionID → pinned worker
}

// newLocalRegistry constructs an empty in-memory session registry.
func newLocalRegistry[C any]() *localRegistry[C] {
	r := new(localRegistry[C])
	r.sessions = map[string]Worker[C]{}
	return r
}

// Get reports the worker currently pinned to sessionID. A missing
// session yields (nil, nil); the error is always nil for the in-memory
// registry.
func (r *localRegistry[C]) Get(_ context.Context, sessionID string) (Worker[C], error) {
	r.mu.RLock()
	w := r.sessions[sessionID]
	r.mu.RUnlock()
	return w, nil
}

// Put records w as the worker pinned to sessionID, replacing any prior
// pin for that ID. It never fails for the in-memory registry.
func (r *localRegistry[C]) Put(_ context.Context, sessionID string, w Worker[C]) error {
	r.mu.Lock()
	r.sessions[sessionID] = w
	r.mu.Unlock()
	return nil
}

// Delete drops the pin for sessionID; deleting an unknown ID is a
// no-op. It never fails for the in-memory registry.
func (r *localRegistry[C]) Delete(_ context.Context, sessionID string) error {
	r.mu.Lock()
	delete(r.sessions, sessionID)
	r.mu.Unlock()
	return nil
}

// List returns an independent snapshot of every pinned session, so the
// caller can iterate or mutate the result without racing the registry.
func (r *localRegistry[C]) List(_ context.Context) (map[string]Worker[C], error) {
	r.mu.RLock()
	defer r.mu.RUnlock()

	snapshot := make(map[string]Worker[C], len(r.sessions))
	for id, w := range r.sessions {
		snapshot[id] = w
	}
	return snapshot, nil
}

// Len reports how many sessions are currently pinned.
func (r *localRegistry[C]) Len() int {
	r.mu.RLock()
	n := len(r.sessions)
	r.mu.RUnlock()
	return n
}
6 changes: 4 additions & 2 deletions ttl.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
package herd

import (
"context"
"log"
"time"
)
Expand Down Expand Up @@ -61,6 +62,7 @@ func (p *Pool[C]) sweepExpired() {
sessionID string
worker Worker[C]
}
sessions, _ := p.registry.List(context.Background())
for sid, lastSeen := range p.lastAccessed {
if now.Sub(lastSeen) < p.cfg.ttl {
continue
Expand All @@ -69,7 +71,7 @@ func (p *Pool[C]) sweepExpired() {
if p.activeConns[sid] > 0 {
continue
}
w, ok := p.sessions[sid]
w, ok := sessions[sid]
if !ok {
// Session already released — clean up orphaned timestamp
delete(p.lastAccessed, sid)
Expand All @@ -79,7 +81,7 @@ func (p *Pool[C]) sweepExpired() {
sessionID string
worker Worker[C]
}{sid, w})
delete(p.sessions, sid)
_ = p.registry.Delete(context.Background(), sid)
delete(p.lastAccessed, sid)
}
p.mu.Unlock()
Expand Down
Loading