# tests/deployment/test_modal_starting_limiter.py
"""Tests for the Modal cold-start fleet limiter (2026-05-22 Patch).

Covers:
  * `_get_starting_semaphore` reads per-worker permits from
    MODAL_MAX_STARTING_PER_WORKER, is lazy, idempotent (singleton), and clamps to >=1.
  * `ModalDeployment.start` retry loop respects max_retries=2 (not 5).
  * `ModalDeployment.start` wall-clock budget aborts further attempts once
    `MODAL_INIT_WALL_BUDGET` is exceeded.
  * `asyncio.wait_for` inside the retry loop cancels a hung `_start`.
  * The STARTING semaphore actually serializes overlapping `_start` calls and
    is released on both success and failure paths.

We do NOT import or hit real modal.com. We bypass `ModalDeployment.__init__`
(which would invoke `_ImageBuilder.auto`) via `object.__new__` and manually
assign the few attributes the methods under test read.
"""

from __future__ import annotations

import asyncio
import logging
import time
from typing import Awaitable, Callable

import pytest

from uni_agent.deployment.modal import deployment as mod
from uni_agent.deployment.modal.deployment import ModalDeployment

# -------------------- helpers --------------------


def _reset_limiter_state(monkeypatch, *, per_worker=16, wall_budget=900.0):
    """Pin the limiter env vars to known values and force semaphore re-init.

    The limiter reads MODAL_MAX_STARTING_PER_WORKER and MODAL_INIT_WALL_BUDGET
    lazily (at first use / per call), so we set them via monkeypatch.setenv and
    they revert after the test. `_STARTING_SEMA` is the singleton cache -- clear
    it so the next call rebuilds with the patched value.
    """
    monkeypatch.setenv("MODAL_MAX_STARTING_PER_WORKER", str(per_worker))
    monkeypatch.setenv("MODAL_INIT_WALL_BUDGET", str(wall_budget))
    monkeypatch.setattr(mod, "_STARTING_SEMA", None, raising=True)


def _make_deployment(
    start_fn: Callable[[_FakeDeployment], Awaitable[None]] | None = None,
    stop_fn: Callable[[_FakeDeployment], Awaitable[None]] | None = None,
) -> _FakeDeployment:
    """Build a `_FakeDeployment` without running `ModalDeployment.__init__`.

    `start_fn` / `stop_fn` are called with the deployment and their result is
    awaited inside `_start` / `stop`; both default to a no-op coroutine.
    """
    dep = object.__new__(_FakeDeployment)
    dep.logger = logging.getLogger("test-modal-limiter")
    dep.run_id = "test-run"
    dep._sandbox = None
    dep._runtime = None
    dep._start_calls = 0
    dep._stop_calls = 0
    dep._concurrent_in_start = 0
    dep._max_concurrent_observed = 0
    dep._start_fn = start_fn or (lambda d: _ok())
    dep._stop_fn = stop_fn or (lambda d: _ok())
    return dep


async def _ok():
    """No-op coroutine used as the default start/stop callback."""
    return None


class _FakeDeployment(ModalDeployment):
    """ModalDeployment with `_start` and `stop` rewired to user callbacks.

    `_start` acquires the module's STARTING semaphore itself (mirroring the
    production wrapping) so the tests can observe serialization; the body then
    delegates to the test callback while the permit is held.
    """

    async def _start(self):  # type: ignore[override]
        async with mod._get_starting_semaphore():
            self._start_calls += 1
            self._concurrent_in_start += 1
            self._max_concurrent_observed = max(self._max_concurrent_observed, self._concurrent_in_start)
            try:
                await self._start_fn(self)
            finally:
                # Runs on failure and on cancellation too, so a test can check
                # that the counter returned to zero.
                self._concurrent_in_start -= 1

    async def stop(self):  # type: ignore[override]
        self._stop_calls += 1
        await self._stop_fn(self)


# -------------------- _get_starting_semaphore --------------------


def test_starting_semaphore_uses_per_worker_env(monkeypatch):
    _reset_limiter_state(monkeypatch, per_worker=16)

    async def _check():
        sem = mod._get_starting_semaphore()
        # asyncio.Semaphore keeps its remaining permits in the private
        # `_value` attribute on CPython; the tests rely on that detail.
        assert sem._value == 16, f"expected 16, got {sem._value}"

    asyncio.run(_check())


def test_starting_semaphore_clamps_to_one(monkeypatch):
    # A per-worker value of 0 (or negative) must clamp to >=1 -> no deadlock.
    _reset_limiter_state(monkeypatch, per_worker=0)

    async def _check():
        sem = mod._get_starting_semaphore()
        assert sem._value == 1, f"expected clamp to 1, got {sem._value}"

    asyncio.run(_check())


def test_starting_semaphore_is_singleton(monkeypatch):
    _reset_limiter_state(monkeypatch, per_worker=5)

    async def _check():
        a = mod._get_starting_semaphore()
        b = mod._get_starting_semaphore()
        assert a is b, "semaphore should be lazily cached"

    asyncio.run(_check())


def test_starting_semaphore_respects_large_value(monkeypatch):
    _reset_limiter_state(monkeypatch, per_worker=64)

    async def _check():
        sem = mod._get_starting_semaphore()
        assert sem._value == 64

    asyncio.run(_check())


# -------------------- start() retry + wall-budget --------------------


def test_start_succeeds_on_first_attempt_does_not_retry(monkeypatch):
    _reset_limiter_state(monkeypatch)
    dep = _make_deployment()

    async def _go():
        await dep.start()

    asyncio.run(_go())
    assert dep._start_calls == 1
    assert dep._stop_calls == 0  # stop() only called on failure


def test_start_retries_once_then_succeeds(monkeypatch):
    _reset_limiter_state(monkeypatch)

    attempts = {"n": 0}

    async def flaky(dep):
        attempts["n"] += 1
        if attempts["n"] == 1:
            raise RuntimeError("simulated transient cold-start failure")

    dep = _make_deployment(start_fn=flaky)

    async def _go():
        await dep.start()

    asyncio.run(_go())
    assert dep._start_calls == 2
    assert dep._stop_calls == 1  # cleanup ran once between attempts


def test_start_max_retries_is_two_not_five(monkeypatch):
    """The pre-patch loop tried 5 times. The patched loop must give up at 2."""
    _reset_limiter_state(monkeypatch)

    async def always_fail(dep):
        raise RuntimeError("never works")

    dep = _make_deployment(start_fn=always_fail)

    async def _go():
        await dep.start()

    with pytest.raises(RuntimeError, match=r"after 2 retries"):
        asyncio.run(_go())
    assert dep._start_calls == 2, "must attempt exactly 2 times, not 5"
    assert dep._stop_calls == 2


def test_start_wall_budget_aborts_before_attempt_when_exhausted(monkeypatch):
    """Once the wall budget is spent, `_start` must NOT be called again."""
    # 0.5s budget: first attempt sleeps 0.6s and fails -> second attempt
    # should be vetoed by the deadline check (not even invoked).
    _reset_limiter_state(monkeypatch, wall_budget=0.5)

    async def slow_fail(dep):
        await asyncio.sleep(0.6)
        raise RuntimeError("too slow")

    dep = _make_deployment(start_fn=slow_fail)

    async def _go():
        await dep.start()

    t0 = time.monotonic()
    with pytest.raises(RuntimeError):
        asyncio.run(_go())
    elapsed = time.monotonic() - t0
    assert dep._start_calls == 1, (
        f"second attempt must be skipped after wall budget exhausted, got start_calls={dep._start_calls}"
    )
    # Generous upper bound: first attempt 0.6s + sleep gap + cleanup << 3s
    assert elapsed < 3.0, f"wall budget should short-circuit, elapsed={elapsed:.2f}s"


def test_start_wait_for_cancels_hung_start(monkeypatch):
    """A `_start` that hangs forever must be cancelled by the per-attempt wait_for.

    `start()` floors the per-attempt timeout at 60s (`max(60.0, remaining)`),
    so letting the real timeout fire would stall the suite for a full minute.
    Instead we intercept `asyncio.wait_for` as seen by the deployment module,
    record the timeout it asked for, and delegate to the real `wait_for` with
    a much shorter one. The cancellation and cleanup path exercised is the
    same; only the waiting time differs.
    """
    _reset_limiter_state(monkeypatch, per_worker=16, wall_budget=0.4)

    real_wait_for = asyncio.wait_for
    requested_timeouts: list[float] = []

    async def impatient_wait_for(aw, timeout):
        requested_timeouts.append(timeout)
        return await real_wait_for(aw, timeout=min(timeout, 0.05))

    # `mod.asyncio` is the asyncio module itself; monkeypatch restores the
    # attribute when the test finishes.
    monkeypatch.setattr(mod.asyncio, "wait_for", impatient_wait_for)

    async def hang_forever(dep):
        await asyncio.sleep(3600)

    dep = _make_deployment(start_fn=hang_forever)

    async def _go():
        await dep.start()

    t0 = time.monotonic()
    with pytest.raises(RuntimeError):
        asyncio.run(_go())
    elapsed = time.monotonic() - t0

    # Exactly one attempt went through wait_for; the wall budget (0.4s) is
    # gone by the time a second attempt would be considered.
    assert len(requested_timeouts) == 1
    assert dep._start_calls == 1
    assert dep._stop_calls == 1
    # The hung body was really cancelled (its `finally` ran) and the permit
    # it held went back to the semaphore.
    assert dep._concurrent_in_start == 0
    assert mod._STARTING_SEMA is not None and mod._STARTING_SEMA._value == 16
    assert elapsed < 5.0, f"start() must not hang, elapsed={elapsed:.1f}s"


# -------------------- semaphore serialization --------------------


def test_starting_semaphore_serializes_concurrent_starts(monkeypatch):
    """With per-worker permits=2, 6 concurrent starts must all complete and
    hand every permit back.
    """
    _reset_limiter_state(monkeypatch, per_worker=2)  # 2 permits

    # Each _start holds the permit for 0.05s, then succeeds.
    async def slow_ok(dep):
        await asyncio.sleep(0.05)

    deps = [_make_deployment(start_fn=slow_ok) for _ in range(6)]

    async def _go():
        await asyncio.gather(*[d.start() for d in deps])

    asyncio.run(_go())

    # The per-deployment counters are local to each instance, so they can only
    # ever reach 1; the cross-instance cap is proven by the shared-counter test
    # below. Here we check that nobody was starved or retried, and that the
    # success path released every permit.
    assert all(d._start_calls == 1 for d in deps)
    assert all(d._max_concurrent_observed == 1 for d in deps)
    assert all(d._concurrent_in_start == 0 for d in deps)
    assert mod._STARTING_SEMA is not None and mod._STARTING_SEMA._value == 2


def test_starting_semaphore_caps_global_in_flight(monkeypatch):
    """Stronger version: instrument a SHARED counter to prove the
    semaphore really caps the number of `_start` bodies running
    simultaneously across multiple ModalDeployment instances.
    """
    _reset_limiter_state(monkeypatch, per_worker=2)  # 2 permits

    shared = {"in_flight": 0, "peak": 0}

    async def track(dep):
        # No lock needed: there is no await between reading and writing the
        # counters, so each update is atomic with respect to the event loop.
        shared["in_flight"] += 1
        shared["peak"] = max(shared["peak"], shared["in_flight"])
        await asyncio.sleep(0.03)
        shared["in_flight"] -= 1

    deps = [_make_deployment(start_fn=track) for _ in range(10)]

    async def _go():
        await asyncio.gather(*[d.start() for d in deps])

    asyncio.run(_go())

    assert shared["in_flight"] == 0
    # Exactly 2: more would mean the cap leaks, fewer would mean the limiter
    # serialises harder than configured.
    assert shared["peak"] == 2, (
        f"semaphore must cap concurrent _start bodies at 2 (per-worker permits), observed peak={shared['peak']}"
    )


def test_starting_semaphore_released_on_failure(monkeypatch):
    """Permit must be released even when `_start` raises -- otherwise
    a chain of failures would slowly leak all permits and deadlock.
    """
    _reset_limiter_state(monkeypatch, per_worker=1)  # 1 permit

    failures = {"n": 0}

    async def flaky(dep):
        failures["n"] += 1
        if failures["n"] <= 2:
            raise RuntimeError("transient")

    # dep_a burns both of its attempts (max_retries=2) on failures. dep_b then
    # needs the single permit, so it can only start if both failed attempts
    # released it.
    dep_a = _make_deployment(start_fn=flaky)
    dep_b = _make_deployment(start_fn=lambda d: _ok())

    async def _go():
        with pytest.raises(RuntimeError):
            await dep_a.start()
        # dep_b must NOT block: if the single permit leaked, it would hang.
        await asyncio.wait_for(dep_b.start(), timeout=2.0)

    asyncio.run(_go())
    assert dep_a._start_calls == 2
    assert dep_b._start_calls == 1
+ with pytest.raises(RuntimeError): + await dep_a.start() + # dep_b must NOT block: if the single permit leaked, it would hang. + await asyncio.wait_for(dep_b.start(), timeout=2.0) + + asyncio.run(_go()) + assert dep_b._start_calls == 1 diff --git a/uni_agent/deployment/modal/deployment.py b/uni_agent/deployment/modal/deployment.py index 8622f16a..7168f132 100644 --- a/uni_agent/deployment/modal/deployment.py +++ b/uni_agent/deployment/modal/deployment.py @@ -23,6 +23,34 @@ __all__ = ["ModalDeployment"] +# Cap how many Modal sandboxes are simultaneously in the "created but runtime +# not yet alive" state; too many at once cause "Runtime did not start" timeouts. +# The semaphore is process-local, so MODAL_MAX_STARTING_PER_WORKER is a +# per-worker cap (size it as fleet-wide target / num rollout workers). +# MODAL_INIT_WALL_BUDGET caps a single trajectory's total init wall-clock. +_DEFAULT_MAX_STARTING_PER_WORKER = 8 +_DEFAULT_INIT_WALL_BUDGET = 900.0 +_STARTING_SEMA: asyncio.Semaphore | None = None + + +def _get_starting_semaphore() -> asyncio.Semaphore: + """Lazy-init the per-worker STARTING semaphore. + + Lazy so it is built inside the running loop and the env var is read at + first use, not at import (so vars set after import still apply). 
+ """ + global _STARTING_SEMA + if _STARTING_SEMA is None: + per_worker = max(1, int(os.getenv("MODAL_MAX_STARTING_PER_WORKER", str(_DEFAULT_MAX_STARTING_PER_WORKER)))) + _STARTING_SEMA = asyncio.Semaphore(per_worker) + return _STARTING_SEMA + + +def _get_init_wall_budget() -> float: + """Per-trajectory init wall-clock budget (seconds), resolved at call time.""" + return float(os.getenv("MODAL_INIT_WALL_BUDGET", str(_DEFAULT_INIT_WALL_BUDGET))) + + def _get_modal_user() -> str: # not sure how to get the user from the modal api return modal.config._profile # type: ignore @@ -199,76 +227,110 @@ async def _start(self): if self._app is None: self._app = await modal.App.lookup.aio("swe-rex", create_if_missing=True) - self.logger.info(f"Starting modal sandbox with image {self._image_name}") - self._hooks.on_custom_step("Starting modal sandbox") - t0 = time.time() - token = self._get_token() - self._sandbox = await modal.Sandbox.create.aio( - "/usr/bin/env", - "bash", - "-c", - self._start_swerex_cmd(token), - image=self._image, - timeout=int(self._deployment_timeout), - encrypted_ports=[self._port], - app=self._app, - **self._modal_kwargs, - ) - tunnels = await self._sandbox.tunnels.aio() - tunnel = tunnels[self._port] - elapsed_sandbox_creation = time.time() - t0 - self.logger.info(f"Sandbox ({self._sandbox.object_id}) created in {elapsed_sandbox_creation:.2f}s") - self.logger.info(f"Check sandbox logs at {await self.get_modal_log_url()}") - self.logger.info(f"Sandbox created with id {self._sandbox.object_id}") - await asyncio.sleep(1) - self.logger.info(f"Starting runtime at {tunnel.url}") - self._hooks.on_custom_step("Starting runtime") - runtime_config = RemoteRuntimeConfig( - host=tunnel.url, - timeout=self._runtime_timeout, - auth_token=token, - proxy=self._proxy, - ) - self._runtime = RemoteRuntime.from_config(runtime_config, run_id=self.run_id) - remaining_startup_timeout = max(0, self._startup_timeout - elapsed_sandbox_creation) - t1 = time.time() - await 
self._wait_until_alive(timeout=remaining_startup_timeout) - await self.runtime.create_session(CreateBashSessionRequest(startup_timeout=60)) - self.logger.info(f"Runtime started in {time.time() - t1:.2f}s") - - async def start(self, max_retries: int = 5): - """Starts the runtime with retry.""" + # Hold the STARTING permit only through runtime startup; the tool-call + # body afterwards is LLM-bound and must not occupy a permit. + async with _get_starting_semaphore(): + self.logger.info(f"Starting modal sandbox with image {self._image_name}") + self._hooks.on_custom_step("Starting modal sandbox") + t0 = time.time() + token = self._get_token() + self._sandbox = await modal.Sandbox.create.aio( + "/usr/bin/env", + "bash", + "-c", + self._start_swerex_cmd(token), + image=self._image, + timeout=int(self._deployment_timeout), + encrypted_ports=[self._port], + app=self._app, + **self._modal_kwargs, + ) + tunnels = await self._sandbox.tunnels.aio() + tunnel = tunnels[self._port] + elapsed_sandbox_creation = time.time() - t0 + self.logger.info(f"Sandbox ({self._sandbox.object_id}) created in {elapsed_sandbox_creation:.2f}s") + self.logger.info(f"Check sandbox logs at {await self.get_modal_log_url()}") + self.logger.info(f"Sandbox created with id {self._sandbox.object_id}") + await asyncio.sleep(1) + self.logger.info(f"Starting runtime at {tunnel.url}") + self._hooks.on_custom_step("Starting runtime") + runtime_config = RemoteRuntimeConfig( + host=tunnel.url, + timeout=self._runtime_timeout, + auth_token=token, + proxy=self._proxy, + ) + self._runtime = RemoteRuntime.from_config(runtime_config, run_id=self.run_id) + remaining_startup_timeout = max(0, self._startup_timeout - elapsed_sandbox_creation) + t1 = time.time() + await self._wait_until_alive(timeout=remaining_startup_timeout) + await self.runtime.create_session(CreateBashSessionRequest(startup_timeout=60)) + self.logger.info(f"Runtime started in {time.time() - t1:.2f}s") + + async def start(self, max_retries: int = 2): + 
"""Start the runtime with retry, bounded by MODAL_INIT_WALL_BUDGET. + + Few retries + a wall-clock cap stop a stuck trajectory from holding a + STARTING permit for many minutes; on exhaustion it raises and the outer + agent_loop turns it into a reward=0 masked sample. + """ last_error: Exception | None = None + wall_budget = _get_init_wall_budget() + deadline = time.monotonic() + wall_budget for retry in range(max_retries): + remaining = deadline - time.monotonic() + if remaining <= 0: + self.logger.critical(f"Wall-clock budget {wall_budget}s exhausted before attempt {retry + 1}") + break try: - await self._start() + await asyncio.wait_for(self._start(), timeout=max(60.0, remaining)) return except Exception as exc: last_error = exc self.logger.critical(f"Failed to create modal sandbox: {exc}") - # Best-effort cleanup; never let stop() failures shadow the real start - # error or short-circuit the retry loop. - try: - await self.stop() - except Exception as stop_exc: - self.logger.error(f"Cleanup after failed sandbox start raised: {stop_exc}") - if retry < max_retries - 1: - sleep_time = min(30, 2**retry) + await self.stop() + if retry < max_retries - 1 and time.monotonic() < deadline: + sleep_time = min(10, 2**retry) self.logger.info(f"Retrying modal deployment startup in {sleep_time} seconds...") await asyncio.sleep(sleep_time) - raise RuntimeError(f"Failed to create modal sandbox after {max_retries} retries") from last_error + raise RuntimeError( + f"Failed to create modal sandbox after {max_retries} retries (wall budget {wall_budget}s)" + ) from last_error async def stop(self): - """Stops the runtime.""" + """Stop the runtime, best-effort. + + Each step is wrapped so a transient failure (e.g. runtime.close raising + when the socket is already gone) never skips sandbox terminate -- a + leaked sandbox counts against the account's concurrent-sandbox cap. 
+ """ if self._runtime is not None: - await self._runtime.close() + try: + await self._runtime.close() + except Exception as exc: + self.logger.warning(f"runtime.close() swallowed (continuing teardown): {type(exc).__name__}: {exc}") self._runtime = None + + # CRITICAL — must always run to avoid leaking the modal sandbox. if self._sandbox is not None: - exit_code = await self._sandbox.poll.aio() - if exit_code is None: - await self._sandbox.terminate.aio() - self._sandbox = None + try: + exit_code = await self._sandbox.poll.aio() + if exit_code is None: + await self._sandbox.terminate.aio() + except Exception as exc: + self.logger.warning( + f"sandbox poll/terminate first attempt failed: " + f"{type(exc).__name__}: {exc}; retrying terminate once." + ) + try: + await self._sandbox.terminate.aio() + except Exception as exc2: + self.logger.error( + f"sandbox.terminate.aio() retry also failed: {type(exc2).__name__}: {exc2}. Sandbox may leak." + ) + self._sandbox = None + self._app = None @property