diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a2b0d16..6d0d741 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,7 +16,8 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - go-version: [ "1.21.x", "1.22.x", "1.23.x", "1.24.x", "1.25.x", "1.26.x"] + # 1.21.x dropped: go.mod requires >= 1.22 (needed for SysProcAttr.CgroupFD) + go-version: [ "1.22.x", "1.23.x", "1.24.x", "1.25.x", "1.26.x"] steps: - uses: actions/checkout@v4 @@ -35,6 +36,42 @@ jobs: - name: Run tests with race detector run: go test -v -race -coverprofile=coverage.txt -covermode=atomic ./... + # cgroup-integration: Layer 3 tests that require real cgroupv2 and root access. + # Runs only on the latest stable Go version to keep CI fast. + cgroup-integration: + name: Cgroup Integration Tests + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: "1.26.x" + + - name: Verify cgroupv2 is available + run: | + if ! grep -q cgroup2 /proc/mounts; then + echo "cgroupv2 not mounted — skipping integration tests" + echo "CGROUP_AVAILABLE=false" >> "$GITHUB_ENV" + else + echo "CGROUP_AVAILABLE=true" >> "$GITHUB_ENV" + fi + + - name: Build test dependencies (healthworker) + run: go build ./testdata/healthworker/... + + - name: Run cgroup integration tests (as root) + if: env.CGROUP_AVAILABLE == 'true' + run: | + sudo --preserve-env=PATH,GOPATH,GOCACHE,HOME \ + env HERD_CGROUP_TEST=1 \ + $(which go) test -v -run TestSandbox -timeout 60s ./... 
+        env:
+          # runner.temp is always defined and writable; the previous values
+          # referenced env.GOPATH/env.GOCACHE, which are never set at the
+          # workflow level and therefore expanded to empty strings, giving the
+          # sudo'd `go` invocation no usable module/build cache.
+          GOPATH: ${{ runner.temp }}/gopath
+          GOCACHE: ${{ runner.temp }}/gocache
+
   lint:
     name: Lint Code
     runs-on: ubuntu-latest
@@ -44,7 +81,7 @@ jobs:
       - name: Set up Go
         uses: actions/setup-go@v5
         with:
-          go-version: "1.21.x"
+          go-version: "1.22.x"
           cache: false # golangci-lint-action handles its own caching
 
       - name: golangci-lint
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..c070b8a
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,24 @@
+FROM golang:1.24-alpine AS builder
+
+WORKDIR /herd
+
+COPY . .
+
+WORKDIR /herd/examples/playwright
+
+RUN go mod download
+
+RUN go build -o main main.go
+
+FROM ubuntu:24.04 AS runner
+
+RUN apt-get update && apt-get install -y --no-install-recommends nodejs npm
+
+
+RUN npx playwright install chromium --with-deps
+
+WORKDIR /app
+
+COPY --from=builder /herd/examples/playwright/main .
+
+CMD ["./main"]
diff --git a/docker-compose.yaml b/docker-compose.yaml
new file mode 100644
index 0000000..573e076
--- /dev/null
+++ b/docker-compose.yaml
@@ -0,0 +1,5 @@
+services:
+  playwright:
+    build: .
+    ports:
+      - "8080:8080"
diff --git a/factory_cgroup_test.go b/factory_cgroup_test.go
new file mode 100644
index 0000000..384c6ec
--- /dev/null
+++ b/factory_cgroup_test.go
@@ -0,0 +1,137 @@
+// factory_cgroup_test.go — Unit tests for ProcessFactory cgroup configuration.
+//
+// No build tag: runs on all platforms (macOS, Linux, Windows).
+// No processes are spawned; only field values and option validation are tested.
+package herd + +import ( + "testing" +) + +func TestNewProcessFactory_DefaultCgroupPIDs(t *testing.T) { + f := NewProcessFactory("./fake-binary") + if f.cgroupPIDs != 100 { + t.Errorf("expected default cgroupPIDs=100, got %d", f.cgroupPIDs) + } +} + +func TestNewProcessFactory_DefaultMemoryCPUUnlimited(t *testing.T) { + f := NewProcessFactory("./fake-binary") + if f.cgroupMemory != 0 { + t.Errorf("expected default cgroupMemory=0 (unlimited), got %d", f.cgroupMemory) + } + if f.cgroupCPU != 0 { + t.Errorf("expected default cgroupCPU=0 (unlimited), got %d", f.cgroupCPU) + } +} + +func TestWithMemoryLimit_StoresBytes(t *testing.T) { + const limit = 512 * 1024 * 1024 // 512 MB + f := NewProcessFactory("./fake-binary").WithMemoryLimit(limit) + if f.cgroupMemory != limit { + t.Errorf("expected cgroupMemory=%d, got %d", limit, f.cgroupMemory) + } +} + +func TestWithMemoryLimit_Zero_DisablesLimit(t *testing.T) { + f := NewProcessFactory("./fake-binary").WithMemoryLimit(512 * 1024 * 1024).WithMemoryLimit(0) + if f.cgroupMemory != 0 { + t.Errorf("expected cgroupMemory=0 after zeroing, got %d", f.cgroupMemory) + } +} + +func TestWithMemoryLimit_NegativePanics(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("expected panic for negative WithMemoryLimit") + } + }() + NewProcessFactory("./fake-binary").WithMemoryLimit(-1) +} + +func TestWithCPULimit_HalfCore(t *testing.T) { + f := NewProcessFactory("./fake-binary").WithCPULimit(0.5) + if f.cgroupCPU != 50_000 { + t.Errorf("expected cgroupCPU=50000 for 0.5 cores, got %d", f.cgroupCPU) + } +} + +func TestWithCPULimit_TwoCores(t *testing.T) { + f := NewProcessFactory("./fake-binary").WithCPULimit(2.0) + if f.cgroupCPU != 200_000 { + t.Errorf("expected cgroupCPU=200000 for 2.0 cores, got %d", f.cgroupCPU) + } +} + +func TestWithCPULimit_Zero_DisablesLimit(t *testing.T) { + f := NewProcessFactory("./fake-binary").WithCPULimit(1.0).WithCPULimit(0) + if f.cgroupCPU != 0 { + t.Errorf("expected cgroupCPU=0 
after zeroing, got %d", f.cgroupCPU) + } +} + +func TestWithCPULimit_NegativePanics(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("expected panic for negative WithCPULimit") + } + }() + NewProcessFactory("./fake-binary").WithCPULimit(-0.5) +} + +func TestWithPIDsLimit_Explicit(t *testing.T) { + f := NewProcessFactory("./fake-binary").WithPIDsLimit(50) + if f.cgroupPIDs != 50 { + t.Errorf("expected cgroupPIDs=50, got %d", f.cgroupPIDs) + } +} + +func TestWithPIDsLimit_Unlimited(t *testing.T) { + f := NewProcessFactory("./fake-binary").WithPIDsLimit(-1) + if f.cgroupPIDs != -1 { + t.Errorf("expected cgroupPIDs=-1 for unlimited, got %d", f.cgroupPIDs) + } +} + +func TestWithPIDsLimit_ZeroPanics(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("expected panic for WithPIDsLimit(0)") + } + }() + NewProcessFactory("./fake-binary").WithPIDsLimit(0) +} + +func TestWithPIDsLimit_LessThanNegativeOnePanics(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("expected panic for WithPIDsLimit(-2)") + } + }() + NewProcessFactory("./fake-binary").WithPIDsLimit(-2) +} + +func TestWithPIDsLimit_Chaining(t *testing.T) { + // Verify the builder returns the same factory pointer for fluent chaining. 
+ f := NewProcessFactory("./fake-binary") + f2 := f.WithPIDsLimit(25) + if f != f2 { + t.Error("WithPIDsLimit should return the same *ProcessFactory for chaining") + } +} + +func TestWithMemoryLimit_Chaining(t *testing.T) { + f := NewProcessFactory("./fake-binary") + f2 := f.WithMemoryLimit(1024) + if f != f2 { + t.Error("WithMemoryLimit should return the same *ProcessFactory for chaining") + } +} + +func TestWithCPULimit_Chaining(t *testing.T) { + f := NewProcessFactory("./fake-binary") + f2 := f.WithCPULimit(1.0) + if f != f2 { + t.Error("WithCPULimit should return the same *ProcessFactory for chaining") + } +} diff --git a/go.mod b/go.mod index 813c1bb..d73be28 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/hackstrix/herd -go 1.21 +go 1.22 diff --git a/process_worker_factory.go b/process_worker_factory.go index 0c40872..38eae8b 100644 --- a/process_worker_factory.go +++ b/process_worker_factory.go @@ -60,6 +60,8 @@ type processWorker struct { healthPath string // e.g. "/health" or "/" client *http.Client + cgroupHandle sandboxHandle + mu sync.Mutex cmd *exec.Cmd sessionID string // guarded by mu @@ -134,6 +136,9 @@ func (w *processWorker) monitor() { // broadcast to all the listeners that the worker is dead. close(w.dead) + if w.cgroupHandle != nil { + w.cgroupHandle.Cleanup() + } w.mu.Lock() prevSession := w.sessionID @@ -166,6 +171,10 @@ type ProcessFactory struct { healthPath string // path to poll for liveness; defaults to "/health" startTimeout time.Duration // maximum time to wait for the first successful health check startHealthCheckDelay time.Duration // delay the health check for the first time. 
+ enableSandbox bool // true by default for isolation + cgroupMemory int64 // bytes; 0 means unlimited + cgroupCPU int64 // quota in micros per 100ms period; 0 means unlimited + cgroupPIDs int64 // max pids; -1 means unlimited counter atomic.Int64 } @@ -183,6 +192,8 @@ func NewProcessFactory(binary string, args ...string) *ProcessFactory { healthPath: "/health", startTimeout: 30 * time.Second, startHealthCheckDelay: 1 * time.Second, + enableSandbox: true, + cgroupPIDs: 100, } } @@ -229,6 +240,48 @@ func (f *ProcessFactory) WithStartHealthCheckDelay(d time.Duration) *ProcessFact return f } +// WithMemoryLimit sets the cgroup memory limit, in bytes, for each spawned worker. +// A value of 0 disables the memory limit. +func (f *ProcessFactory) WithMemoryLimit(bytes int64) *ProcessFactory { + if bytes < 0 { + panic("herd: WithMemoryLimit bytes must be >= 0") + } + f.cgroupMemory = bytes + return f +} + +// WithCPULimit sets the cgroup CPU quota in cores for each spawned worker. +// For example, 0.5 means half a CPU and 2 means two CPUs. A value of 0 disables the limit. +func (f *ProcessFactory) WithCPULimit(cores float64) *ProcessFactory { + if cores < 0 { + panic("herd: WithCPULimit cores must be >= 0") + } + if cores == 0 { + f.cgroupCPU = 0 + return f + } + f.cgroupCPU = int64(cores * 100_000) + return f +} + +// WithPIDsLimit sets the cgroup PID limit for each spawned worker. +// Pass -1 for unlimited. Values of 0 or less than -1 are invalid. +func (f *ProcessFactory) WithPIDsLimit(n int64) *ProcessFactory { + if n == 0 || n < -1 { + panic("herd: WithPIDsLimit n must be > 0 or -1 for unlimited") + } + f.cgroupPIDs = n + return f +} + +// WithInsecureSandbox disables the namespace/cgroup sandbox. +// Use only for local debugging on non-Linux systems or when you explicitly +// trust the spawned processes. 
+func (f *ProcessFactory) WithInsecureSandbox() *ProcessFactory { + f.enableSandbox = false + return f +} + func streamLogs(workerID string, pipe io.ReadCloser, isError bool) { // bufio.Scanner guarantees we read line-by-line, preventing torn logs. scanner := bufio.NewScanner(pipe) @@ -271,6 +324,21 @@ func (f *ProcessFactory) Spawn(ctx context.Context) (Worker[*http.Client], error // During program exits, this should be cleaned up by the Shutdown method cmd := exec.Command(f.binary, resolvedArgs...) cmd.Env = append(os.Environ(), append([]string{"PORT=" + portStr}, resolvedEnv...)...) + var cgroupHandle sandboxHandle + + if f.enableSandbox { + h, err := applySandboxFlags(cmd, id, sandboxConfig{ + memoryMaxBytes: f.cgroupMemory, + cpuMaxMicros: f.cgroupCPU, + pidsMax: f.cgroupPIDs, + }) + if err != nil { + return nil, fmt.Errorf("herd: ProcessFactory: failed to apply sandbox: %w", err) + } + cgroupHandle = h + } else { + log.Printf("[%s] WARNING: running UN-SANDBOXED. Not recommended for production.", id) + } stdout, err := cmd.StdoutPipe() if err != nil { @@ -284,6 +352,9 @@ func (f *ProcessFactory) Spawn(ctx context.Context) (Worker[*http.Client], error if err := cmd.Start(); err != nil { return nil, fmt.Errorf("herd: ProcessFactory: start %s: %w", f.binary, err) } + if cgroupHandle != nil { + cgroupHandle.PostStart() + } log.Printf("[%s] started pid=%d addr=%s", id, cmd.Process.Pid, address) // Stream logs in background @@ -291,13 +362,14 @@ func (f *ProcessFactory) Spawn(ctx context.Context) (Worker[*http.Client], error go streamLogs(id, stderr, true) w := &processWorker{ - id: id, - port: port, - address: address, - healthPath: f.healthPath, - client: &http.Client{Timeout: 3 * time.Second}, - cmd: cmd, - dead: make(chan struct{}), + id: id, + port: port, + address: address, + healthPath: f.healthPath, + client: &http.Client{Timeout: 3 * time.Second}, + cgroupHandle: cgroupHandle, + cmd: cmd, + dead: make(chan struct{}), } // Monitor the process in background — 
fires onCrash if it exits unexpectedly diff --git a/sandbox.go b/sandbox.go new file mode 100644 index 0000000..ff3f9ed --- /dev/null +++ b/sandbox.go @@ -0,0 +1,16 @@ +package herd + +// sandboxConfig contains per-worker sandbox resource constraints. +// A value of 0 means "unlimited" for memory and CPU. +type sandboxConfig struct { + memoryMaxBytes int64 + cpuMaxMicros int64 + pidsMax int64 +} + +// sandboxHandle owns post-start and cleanup hooks for sandbox resources. +// Implementations may be no-op on unsupported or soft-fail paths. +type sandboxHandle interface { + PostStart() + Cleanup() +} diff --git a/sandbox_integration_test.go b/sandbox_integration_test.go new file mode 100644 index 0000000..f04493c --- /dev/null +++ b/sandbox_integration_test.go @@ -0,0 +1,242 @@ +//go:build linux + +// sandbox_integration_test.go — Integration tests that spawn real processes and +// verify kernel-level cgroup enforcement. +// +// These tests require: +// - Linux with cgroupv2 enabled (/sys/fs/cgroup must be a cgroup2 mount) +// - Root, or a user with cgroup delegation (e.g. a systemd user slice) +// - The HERD_CGROUP_TEST=1 environment variable to be set +// +// Run with: +// +// HERD_CGROUP_TEST=1 go test -v -run TestSandbox -timeout 60s ./... +package herd + +import ( + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "time" +) + +// requireCgroupIntegration skips the test unless HERD_CGROUP_TEST=1 is set. +func requireCgroupIntegration(t *testing.T) { + t.Helper() + if os.Getenv("HERD_CGROUP_TEST") != "1" { + t.Skip("skipping cgroup integration test: set HERD_CGROUP_TEST=1 to run") + } +} + +// buildHealthWorker compiles the testdata/healthworker binary into a temp dir +// and returns the path to the compiled binary. 
+func buildHealthWorker(t *testing.T) string {
+	t.Helper()
+	bin := filepath.Join(t.TempDir(), "healthworker")
+	cmd := exec.Command("go", "build", "-o", bin, "./testdata/healthworker")
+	// Build from the module root (where go.mod lives) so the relative
+	// ./testdata/healthworker package path resolves.
+	cmd.Dir = "."
+	out, err := cmd.CombinedOutput()
+	if err != nil {
+		t.Fatalf("build healthworker: %v\n%s", err, out)
+	}
+	return bin
+}
+
+// ---------------------------------------------------------------------------
+// Test: born-in-cgroup verification
+// ---------------------------------------------------------------------------
+
+func TestSandbox_BornInCgroup(t *testing.T) {
+	requireCgroupIntegration(t)
+
+	bin := buildHealthWorker(t)
+
+	factory := NewProcessFactory(bin).
+		WithHealthPath("/health").
+		WithStartTimeout(10 * time.Second).
+		WithStartHealthCheckDelay(100 * time.Millisecond).
+		WithPIDsLimit(50)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
+	defer cancel()
+
+	worker, err := factory.Spawn(ctx)
+	if err != nil {
+		t.Fatalf("Spawn: %v", err)
+	}
+	defer worker.Close()
+
+	// Extract the OS-level pid from the processWorker.
+	pw, ok := worker.(*processWorker)
+	if !ok {
+		t.Fatal("expected *processWorker from Spawn")
+	}
+	pw.mu.Lock()
+	pid := pw.cmd.Process.Pid
+	pw.mu.Unlock()
+
+	// Read /proc/<pid>/cgroup and verify placement.
+ cgroupFile := fmt.Sprintf("/proc/%d/cgroup", pid) + data, err := os.ReadFile(cgroupFile) + if err != nil { + t.Fatalf("read %s: %v", cgroupFile, err) + } + contents := string(data) + t.Logf("/proc/%d/cgroup:\n%s", pid, contents) + + workerCgroupPath := "/herd/" + worker.ID() + if !strings.Contains(contents, workerCgroupPath) { + t.Errorf("expected cgroup path to contain %q, got:\n%s", workerCgroupPath, contents) + } +} + +// --------------------------------------------------------------------------- +// Test: cgroup directory lifecycle (exists after spawn, gone after close) +// --------------------------------------------------------------------------- + +func TestSandbox_CgroupDirLifecycle(t *testing.T) { + requireCgroupIntegration(t) + + bin := buildHealthWorker(t) + + factory := NewProcessFactory(bin). + WithHealthPath("/health"). + WithStartTimeout(10 * time.Second). + WithStartHealthCheckDelay(100 * time.Millisecond) + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + worker, err := factory.Spawn(ctx) + if err != nil { + t.Fatalf("Spawn: %v", err) + } + + cgroupPath := filepath.Join("/sys/fs/cgroup/herd", worker.ID()) + + // Directory must exist while the worker is alive. + if _, err := os.Stat(cgroupPath); os.IsNotExist(err) { + t.Errorf("expected cgroup dir %q to exist while worker is alive", cgroupPath) + } else { + t.Logf("cgroup dir present: %s", cgroupPath) + } + + // Close the worker and wait for monitor() to run Cleanup(). + pw := worker.(*processWorker) + _ = worker.Close() + select { + case <-pw.dead: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for worker to die after Close") + } + + // Give monitor() a moment to run Cleanup. 
+ time.Sleep(100 * time.Millisecond) + + if _, err := os.Stat(cgroupPath); !os.IsNotExist(err) { + t.Errorf("expected cgroup dir %q to be removed after Close, but it still exists", cgroupPath) + } +} + +// --------------------------------------------------------------------------- +// Test: pids.max limit file verification +// --------------------------------------------------------------------------- + +func TestSandbox_PIDEnforcement(t *testing.T) { + requireCgroupIntegration(t) + + bin := buildHealthWorker(t) + + // Set a reasonable PID limit that still allows the healthworker to start. + // A Go HTTP server needs ~5-10 PIDs for the runtime + main goroutine + HTTP handling. + // Setting 30 leaves room but still demonstrates the limit is enforced at the kernel level. + factory := NewProcessFactory(bin). + WithHealthPath("/health"). + WithStartTimeout(10 * time.Second). + WithStartHealthCheckDelay(100 * time.Millisecond). + WithPIDsLimit(30) + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + worker, err := factory.Spawn(ctx) + if err != nil { + t.Fatalf("Spawn: %v", err) + } + defer worker.Close() + + // Read pids.max from the actual cgroup dir to confirm the limit was written. + // This verifies that the cgroup limit file is set correctly at the kernel level. + cgroupPath := filepath.Join("/sys/fs/cgroup/herd", worker.ID()) + data, err := os.ReadFile(filepath.Join(cgroupPath, "pids.max")) + if err != nil { + t.Fatalf("read pids.max: %v", err) + } + got := strings.TrimSpace(string(data)) + if got != "30" { + t.Errorf("pids.max: expected '30', got %q", got) + } + t.Logf("pids.max confirmed at kernel level: %s", got) + + // Check pids.current to see how many PIDs the worker is actually using. 
+ currentData, err := os.ReadFile(filepath.Join(cgroupPath, "pids.current")) + if err != nil { + t.Logf("note: could not read pids.current: %v", err) + } else { + current := strings.TrimSpace(string(currentData)) + t.Logf("pids.current (actual usage): %s / 30", current) + } +} + +// --------------------------------------------------------------------------- +// Test: memory.max enforcement — confirm limit file is written correctly +// --------------------------------------------------------------------------- + +func TestSandbox_MemoryLimitFileWritten(t *testing.T) { + requireCgroupIntegration(t) + + bin := buildHealthWorker(t) + + const memLimit int64 = 64 * 1024 * 1024 // 64 MB + + factory := NewProcessFactory(bin). + WithHealthPath("/health"). + WithStartTimeout(10 * time.Second). + WithStartHealthCheckDelay(100 * time.Millisecond). + WithMemoryLimit(memLimit) + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + worker, err := factory.Spawn(ctx) + if err != nil { + t.Fatalf("Spawn: %v", err) + } + defer worker.Close() + + cgroupPath := filepath.Join("/sys/fs/cgroup/herd", worker.ID()) + + memMax, err := os.ReadFile(filepath.Join(cgroupPath, "memory.max")) + if err != nil { + t.Fatalf("read memory.max: %v", err) + } + if got := strings.TrimSpace(string(memMax)); got != "67108864" { + t.Errorf("memory.max: expected '67108864', got %q", got) + } + + swapMax, err := os.ReadFile(filepath.Join(cgroupPath, "memory.swap.max")) + if err != nil { + t.Fatalf("read memory.swap.max: %v", err) + } + if got := strings.TrimSpace(string(swapMax)); got != "0" { + t.Errorf("memory.swap.max: expected '0', got %q", got) + } + t.Logf("memory limits confirmed: max=%s swap=%s", + strings.TrimSpace(string(memMax)), strings.TrimSpace(string(swapMax))) +} diff --git a/sandbox_linux.go b/sandbox_linux.go new file mode 100644 index 0000000..3694edd --- /dev/null +++ b/sandbox_linux.go @@ -0,0 +1,125 @@ +//go:build linux + +package herd + +import 
( + "errors" + "fmt" + "log" + "os" + "os/exec" + "path/filepath" + "strconv" + "syscall" +) + +const ( + herdCgroupRoot = "/sys/fs/cgroup/herd" + cpuPeriodMicros = 100_000 +) + +// activeCgroupRoot is the base directory used for all cgroup operations. +// It defaults to herdCgroupRoot but can be overridden in tests to redirect +// cgroup file writes to a temp dir without needing real cgroup privileges. +var activeCgroupRoot = herdCgroupRoot + +type cgroupHandle struct { + path string + fd *os.File +} + +func (h *cgroupHandle) PostStart() { + if h == nil || h.fd == nil { + return + } + if err := h.fd.Close(); err != nil { + log.Printf("[sandbox] warning: close cgroup fd for %s: %v", h.path, err) + } + h.fd = nil +} + +func (h *cgroupHandle) Cleanup() { + if h == nil || h.path == "" { + return + } + if err := syscall.Rmdir(h.path); err != nil && !errors.Is(err, syscall.ENOENT) { + log.Printf("[sandbox] warning: cleanup cgroup %s: %v", h.path, err) + } +} + +// applySandboxFlags applies Linux cgroup v2 constraints to the command. +// If cgroup provisioning is unavailable (for example due to permissions), it +// soft-fails and allows the worker to start without constraints. 
+func applySandboxFlags(cmd *exec.Cmd, workerID string, cfg sandboxConfig) (sandboxHandle, error) {
+	if cfg.pidsMax == 0 {
+		cfg.pidsMax = 100
+	}
+
+	if err := os.MkdirAll(activeCgroupRoot, 0o755); err != nil {
+		log.Printf("[sandbox:%s] WARNING: cgroup root mkdir failed: %v; continuing without cgroup constraints", workerID, err)
+		return nil, nil
+	}
+
+	if err := writeCgroupFile(activeCgroupRoot, "cgroup.subtree_control", "+memory +cpu +pids"); err != nil {
+		log.Printf("[sandbox:%s] WARNING: cgroup controller enable failed: %v; continuing without cgroup constraints", workerID, err)
+		return nil, nil
+	}
+
+	cgroupPath := filepath.Join(activeCgroupRoot, workerID)
+	if err := os.Mkdir(cgroupPath, 0o755); err != nil {
+		if !errors.Is(err, os.ErrExist) {
+			log.Printf("[sandbox:%s] WARNING: cgroup leaf mkdir failed: %v; continuing without cgroup constraints", workerID, err)
+			return nil, nil
+		}
+	}
+
+	if cfg.memoryMaxBytes > 0 {
+		if err := writeCgroupFile(cgroupPath, "memory.max", strconv.FormatInt(cfg.memoryMaxBytes, 10)); err != nil {
+			return softFail(cgroupPath, workerID, "memory.max write failed", err)
+		}
+		if err := writeCgroupFile(cgroupPath, "memory.swap.max", "0"); err != nil {
+			return softFail(cgroupPath, workerID, "memory.swap.max write failed", err)
+		}
+	}
+
+	if cfg.cpuMaxMicros > 0 {
+		cpuMax := fmt.Sprintf("%d %d", cfg.cpuMaxMicros, cpuPeriodMicros)
+		if err := writeCgroupFile(cgroupPath, "cpu.max", cpuMax); err != nil {
+			return softFail(cgroupPath, workerID, "cpu.max write failed", err)
+		}
+	}
+
+	pidsValue := "max"
+	if cfg.pidsMax > 0 {
+		pidsValue = strconv.FormatInt(cfg.pidsMax, 10)
+	}
+	if err := writeCgroupFile(cgroupPath, "pids.max", pidsValue); err != nil {
+		return softFail(cgroupPath, workerID, "pids.max write failed", err)
+	}
+	dir, err := os.Open(cgroupPath)
+	if err != nil {
+		return softFail(cgroupPath, workerID, "open cgroup directory failed", err)
+	}
+
+	sys := cmd.SysProcAttr
+	if sys == nil {
+		sys = &syscall.SysProcAttr{}
+	}
+	sys.CgroupFD = int(dir.Fd())
+	sys.UseCgroupFD = true
+	cmd.SysProcAttr = sys
+
+	return &cgroupHandle{path: cgroupPath, fd: dir}, nil
+}
+
+// softFail logs, removes the partially provisioned leaf cgroup so it is not leaked, and tells the caller to continue unconstrained.
+func softFail(cgroupPath, workerID, what string, err error) (sandboxHandle, error) {
+	log.Printf("[sandbox:%s] WARNING: %s: %v; continuing without cgroup constraints", workerID, what, err)
+	_ = syscall.Rmdir(cgroupPath)
+	return nil, nil
+}
+
+func writeCgroupFile(cgroupPath, filename, value string) error {
+	path := filepath.Join(cgroupPath, filename)
+	return os.WriteFile(path, []byte(value), 0o644)
+}
diff --git a/sandbox_linux_test.go b/sandbox_linux_test.go
new file mode 100644
index 0000000..5e4f74e
--- /dev/null
+++ b/sandbox_linux_test.go
@@ -0,0 +1,243 @@
+//go:build linux
+
+// sandbox_linux_test.go — Unit tests for applySandboxFlags and cgroupHandle.
+//
+// These tests redirect activeCgroupRoot to t.TempDir() so they work
+// without real cgroup privileges — all file writes go to a temp directory.
+// The SysProcAttr wiring is verified on an uncommitted exec.Cmd (never started).
+package herd
+
+import (
+	"os"
+	"os/exec"
+	"path/filepath"
+	"strings"
+	"testing"
+)
+
+// withTempCgroupRoot points activeCgroupRoot to a temp dir for the duration
+// of the test and resets it afterwards.
+func withTempCgroupRoot(t *testing.T) string {
+	t.Helper()
+	root := t.TempDir()
+	// Pre-create the subtree_control file so writeCgroupFile can write to it.
+	// (Real cgroupfs has this; our temp dir does not.)
+	if err := os.WriteFile(filepath.Join(root, "cgroup.subtree_control"), []byte(""), 0o644); err != nil {
+		t.Fatalf("setup: create subtree_control: %v", err)
+	}
+	old := activeCgroupRoot
+	activeCgroupRoot = root
+	t.Cleanup(func() { activeCgroupRoot = old })
+	return root
+}
+
+func newFakeCmd() *exec.Cmd {
+	// "true" exists on all Linux systems and is harmless — the Cmd is never started.
+ return exec.Command("true") +} + +func readCgroupFile(t *testing.T, cgroupPath, filename string) string { + t.Helper() + data, err := os.ReadFile(filepath.Join(cgroupPath, filename)) + if err != nil { + t.Fatalf("read cgroup file %s: %v", filename, err) + } + return strings.TrimSpace(string(data)) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +func TestApplySandboxFlags_DefaultPIDs(t *testing.T) { + root := withTempCgroupRoot(t) + cmd := newFakeCmd() + + h, err := applySandboxFlags(cmd, "worker-1", sandboxConfig{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if h == nil { + t.Fatal("expected non-nil handle, got nil (soft-fail triggered unexpectedly)") + } + + cgroupPath := filepath.Join(root, "worker-1") + got := readCgroupFile(t, cgroupPath, "pids.max") + if got != "100" { + t.Errorf("pids.max: expected '100' (default), got %q", got) + } +} + +func TestApplySandboxFlags_MemoryLimit(t *testing.T) { + root := withTempCgroupRoot(t) + cmd := newFakeCmd() + const limit int64 = 64 * 1024 * 1024 // 64 MB + + h, err := applySandboxFlags(cmd, "worker-mem", sandboxConfig{memoryMaxBytes: limit}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if h == nil { + t.Fatal("expected non-nil handle") + } + + cgroupPath := filepath.Join(root, "worker-mem") + if got := readCgroupFile(t, cgroupPath, "memory.max"); got != "67108864" { + t.Errorf("memory.max: expected '67108864', got %q", got) + } + if got := readCgroupFile(t, cgroupPath, "memory.swap.max"); got != "0" { + t.Errorf("memory.swap.max: expected '0', got %q", got) + } +} + +func TestApplySandboxFlags_CPULimit(t *testing.T) { + root := withTempCgroupRoot(t) + cmd := newFakeCmd() + + h, err := applySandboxFlags(cmd, "worker-cpu", sandboxConfig{cpuMaxMicros: 50_000}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if h == nil { + 
t.Fatal("expected non-nil handle") + } + + cgroupPath := filepath.Join(root, "worker-cpu") + if got := readCgroupFile(t, cgroupPath, "cpu.max"); got != "50000 100000" { + t.Errorf("cpu.max: expected '50000 100000', got %q", got) + } +} + +func TestApplySandboxFlags_UnlimitedPIDs(t *testing.T) { + root := withTempCgroupRoot(t) + cmd := newFakeCmd() + + h, err := applySandboxFlags(cmd, "worker-nopid", sandboxConfig{pidsMax: -1}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if h == nil { + t.Fatal("expected non-nil handle") + } + + cgroupPath := filepath.Join(root, "worker-nopid") + if got := readCgroupFile(t, cgroupPath, "pids.max"); got != "max" { + t.Errorf("pids.max: expected 'max' for -1, got %q", got) + } +} + +func TestApplySandboxFlags_NoCPULimitFileWhenZero(t *testing.T) { + root := withTempCgroupRoot(t) + cmd := newFakeCmd() + + _, err := applySandboxFlags(cmd, "worker-nocpu", sandboxConfig{cpuMaxMicros: 0}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + cgroupPath := filepath.Join(root, "worker-nocpu") + if _, err := os.Stat(filepath.Join(cgroupPath, "cpu.max")); err == nil { + t.Error("cpu.max should not be written when cpuMaxMicros=0") + } +} + +func TestApplySandboxFlags_SysProcAttrWired(t *testing.T) { + withTempCgroupRoot(t) + cmd := newFakeCmd() + + h, err := applySandboxFlags(cmd, "worker-attr", sandboxConfig{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if h == nil { + t.Fatal("expected non-nil handle") + } + if cmd.SysProcAttr == nil { + t.Fatal("SysProcAttr should be set after applySandboxFlags") + } + if !cmd.SysProcAttr.UseCgroupFD { + t.Error("UseCgroupFD should be true") + } + if cmd.SysProcAttr.CgroupFD <= 0 { + t.Errorf("CgroupFD should be a valid fd (>0), got %d", cmd.SysProcAttr.CgroupFD) + } +} + +func TestApplySandboxFlags_SoftFailOnBadRoot(t *testing.T) { + // Point to a path that cannot be created (inside /proc which is read-only). 
+ old := activeCgroupRoot + activeCgroupRoot = "/proc/herd_test_unreachable_path" + defer func() { activeCgroupRoot = old }() + + cmd := newFakeCmd() + h, err := applySandboxFlags(cmd, "worker-fail", sandboxConfig{}) + if err != nil { + t.Fatalf("expected soft fail (nil, nil) but got error: %v", err) + } + if h != nil { + t.Errorf("expected nil handle on soft fail, got %v", h) + } +} + +func TestApplySandboxFlags_ExistingCgroupDirIsReused(t *testing.T) { + root := withTempCgroupRoot(t) + // Pre-create the leaf dir to simulate a stale entry. + cgroupPath := filepath.Join(root, "worker-exist") + if err := os.Mkdir(cgroupPath, 0o755); err != nil { + t.Fatalf("pre-create cgroup dir: %v", err) + } + cmd := newFakeCmd() + + h, err := applySandboxFlags(cmd, "worker-exist", sandboxConfig{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // Should not soft-fail just because the dir already exists. + if h == nil { + t.Error("expected non-nil handle when cgroup dir already exists") + } +} + +func TestCgroupHandle_PostStart_ClosesFile(t *testing.T) { + // Create a real file to wrap as the fd. + tmp, err := os.CreateTemp(t.TempDir(), "cgfd-*") + if err != nil { + t.Fatalf("create temp file: %v", err) + } + h := &cgroupHandle{path: t.TempDir(), fd: tmp} + h.PostStart() + if h.fd != nil { + t.Error("expected fd to be nil after PostStart") + } + // Calling PostStart again should be a no-op (fd already nil). 
+ h.PostStart() +} + +func TestCgroupHandle_Cleanup_RemovesDir(t *testing.T) { + dir := t.TempDir() + leaf := filepath.Join(dir, "leaf") + if err := os.Mkdir(leaf, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + h := &cgroupHandle{path: leaf} + h.Cleanup() + if _, err := os.Stat(leaf); !os.IsNotExist(err) { + t.Error("expected cgroup leaf dir to be removed after Cleanup") + } +} + +func TestCgroupHandle_Cleanup_Idempotent(t *testing.T) { + dir := t.TempDir() + leaf := filepath.Join(dir, "stale") + if err := os.Mkdir(leaf, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + h := &cgroupHandle{path: leaf} + h.Cleanup() // removes dir + h.Cleanup() // dir already gone — should not panic or log error as warning +} + +func TestCgroupHandle_Cleanup_NilSafe(t *testing.T) { + var h *cgroupHandle + h.Cleanup() // must not panic +} diff --git a/sandbox_unsupported.go b/sandbox_unsupported.go new file mode 100644 index 0000000..908aa65 --- /dev/null +++ b/sandbox_unsupported.go @@ -0,0 +1,29 @@ +//go:build !linux + +package herd + +import ( + "errors" + "os/exec" + "runtime" +) + +// ErrSandboxUnsupported is returned when sandbox mode is requested on a non-Linux OS. 
+var ErrSandboxUnsupported = errors.New( + "\n\n##################################### WARNING ##################################################\n\n" + + "herd: STRICT SANDBOX ENABLED BUT UNSUPPORTED ON THIS OS.\n\n" + + " The security sandbox relies on Linux cgroups and namespaces (CLONE_NEWUSER, etc.),\n" + + " which do not exist on " + runtime.GOOS + ".\n\n" + + " FIX: If you are developing locally on macOS or Windows and want to test your pool logic,\n" + + " you MUST explicitly opt-out of sandbox mode by using:\n\n" + + " factory.WithInsecureSandbox()\n\n" + + " Warning: Do not use WithInsecureSandbox() in production unless you fully trust the workloads.\n\n" + + "###############################################################################################", +) + +// applySandboxFlags applies Linux-specific sandbox isolation. +// On non-Linux systems, this returns an error if sandbox mode is enabled, +// forcing a loud failure instead of a false sense of security. +func applySandboxFlags(cmd *exec.Cmd, workerID string, cfg sandboxConfig) (sandboxHandle, error) { + return nil, ErrSandboxUnsupported +} diff --git a/testdata/healthworker/main.go b/testdata/healthworker/main.go new file mode 100644 index 0000000..778da0c --- /dev/null +++ b/testdata/healthworker/main.go @@ -0,0 +1,51 @@ +// testdata/healthworker/main.go — minimal HTTP server used by integration tests. +// +// The binary: +// - Listens on the port given by the PORT env var (default 8080). +// - Responds with 200 OK on GET /health. +// - Exits with status 0 on SIGTERM/SIGINT. +// +// It is compiled by integration tests at runtime via `go build ./testdata/healthworker`. 
+package main + +import ( + "fmt" + "log" + "net" + "net/http" + "os" + "os/signal" + "syscall" +) + +func main() { + port := os.Getenv("PORT") + if port == "" { + port = "8080" + } + + mux := http.NewServeMux() + mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprintln(w, "ok") + }) + + ln, err := net.Listen("tcp", "127.0.0.1:"+port) + if err != nil { + log.Fatalf("listen: %v", err) + } + + srv := &http.Server{Handler: mux} + + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGTERM, syscall.SIGINT) + go func() { + <-quit + _ = srv.Close() + }() + + log.Printf("healthworker listening on :%s", port) + if err := srv.Serve(ln); err != nil && err != http.ErrServerClosed { + log.Fatalf("serve: %v", err) + } +}