diff --git a/.gitignore b/.gitignore index fe58ee8..08c74be 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,7 @@ examples/ollama/ollama # macOS .DS_Store + +# Coverage +coverage.out +coverage.html diff --git a/README.md b/README.md index 10949e6..0e0e48d 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # Herd +![Architecture Comparison](./assets/arch_comparison.png) + **Herd** is a session-affine process pool for Go. It manages a fleet of OS subprocess "workers" and routes incoming requests to the correct worker based on an arbitrary session ID. ### The Core Invariant diff --git a/assets/arch_comparison.png b/assets/arch_comparison.png new file mode 100644 index 0000000..5b57697 Binary files /dev/null and b/assets/arch_comparison.png differ diff --git a/examples/ollama/main.go b/examples/ollama/main.go index 2bcc316..e35a738 100644 --- a/examples/ollama/main.go +++ b/examples/ollama/main.go @@ -62,7 +62,9 @@ func main() { WithEnv("OLLAMA_HOST=127.0.0.1:{{.Port}}"). WithEnv("OLLAMA_MODELS=" + *modelsDir). WithHealthPath("/"). // ollama: GET / → 200 "Ollama is running" - WithStartTimeout(2 * time.Minute) + WithStartTimeout(2 * time.Minute). + WithMemoryLimit(1024 * 1024 * 512) // 512MB + // ── Pool ─────────────────────────────────────────────────────────────── pool, err := herd.New(factory, diff --git a/factory.go b/factory.go index dfad06f..f6efffa 100644 --- a/factory.go +++ b/factory.go @@ -138,12 +138,13 @@ func (w *processWorker) monitor() { // // pool, err := herd.New(herd.NewProcessFactory("./my-binary", "--port", "{{.Port}}")) type ProcessFactory struct { - binary string - args []string // may contain "{{.Port}}" — replaced at spawn time - extraEnv []string // additional KEY=VALUE env vars; "{{.Port}}" is replaced here too - healthPath string // path to poll for liveness; defaults to "/health" - startTimeout time.Duration // maximum time to wait for the first successful health check - counter atomic.Int64 + binary string + args []string // may contain "{{.Port}}" — replaced at spawn time + extraEnv []string // additional KEY=VALUE env vars; "{{.Port}}" is replaced here too + healthPath string // path to poll for liveness; defaults to "/health" + startTimeout time.Duration // maximum time to wait for the first successful health check + memoryLimitBytes uint64 // maximum memory in bytes for the child process + counter atomic.Int64 } // NewProcessFactory returns a ProcessFactory that spawns the given binary. @@ -198,6 +199,19 @@ func (f *ProcessFactory) WithStartTimeout(d time.Duration) *ProcessFactory { return f } +// WithMemoryLimit sets a soft virtual memory limit on the worker process in bytes. +// +// On Linux, this is enforced using a shell wrapper (`sh -c ulimit -v `). +// If the worker exceeds this limit, the OS will kill it (typically via SIGSEGV/SIGABRT), +// and herd's crash handler (if configured via WithCrashHandler) will clean up the session. +// +// On macOS and platforms where `ulimit` cannot be modified by unprivileged users, +// the worker will still spawn gracefully but the memory limit will act as a no-op. +func (f *ProcessFactory) WithMemoryLimit(limitBytes uint64) *ProcessFactory { + f.memoryLimitBytes = limitBytes + return f +} + // Spawn implements WorkerFactory[*http.Client]. // It allocates a free port, starts the binary, and blocks until the worker // passes a /health check or ctx is cancelled. @@ -223,7 +237,21 @@ func (f *ProcessFactory) Spawn(ctx context.Context) (Worker[*http.Client], error resolvedEnv[i] = strings.ReplaceAll(e, "{{.Port}}", portStr) } - cmd := exec.CommandContext(ctx, f.binary, resolvedArgs...) + var cmd *exec.Cmd + if f.memoryLimitBytes > 0 { + limitKB := f.memoryLimitBytes / 1024 + // Execute via shell wrapper to set ulimit. + // On macOS, ulimit -v might fail (Invalid argument) so we gracefully fallback using '|| true'. + script := fmt.Sprintf("ulimit -v %d 2>/dev/null || true; exec \"$@\"", limitKB) + + shellArgs := []string{"-c", script, "--", f.binary} + shellArgs = append(shellArgs, resolvedArgs...) + + cmd = exec.CommandContext(ctx, "sh", shellArgs...) + } else { + cmd = exec.CommandContext(ctx, f.binary, resolvedArgs...) + } + cmd.Env = append(os.Environ(), append([]string{"PORT=" + portStr}, resolvedEnv...)...) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr diff --git a/factory_test.go b/factory_test.go new file mode 100644 index 0000000..04952fe --- /dev/null +++ b/factory_test.go @@ -0,0 +1,109 @@ +package herd + +import ( + "context" + "net/http" + "os" + "testing" + "time" +) + +// TestHelperProcess isn't a real test. It's used as a dummy worker process +// for ProcessFactory tests. +func TestHelperProcess(t *testing.T) { + if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" { + return + } + + mode := os.Getenv("HELPER_MODE") + port := os.Getenv("PORT") + + if mode == "immediate_exit" { + os.Exit(1) + } + + if mode == "hang" { + // Just sleep forever, never start HTTP server + time.Sleep(1 * time.Hour) + os.Exit(0) + } + + http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + }) + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + }) + + // Start server + if err := http.ListenAndServe("127.0.0.1:"+port, nil); err != nil { + os.Exit(1) + } + os.Exit(0) +} + +func TestProcessFactory_SpawnSuccess(t *testing.T) { + factory := NewProcessFactory(os.Args[0], "-test.run=TestHelperProcess") + factory.WithEnv("GO_WANT_HELPER_PROCESS=1") + factory.WithEnv("HELPER_MODE=success") + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + worker, err := factory.Spawn(ctx) + if err != nil { + t.Fatalf("Spawn failed: %v", err) + } + defer worker.Close() + + if worker.ID() == "" { + t.Errorf("worker ID is empty") + } + if worker.Address() == "" { + t.Errorf("worker Address is empty") + } + + // Double check health passes + err = worker.Healthy(ctx) + if err != nil { + t.Errorf("worker.Healthy failed: %v", err) + } +} + +func TestProcessFactory_SpawnImmediateExit(t *testing.T) { + factory := NewProcessFactory(os.Args[0], "-test.run=TestHelperProcess") + factory.WithEnv("GO_WANT_HELPER_PROCESS=1") + factory.WithEnv("HELPER_MODE=immediate_exit") + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + _, err := factory.Spawn(ctx) + if err == nil { + t.Fatalf("expected Spawn to fail when process exits immediately, got nil") + } +} + +func TestProcessFactory_SpawnTimeout(t *testing.T) { + factory := NewProcessFactory(os.Args[0], "-test.run=TestHelperProcess") + factory.WithEnv("GO_WANT_HELPER_PROCESS=1") + factory.WithEnv("HELPER_MODE=hang") + factory.WithStartTimeout(500 * time.Millisecond) // Short timeout + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + _, err := factory.Spawn(ctx) + if err == nil { + t.Fatalf("expected Spawn to fail due to timeout, got nil") + } +} + +func TestProcessFactory_WithMemoryLimit(t *testing.T) { + factory := NewProcessFactory("echo", "hello").WithMemoryLimit(1024 * 1024) + if factory.memoryLimitBytes != 1024*1024 { + t.Errorf("expected 1024*1024 limit bytes, got %d", factory.memoryLimitBytes) + } +} diff --git a/integration_test.go b/integration_test.go new file mode 100644 index 0000000..ceb4b6a --- /dev/null +++ b/integration_test.go @@ -0,0 +1,89 @@ +package herd_test + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" + + "github.com/hackstrix/herd" + "github.com/hackstrix/herd/proxy" +) + +// In integration test, we use the same TestHelperProcess defined in factory_test.go +// as our dummy target. The TestMain func or factory_test setup handles it. +func TestIntegration_EndToEnd(t *testing.T) { + // Stand up a ProcessFactory that spawns our dummy HTTP server + factory := herd.NewProcessFactory(os.Args[0], "-test.run=TestHelperProcess") + factory.WithEnv("GO_WANT_HELPER_PROCESS=1") + factory.WithEnv("HELPER_MODE=success") + + // Create pool: 1 worker min, 1 max. + pool, err := herd.New(factory, herd.WithAutoScale(1, 1), herd.WithTTL(500*time.Millisecond)) + if err != nil { + t.Fatalf("failed to create pool: %v", err) + } + defer pool.Shutdown(context.Background()) + + // Wait briefly for min workers to be spawned + time.Sleep(200 * time.Millisecond) + + // Build the reverse proxy using a custom session ID extractor + p := proxy.NewReverseProxy(pool, func(r *http.Request) string { + return r.Header.Get("X-Session-ID") + }) + + proxyServer := httptest.NewServer(p) + defer proxyServer.Close() + + // ---------------------------------------------------- + // Request 1: Should successfully hit the dummy worker + // ---------------------------------------------------- + req1, _ := http.NewRequest("GET", proxyServer.URL+"/", nil) + req1.Header.Set("X-Session-ID", "integration-session-1") + + client := proxyServer.Client() + resp1, err := client.Do(req1) + if err != nil { + t.Fatalf("proxy request 1 failed: %v", err) + } + defer resp1.Body.Close() + + if resp1.StatusCode != http.StatusOK { + t.Fatalf("expected proxy to return 200 OK, got %d", resp1.StatusCode) + } + + body1, _ := io.ReadAll(resp1.Body) + if string(body1) != "OK" { + t.Errorf("expected standard 'OK' from dummy worker, got %q", string(body1)) + } + + // ---------------------------------------------------- + // Pool Stats Check + // ---------------------------------------------------- + // After the proxy request finishes, the session is released and worker returned to pool. + stats := pool.Stats() + if stats.ActiveSessions != 0 { + t.Errorf("expected 0 active sessions after request, got %d", stats.ActiveSessions) + } + if stats.AvailableWorkers != 1 { + t.Errorf("expected 1 available worker, got %d", stats.AvailableWorkers) + } + + // ---------------------------------------------------- + // TTL Expiry Check + // ---------------------------------------------------- + // Wait out the TTL logic (500ms) + time.Sleep(1 * time.Second) + + // Since we set TTL=500ms, the session should be evicted, though with min=1, the worker is kept! + // Actually, TTL sweeper drops the worker completely, then the background maintenance + // loop will immediately notice we are < min and spawn a new one. + statsAfterTTL := pool.Stats() + if statsAfterTTL.ActiveSessions != 0 { + t.Errorf("expected 0 active sessions, got %d", statsAfterTTL.ActiveSessions) + } +} diff --git a/justfile b/justfile new file mode 100644 index 0000000..dcabc09 --- /dev/null +++ b/justfile @@ -0,0 +1,27 @@ +# Default recipe +default: test-race + +# Run standard tests +test: + go test ./... + +# Run tests with race detector +test-race: + go test -race ./... + +# Run golangci-lint +lint: + golangci-lint run ./... + +# Tidy module dependencies +tidy: + go mod tidy + +# Generate coverage report +coverage: + go test -coverprofile=coverage.out ./... + go tool cover -html=coverage.out -o coverage.html + @echo "Coverage report written to coverage.html" + +# CI pipeline shortcut (lint, tidy, test-race) +ci: lint tidy test-race diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go new file mode 100644 index 0000000..755e5f1 --- /dev/null +++ b/proxy/proxy_test.go @@ -0,0 +1,142 @@ +package proxy_test + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/hackstrix/herd" + "github.com/hackstrix/herd/proxy" +) + +// --------------------------------------------------------------------------- +// Mock Pool +// --------------------------------------------------------------------------- + +type mockClient struct{} + +type mockWorker struct { + id string + address string +} + +func (m *mockWorker) ID() string { return m.id } +func (m *mockWorker) Address() string { return m.address } +func (m *mockWorker) Client() *mockClient { return &mockClient{} } +func (m *mockWorker) Healthy(ctx context.Context) error { return nil } +func (m *mockWorker) Close() error { return nil } + +type mockFactory struct { + worker *mockWorker +} + +func (m *mockFactory) Spawn(ctx context.Context) (herd.Worker[*mockClient], error) { + return m.worker, nil +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +func TestReverseProxy_SuccessRoutesRequest(t *testing.T) { + // 1. Create a dummy downstream server that the worker "represents" + targetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Backend-Handled", "true") + w.WriteHeader(http.StatusOK) + w.Write([]byte("hello from worker")) + })) + defer targetServer.Close() + + // 2. Build the pool with our mock setup + factory := &mockFactory{ + worker: &mockWorker{ + id: "worker-1", + address: targetServer.URL, + }, + } + pool, err := herd.New(factory, herd.WithAutoScale(1, 1)) + if err != nil { + t.Fatalf("failed to create pool: %v", err) + } + + // 3. Create the proxy + p := proxy.NewReverseProxy(pool, func(r *http.Request) string { + return r.Header.Get("X-Session-ID") + }) + + proxyServer := httptest.NewServer(p) + defer proxyServer.Close() + + // 4. Send a request hitting the proxy with session ID + req, _ := http.NewRequest("GET", proxyServer.URL+"/", nil) + req.Header.Set("X-Session-ID", "session-123") + + client := proxyServer.Client() + resp, err := client.Do(req) + if err != nil { + t.Fatalf("proxy request failed: %v", err) + } + defer resp.Body.Close() + + // 5. Verify the route occurred + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200 OK, got %d", resp.StatusCode) + } + if got := resp.Header.Get("X-Backend-Handled"); got != "true" { + t.Errorf("expected header X-Backend-Handled=true, got %q", got) + } + + body, _ := io.ReadAll(resp.Body) + if string(body) != "hello from worker" { + t.Errorf("expected body 'hello from worker', got %q", string(body)) + } +} + +func TestReverseProxy_MissingSessionIDReturns400(t *testing.T) { + pool, _ := herd.New(&mockFactory{}, herd.WithAutoScale(1, 1)) + p := proxy.NewReverseProxy(pool, func(r *http.Request) string { + return "" // empty session ID + }) + + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + p.ServeHTTP(w, req) + + res := w.Result() + if res.StatusCode != http.StatusBadRequest { + t.Errorf("expected 400 Bad Request, got %d", res.StatusCode) + } +} + +func TestReverseProxy_AcquireFailureReturns503(t *testing.T) { + pool, _ := herd.New(&mockFactory{worker: &mockWorker{address: "http://127.0.0.1"}}, herd.WithAutoScale(1, 1)) + + // Exhaust the pool by acquiring the only worker + _, err := pool.Acquire(context.Background(), "other-session") + if err != nil { + t.Fatalf("failed to exhaust pool: %v", err) + } + + p := proxy.NewReverseProxy(pool, func(r *http.Request) string { + return "session-wait" + }) + + req := httptest.NewRequest("GET", "/", nil) + // Extremely short timeout so it fails to acquire fast + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + req = req.WithContext(ctx) + + w := httptest.NewRecorder() + + p.ServeHTTP(w, req) + + res := w.Result() + if res.StatusCode != http.StatusServiceUnavailable { + t.Errorf("expected 503 Service Unavailable, got %d", res.StatusCode) + } +} diff --git a/worker_test.go b/worker_test.go new file mode 100644 index 0000000..33cff76 --- /dev/null +++ b/worker_test.go @@ -0,0 +1,118 @@ +package herd + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "os/exec" + "testing" + "time" +) + +// roundTripFunc implements http.RoundTripper +type roundTripFunc func(req *http.Request) *http.Response + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req), nil +} + +func TestWorkerHealthy(t *testing.T) { + tests := []struct { + name string + statusCode int + wantErr bool + }{ + {"200 OK", http.StatusOK, false}, + {"500 Internal Error", http.StatusInternalServerError, true}, + {"404 Not Found", http.StatusNotFound, true}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := &http.Client{ + Transport: roundTripFunc(func(req *http.Request) *http.Response { + return &http.Response{ + StatusCode: tc.statusCode, + Body: io.NopCloser(bytes.NewBufferString("dummy body")), + Header: make(http.Header), + } + }), + } + + w := &processWorker{ + id: "test-worker", + address: "http://127.0.0.1:9999", + healthPath: "/health", + client: client, + } + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + err := w.Healthy(ctx) + if (err != nil) != tc.wantErr { + t.Fatalf("Healthy() error = %v, wantErr %v", err, tc.wantErr) + } + }) + } +} + +func TestWorkerHealthy_RequestError(t *testing.T) { + // A transport that returns an error + + client := &http.Client{ + Transport: &errorRoundTripper{err: fmt.Errorf("network reset")}, + } + + w := &processWorker{ + id: "test-worker", + address: "http://127.0.0.1:9999", + healthPath: "/health", + client: client, + } + + ctx := context.Background() + err := w.Healthy(ctx) + if err == nil { + t.Fatalf("expected error on Healthy due to transport failure, got nil") + } +} + +type errorRoundTripper struct { + err error +} + +func (e *errorRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return nil, e.err +} + +func TestWorkerClose_KillsProcess(t *testing.T) { + // Start a long running process that we will kill + cmd := exec.Command("sleep", "60") + if err := cmd.Start(); err != nil { + t.Fatalf("failed to start sleep command: %v", err) + } + + w := &processWorker{ + id: "test-worker-kill", + cmd: cmd, + } + + err := w.Close() + if err != nil { + t.Fatalf("Close() failed: %v", err) + } + + // Wait for process to exit + err = cmd.Wait() + if err == nil { + t.Fatalf("expected process to have an error from being killed, but it exited cleanly") + } + + // Double check that it's drained + if w.draining.Load() != 1 { + t.Errorf("expected worker to be marked as draining, got %d", w.draining.Load()) + } +}