Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .secrets.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@
"filename": "README.md",
"hashed_secret": "a8253456364f1bfc7da7ae4a1db5b45d106317a5",
"is_verified": false,
"line_number": 454
"line_number": 514
}
],
"SLURM.md": [
Expand Down Expand Up @@ -561,5 +561,5 @@
}
]
},
"generated_at": "2026-03-02T22:46:56Z"
"generated_at": "2026-03-11T21:25:28Z"
}
70 changes: 65 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -376,17 +376,27 @@ python gsm8k_server.py evaluate \
Run the following commands in **separate terminals**, in this order:

**Terminal 1** — Start the API server first (must be running before environments connect):
```sh
```bash
run-api
```

**Terminal 2** — Start an environment:
```sh
python gsm8k_server.py serve --slurm False # or an env of your choice
```bash
python environments/gsm8k_server.py serve --slurm False # or an env of your choice
```

**Terminal 3** — (Optional) Dry-run your configuration:

```bash
atropos-sft-gen path/to/output.jsonl \
--tokenizer Qwen/Qwen2.5-1.5B-Instruct \
--dry-run
```

If this succeeds, your tokenizer and rollout server connectivity are correctly configured.

**Terminal 3** — Then, in the same terminal, generate data:
```sh
```bash
atropos-sft-gen path/to/output.jsonl --tokenizer Qwen/Qwen2.5-1.5B-Instruct # or whichever tokenizer you have in your env config
```
Rejection sampling can be controlled via `--save-top-n-per-group`, `--allow-negative-scores`, and `--minimum-score-diff-max-min`. See `atropos-sft-gen -h` for more detailed usage info.
Expand Down Expand Up @@ -442,10 +452,60 @@ Ensure you're using a clean virtual environment with the correct Python version:

```bash
python -m venv .venv
source .venv/bin/activate # On Windows: .venv\Scripts\activate
source .venv/bin/activate # On Windows (PowerShell): .venv\Scripts\Activate.ps1
pip install -e ".[dev]"
```

### Windows Quickstart

While Atropos is primarily documented with Unix-like shells in mind, it works well on Windows too.
Below is a minimal end-to-end example using **PowerShell**.

1. Create and activate a virtual environment:

```powershell
cd C:\path\to\atropos
python -m venv .venv
.venv\Scripts\Activate.ps1
pip install -e .[dev]
```

2. Start the API server:

```powershell
run-api
```

3. In a new PowerShell window, start an environment (for example GSM8K):

```powershell
cd C:\path\to\atropos
.venv\Scripts\Activate.ps1
python .\environments\gsm8k_server.py serve --slurm False
```

4. In a third PowerShell window, dry-run your offline data generation setup, then generate data:

```powershell
cd C:\path\to\atropos
.venv\Scripts\Activate.ps1

# Optional: configuration check
atropos-sft-gen .\gsm8k_rollouts.jsonl `
--tokenizer Qwen/Qwen2.5-1.5B-Instruct `
--dry-run

# Actual data generation
atropos-sft-gen .\gsm8k_rollouts.jsonl `
--tokenizer Qwen/Qwen2.5-1.5B-Instruct
```

If you see connectivity errors in dry-run, double-check that:

- `run-api` is running and listening on the expected port (default `http://localhost:8000`)
- Your environment script (e.g. `gsm8k_server.py`) is running without errors
- Any firewall or VPN software is not blocking local HTTP requests

**`OPENAI_API_KEY` not set errors**

Set your API key as an environment variable, or configure it in the environment's `config_init`:
Expand Down
113 changes: 97 additions & 16 deletions atroposlib/cli/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,15 +235,76 @@ async def grab_batch(jsonl_writer: jsonlines.Writer):
pbar.update(min(batch_count, num_seqs_to_save - total_count))


async def dpo_dry_run(
    api_url: str,
    tokenizer: str,
) -> None:
    """
    Sanity-check the offline DPO data-generation setup without side effects.

    Two things are validated:
    - the requested tokenizer can be loaded, and
    - the rollout API at ``api_url`` answers HTTP requests.

    No run is registered, no server state is reset, and no files are written.
    """
    print("[atropos-dpo-gen] Starting dry run...")

    # Step 1: the tokenizer must load before any data generation can work.
    try:
        AutoTokenizer.from_pretrained(tokenizer)
        print(f"[atropos-dpo-gen] ✅ Loaded tokenizer '{tokenizer}'.")
    except Exception as exc:  # pragma: no cover - defensive
        raise SystemExit(
            "[atropos-dpo-gen] Failed to load tokenizer "
            f"'{tokenizer}'. Please check the model name and your "
            "Hugging Face credentials.\n"
            f"Underlying error: {exc}"
        ) from exc

    # Step 2: probe the rollout server with a single read-only request.
    try:
        async with aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=5)
        ) as session, session.get(f"{api_url}/batch") as response:
            if response.status >= 400:
                raise SystemExit(
                    "[atropos-dpo-gen] Connected to rollout server but "
                    f"received HTTP {response.status} for GET /batch. "
                    "Please ensure the Atropos API is started with "
                    "`run-api` and that the URL/port is correct."
                )
        print(f"[atropos-dpo-gen] ✅ Reached rollout server at '{api_url}'.")
    except aiohttp.ClientConnectorError as exc:
        raise SystemExit(
            "[atropos-dpo-gen] Could not reach rollout server at "
            f"'{api_url}'. Please ensure:\n"
            "- `run-api` is running (default: http://localhost:8000)\n"
            "- your `--api-url` matches the server host and port\n"
            "- no firewall or proxy is blocking the request."
        ) from exc
    except asyncio.TimeoutError as exc:
        raise SystemExit(
            "[atropos-dpo-gen] Timed out while trying to reach the rollout "
            f"server at '{api_url}'. The server may be overloaded or not "
            "running. Try again after confirming `run-api` is up."
        ) from exc

    print(
        "[atropos-dpo-gen] Dry run completed successfully. "
        "You are ready to generate DPO data."
    )


def main():
parser = argparse.ArgumentParser(
description="Grab SFT data from an Atropos API instance."
description="Grab DPO data from an Atropos API instance."
)
parser.add_argument(
"filepath",
type=str,
default="sft_data.jsonl",
help="Path to the output JSONL file for SFT data.",
help="Path to the output JSONL file for DPO data.",
)
parser.add_argument(
"--api-url",
Expand Down Expand Up @@ -302,22 +363,42 @@ def main():
action="store_true",
help="Append to the previous file instead of overwriting it.",
)
parser.add_argument(
"--dry-run",
action="store_true",
help=(
"Validate tokenizer loading and rollout server connectivity "
"without registering a run or writing any output. "
"Useful for quickly checking your configuration before "
"collecting DPO data."
),
)
args = parser.parse_args()
asyncio.run(
dpo_data_grabber(
args.filepath,
args.api_url,
args.group_size,
args.max_token_len,
args.tokenizer,
args.save_messages,
args.save_n_pairs_per_group,
args.num_seqs_to_save,
args.allow_negative_scores,
args.minimum_score_diff_max_min,
args.append_to_previous,

# Run either a dry run (connectivity + config check) or full data grab
if args.dry_run:
asyncio.run(
dpo_dry_run(
api_url=args.api_url,
tokenizer=args.tokenizer,
)
)
else:
asyncio.run(
dpo_data_grabber(
args.filepath,
args.api_url,
args.group_size,
args.max_token_len,
args.tokenizer,
args.save_messages,
args.save_n_pairs_per_group,
args.num_seqs_to_save,
args.allow_negative_scores,
args.minimum_score_diff_max_min,
args.append_to_previous,
)
)
)


if __name__ == "__main__":
Expand Down
112 changes: 96 additions & 16 deletions atroposlib/cli/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,68 @@ async def grab_batch(jsonl_writer: jsonlines.Writer):
pbar.update(min(batch_count, num_seqs_to_save - total_count))


async def sft_dry_run(
    api_url: str,
    tokenizer: str,
) -> None:
    """
    Sanity-check the offline SFT data-generation setup without side effects.

    Two things are validated:
    - the requested tokenizer can be loaded, and
    - the rollout API at ``api_url`` answers HTTP requests.

    No run is registered, no server state is reset, and no files are written.
    """
    print("[atropos-sft-gen] Starting dry run...")

    # Step 1: the tokenizer must load before any data generation can work.
    try:
        AutoTokenizer.from_pretrained(tokenizer)
        print(f"[atropos-sft-gen] ✅ Loaded tokenizer '{tokenizer}'.")
    except Exception as exc:  # pragma: no cover - defensive
        raise SystemExit(
            "[atropos-sft-gen] Failed to load tokenizer "
            f"'{tokenizer}'. Please check the model name and your "
            "Hugging Face credentials.\n"
            f"Underlying error: {exc}"
        ) from exc

    # Step 2: probe the rollout server with a single read-only request.
    # GET /batch is always available on a running rollout server.
    try:
        async with aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=5)
        ) as session, session.get(f"{api_url}/batch") as response:
            if response.status >= 400:
                raise SystemExit(
                    "[atropos-sft-gen] Connected to rollout server but "
                    f"received HTTP {response.status} for GET /batch. "
                    "Please ensure the Atropos API is started with "
                    "`run-api` and that the URL/port is correct."
                )
        print(f"[atropos-sft-gen] ✅ Reached rollout server at '{api_url}'.")
    except aiohttp.ClientConnectorError as exc:
        raise SystemExit(
            "[atropos-sft-gen] Could not reach rollout server at "
            f"'{api_url}'. Please ensure:\n"
            "- `run-api` is running (default: http://localhost:8000)\n"
            "- your `--api-url` matches the server host and port\n"
            "- no firewall or proxy is blocking the request."
        ) from exc
    except asyncio.TimeoutError as exc:
        raise SystemExit(
            "[atropos-sft-gen] Timed out while trying to reach the rollout "
            f"server at '{api_url}'. The server may be overloaded or not "
            "running. Try again after confirming `run-api` is up."
        ) from exc

    print(
        "[atropos-sft-gen] Dry run completed successfully. "
        "You are ready to generate SFT data."
    )


def main():
"""Parses command-line arguments and runs the SFT data grabber."""
parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -313,25 +375,43 @@ def main():
default=64,
help="Number of tasks per step for batch size calculation (batch_size = group_size * tasks_per_step).",
)
parser.add_argument(
"--dry-run",
action="store_true",
help=(
"Validate tokenizer loading and rollout server connectivity "
"without registering a run or writing any output. "
"Useful for quickly checking your configuration before "
"collecting SFT data."
),
)
args = parser.parse_args()

# Run the main async function
asyncio.run(
sft_data_grabber(
args.filepath,
args.api_url,
args.group_size,
args.max_token_len,
args.tokenizer,
args.save_messages,
args.save_top_n_per_group,
args.num_seqs_to_save,
args.allow_negative_scores,
args.minimum_score_diff_max_min,
args.append_to_previous,
args.tasks_per_step,
# Run either a dry run (connectivity + config check) or full data grab
if args.dry_run:
asyncio.run(
sft_dry_run(
api_url=args.api_url,
tokenizer=args.tokenizer,
)
)
else:
asyncio.run(
sft_data_grabber(
args.filepath,
args.api_url,
args.group_size,
args.max_token_len,
args.tokenizer,
args.save_messages,
args.save_top_n_per_group,
args.num_seqs_to_save,
args.allow_negative_scores,
args.minimum_score_diff_max_min,
args.append_to_previous,
args.tasks_per_step,
)
)
)


if __name__ == "__main__":
Expand Down
Loading