diff --git a/.secrets.baseline b/.secrets.baseline index 7059d28bb..53b134534 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -133,7 +133,7 @@ "filename": "README.md", "hashed_secret": "a8253456364f1bfc7da7ae4a1db5b45d106317a5", "is_verified": false, - "line_number": 454 + "line_number": 514 } ], "SLURM.md": [ @@ -561,5 +561,5 @@ } ] }, - "generated_at": "2026-03-02T22:46:56Z" + "generated_at": "2026-03-11T21:25:28Z" } diff --git a/README.md b/README.md index 3b533a9bc..6dd060309 100644 --- a/README.md +++ b/README.md @@ -376,17 +376,27 @@ python gsm8k_server.py evaluate \ Run the following commands in **separate terminals**, in this order: **Terminal 1** — Start the API server first (must be running before environments connect): -```sh +```bash run-api ``` **Terminal 2** — Start an environment: -```sh -python gsm8k_server.py serve --slurm False # or an env of your choice +```bash +python environments/gsm8k_server.py serve --slurm False # or an env of your choice +``` + +**Terminal 3** — (Optional) Dry-run your configuration before generating data (reuse this terminal for the next step): + +```bash +atropos-sft-gen path/to/output.jsonl \ + --tokenizer Qwen/Qwen2.5-1.5B-Instruct \ + --dry-run ``` +If this succeeds, your tokenizer and rollout server connectivity are correctly configured. + **Terminal 3** — Generate data: -```sh +```bash atropos-sft-gen path/to/output.jsonl --tokenizer Qwen/Qwen2.5-1.5B-Instruct # or whichever tokenizer you have in your env config ``` Rejection sampling can be controlled via `--save-top-n-per-group`, `--allow-negative-scores`, and `--minimum-score-diff-max-min`. See `atropos-sft-gen -h` for more detailed usage info. 
@@ -442,10 +452,60 @@ Ensure you're using a clean virtual environment with the correct Python version: ```bash python -m venv .venv -source .venv/bin/activate # On Windows: .venv\Scripts\activate +source .venv/bin/activate # On Windows (PowerShell): .venv\Scripts\Activate.ps1 pip install -e ".[dev]" ``` +### Windows Quickstart + +While Atropos is primarily documented with Unix-like shells in mind, it works well on Windows too. +Below is a minimal end-to-end example using **PowerShell**. + +1. Create and activate a virtual environment: + +```powershell +cd C:\path\to\atropos +python -m venv .venv +.venv\Scripts\Activate.ps1 +pip install -e ".[dev]" +``` + +2. Start the API server: + +```powershell +run-api +``` + +3. In a new PowerShell window, start an environment (for example GSM8K): + +```powershell +cd C:\path\to\atropos +.venv\Scripts\Activate.ps1 +python .\environments\gsm8k_server.py serve --slurm False +``` + +4. In a third PowerShell window, dry-run your offline data generation setup, then generate data: + +```powershell +cd C:\path\to\atropos +.venv\Scripts\Activate.ps1 + +# Optional: configuration check +atropos-sft-gen .\gsm8k_rollouts.jsonl ` + --tokenizer Qwen/Qwen2.5-1.5B-Instruct ` + --dry-run + +# Actual data generation +atropos-sft-gen .\gsm8k_rollouts.jsonl ` + --tokenizer Qwen/Qwen2.5-1.5B-Instruct +``` + +If you see connectivity errors during the dry run, double-check that: + +- `run-api` is running and listening at the expected address (default `http://localhost:8000`) +- Your environment script (e.g. 
`gsm8k_server.py`) is running without errors +- Any firewall or VPN software is not blocking local HTTP requests + **`OPENAI_API_KEY` not set errors** Set your API key as an environment variable, or configure it in the environment's `config_init`: diff --git a/atroposlib/cli/dpo.py b/atroposlib/cli/dpo.py index 5f8a32806..d7a7d84bd 100644 --- a/atroposlib/cli/dpo.py +++ b/atroposlib/cli/dpo.py @@ -235,15 +235,76 @@ async def grab_batch(jsonl_writer: jsonlines.Writer): pbar.update(min(batch_count, num_seqs_to_save - total_count)) +async def dpo_dry_run( + api_url: str, + tokenizer: str, +) -> None: + """ + Lightweight connectivity check for offline DPO data generation. + + This function verifies that: + - the tokenizer can be loaded, and + - the rollout API is reachable. + + It does not register a run, reset any server state, or write output files. + """ + print("[atropos-dpo-gen] Starting dry run...") + + # 1) Check tokenizer loading + try: + AutoTokenizer.from_pretrained(tokenizer) + print(f"[atropos-dpo-gen] ✅ Loaded tokenizer '{tokenizer}'.") + except Exception as exc: # pragma: no cover - defensive + raise SystemExit( + "[atropos-dpo-gen] Failed to load tokenizer " + f"'{tokenizer}'. Please check the model name and your " + "Hugging Face credentials.\n" + f"Underlying error: {exc}" + ) from exc + + # 2) Check basic API connectivity without mutating state + try: + timeout = aiohttp.ClientTimeout(total=5) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.get(f"{api_url}/batch") as response: + if response.status >= 400: + raise SystemExit( + "[atropos-dpo-gen] Connected to rollout server but " + f"received HTTP {response.status} for GET /batch. " + "Please ensure the Atropos API is started with " + "`run-api` and that the URL/port is correct." 
+ ) + print(f"[atropos-dpo-gen] ✅ Reached rollout server at '{api_url}'.") + except aiohttp.ClientConnectorError as exc: + raise SystemExit( + "[atropos-dpo-gen] Could not reach rollout server at " + f"'{api_url}'. Please ensure:\n" + "- `run-api` is running (default: http://localhost:8000)\n" + "- your `--api-url` matches the server host and port\n" + "- no firewall or proxy is blocking the request." + ) from exc + except asyncio.TimeoutError as exc: + raise SystemExit( + "[atropos-dpo-gen] Timed out while trying to reach the rollout " + f"server at '{api_url}'. The server may be overloaded or not " + "running. Try again after confirming `run-api` is up." + ) from exc + + print( + "[atropos-dpo-gen] Dry run completed successfully. " + "You are ready to generate DPO data." + ) + + def main(): parser = argparse.ArgumentParser( - description="Grab SFT data from an Atropos API instance." + description="Grab DPO data from an Atropos API instance." ) parser.add_argument( "filepath", type=str, default="sft_data.jsonl", - help="Path to the output JSONL file for SFT data.", + help="Path to the output JSONL file for DPO data.", ) parser.add_argument( "--api-url", @@ -302,22 +363,42 @@ def main(): action="store_true", help="Append to the previous file instead of overwriting it.", ) + parser.add_argument( + "--dry-run", + action="store_true", + help=( + "Validate tokenizer loading and rollout server connectivity " + "without registering a run or writing any output. " + "Useful for quickly checking your configuration before " + "collecting DPO data." 
+ ), + ) args = parser.parse_args() - asyncio.run( - dpo_data_grabber( - args.filepath, - args.api_url, - args.group_size, - args.max_token_len, - args.tokenizer, - args.save_messages, - args.save_n_pairs_per_group, - args.num_seqs_to_save, - args.allow_negative_scores, - args.minimum_score_diff_max_min, - args.append_to_previous, + + # Run either a dry run (connectivity + config check) or full data grab + if args.dry_run: + asyncio.run( + dpo_dry_run( + api_url=args.api_url, + tokenizer=args.tokenizer, + ) + ) + else: + asyncio.run( + dpo_data_grabber( + args.filepath, + args.api_url, + args.group_size, + args.max_token_len, + args.tokenizer, + args.save_messages, + args.save_n_pairs_per_group, + args.num_seqs_to_save, + args.allow_negative_scores, + args.minimum_score_diff_max_min, + args.append_to_previous, + ) ) - ) if __name__ == "__main__": diff --git a/atroposlib/cli/sft.py b/atroposlib/cli/sft.py index 7d5995f35..bd86221d9 100644 --- a/atroposlib/cli/sft.py +++ b/atroposlib/cli/sft.py @@ -239,6 +239,68 @@ async def grab_batch(jsonl_writer: jsonlines.Writer): pbar.update(min(batch_count, num_seqs_to_save - total_count)) +async def sft_dry_run( + api_url: str, + tokenizer: str, +) -> None: + """ + Lightweight connectivity check for offline SFT data generation. + + This function verifies that: + - the tokenizer can be loaded, and + - the rollout API is reachable. + + It does not register a run, reset any server state, or write output files. + """ + print("[atropos-sft-gen] Starting dry run...") + + # 1) Check tokenizer loading + try: + AutoTokenizer.from_pretrained(tokenizer) + print(f"[atropos-sft-gen] ✅ Loaded tokenizer '{tokenizer}'.") + except Exception as exc: # pragma: no cover - defensive + raise SystemExit( + "[atropos-sft-gen] Failed to load tokenizer " + f"'{tokenizer}'. 
Please check the model name and your " + "Hugging Face credentials.\n" + f"Underlying error: {exc}" + ) from exc + + # 2) Check basic API connectivity without mutating state + try: + timeout = aiohttp.ClientTimeout(total=5) + async with aiohttp.ClientSession(timeout=timeout) as session: + # /batch is always expected to exist on a running rollout server + async with session.get(f"{api_url}/batch") as response: + if response.status >= 400: + raise SystemExit( + "[atropos-sft-gen] Connected to rollout server but " + f"received HTTP {response.status} for GET /batch. " + "Please ensure the Atropos API is started with " + "`run-api` and that the URL/port is correct." + ) + print(f"[atropos-sft-gen] ✅ Reached rollout server at '{api_url}'.") + except aiohttp.ClientConnectorError as exc: + raise SystemExit( + "[atropos-sft-gen] Could not reach rollout server at " + f"'{api_url}'. Please ensure:\n" + "- `run-api` is running (default: http://localhost:8000)\n" + "- your `--api-url` matches the server host and port\n" + "- no firewall or proxy is blocking the request." + ) from exc + except asyncio.TimeoutError as exc: + raise SystemExit( + "[atropos-sft-gen] Timed out while trying to reach the rollout " + f"server at '{api_url}'. The server may be overloaded or not " + "running. Try again after confirming `run-api` is up." + ) from exc + + print( + "[atropos-sft-gen] Dry run completed successfully. " + "You are ready to generate SFT data." + ) + + def main(): """Parses command-line arguments and runs the SFT data grabber.""" parser = argparse.ArgumentParser( @@ -313,25 +375,43 @@ def main(): default=64, help="Number of tasks per step for batch size calculation (batch_size = group_size * tasks_per_step).", ) + parser.add_argument( + "--dry-run", + action="store_true", + help=( + "Validate tokenizer loading and rollout server connectivity " + "without registering a run or writing any output. " + "Useful for quickly checking your configuration before " + "collecting SFT data." 
+ ), + ) args = parser.parse_args() - # Run the main async function - asyncio.run( - sft_data_grabber( - args.filepath, - args.api_url, - args.group_size, - args.max_token_len, - args.tokenizer, - args.save_messages, - args.save_top_n_per_group, - args.num_seqs_to_save, - args.allow_negative_scores, - args.minimum_score_diff_max_min, - args.append_to_previous, - args.tasks_per_step, + # Run either a dry run (connectivity + config check) or full data grab + if args.dry_run: + asyncio.run( + sft_dry_run( + api_url=args.api_url, + tokenizer=args.tokenizer, + ) + ) + else: + asyncio.run( + sft_data_grabber( + args.filepath, + args.api_url, + args.group_size, + args.max_token_len, + args.tokenizer, + args.save_messages, + args.save_top_n_per_group, + args.num_seqs_to_save, + args.allow_negative_scores, + args.minimum_score_diff_max_min, + args.append_to_previous, + args.tasks_per_step, + ) ) - ) if __name__ == "__main__": diff --git a/atroposlib/tests/test_offline_cli_dry_run.py b/atroposlib/tests/test_offline_cli_dry_run.py new file mode 100644 index 000000000..1a7ac2986 --- /dev/null +++ b/atroposlib/tests/test_offline_cli_dry_run.py @@ -0,0 +1,175 @@ +import asyncio +from types import SimpleNamespace + +import pytest + +import atroposlib.cli.dpo as dpo_cli +import atroposlib.cli.sft as sft_cli + + +class _DummyResponse: + def __init__(self, status: int = 200) -> None: + self.status = status + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + +class _DummySession: + def __init__(self, *_, **__) -> None: + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + def get(self, *_args, **_kwargs): + return _DummyResponse() + + +@pytest.mark.asyncio +async def test_sft_dry_run_uses_tokenizer_and_reaches_api(monkeypatch): + called = SimpleNamespace(tok=False, session=False) + + def _fake_from_pretrained(model_name: str): + called.tok = True + assert 
model_name == "dummy-tokenizer" + return object() + + monkeypatch.setattr( + sft_cli, "AutoTokenizer", SimpleNamespace(from_pretrained=_fake_from_pretrained) + ) + monkeypatch.setattr(sft_cli.aiohttp, "ClientSession", _DummySession) + + await sft_cli.sft_dry_run( + api_url="http://localhost:8000", tokenizer="dummy-tokenizer" + ) + + assert called.tok is True + + +@pytest.mark.asyncio +async def test_dpo_dry_run_uses_tokenizer_and_reaches_api(monkeypatch): + called = SimpleNamespace(tok=False) + + def _fake_from_pretrained(model_name: str): + called.tok = True + assert model_name == "dummy-tokenizer" + return object() + + monkeypatch.setattr( + dpo_cli, "AutoTokenizer", SimpleNamespace(from_pretrained=_fake_from_pretrained) + ) + monkeypatch.setattr(dpo_cli.aiohttp, "ClientSession", _DummySession) + + await dpo_cli.dpo_dry_run( + api_url="http://localhost:8000", tokenizer="dummy-tokenizer" + ) + + assert called.tok is True + + +def test_sft_main_invokes_dry_run_when_flag_is_set(monkeypatch): + """ + Ensure that passing --dry-run to the entrypoint does *not* + call the full data grabber, only the dry run helper. 
+ """ + + called = SimpleNamespace(dry=False, full=False) + + async def _fake_sft_dry_run(api_url: str, tokenizer: str) -> None: + called.dry = True + assert api_url == "http://example.com" + assert tokenizer == "tok" + + async def _fake_sft_data_grabber(*_args, **_kwargs): + called.full = True + + monkeypatch.setattr(sft_cli, "sft_dry_run", _fake_sft_dry_run) + monkeypatch.setattr(sft_cli, "sft_data_grabber", _fake_sft_data_grabber) + + # Pre-parsed args with --dry-run set; returned by the patched parse_args() below + class _Args: + filepath = "out.jsonl" + api_url = "http://example.com" + group_size = 2 + max_token_len = 2048 + tokenizer = "tok" + save_messages = False + save_top_n_per_group = 3 + num_seqs_to_save = 10 + allow_negative_scores = False + minimum_score_diff_max_min = 0.0 + append_to_previous = False + tasks_per_step = 64 + dry_run = True + + monkeypatch.setattr( + sft_cli.argparse.ArgumentParser, + "parse_args", + lambda *_, **__: _Args, # keep the real parser so main()'s add_argument() calls still work + ) + monkeypatch.setattr( + sft_cli, + "asyncio", + SimpleNamespace( + run=asyncio.run # real runner; avoids the deprecated get_event_loop() path + ), + ) + + sft_cli.main() + + assert called.dry is True + assert called.full is False + + +def test_dpo_main_invokes_dry_run_when_flag_is_set(monkeypatch): + called = SimpleNamespace(dry=False, full=False) + + async def _fake_dpo_dry_run(api_url: str, tokenizer: str) -> None: + called.dry = True + assert api_url == "http://example.com" + assert tokenizer == "tok" + + async def _fake_dpo_data_grabber(*_args, **_kwargs): + called.full = True + + monkeypatch.setattr(dpo_cli, "dpo_dry_run", _fake_dpo_dry_run) + monkeypatch.setattr(dpo_cli, "dpo_data_grabber", _fake_dpo_data_grabber) + + class _Args: + filepath = "out.jsonl" + api_url = "http://example.com" + group_size = 2 + max_token_len = 2048 + tokenizer = "tok" + save_messages = False + save_n_pairs_per_group = 3 + num_seqs_to_save = 10 + allow_negative_scores = False + minimum_score_diff_max_min = 0.5 + append_to_previous
= False + dry_run = True + + monkeypatch.setattr( + dpo_cli.argparse.ArgumentParser, + "parse_args", + lambda *_, **__: _Args, # keep the real parser so main()'s add_argument() calls still work + ) + monkeypatch.setattr( + dpo_cli, + "asyncio", + SimpleNamespace( + run=asyncio.run # real runner; avoids the deprecated get_event_loop() path + ), + ) + + dpo_cli.main() + + assert called.dry is True + assert called.full is False