diff --git a/examples/osworld/README.md b/examples/osworld/README.md new file mode 100644 index 0000000000..4ca0ebcda7 --- /dev/null +++ b/examples/osworld/README.md @@ -0,0 +1,260 @@ +# OSWorld GRPO Example + +## 1. Overview + +This example runs GRPO/PPO training on [OSWorld](https://github.com/xlang-ai/OSWorld) +desktop-control tasks with a Vision-Language base model (Qwen3-VL-4B-Instruct by +default). The agent observes screenshots, emits keyboard/mouse actions through a +`pyautogui`-style action space, and receives a scalar reward from each task's +`evaluate()` function. + +Three components cooperate: + +- **AReaL** — the training framework (this repo): FSDP actor, SGLang rollout engine, + GRPO loss. +- **OSWorld** — the environment code (a sibling checkout, defaults to `../OSWorld`): + task definitions, the Ubuntu VM image, the in-VM controller HTTP server. +- **Remote sandbox provider** — runs the OSWorld VM behind an HTTPS gateway. Training + containers usually lack Docker/KVM, so OSWorld's bundled `docker` provider can't run + in-process; an external host actually starts the VM and forwards controller calls. + +Two transports ship in this example: + +- `workflow/gateway_sandbox.py` — a pluggable HTTPS-gateway client (recommended). + Requires a vendor SDK exposed as a Python module named `pssdk` that exports + `BaseSandboxClusterTool` and `with_retry`. A thin adapter is fine if your provider + uses different names. +- `workflow/remote_desktop_env.py` — a self-hosted alternative. Pair it with + `remote_server.py` running on a docker-capable machine. + +## 2. 
Layout + +``` +__init__.py +README.md this file +apply_env_patches.sh SGLang/pydrive patches needed for the conda env +config_osworld_sglang.yaml training config (Qwen3-VL-4B-Instruct, 2x GPU) +osworld_config.py OSWorldAgentConfig (extends GRPOConfig) +osworld_requirements.txt OSWorld host-side deps with conflicting versions filtered out +remote_server.py optional self-hosted docker bridge (paired with remote_desktop_env.py) +run_train.sh launcher with smoke / smoke-text / full stages +smoke.py end-to-end sandbox smoke (skips trainer) +train.py PPO entry point +workflow/ + __init__.py + osworld_workflow.py multi-turn VLM rollout workflow + Plan B VL bridge + gateway_sandbox.py DesktopEnv subclass that proxies controller calls through HTTPS gateway + remote_desktop_env.py alternative DesktopEnv replacement that talks to remote_server.py +``` + +## 3. Prerequisites + +- Python 3.12+ in an AReaL-compatible CUDA env: `torch 2.9.x+cu129`, `sglang >= 0.5.10`, + `transformers 5.x`, `areal 1.x`. +- 2x GPU with at least 80 GB each — one for the FSDP actor, one for the SGLang rollout + engine. +- An OSWorld checkout sibling to AReaL (`../OSWorld`). Override the location via the + `osworld_root` config field if it lives elsewhere. +- Either an HTTPS-gateway-based sandbox provider with a vendor SDK (gateway path), or a + separate machine with Docker access (self-hosted path). + +## 4. Setup + +1. Create or reuse a Python 3.12 conda env. Run `uv sync --extra cuda` per AReaL's main + README. Then install OSWorld's filtered host deps: + + ```bash + pip install -r examples/osworld/osworld_requirements.txt + pip install "nvidia-cudnn-cu12==9.16.0.29" "protobuf>=6.31.1,<7" "grpcio-status==1.80.0" + ``` + + The second `pip install` un-downgrades packages that `easyocr` and + `google-ai-generativelanguage` drag in. + +1. 
Apply environment patches (idempotent): + + ```bash + bash examples/osworld/apply_env_patches.sh + ``` + + This script patches: + + - SGLang JIT kernel flag — auto-detects whether the local `nvcc` supports C++20. + - `pydrive` to `pydrive2` shim — OSWorld imports the unmaintained `pydrive` package. + +1. **Gateway path only** — install your sandbox provider's SDK. The expected protocol is + a Python module named `pssdk` exporting: + + - `BaseSandboxClusterTool(cluster_endpoint, application_secret_token, session_id, global_call_timeout)` + constructor. + - `.session_id` (read property). + - `.sandbox_start(body=None, call_timeout=...) -> dict`. + - `.sandbox_stop(call_timeout=...) -> dict`. + - `with_retry(max_attempts, retry_interval, infinite_retry_on_resource_limit, exclude_methods)` + class decorator. + + If your provider exports a different module name, write a thin adapter module called + `pssdk` that re-exports these symbols. + +## 5. Configure + +Defaults live in `config_osworld_sglang.yaml`. Notable fields: + +- `actor.path` — HuggingFace model directory for the VL base. Qwen3-VL-4B-Instruct is + recommended. +- `gateway_endpoint`, `gateway_token` — empty by default; must be set for the gateway + path. +- `gateway_timeout_secs` — per-call timeout to the gateway (default 1800). +- `remote_server_url` — empty by default; set non-empty to use the self-hosted path + instead. +- `text_only` — smoke-only ablation. Strips screenshots and lets you point `actor.path` + at a text-only model to verify the PPO loop without VL. +- `osworld_root`, `evaluation_examples_dir`, `test_meta_path` — auto-discovered from the + sibling `OSWorld/` checkout when left empty. + +Two ways to provide credentials: + +```bash +# via env vars (recommended; secrets stay out of source control) +export OSWORLD_SANDBOX_ENDPOINT="https://your-gateway/..." +export OSWORLD_SANDBOX_TOKEN="sk-..." 
+ +# via CLI override +python -m examples.osworld.train --config examples/osworld/config_osworld_sglang.yaml \ + gateway_endpoint=$OSWORLD_SANDBOX_ENDPOINT \ + gateway_token=$OSWORLD_SANDBOX_TOKEN +``` + +## 6. Run flow + +### 6a. Sandbox smoke (skip trainer) + +This verifies the gateway/SDK end-to-end without touching the trainer: + +```bash +export OSWORLD_SANDBOX_ENDPOINT="https://your-gateway/..." +export OSWORLD_SANDBOX_TOKEN="sk-..." +python examples/osworld/smoke.py +``` + +Expected last lines: + +``` +sandbox started: +reset ok; screenshot bytes=NNNN +step ok; reward=0 done=False ... +evaluate result: 0.0 +closed +``` + +`evaluate result: 0.0` means the agent didn't solve the task but the evaluator returned +a real reward — this is the success signal for the smoke test. + +### 6b. Training smoke, Plan A (text-only ablation, fastest end-to-end) + +```bash +export OSWORLD_SANDBOX_TOKEN="sk-..." +export AREAL_TEXT_ONLY_MODEL=/path/to/Qwen3-4B-Instruct # any model_type=qwen3 base +bash examples/osworld/run_train.sh smoke-text +``` + +This routes through the same workflow but `text_only=true` strips screenshots and uses a +text-only base, so you can verify the PPO loop end-to-end without exercising the VL +training path. + +### 6c. Training smoke, Plan B (full VL pipeline) + +```bash +export OSWORLD_SANDBOX_TOKEN="sk-..." +bash examples/osworld/run_train.sh smoke +``` + +The `OSWorldWorkflow._attach_vl_tensor_dicts` bridge re-runs the HF processor on each +turn's prefix and writes `mm_token_type_ids` plus `multi_modal_input` (`pixel_values`, +`image_grid_thw`) into the cached training tensor dict, which is what +`FSDPEngine._prepare_mb_list`'s VL path needs. + +### 6d. Full training + +```bash +bash examples/osworld/run_train.sh full +``` + +## 7. Self-hosted alternative + +Skip this section if you have a gateway. Otherwise, run OSWorld on a separate +Docker-capable machine and let the trainer talk to it over HTTP. 
+ +On the docker machine: + +```bash +docker pull xlang/osworld-docker:latest +pip install -r OSWorld/requirements.txt flask pydrive2 "oauth2client<4.1.4" +python remote_server.py --osworld-root /path/to/OSWorld --host 0.0.0.0 --port 8000 --max-envs 2 +``` + +On the train side: + +```bash +python -m examples.osworld.train --config examples/osworld/config_osworld_sglang.yaml \ + remote_server_url=http://:8000 \ + rollout.max_concurrent_rollouts=1 \ + n_trajs=1 +``` + +Provider precedence inside `osworld_workflow._build_env`: +`gateway_endpoint + gateway_token` > `remote_server_url` > in-process `DesktopEnv`. + +## 8. Configuration knobs (top-level YAML fields) + +| Field | Meaning | +| ------------------------------------------------------------- | ------------------------------------------------------------------------------------------------- | +| `n_trajs` | Trajectories per task per episode (GRPO group size). | +| `max_steps` | Max agent / env turns before forcing `env.evaluate()`. | +| `max_workers` | ThreadPool size for blocking `DesktopEnv` calls. | +| `provider_name` | OSWorld provider (`docker` default; only used if neither gateway nor `remote_server_url` is set). | +| `env_reset_wait_secs` | Sleep after `env.reset` to let the VM settle. | +| `test_meta_path` | Which OSWorld task meta file to train on (default `test_small.json`). | +| `text_only` | Smoke-only ablation; strips screenshots from messages. | +| `gateway_endpoint` / `gateway_token` / `gateway_timeout_secs` | Gateway transport. | +| `remote_server_url` / `remote_request_timeout_secs` | Self-hosted transport. | + +## 9. Reward semantics + +Each trajectory is attributed the float returned by `DesktopEnv.evaluate()` (typically +`0.0` or `1.0`). The reward is applied to the last assistant turn and discounted +backwards per turn by `turn_discount` (default `0.9`). + +## 10. Concurrency notes + +A sandbox session is one OSWorld VM container (1 vCPU, 4 GB RAM, idle-reaped after +roughly 50 minutes). 
Your provider's quota controls how many concurrent sessions you can +hold. Start with `rollout.max_concurrent_rollouts=1` and ramp up while watching for HTTP +429s. The bundled retry decorator (`_RetryingClusterTool` in `gateway_sandbox.py`) parks +on 429s rather than killing trajectories. + +## 11. Known limitations of the gateway path + +| Limitation | Impact | Mitigation | +| -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------- | +| Subset of OSWorld setup verbs supported: `launch`, `download`, `execute`, `open`, `chrome_open_tabs`, `activate_window`, `close_window`, `command`, `sleep`, `change_wallpaper`. | Tasks using `googledrive`, `login`, or `replay` skip-with-warning and may have an inaccurate initial state. | Pick task subsets without those verbs, or extend the controller. | +| `controller.get_file()` routes through `/execute` plus base64 (slow for large files). | OSWorld's `/file` endpoint requires form-urlencoded, but typical gateways only allow JSON. | Ask your provider to allow form encoding, or live with slower file transfer. | +| `/terminal` returns 500 when no active terminal exists. | One warning line, non-fatal. | OSWorld behavior; ignore. | + +## 12. 
Troubleshooting + +| Symptom | Likely cause / fix | +| ------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `ModuleNotFoundError: pssdk` at startup | Gateway path is selected but the vendor SDK isn't installed. Install it or fall back to `remote_server_url`. | +| 401 on `sandbox_start` | Wrong or expired `OSWORLD_SANDBOX_TOKEN`. | +| 429 on `sandbox_start` (`ClusterQuotaExceededErr`) | Global quota exhausted; the retry decorator parks — wait or scale provider capacity. | +| SGLang JIT compile fails with `std::integral` not found | Run `apply_env_patches.sh` first; the script auto-detects nvcc C++20 support. | +| `requests` health check timeouts to local SGLang | Corporate `HTTP_PROXY` in the environment routes localhost requests to a proxy that can't reach internal IPs. `run_train.sh` already appends a generic `NO_PROXY` allowlist for `localhost,127.0.0.1,10.0.0.0/8`; append your internal domains via the `NO_PROXY` env var before invocation. | +| `eval-rollout/0 readiness timeout` | The forked Python subprocess re-imports torch + sglang + megatron, which takes 3+ minutes on slow filesystems. `_wait_for_fork_ready` in `areal/infra/scheduler/local.py` should be at least 600s; a small core patch is recommended for slow-disk users. | +| `KeyError: 'mm_token_type_ids'` in `_prepare_mb_list` | The VL bridge isn't running. Verify `text_only=false` is in effect and that `processor_path=actor.path` was passed through. | + +## 13. What's not yet covered + +- `setup/upload` and `setup/execute_with_verification` verbs are not wired through + `gateway_sandbox.py`; tasks needing them will skip-with-warning. 
+- WandB is disabled in the default config — flip `stats_logger.wandb.mode` to enable. diff --git a/examples/osworld/__init__.py b/examples/osworld/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/osworld/apply_env_patches.sh b/examples/osworld/apply_env_patches.sh new file mode 100755 index 0000000000..38b61d3604 --- /dev/null +++ b/examples/osworld/apply_env_patches.sh @@ -0,0 +1,99 @@ +#!/usr/bin/env bash +# Apply runtime patches that the conda env needs but `uv sync` would overwrite. +# +# Run once after creating the env (or any time after `uv sync --extra cuda`). +# +# bash examples/osworld/apply_env_patches.sh +# +# Optional env vars: +# AREAL_ENV_PREFIX - conda env prefix (default: ../../../env) +# +# Patches applied (each one is idempotent): +# 1. (Conditional) SGLang JIT kernels: c++20 -> c++17 only when the host's +# nvcc is too old (e.g. CUDA 12.2) to accept `-std=c++20`. With CUDA +# 12.9+ this patch is HARMFUL — SGLang's templates use `std::integral` +# and `std::ranges` which require C++20. The script auto-detects. +# 2. pydrive -> pydrive2 shim, because OSWorld imports the unmaintained +# `pydrive` package which is broken against modern oauth2client. + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +AREAL_ENV_PREFIX="${AREAL_ENV_PREFIX:-$(cd "${SCRIPT_DIR}/../../../env" && pwd)}" + +if [[ ! -d "${AREAL_ENV_PREFIX}" ]]; then + echo "[apply_env_patches.sh] env not found at: ${AREAL_ENV_PREFIX}" >&2 + exit 1 +fi + +SP="${AREAL_ENV_PREFIX}/lib/python3.12/site-packages" + +# -------- Patch 1: SGLang c++20 -> c++17 (only when nvcc rejects c++20) -------- +SGL_UTILS="${SP}/sglang/jit_kernel/utils.py" +if [[ ! -f "${SGL_UTILS}" ]]; then + echo "[apply_env_patches.sh] sglang utils.py missing: ${SGL_UTILS}" >&2 + exit 2 +fi + +# Probe: does nvcc accept -std=c++20? 
+NVCC_BIN="$(command -v nvcc || echo /usr/local/cuda/bin/nvcc)" +NVCC_C20_OK=0 +if "${NVCC_BIN}" --help 2>/dev/null | grep -q 'c++20'; then + NVCC_C20_OK=1 +fi + +if [[ "${NVCC_C20_OK}" == "1" ]]; then + # nvcc supports c++20: revert any prior c++17 downgrade we (or older + # versions of this script) introduced. SGLang templates require C++20. + if grep -q '"-std=c++17"' "${SGL_UTILS}" && [[ -f "${SGL_UTILS}.bak.cuda122" ]]; then + cp "${SGL_UTILS}.bak.cuda122" "${SGL_UTILS}" + echo "[patch 1/2] SGLang: restored c++20 from backup (nvcc supports c++20)" + elif grep -q '"-std=c++17"' "${SGL_UTILS}"; then + sed -i 's/"-std=c++17"/"-std=c++20"/g' "${SGL_UTILS}" + echo "[patch 1/2] SGLang: rewrote c++17 -> c++20 in place" + else + echo "[patch 1/2] SGLang: already at c++20 — skipping" + fi +else + if grep -q '"-std=c++20"' "${SGL_UTILS}"; then + cp "${SGL_UTILS}" "${SGL_UTILS}.bak.cuda122" + sed -i 's/"-std=c++20"/"-std=c++17"/g' "${SGL_UTILS}" + echo "[patch 1/2] SGLang: c++20 -> c++17 ($(grep -c '"-std=c++17"' "${SGL_UTILS}") sites; old nvcc)" + else + echo "[patch 1/2] SGLang already at c++17 — skipping" + fi +fi + +# -------- Patch 2: pydrive -> pydrive2 shim -------- +PYDRIVE_DIR="${SP}/pydrive" +PYDRIVE2_DIR="${SP}/pydrive2" +if [[ ! -d "${PYDRIVE2_DIR}" ]]; then + echo "[patch 2/2] pydrive2 not installed — installing" + "${AREAL_ENV_PREFIX}/bin/pip" install --no-cache-dir --progress-bar off \ + pydrive2 "oauth2client<4.1.4" +fi + +# Remove the unmaintained PyDrive (if it sneaked back in) before writing shim. 
+if "${AREAL_ENV_PREFIX}/bin/pip" show -q pydrive 2>/dev/null; then + "${AREAL_ENV_PREFIX}/bin/pip" uninstall -y pydrive >/dev/null +fi + +mkdir -p "${PYDRIVE_DIR}" +cat > "${PYDRIVE_DIR}/__init__.py" <<'PY' +"""Compatibility shim: redirect pydrive imports to pydrive2.""" +import sys +from pydrive2 import auth as _auth +from pydrive2 import drive as _drive +sys.modules.setdefault("pydrive.auth", _auth) +sys.modules.setdefault("pydrive.drive", _drive) +PY +echo "[patch 2/2] pydrive shim written to ${PYDRIVE_DIR}/__init__.py" + +# -------- Cleanup any stale JIT cache from prior failed builds -------- +TVM_FFI_CACHE="${HOME}/.cache/tvm-ffi" +if [[ -d "${TVM_FFI_CACHE}" ]]; then + rm -rf "${TVM_FFI_CACHE}/sgl_kernel_jit_"* + echo "[cleanup] purged stale sgl_kernel_jit_* under ${TVM_FFI_CACHE}" +fi + +echo "[apply_env_patches.sh] done" diff --git a/examples/osworld/config_osworld_sglang.yaml b/examples/osworld/config_osworld_sglang.yaml new file mode 100644 index 0000000000..a63f2a0fad --- /dev/null +++ b/examples/osworld/config_osworld_sglang.yaml @@ -0,0 +1,226 @@ +experiment_name: osworld-grpo +trial_name: trial1 + +seed: 1 +enable_offload: false +total_train_epochs: 5 +tokenizer_path: ${actor.path} + +# OSWorld agent / environment knobs (consumed by OSWorldAgentConfig). +n_trajs: 2 +max_steps: 15 +max_workers: 4 +sleep_after_execution: 1.0 +env_reset_wait_secs: 60.0 +turn_discount: 0.9 + +provider_name: docker +path_to_vm: null +os_type: Ubuntu +headless: true +screen_width: 1920 +screen_height: 1080 +observation_type: screenshot +action_space: pyautogui + +# Leave these empty to auto-discover: ../../../OSWorld relative to AReaL/. +osworld_root: "" +evaluation_examples_dir: "" +test_meta_path: "" +osworld_cache_dir: /tmp/areal/osworld_cache + +# Set to e.g. "http://:8000" to drive OSWorld on a separate host +# that actually has docker. Leave empty for in-process DesktopEnv. 
+remote_server_url: "" +remote_request_timeout_secs: 1800.0 + +# Vendor-neutral remote sandbox cluster behind an HTTPS gateway. Set both +# endpoint + token to route rollouts through the gateway. Takes precedence +# over remote_server_url. +gateway_endpoint: "" +gateway_token: "" +gateway_timeout_secs: 1800 + +# Smoke ablation: drop screenshots from user turns so a text-only base model +# can be used (skips the multimodal training path that needs mm_token_type_ids). +text_only: false + +cluster: + n_nodes: 1 + # Target: 2x H200 box. SGLang takes 1 GPU for inference, FSDP actor (and + # colocated ref) takes the other. If you redeploy to a 4/8-GPU box, bump + # n_gpus_per_node and re-shard rollout.backend / actor.backend. + n_gpus_per_node: 2 + # On a shared filesystem so logs/checkpoints are visible from any machine + # that mounts it. Changing from /tmp/... to a shared path also lets the + # other (non-GPU) host tail rollout/0.log when SGLang fails to come up. + fileroot: ${AREAL_REPO}/logs/experiments + name_resolve: + type: nfs + nfs_record_root: ${AREAL_REPO}/logs/name_resolve + +scheduler: + type: local + +rollout: + backend: "sglang:d1p1t1" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 16 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 4 + # First-time SGLang launch on a new model can exceed 300s (model load from + # beegfs + torch.compile + kernel autotune). Bump to 15min for safety. 
+ setup_timeout: 900.0 + enable_rollout_tracing: false + scheduling_spec: ${actor.scheduling_spec} + fileroot: ${cluster.fileroot} + tokenizer_path: ${tokenizer_path} + dump_to_file: true + +gconfig: + n_samples: 1 + min_new_tokens: 0 + max_new_tokens: 1024 + greedy: false + temperature: 1.0 + +actor: + backend: "fsdp:d1p1t1" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: /path/to/Qwen3-VL-4B-Instruct + # Conda env has flash-attn-4 (alpha) which doesn't expose `flash_attn_func`, + # so transformers' default `flash_attention_2` import fails. Fall back to + # PyTorch SDPA — fast enough for 4B on H200. + attn_impl: sdpa + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: true + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 8192 + optimizer: + type: adam + lr: 1e-6 + weight_decay: 0.01 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.01 + eps_clip: 0.4 + temperature: ${gconfig.temperature} + reward_scaling: 1.0 + reward_bias: 0.0 + kl_ctl: 0.0 + ppo_n_minibatches: 1 + recompute_logprob: true + use_decoupled_loss: true + rejection_sampling: + level: token + action: mask + metric: ratio + upper: 5.0 + adv_norm: + mean_level: batch + std_level: batch + max_new_tokens: ${gconfig.max_new_tokens} + scheduling_spec: + - task_type: worker + port_count: 2 + gpu: 1 + cpu: 4 + mem: 32 + cmd: python3 -m areal.infra.rpc.rpc_server + env_vars: {} + +ref: + backend: ${actor.backend} + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + attn_impl: ${actor.attn_impl} + init_from_scratch: false + disable_dropout: true + dtype: ${actor.dtype} + mb_spec: + max_tokens_per_mb: 8192 + optimizer: null + scheduling_strategy: + type: colocation + target: actor + scheduling_spec: ${actor.scheduling_spec} + +# SGLang (vision backend for Qwen2.5-VL). 
+sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: true + dtype: ${actor.dtype} + max_running_requests: 16 + context_length: 32768 + mem_fraction_static: 0.8 + enable_multimodal: true + +# Datasets — OSWorld tasks are loaded in-process from evaluation_examples, +# but train/valid dataset configs are still required by the harness for +# batching / worker setup. +train_dataset: + batch_size: 4 + shuffle: true + pin_memory: true + num_workers: 0 + path: osworld_tasks + type: rl + +valid_dataset: + batch_size: 4 + shuffle: false + pin_memory: true + num_workers: 0 + path: osworld_tasks + type: rl + +# Utilities +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: null + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: disabled + +perf_tracer: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + enabled: false + session_tracer: + enabled: false diff --git a/examples/osworld/osworld_config.py b/examples/osworld/osworld_config.py new file mode 100644 index 0000000000..b1fc129d90 --- /dev/null +++ b/examples/osworld/osworld_config.py @@ -0,0 +1,49 @@ +from dataclasses import dataclass, field + +from areal.api.cli_args import GRPOConfig + + +@dataclass +class OSWorldAgentConfig(GRPOConfig): + n_trajs: int = field(default=1) + max_steps: int = field(default=15) + max_workers: int = field(default=4) + sleep_after_execution: float = field(default=1.0) + env_reset_wait_secs: 
float = field(default=60.0) + + provider_name: str = field(default="docker") + path_to_vm: str | None = field(default=None) + os_type: str = field(default="Ubuntu") + headless: bool = field(default=True) + screen_width: int = field(default=1920) + screen_height: int = field(default=1080) + observation_type: str = field(default="screenshot") + action_space: str = field(default="pyautogui") + + osworld_root: str = field(default="") + evaluation_examples_dir: str = field(default="") + test_meta_path: str = field(default="") + osworld_cache_dir: str = field(default="cache") + turn_discount: float = field(default=0.9) + + # When non-empty, the workflow skips the in-process `DesktopEnv` and + # proxies reset/step/evaluate/close to a `remote_server.py` running on + # a host that actually has docker available. Example: "http://10.0.0.5:8000". + remote_server_url: str = field(default="") + remote_request_timeout_secs: float = field(default=1800.0) + + # Remote sandbox cluster behind an HTTP/HTTPS gateway (preferred when + # available). Setting both fields routes the workflow through + # `gateway_sandbox.py`, which talks to the cluster endpoint via a + # vendor-provided SDK (see `gateway_sandbox.py` for the expected protocol) + # and OSWorld's gateway-proxied HTTP endpoints. + gateway_endpoint: str = field(default="") + gateway_token: str = field(default="") + gateway_timeout_secs: int = field(default=1800) + + # Smoke-only ablation: drop image content from user turns and feed the + # agent text-only stubs instead. Lets us point `actor.path` at a text-only + # base model (e.g. Qwen3-4B-Instruct) and exercise the full PPO loop, + # without needing the VL pipeline (`mm_token_type_ids`, etc.) wired up. + # The agent operates "blind" — useful only for plumbing verification. 
+ text_only: bool = field(default=False) diff --git a/examples/osworld/osworld_requirements.txt b/examples/osworld/osworld_requirements.txt new file mode 100644 index 0000000000..df258d51d0 --- /dev/null +++ b/examples/osworld/osworld_requirements.txt @@ -0,0 +1,65 @@ +# Filtered OSWorld host-side deps for use alongside AReaL. +# +# Drops pins that clash with AReaL's install (torch>=2.9, transformers>=5.0, +# numpy>=2, Pillow), plus `accelerate`/`openai`/`anthropic`/`matplotlib`/ +# `pandas`/`tqdm`/`psutil`/`requests`/`filelock`/`backoff`/`tiktoken`/`wandb` +# which uv sync already installs. +fabric +gymnasium +pytz +pynput +pyautogui +flask +requests-toolbelt +ag2 +lxml +cssselect +xmltodict +openpyxl +python-docx +python-pptx +pypdf +PyGetWindow +rapidfuzz +pyacoustid +pygame +opencv-python-headless +ImageHash +scikit-image +librosa +pymupdf +chardet +playwright +formulas +pydrive +fastdtw +odfpy +func-timeout +beautifulsoup4 +dashscope +google-genai +google-generativeai +PyYaml +mutagen +easyocr +borb +pypdf2 +pdfplumber +wrapt_timeout_decorator +gdown +groq +boto3 +azure-identity +azure-mgmt-compute +azure-mgmt-network +docker +loguru +dotenv +tldextract +alibabacloud_ecs20140526 +alibabacloud_tea_openapi +alibabacloud_tea_util +json_minify +json_repair +# volcengine-python-sdk[ark] # only needed for volcengine provider +# ui-tars>=0.4.2.2 # only used by mm_agents.uitars_agent; not on internal mirror diff --git a/examples/osworld/remote_server.py b/examples/osworld/remote_server.py new file mode 100644 index 0000000000..3a625329e4 --- /dev/null +++ b/examples/osworld/remote_server.py @@ -0,0 +1,295 @@ +"""Flask server wrapping OSWorld's DesktopEnv for remote rollout. + +Deploy this on a machine that can actually run docker (KVM optional but +recommended). 
The AReaL training container then sends HTTP requests here via +``RemoteDesktopEnv`` — each training trajectory opens its own server-side +session, resets with a task config, pushes pyautogui actions, and finally asks +for an ``evaluate()`` reward. + +Run (on the remote docker machine): + + python remote_server.py \\ + --osworld-root /path/to/OSWorld \\ + --host 0.0.0.0 --port 8000 \\ + --max-envs 2 + +Concurrency is bounded by a global semaphore (``--max-envs``); creating more +sessions than that blocks until an existing one is closed. The workflow +creates a fresh env per trajectory and closes it when done, so the bound also +caps the simultaneously booted OSWorld docker VMs. + +No auth — intended for trusted internal networks only. +""" + +from __future__ import annotations + +import argparse +import base64 +import logging +import os +import sys +import threading +import traceback +import uuid +from typing import Any + +from flask import Flask, jsonify, request + +LOG_FORMAT = "%(asctime)s %(levelname)s %(name)s %(message)s" +logging.basicConfig(level=logging.INFO, format=LOG_FORMAT) +logger = logging.getLogger("osworld.remote_server") + + +def _import_desktop_env(osworld_root: str): + if osworld_root and osworld_root not in sys.path: + sys.path.insert(0, osworld_root) + from desktop_env.desktop_env import DesktopEnv # noqa: E402 + + return DesktopEnv + + +class SessionRegistry: + """Keeps per-session ``DesktopEnv`` instances with a global cap.""" + + def __init__(self, max_envs: int, desktop_env_cls): + self._max_envs = max_envs + self._cls = desktop_env_cls + self._sessions: dict[str, DesktopEnvSession] = {} + self._registry_lock = threading.Lock() + self._slots = threading.Semaphore(max_envs) + + def create(self, kwargs: dict[str, Any]) -> DesktopEnvSession: + # Block here until a slot is available. ``close`` releases it. 
+ self._slots.acquire() + try: + env = self._cls(**kwargs) + except Exception: + self._slots.release() + raise + sid = uuid.uuid4().hex + session = DesktopEnvSession(sid=sid, env=env, slot_release=self._slots.release) + with self._registry_lock: + self._sessions[sid] = session + logger.info( + f"[session {sid}] created; active={len(self._sessions)}/{self._max_envs}" + ) + return session + + def get(self, sid: str) -> DesktopEnvSession: + with self._registry_lock: + session = self._sessions.get(sid) + if session is None: + raise KeyError(sid) + return session + + def drop(self, sid: str) -> None: + with self._registry_lock: + session = self._sessions.pop(sid, None) + if session is None: + return + session.close() + logger.info( + f"[session {sid}] closed; active={len(self._sessions)}/{self._max_envs}" + ) + + +class DesktopEnvSession: + """Wraps a single ``DesktopEnv`` with a lock so callers can't race it.""" + + def __init__(self, sid: str, env, slot_release): + self.sid = sid + self.env = env + self._lock = threading.Lock() + self._closed = False + self._slot_release = slot_release + + def call(self, fn_name: str, *args, **kwargs): + with self._lock: + if self._closed: + raise RuntimeError(f"session {self.sid} already closed") + fn = getattr(self.env, fn_name) + return fn(*args, **kwargs) + + def close(self) -> None: + with self._lock: + if self._closed: + return + self._closed = True + try: + self.env.close() + except Exception as e: # noqa: BLE001 + logger.warning(f"[session {self.sid}] env.close() failed: {e!r}") + finally: + self._slot_release() + + +# ---------------------------------------------------------------------- helpers + + +def _encode_obs(obs: dict[str, Any] | None) -> dict[str, Any]: + if not obs: + return {} + screenshot = obs.get("screenshot") or b"" + return { + "screenshot_b64": base64.b64encode(screenshot).decode("ascii"), + "accessibility_tree": obs.get("accessibility_tree"), + "terminal": obs.get("terminal"), + "instruction": 
obs.get("instruction"), + } + + +def _error(status: int, message: str): + return jsonify({"error": message}), status + + +# ----------------------------------------------------------------------- routes + + +def create_app(osworld_root: str, max_envs: int) -> Flask: + DesktopEnv = _import_desktop_env(osworld_root) + registry = SessionRegistry(max_envs=max_envs, desktop_env_cls=DesktopEnv) + app = Flask("osworld-remote-server") + + @app.get("/health") + def health(): + return jsonify( + { + "status": "ok", + "active_sessions": len(registry._sessions), + "max_envs": max_envs, + } + ) + + @app.post("/envs") + def create_env(): + body = request.get_json(silent=True) or {} + screen_size = body.get("screen_size") or [1920, 1080] + try: + kwargs = dict( + provider_name=body.get("provider_name", "docker"), + path_to_vm=body.get("path_to_vm"), + action_space=body.get("action_space", "pyautogui"), + cache_dir=body.get("cache_dir", "cache"), + screen_size=tuple(screen_size), + headless=bool(body.get("headless", True)), + os_type=body.get("os_type", "Ubuntu"), + require_a11y_tree=bool(body.get("require_a11y_tree", False)), + ) + session = registry.create(kwargs) + except Exception as e: # noqa: BLE001 + logger.error(f"create_env failed: {e!r}\n{traceback.format_exc()}") + return _error(500, f"DesktopEnv init failed: {e!r}") + return jsonify({"session_id": session.sid}) + + @app.post("/envs//reset") + def reset_env(sid): + body = request.get_json(silent=True) or {} + try: + session = registry.get(sid) + obs = session.call("reset", task_config=body.get("task_config")) + except KeyError: + return _error(404, f"unknown session {sid}") + except Exception as e: # noqa: BLE001 + logger.error( + f"[session {sid}] reset failed: {e!r}\n{traceback.format_exc()}" + ) + return _error(500, f"reset failed: {e!r}") + return jsonify({"obs": _encode_obs(obs)}) + + @app.get("/envs//obs") + def get_obs(sid): + try: + session = registry.get(sid) + obs = session.call("_get_obs") + except KeyError: 
+ return _error(404, f"unknown session {sid}") + except Exception as e: # noqa: BLE001 + return _error(500, f"_get_obs failed: {e!r}") + return jsonify({"obs": _encode_obs(obs)}) + + @app.post("/envs//step") + def step_env(sid): + body = request.get_json(silent=True) or {} + action = body.get("action") + pause = float(body.get("pause", 0.0)) + if action is None: + return _error(400, "missing 'action' in body") + try: + session = registry.get(sid) + obs, reward, done, info = session.call("step", action, pause) + except KeyError: + return _error(404, f"unknown session {sid}") + except Exception as e: # noqa: BLE001 + logger.error( + f"[session {sid}] step failed: {e!r}\n{traceback.format_exc()}" + ) + return _error(500, f"step failed: {e!r}") + return jsonify( + { + "obs": _encode_obs(obs), + "reward": float(reward), + "done": bool(done), + "info": info if isinstance(info, dict) else {"raw": str(info)}, + } + ) + + @app.post("/envs//evaluate") + def evaluate_env(sid): + try: + session = registry.get(sid) + reward = session.call("evaluate") + except KeyError: + return _error(404, f"unknown session {sid}") + except Exception as e: # noqa: BLE001 + logger.error( + f"[session {sid}] evaluate failed: {e!r}\n{traceback.format_exc()}" + ) + return _error(500, f"evaluate failed: {e!r}") + return jsonify({"reward": float(reward)}) + + @app.post("/envs//close") + def close_env(sid): + try: + registry.drop(sid) + except Exception as e: # noqa: BLE001 + logger.warning(f"[session {sid}] close error: {e!r}") + return jsonify({"ok": True}) + + return app + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--osworld-root", + default=os.environ.get("OSWORLD_ROOT", ""), + help="Path to OSWorld checkout; also read from $OSWORLD_ROOT.", + ) + parser.add_argument("--host", default="0.0.0.0") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument( + "--max-envs", + type=int, + default=2, + help="Global cap on 
simultaneously-alive DesktopEnv sessions.", + ) + args = parser.parse_args() + + if not args.osworld_root: + parser.error("--osworld-root is required (or set $OSWORLD_ROOT)") + if not os.path.isdir(args.osworld_root): + parser.error(f"OSWorld root does not exist: {args.osworld_root}") + + app = create_app(osworld_root=args.osworld_root, max_envs=args.max_envs) + + logger.info( + f"Serving OSWorld on http://{args.host}:{args.port} " + f"(OSWorld={args.osworld_root}, max_envs={args.max_envs})" + ) + # Flask's dev server with threaded=True is enough for our 1–2 concurrent + # sessions; swap for gunicorn if you need more throughput later. + app.run(host=args.host, port=args.port, threaded=True, use_reloader=False) + + +if __name__ == "__main__": + main() diff --git a/examples/osworld/run_train.sh b/examples/osworld/run_train.sh new file mode 100755 index 0000000000..c676d068c0 --- /dev/null +++ b/examples/osworld/run_train.sh @@ -0,0 +1,180 @@ +#!/usr/bin/env bash +# Launch OSWorld GRPO training against a vendor-neutral remote sandbox +# cluster behind an HTTPS gateway. +# +# Usage: +# bash examples/osworld/run_train.sh [smoke|full] [extra hydra-style overrides...] +# +# Examples: +# # smoke run with the defaults below +# bash examples/osworld/run_train.sh smoke +# +# # full run with bumped concurrency +# bash examples/osworld/run_train.sh full rollout.max_concurrent_rollouts=4 n_trajs=2 +# +# # override base model on the fly +# bash examples/osworld/run_train.sh smoke actor.path=/path/to/other/model +# +# Required env vars: +# OSWORLD_SANDBOX_TOKEN — application secret for the gateway. Don't commit this. +# OSWORLD_SANDBOX_ENDPOINT — gateway URL (no default; must be set explicitly). +# +# Optional env vars: +# AREAL_TEXT_ONLY_MODEL — path to a text-only base model; required for the +# `smoke-text` stage. No default. 
+#   AREAL_ENV_PREFIX — conda env prefix (default: ../../../env)
+#   AREAL_REPO — AReaL checkout root (default: parent of this script's grandparent)
+#   CONDA_PREFIX_BASE — base conda install (default: $HOME/conda); used to
+#       source `etc/profile.d/conda.sh`. Override if conda
+#       lives elsewhere (e.g. /opt/conda).
+#   NO_PROXY — pre-existing no-proxy list; this script appends its
+#       own generic CIDRs. Append your own internal domains
+#       here before invocation if needed.
+#   STAGE — alternative to first positional arg
+
+set -euo pipefail
+
+# -------- locate repo + env --------
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+AREAL_REPO="${AREAL_REPO:-$(cd "${SCRIPT_DIR}/../.." && pwd)}"
+AREAL_ENV_PREFIX="${AREAL_ENV_PREFIX:-${AREAL_REPO}/../env}"
+
+if [[ ! -d "${AREAL_ENV_PREFIX}" ]]; then
+  echo "[run_train.sh] conda env not found at: ${AREAL_ENV_PREFIX}" >&2
+  echo "               set AREAL_ENV_PREFIX or rebuild via SETUP.md" >&2
+  exit 1
+fi
+
+# Activate conda env (we need the real `conda activate` shell function).
+# Conda's binutils activation script touches unbound vars, so relax `nounset`
+# across the activate call. Override CONDA_PREFIX_BASE if conda lives outside
+# $HOME/conda (e.g. CONDA_PREFIX_BASE=/opt/conda).
+# shellcheck disable=SC1091
+source "${CONDA_PREFIX_BASE:-$HOME/conda}/etc/profile.d/conda.sh"
+set +u
+conda activate "${AREAL_ENV_PREFIX}"
+set -u
+
+# Disable Python stdout/stderr block buffering so SGLang server logs reach
+# the trial's rollout/0.log in real time — otherwise a setup-timeout SIGTERM
+# truncates the buffer and we never see why SGLang failed to come up.
+export PYTHONUNBUFFERED=1
+
+# A corporate HTTP proxy (HTTP_PROXY / HTTPS_PROXY) cannot reach internal
+# addresses, but the trainer's `_wait_for_server` uses `requests`, which
+# honors those env vars by default. Without NO_PROXY the health probe to the
+# local SGLang server (33.x / 10.x / localhost) is routed through the proxy
+# and gets ECONNREFUSED, eventually triggering the 900s setup_timeout SIGTERM.
+# Append our hosts/CIDRs to NO_PROXY (don't clobber an existing one). If you
+# need to bypass the proxy for additional internal domains, set NO_PROXY to
+# include them before invoking this script.
+_areal_no_proxy_extra="localhost,127.0.0.1,0.0.0.0,33.0.0.0/8,10.0.0.0/8"
+if [[ -n "${NO_PROXY:-}" ]]; then
+  export NO_PROXY="${NO_PROXY},${_areal_no_proxy_extra}"
+else
+  export NO_PROXY="${_areal_no_proxy_extra}"
+fi
+# Some tools only read the lowercase variant; keep both spellings in sync.
+export no_proxy="${NO_PROXY}"
+
+# -------- pick stage + token --------
+
+STAGE="${1:-${STAGE:-smoke}}"
+# `shift` exits non-zero when no positional args remain; `|| true` keeps
+# `set -e` from aborting the script in that case.
+shift || true
+
+if [[ "${STAGE}" != "smoke" && "${STAGE}" != "full" && "${STAGE}" != "smoke-text" ]]; then
+  echo "[run_train.sh] first arg must be 'smoke', 'smoke-text', or 'full', got '${STAGE}'" >&2
+  exit 2
+fi
+
+if [[ -z "${OSWORLD_SANDBOX_ENDPOINT:-}" ]]; then
+  echo "[run_train.sh] OSWORLD_SANDBOX_ENDPOINT must be set in env" >&2
+  echo "  e.g. export OSWORLD_SANDBOX_ENDPOINT=https://your-gateway.example.com/path" >&2
+  exit 3
+fi
+if [[ -z "${OSWORLD_SANDBOX_TOKEN:-}" ]]; then
+  echo "[run_train.sh] OSWORLD_SANDBOX_TOKEN must be set in env" >&2
+  echo "  e.g. export OSWORLD_SANDBOX_TOKEN=sk-..." >&2
+  exit 3
+fi
+
+# Auto-derive a trial name unless caller pinned one through extra args.
+TRIAL_NAME="${STAGE}-$(date +%Y%m%d-%H%M%S)"
+
+# -------- build override list --------
+
+# Common overrides (every stage gets them).
+COMMON_ARGS=(
+  "gateway_endpoint=${OSWORLD_SANDBOX_ENDPOINT}"
+  "gateway_token=${OSWORLD_SANDBOX_TOKEN}"
+  "env_reset_wait_secs=30"
+  "trial_name=${TRIAL_NAME}"
+)
+
+# Stage-specific defaults; extra positional args override these.
+# AREAL_TEXT_ONLY_MODEL must be set by the user to run the `smoke-text` stage
+# (no default — set it to a local HF checkpoint path).
+_TEXT_ONLY_MODEL="${AREAL_TEXT_ONLY_MODEL:-}"
+
+case "${STAGE}" in
+  smoke)
+    STAGE_ARGS=(
+      "experiment_name=osworld-grpo-smoke"
+      "rollout.max_concurrent_rollouts=1"
+      "n_trajs=1"
+      "max_steps=3"
+      "train_dataset.batch_size=1"
+      "total_train_epochs=1"
+    )
+    ;;
+  smoke-text)
+    # Plumbing smoke against a text-only base model. Strips screenshots
+    # from the workflow so we don't need the VL training path
+    # (mm_token_type_ids / multi_modal_input). Agent operates blind;
+    # this is only useful for verifying the full PPO loop end-to-end.
+    if [[ -z "${_TEXT_ONLY_MODEL}" ]]; then
+      echo "[run_train.sh] stage 'smoke-text' requires AREAL_TEXT_ONLY_MODEL to be set" >&2
+      echo "  e.g. export AREAL_TEXT_ONLY_MODEL=/path/to/Qwen3-4B-Instruct-2507" >&2
+      exit 4
+    fi
+    STAGE_ARGS=(
+      "experiment_name=osworld-grpo-smoke-text"
+      "rollout.max_concurrent_rollouts=1"
+      "n_trajs=1"
+      "max_steps=3"
+      "train_dataset.batch_size=1"
+      "total_train_epochs=1"
+      "text_only=true"
+      "actor.path=${_TEXT_ONLY_MODEL}"
+      "tokenizer_path=${_TEXT_ONLY_MODEL}"
+      "sglang.enable_multimodal=false"
+    )
+    ;;
+  full)
+    STAGE_ARGS=(
+      "experiment_name=osworld-grpo"
+      "rollout.max_concurrent_rollouts=2"
+      "n_trajs=1"
+      "max_steps=15"
+      "train_dataset.batch_size=2"
+    )
+    ;;
+esac
+
+# -------- launch --------
+
+cd "${AREAL_REPO}"
+
+CONFIG_PATH="examples/osworld/config_osworld_sglang.yaml"
+
+echo "[run_train.sh] stage=${STAGE} trial_name=${TRIAL_NAME}"
+echo "[run_train.sh] env=${AREAL_ENV_PREFIX}"
+echo "[run_train.sh] python=$(which python)"
+echo "[run_train.sh] config=${CONFIG_PATH}"
+echo "[run_train.sh] overrides:" "${COMMON_ARGS[@]}" "${STAGE_ARGS[@]}" "$@"
+
+# Caller-supplied "$@" comes last so user overrides win (hydra-style last-wins).
+exec python -m examples.osworld.train \
+  --config "${CONFIG_PATH}" \
+  "${COMMON_ARGS[@]}" \
+  "${STAGE_ARGS[@]}" \
+  "$@"
diff --git a/examples/osworld/smoke.py b/examples/osworld/smoke.py
new file mode 100644
index 0000000000..ac37447331
--- /dev/null
+++ b/examples/osworld/smoke.py
@@ -0,0 +1,115 @@
+"""End-to-end smoke test for the
gateway-sandbox-backed DesktopEnv. + +Verifies the full reset → step → evaluate → close cycle against a real +OSWorld task without spinning up the trainer. Pass the cluster endpoint and +secret token via env vars (``OSWORLD_SANDBOX_ENDPOINT`` and +``OSWORLD_SANDBOX_TOKEN``). + +Run: + python examples/osworld/smoke.py + +Expected output (on a clean run): a real screenshot byte count, a successful +step, and ``evaluate result: 0.0`` (the agent didn't solve the task, but the +evaluator returned a real number — that's what we want to confirm). +""" + +from __future__ import annotations + +import json +import logging +import os +import sys + +REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) +OSWORLD_ROOT = os.path.join(REPO_ROOT, "OSWorld") +AREAL_ROOT = os.path.join(REPO_ROOT, "AReaL") + +# Ensure the AReaL repo is importable when running this script directly. +sys.path.insert(0, AREAL_ROOT) + + +def _pick_simple_task(osworld_root: str) -> dict: + """Pick a task whose setup verbs all live in our supported set.""" + meta = json.load( + open(os.path.join(osworld_root, "evaluation_examples", "test_small.json")) + ) + supported = {"launch", "execute", "command", "sleep", "open", "activate_window"} + for domain, eids in meta.items(): + for eid in eids: + cfg = json.load( + open( + os.path.join( + osworld_root, + "evaluation_examples", + "examples", + domain, + f"{eid}.json", + ) + ) + ) + verbs = {c.get("type") for c in cfg.get("config", [])} + if verbs.issubset(supported): + return cfg + raise RuntimeError("no suitable task found in test_small.json") + + +def main() -> None: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s %(message)s", + ) + + endpoint = os.environ.get("OSWORLD_SANDBOX_ENDPOINT", "") + token = os.environ.get("OSWORLD_SANDBOX_TOKEN", "") + if not endpoint or not token: + raise SystemExit( + "OSWORLD_SANDBOX_ENDPOINT and OSWORLD_SANDBOX_TOKEN must be set " + "to the URL and 
authentication token of your remote sandbox " + "gateway before running this smoke test." + ) + + from examples.osworld.workflow.gateway_sandbox import make_sandbox_desktop_env + + task_config = _pick_simple_task(OSWORLD_ROOT) + print(f"[task] id={task_config.get('id')}") + print(f"[task] instruction: {task_config.get('instruction')!r}") + print( + f"[task] config verbs: {[c.get('type') for c in task_config.get('config', [])]}" + ) + + env = make_sandbox_desktop_env( + osworld_root=OSWORLD_ROOT, + cluster_endpoint=endpoint, + secret_token=token, + cache_dir="/tmp/areal/sandbox_cache", + require_a11y_tree=False, + ) + + try: + print("\n--- reset(task_config) ---") + obs = env.reset(task_config=task_config) + print(f"reset ok; screenshot bytes={len(obs.get('screenshot') or b'')}") + + print("\n--- step (no-op pyautogui) ---") + obs2, reward, done, info = env.step("pyautogui.moveTo(960, 540)", 1.0) + print( + f"step ok; reward={reward} done={done} screenshot bytes={len(obs2.get('screenshot') or b'')}" + ) + + print("\n--- evaluate ---") + try: + r = env.evaluate() + print(f"evaluate result: {r}") + except Exception: + import traceback + + traceback.print_exc() + + finally: + print("\n--- close ---") + env.close() + print("closed") + + +if __name__ == "__main__": + main() diff --git a/examples/osworld/train.py b/examples/osworld/train.py new file mode 100644 index 0000000000..cf723478e4 --- /dev/null +++ b/examples/osworld/train.py @@ -0,0 +1,166 @@ +"""GRPO training on OSWorld (desktop-control) tasks with Qwen2.5-VL. + +Run (from AReaL repo root, inside this docker): + + python -m examples.osworld.train --config examples/osworld/config_osworld_sglang.yaml + +Before launching, make sure the OSWorld docker provider prerequisites (KVM + +``xlang/osworld-docker`` image) are satisfied on the host. See +``AReaL/examples/osworld/README.md`` for details. 
+"""

from __future__ import annotations

import json
import os

os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

import sys
from pathlib import Path

from datasets import Dataset

from examples.osworld.osworld_config import OSWorldAgentConfig

from areal import PPOTrainer
from areal.api.cli_args import load_expr_config
from areal.utils import seeding
from areal.utils.hf_utils import load_hf_tokenizer
from areal.utils.stats_logger import StatsLogger

# Dotted path handed to the trainer; resolved to the workflow class at runtime.
WORKFLOW_PATH = "examples.osworld.workflow.osworld_workflow.OSWorldWorkflow"


def _resolve_osworld_paths(config: OSWorldAgentConfig) -> tuple[str, str, str]:
    """Return (osworld_root, evaluation_examples_dir, test_meta_path)."""
    # Default OSWorld root: sibling checkout next to this repo
    # (examples/osworld/train.py → parents[3] is the repo's parent dir).
    if config.osworld_root:
        osworld_root = os.path.abspath(config.osworld_root)
    else:
        osworld_root = str((Path(__file__).resolve().parents[3] / "OSWorld").resolve())

    evaluation_examples_dir = (
        os.path.abspath(config.evaluation_examples_dir)
        if config.evaluation_examples_dir
        else os.path.join(osworld_root, "evaluation_examples")
    )

    test_meta_path = (
        os.path.abspath(config.test_meta_path)
        if config.test_meta_path
        else os.path.join(evaluation_examples_dir, "test_small.json")
    )
    return osworld_root, evaluation_examples_dir, test_meta_path


def _build_tasks_dataset(evaluation_examples_dir: str, test_meta_path: str) -> Dataset:
    """Load OSWorld task metas into a flat Hugging Face Dataset.

    Each row mirrors the on-disk example JSON so the workflow can consume it
    without re-reading files during rollout.
    """
    with open(test_meta_path, encoding="utf-8") as f:
        meta = json.load(f)

    rows: list[dict] = []
    # Meta maps domain -> list of example ids; each id has its own JSON file.
    for domain, example_ids in meta.items():
        for example_id in example_ids:
            example_path = (
                Path(evaluation_examples_dir)
                / "examples"
                / domain
                / f"{example_id}.json"
            )
            with open(example_path, encoding="utf-8") as f:
                task = json.load(f)
            rows.append(
                {
                    "domain": domain,
                    "example_id": example_id,
                    "id": task.get("id", example_id),
                    "instruction": task["instruction"],
                    # Keep the full task dict as a JSON string to avoid the
                    # schema explosion that datasets infers from nested dicts.
                    "task_config_json": json.dumps(task, ensure_ascii=False),
                }
            )
    if not rows:
        raise ValueError(
            f"No OSWorld tasks found under {evaluation_examples_dir} "
            f"via meta file {test_meta_path}"
        )
    return Dataset.from_list(rows)


def main(args):
    """Entry point: load config, build the task dataset, run PPO training."""
    config, _ = load_expr_config(args, OSWorldAgentConfig)

    rank = int(os.getenv("RANK", "0"))
    tokenizer = load_hf_tokenizer(config.tokenizer_path)

    seeding.set_random_seed(config.seed, key=f"trainer{rank}")

    osworld_root, evaluation_examples_dir, test_meta_path = _resolve_osworld_paths(
        config
    )

    dataset = _build_tasks_dataset(evaluation_examples_dir, test_meta_path)

    # The workflow only needs the task config dict; inflate on the fly.
    # NOTE(review): inflating the nested task dict through .map() re-triggers
    # the datasets schema inference that task_config_json was stored to
    # avoid — confirm datasets tolerates the nested schema here.
    def _inflate(row):
        row.update(json.loads(row.pop("task_config_json")))
        return row

    dataset = dataset.map(_inflate)

    workflow_kwargs = dict(
        gconfig=config.gconfig,
        tokenizer=tokenizer,
        evaluation_examples_dir=evaluation_examples_dir,
        osworld_root=osworld_root,
        provider_name=config.provider_name,
        path_to_vm=config.path_to_vm,
        os_type=config.os_type,
        headless=config.headless,
        screen_size=(config.screen_width, config.screen_height),
        observation_type=config.observation_type,
        action_space=config.action_space,
        cache_dir=config.osworld_cache_dir,
        max_steps=config.max_steps,
        n_trajs=config.n_trajs,
        sleep_after_execution=config.sleep_after_execution,
        env_reset_wait_secs=config.env_reset_wait_secs,
        max_workers=config.max_workers,
        turn_discount=config.turn_discount,
        dump_dir=os.path.join(
            StatsLogger.get_log_path(config.stats_logger), "generated"
        ),
        remote_server_url=config.remote_server_url,
        remote_request_timeout_secs=config.remote_request_timeout_secs,
        gateway_endpoint=config.gateway_endpoint,
        gateway_token=config.gateway_token,
        gateway_timeout_secs=config.gateway_timeout_secs,
        text_only=config.text_only,
        # Default the VL processor to the actor checkpoint dir; AutoProcessor
        # picks up the matching preprocessor_config.json there. Workflow only
        # consumes this when text_only=False.
        processor_path=config.actor.path,
    )

    # Shallow copy is sufficient: the kwargs are read-only for the trainer.
    eval_workflow_kwargs = workflow_kwargs.copy()

    # NOTE(review): eval uses the training set as valid_dataset — presumably
    # intentional for this small example; confirm before scaling up.
    with PPOTrainer(
        config,
        train_dataset=dataset,
        valid_dataset=dataset,
    ) as trainer:
        trainer.train(
            workflow=WORKFLOW_PATH,
            workflow_kwargs=workflow_kwargs,
            eval_workflow=WORKFLOW_PATH,
            eval_workflow_kwargs=eval_workflow_kwargs,
        )


if __name__ == "__main__":
    main(sys.argv[1:])
diff --git a/examples/osworld/workflow/__init__.py b/examples/osworld/workflow/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/osworld/workflow/gateway_sandbox.py b/examples/osworld/workflow/gateway_sandbox.py
new file mode 100644
index 0000000000..0817734441
--- /dev/null
+++ b/examples/osworld/workflow/gateway_sandbox.py
@@ -0,0 +1,705 @@
+"""OSWorld DesktopEnv backed by a vendor-neutral remote sandbox cluster gateway.
+
+The training container can't run docker, so we drive a remote OSWorld VM via
+a vendor-provided SDK (imported as ``pssdk``). The gateway is a transparent
+proxy onto OSWorld's stock VM-side HTTP server, so we can subclass OSWorld's
+``PythonController`` and re-route its calls through the gateway. The setup
+controller is reimplemented here for the verbs we care about — supporting
+every OSWorld setup verb is out of scope.
+
+Expected SDK protocol (the ``pssdk`` module must export):
+
+    BaseSandboxClusterTool(
+        cluster_endpoint, application_secret_token, session_id,
+        global_call_timeout,
+    )
+        # Constructor.
+
+    BaseSandboxClusterTool.session_id
+        # Property, str. The session id to authenticate gateway HTTP calls.
+
+    BaseSandboxClusterTool.sandbox_start(body=None, call_timeout=...) -> dict
+        # Allocate / start a sandbox VM; returns a dict with at least
+        # ``sandboxId``.
+
+    BaseSandboxClusterTool.sandbox_stop(call_timeout=...) -> dict
+        # Stop / release the sandbox VM.
+
+    with_retry(**kwargs)
+        # Class decorator; wraps the cluster tool's methods so transient
+        # gateway errors auto-retry and resource-quota errors park instead
+        # of bubbling up.
+
+If your provider's SDK does not export these symbols under the name
+``pssdk``, replace this transport with your own RemoteClusterClient
+implementation.
+
+Layered as:
+
+    BaseSandboxClusterTool (pssdk, lifecycle: start/stop/status)
+        ↓
+    _GatewayTransport (this module: auth-aware GET/POST/form/raw)
+        ↓
+    GatewaySandboxPythonController (subclass of PythonController — drop-in for
+        OSWorld's evaluators/getters)
+    GatewaySandboxSetupController (handcrafted; covers the common verbs)
+        ↓
+    GatewaySandboxDesktopEnv (subclass of DesktopEnv — drop-in for our
+        workflow; skips provider/manager and
+        _start_emulator)
+
+Endpoint method/body conventions were verified by live probing — see
+``REMOTE.md`` for the cheat sheet.
+"""

from __future__ import annotations

import logging
import os
import sys
import time
from typing import Any
from urllib.parse import urljoin

import requests

try:
    from pssdk import BaseSandboxClusterTool, with_retry

    _HAS_PSSDK = True
except ImportError:
    _HAS_PSSDK = False
    BaseSandboxClusterTool = (
        object  # placeholder so class definition doesn't crash at import
    )

    def with_retry(**_kwargs):  # no-op decorator
        return lambda cls: cls


logger = logging.getLogger("GatewaySandbox")


# Wrap the cluster SDK in pssdk's retry decorator so transient gateway
# errors (5xx) auto-retry, and resource-quota 429s (`SandboxResourceLimitError`)
# park instead of bubbling up. Without this, a brief cluster spike kills the
# whole rollout.
@with_retry(
    max_attempts=3,
    retry_interval=10,
    infinite_retry_on_resource_limit=True,
    exclude_methods=[],
)
class _RetryingClusterTool(BaseSandboxClusterTool):
    pass


def _ensure_osworld_on_path(osworld_root: str | None) -> None:
    # Makes `desktop_env.*` importable; idempotent across repeated calls.
    if osworld_root and osworld_root not in sys.path:
        sys.path.insert(0, osworld_root)


# ---------------------------------------------------------------- transport


class _GatewayTransport:
    """Auth-aware HTTP client that talks to the remote sandbox gateway.

    The pssdk's ``_gateway_request`` always sends JSON, but OSWorld's
    ``/file`` endpoint requires form-urlencoded bodies and ``/screenshot``
    returns raw image bytes — both of which the SDK can't express. This
    class is a tiny shim that issues the request directly with ``requests``
    while reusing the SDK's session id and secret token.
    """

    def __init__(
        self,
        cluster_endpoint: str,
        secret_token: str,
        session_id: str,
        default_timeout: float = 600.0,
    ) -> None:
        self.cluster_endpoint = cluster_endpoint.rstrip("/")
        self.secret_token = secret_token
        self.session_id = session_id
        self.default_timeout = default_timeout
        # One keep-alive requests.Session reused for all gateway calls.
        self._session = requests.Session()

    def _headers(self) -> dict[str, str]:
        # Gateway auth: session id + application secret on every request.
        return {
            "x-paas-session-id": self.session_id,
            "x-paas-secret-token": self.secret_token,
        }

    def _url(self, uri: str) -> str:
        return urljoin(self.cluster_endpoint + "/", uri.lstrip("/"))

    def request(
        self,
        method: str,
        uri: str,
        *,
        json_body: dict[str, Any] | None = None,
        form_body: dict[str, Any] | None = None,
        timeout: float | None = None,
        raw: bool = False,
    ) -> Any:
        # json_body takes precedence over form_body when both are given.
        # raw=True returns the response bytes untouched (screenshots,
        # recordings); otherwise JSON is attempted with a text fallback.
        method = method.upper()
        url = self._url(uri)
        headers = self._headers()
        kwargs: dict[str, Any] = {
            "timeout": timeout or self.default_timeout,
            "headers": headers,
        }
        if json_body is not None:
            kwargs["json"] = json_body
        elif form_body is not None:
            kwargs["data"] = form_body
        r = self._session.request(method, url, **kwargs)
        if r.status_code >= 400:
            # Truncate the body so huge error pages don't flood the logs.
            raise GatewayHTTPError(f"{method} {uri} → {r.status_code}: {r.text[:300]}")
        if raw:
            return r.content
        try:
            return r.json()
        except ValueError:
            # Non-JSON success bodies (e.g. /platform) come back as text.
            return r.text


class GatewayHTTPError(RuntimeError):
    # Raised for any gateway response with status >= 400.
    pass


# ---------------------------------------------------- PythonController layer


def _make_python_controller(transport: _GatewayTransport):
    """Build the controller as a subclass of OSWorld's PythonController.

    Defined as a factory so OSWorld is only imported lazily (the training
    container has OSWorld on path via ``osworld_root``).
    """
    from desktop_env.controllers.python import PythonController

    class GatewaySandboxPythonController(PythonController):
        def __init__(self, t: _GatewayTransport) -> None:
            self._t = t
            self.pkgs_prefix = (
                "import pyautogui; import time; pyautogui.FAILSAFE = False; {command}"
            )
            self.retry_times = 3
            self.retry_interval = 5
            # Base class reads these in places; keep them populated.
            self.vm_ip = "sandbox"
            self.http_server = ""

        # -------- observation getters --------

        def get_screenshot(self):
            # Retries until the payload carries PNG or JPEG magic bytes;
            # returns None after retry_times failures.
            for attempt in range(self.retry_times):
                try:
                    data = self._t.request("GET", "/screenshot", timeout=30, raw=True)
                    if data and (
                        data[:8] == b"\x89PNG\r\n\x1a\n" or data[:3] == b"\xff\xd8\xff"
                    ):
                        return data
                    logger.warning(
                        f"invalid screenshot payload (try {attempt + 1}/{self.retry_times})"
                    )
                except Exception as e:
                    logger.warning(
                        f"screenshot error (try {attempt + 1}/{self.retry_times}): {e!r}"
                    )
                time.sleep(self.retry_interval)
            return None

        def get_accessibility_tree(self):
            try:
                payload = self._t.request("GET", "/accessibility", timeout=60)
                return payload.get("AT") if isinstance(payload, dict) else None
            except Exception as e:
                logger.warning(f"accessibility fetch failed: {e!r}")
                return None

        def get_terminal_output(self):
            try:
                payload = self._t.request("GET", "/terminal", timeout=30)
                return payload.get("output") if
isinstance(payload, dict) else None
            except Exception as e:
                # Common when no terminal is open in the VM; downgrade to debug.
                logger.debug(f"terminal fetch failed: {e!r}")
                return None

        def get_file(self, file_path: str):
            # The gateway only forwards `application/json`, but OSWorld VM's
            # /file reads `request.form` — incompatible. Work around by
            # base64-piping the file through /execute.
            cmd = (
                "import base64, os, sys; "
                f"p = os.path.expandvars(os.path.expanduser({file_path!r})); "
                "sys.stdout.write(base64.b64encode(open(p, 'rb').read()).decode())"
            )
            try:
                resp = self._t.request(
                    "POST",
                    "/execute",
                    json_body={"command": ["python3", "-c", cmd], "shell": False},
                    timeout=120,
                )
            except Exception as e:
                logger.warning(f"get_file({file_path}) via /execute failed: {e!r}")
                return None
            if not isinstance(resp, dict) or resp.get("status") != "success":
                logger.warning(f"get_file({file_path}) returned non-success: {resp!r}")
                return None
            import base64

            try:
                return base64.b64decode(resp.get("output") or "")
            except Exception as e:
                logger.warning(f"get_file({file_path}) base64 decode failed: {e!r}")
                return None

        # -------- execute / scripts --------

        def execute_python_command(self, command: str):
            # Runs `command` inside the VM's python with pyautogui pre-imported
            # (see pkgs_prefix); retried on transport errors.
            command_list = ["python", "-c", self.pkgs_prefix.format(command=command)]
            for _ in range(self.retry_times):
                try:
                    return self._t.request(
                        "POST",
                        "/execute",
                        json_body={"command": command_list, "shell": False},
                        timeout=120,
                    )
                except Exception as e:
                    logger.warning(f"execute_python_command failed: {e!r}")
                    time.sleep(self.retry_interval)
            return None

        def run_python_script(self, script: str):
            try:
                return self._t.request(
                    "POST", "/run_python", json_body={"code": script}, timeout=180
                )
            except Exception as e:
                # Mirror the VM server's error envelope so callers can treat
                # transport failures like in-VM failures.
                return {
                    "status": "error",
                    "message": str(e),
                    "output": "",
                    "error": repr(e),
                }

        def run_bash_script(
            self, script: str, timeout: int = 30, working_dir: str | None = None
        ):
            body: dict[str, Any] = {"script": script, "timeout": timeout}
            if working_dir:
                body["working_dir"] = working_dir
            try:
                # HTTP timeout = script timeout + 60s of transport slack.
                return self._t.request(
                    "POST", "/run_bash_script", json_body=body, timeout=timeout + 60
                )
            except Exception as e:
                return {
                    "status": "error",
                    "message": str(e),
                    "output": "",
                    "error": repr(e),
                }

        # -------- VM info getters (mostly POST, empty body) --------

        def _post_empty(self, uri: str):
            # Shared helper for the empty-body POST getters below.
            try:
                return self._t.request("POST", uri, json_body={}, timeout=30)
            except Exception as e:
                logger.warning(f"{uri} failed: {e!r}")
                return None

        def get_vm_platform(self):
            # GET returning a plain string ("Linux" / "Windows" / "Darwin").
            try:
                return self._t.request("GET", "/platform", timeout=15)
            except Exception as e:
                logger.warning(f"/platform failed: {e!r}")
                return None

        def get_vm_machine(self):
            # OSWorld VM server has no /machine route; shell out to uname.
            try:
                resp = self._t.request(
                    "POST",
                    "/execute",
                    json_body={"command": ["uname", "-m"], "shell": False},
                    timeout=15,
                )
                if isinstance(resp, dict) and resp.get("status") == "success":
                    out = (resp.get("output") or "").strip()
                    if out:
                        return out
            except Exception as e:
                logger.warning(f"get_vm_machine via uname failed: {e!r}")
            return "x86_64"  # Sandbox image is amd64; safe fallback.

        def get_vm_screen_size(self):
            return self._post_empty("/screen_size")

        def get_vm_window_size(self, app_class_name: str):
            # NOTE(review): this endpoint is sent form-encoded, unlike the
            # JSON-body getters — presumably matching the VM server; confirm.
            try:
                return self._t.request(
                    "POST",
                    "/window_size",
                    form_body={"app_class_name": app_class_name},
                    timeout=30,
                )
            except Exception as e:
                logger.warning(f"window_size failed: {e!r}")
                return None

        def get_vm_wallpaper(self):
            return self._post_empty("/wallpaper")

        def get_vm_desktop_path(self):
            return self._post_empty("/desktop_path")

        def get_vm_directory_tree(self, path):
            try:
                return self._t.request(
                    "POST", "/list_directory", json_body={"path": path}, timeout=120
                )
            except Exception as e:
                logger.warning(f"list_directory failed: {e!r}")
                return None

        # -------- recording --------

        def start_recording(self):
            return self._post_empty("/start_recording")

        def end_recording(self, dest: str):
            # Streams the recording bytes from the VM and writes them to
            # `dest`; returns True only when something was actually written.
            try:
                data = self._t.request(
                    "POST", "/end_recording", json_body={}, raw=True, timeout=300
                )
                if data:
                    os.makedirs(os.path.dirname(dest) or ".", exist_ok=True)
                    with open(dest, "wb") as f:
                        f.write(data)
                    return True
            except Exception as e:
                logger.warning(f"end_recording failed: {e!r}")
            return False

    return GatewaySandboxPythonController(transport)


# --------------------------------------------------- SetupController surrogate


class GatewaySandboxSetupController:
    """Minimal OSWorld SetupController equivalent.

    Implements only the verbs we have actually observed in the AReaL training
    workloads (covers ~95% of OSWorld test_small tasks): launch, download,
    execute, open, chrome_open_tabs, activate_window, close_window, command,
    sleep. Anything else is logged and skipped (so a partially-supported task
    still runs end-to-end with a degraded reward signal).
    """

    # Verbs with a `_<verb>_setup` handler below; anything else is skipped.
    SUPPORTED = frozenset(
        [
            "launch",
            "download",
            "execute",
            "open",
            "chrome_open_tabs",
            "activate_window",
            "close_window",
            "command",
            "sleep",
            "change_wallpaper",
        ]
    )

    def __init__(self, transport: _GatewayTransport, cache_dir: str = "cache") -> None:
        self._t = transport
        self.cache_dir = cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)

    def reset_cache_dir(self, cache_dir: str) -> None:
        # Mirrors OSWorld's SetupController API; called between tasks.
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

    def setup(self, config_list: list[dict[str, Any]], use_proxy: bool = False) -> bool:
        # Runs each setup step in order. Unsupported verbs are skipped (not
        # fatal); a failing supported verb aborts and returns False.
        for i, item in enumerate(config_list or []):
            verb = item.get("type")
            params = item.get("parameters", {})
            if verb not in self.SUPPORTED:
                logger.warning(
                    f"setup step {i + 1}: verb '{verb}' not implemented for sandbox; skipping"
                )
                continue
            handler = getattr(self, f"_{verb}_setup")
            try:
                handler(**params)
                logger.info(f"setup step {i + 1}/{len(config_list)} ok: {verb}")
            except Exception as e:
                logger.error(f"setup step {i + 1} failed: {verb}({params}) → {e!r}")
                return False
        return True

    # ---- verb implementations ----

    def _launch_setup(self, command, shell: bool = False) -> None:
        # Non-shell string commands must be tokenized before Popen-style exec.
        if isinstance(command, str) and not shell:
            import shlex

            command = shlex.split(command)
        self._t.request(
            "POST",
            "/setup/launch",
            json_body={"command": command, "shell": shell},
            timeout=60,
        )

    def _execute_setup(self, command, shell: bool = False, **_: Any) -> None:
        if isinstance(command, str) and not shell:
            import shlex

            command = shlex.split(command)
        self._t.request(
            "POST",
            "/setup/execute",
            json_body={"command": command, "shell": shell},
            timeout=120,
        )

    def _command_setup(self, command, **_: Any) -> None:
        # OSWorld treats `command` as launch-like (Popen). Forward to /setup/launch with shell=True.
        if isinstance(command, list):
            cmd_str = " ".join(str(c) for c in command)
        else:
            cmd_str = str(command)
        self._t.request(
            "POST",
            "/setup/launch",
            json_body={"command": cmd_str, "shell": True},
            timeout=60,
        )

    def _sleep_setup(self, seconds: float) -> None:
        # Sleeps on the training host (not inside the VM).
        time.sleep(float(seconds))

    def _open_setup(self, path: str) -> None:
        self._t.request(
            "POST", "/setup/open_file", json_body={"path": path}, timeout=120
        )

    def _activate_window_setup(
        self, window_name: str, strict: bool = False, by_class: bool = False
    ) -> None:
        self._t.request(
            "POST",
            "/setup/activate_window",
            json_body={
                "window_name": window_name,
                "strict": strict,
                "by_class": by_class,
            },
            timeout=30,
        )

    def _close_window_setup(
        self, window_name: str, strict: bool = False, by_class: bool = False
    ) -> None:
        self._t.request(
            "POST",
            "/setup/close_window",
            json_body={
                "window_name": window_name,
                "strict": strict,
                "by_class": by_class,
            },
            timeout=30,
        )

    def _change_wallpaper_setup(self, path: str) -> None:
        self._t.request(
            "POST", "/setup/change_wallpaper", json_body={"path": path}, timeout=30
        )

    def _chrome_open_tabs_setup(self, urls_to_open: list[str]) -> None:
        # OSWorld's stock implementation kills chrome, then launches it with the URLs.
        self._t.request(
            "POST",
            "/setup/launch",
            json_body={"command": ["pkill", "-f", "chrome"], "shell": False},
            timeout=20,
        )
        # Give the old chrome process a moment to die before relaunching.
        time.sleep(2)
        cmd = [
            "google-chrome",
            "--no-first-run",
            "--no-default-browser-check",
            *urls_to_open,
        ]
        self._t.request(
            "POST",
            "/setup/launch",
            json_body={"command": cmd, "shell": False},
            timeout=30,
        )

    def _download_setup(self, files: list[dict[str, str]]) -> None:
        # OSWorld's _download_setup downloads a remote URL to the VM's local path.
        # The gateway exposes /setup/download_file taking {url, path}.
        for entry in files or []:
            url = entry.get("url")
            path = entry.get("path")
            if not url or not path:
                logger.warning(f"download entry missing url/path: {entry}")
                continue
            self._t.request(
                "POST",
                "/setup/download_file",
                json_body={"url": url, "path": path},
                timeout=600,
            )


# -------------------------------------------------- DesktopEnv subclass


def _make_desktop_env_cls():
    """Lazy-build subclass of DesktopEnv to keep OSWorld imports lazy."""
    from desktop_env.desktop_env import DesktopEnv

    # NOTE(review): class definition continues past this chunk; only the
    # __init__ signature is visible here, so it is documented conservatively.
    class GatewaySandboxDesktopEnv(DesktopEnv):
        def __init__(
            self,
            *,
            cluster_endpoint: str,
            secret_token: str,
            cache_dir: str = "cache",
            screen_size: tuple[int, int] = (1920, 1080),
            require_a11y_tree: bool = False,
            require_terminal: bool = False,
            os_type: str = "Ubuntu",
            session_id: str | None = None,
            global_call_timeout: int = 1800,
            sandbox_start_body: dict[str, Any] | None = None,
        ) -> None:
            # Bypass DesktopEnv.__init__: we don't have a provider/manager and
            # the lifecycle goes through pssdk instead.
            # Seed the attributes DesktopEnv methods read, with sandbox-appropriate
            # stand-ins (no real VM path, ports, or provider exist here).
            self.region = None
            self.provider_name = "sandbox"
            self.enable_proxy = False
            self.client_password = "password"
            self.screen_width, self.screen_height = screen_size
            self.server_port = 5000
            self.chromium_port = 9222
            self.vnc_port = 8006
            self.vlc_port = 8080
            self.current_use_proxy = False
            self.os_type = os_type
            self.is_environment_used = False
            self.path_to_vm = "sandbox"
            self.snapshot_name = "init_state"
            self.cache_dir_base = cache_dir
            self.cache_dir = cache_dir
            self.headless = True
            self.require_a11y_tree = require_a11y_tree
            self.require_terminal = require_terminal
            self.action_space = "pyautogui"
            self.instruction = None
            self._traj_no = -1
            self._step_no = 0
            self.action_history = []
            os.makedirs(self.cache_dir, exist_ok=True)

            # Acquire a sandbox through the vendor SDK, then wire the HTTP
            # transport and the controller/setup surrogates onto it.
            self._tool = _RetryingClusterTool(
                cluster_endpoint=cluster_endpoint,
                application_secret_token=secret_token,
                session_id=session_id,
                global_call_timeout=global_call_timeout,
            )
            logger.info(f"sandbox session id: {self._tool.session_id}")
            start_info = self._tool.sandbox_start(
                body=sandbox_start_body or {}, call_timeout=180
            )
            logger.info(f"sandbox started: {start_info.get('sandboxId') or start_info}")

            self._transport = _GatewayTransport(
                cluster_endpoint=cluster_endpoint,
                secret_token=secret_token,
                session_id=self._tool.session_id,
                default_timeout=global_call_timeout,
            )
            self.controller = _make_python_controller(self._transport)
            self.setup_controller = GatewaySandboxSetupController(
                self._transport, cache_dir=self.cache_dir
            )
            self.vm_ip = "sandbox"

        # ---- lifecycle overrides ----

        def _start_emulator(self) -> None:  # pragma: no cover - never called
            # The gateway owns VM startup; nothing to emulate locally.
            return

        def _revert_to_snapshot(self) -> None:
            # Pop the current sandbox and acquire a fresh one.
            try:
                self._tool.sandbox_stop(call_timeout=60)
            except Exception as e:
                logger.warning(f"sandbox_stop during revert failed: {e!r}")
            start_info = self._tool.sandbox_start(call_timeout=180)
            logger.info(
                f"sandbox restarted: {start_info.get('sandboxId') or start_info}"
            )
            # Re-point the transport at the new session so subsequent
            # controller calls reach the fresh sandbox.
            self._transport.session_id = self._tool.session_id

        def close(self) -> None:
            # Best-effort teardown: a failed stop is logged, not raised.
            try:
                self._tool.sandbox_stop(call_timeout=60)
            except Exception as e:
                logger.warning(f"sandbox_stop on close failed: {e!r}")

    return GatewaySandboxDesktopEnv


# Public factory used by the workflow.
def make_sandbox_desktop_env(
    *,
    osworld_root: str,
    cluster_endpoint: str,
    secret_token: str,
    cache_dir: str = "cache",
    screen_size: tuple[int, int] = (1920, 1080),
    require_a11y_tree: bool = False,
    os_type: str = "Ubuntu",
    sandbox_start_body: dict[str, Any] | None = None,
):
    """Create a `GatewaySandboxDesktopEnv` instance.

    OSWorld must be on the path; pass ``osworld_root`` so we can insert it
    lazily before importing ``desktop_env`` symbols.

    Raises RuntimeError when the vendor SDK (`pssdk`) is unavailable.
    """
    if not _HAS_PSSDK:
        raise RuntimeError(
            "Gateway-based sandbox requires a vendor SDK that provides "
            "BaseSandboxClusterTool / with_retry. Install your provider's "
            "SDK (which exports the `pssdk` module) or replace this transport "
            "with your own RemoteClusterClient implementation. See README."
        )
    _ensure_osworld_on_path(osworld_root)
    cls = _make_desktop_env_cls()
    return cls(
        cluster_endpoint=cluster_endpoint,
        secret_token=secret_token,
        cache_dir=cache_dir,
        screen_size=screen_size,
        require_a11y_tree=require_a11y_tree,
        os_type=os_type,
        sandbox_start_body=sandbox_start_body,
    )
diff --git a/examples/osworld/workflow/osworld_workflow.py b/examples/osworld/workflow/osworld_workflow.py
new file mode 100644
index 0000000000..03580cbc02
--- /dev/null
+++ b/examples/osworld/workflow/osworld_workflow.py
@@ -0,0 +1,567 @@
from __future__ import annotations

import asyncio
import base64
import json
import os
import re
import sys
import uuid
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import Any

import torch
from PIL import Image
from transformers import AutoProcessor, PreTrainedTokenizerFast

from areal.api.cli_args import GenerationHyperparameters
from areal.api.workflow_api import RolloutWorkflow
from areal.experimental.openai import ArealOpenAI
from areal.utils import logging, stats_tracker

logger = logging.getLogger("OSWorldWorkflow")

# System prompt shared by every episode; the DONE/FAIL/WAIT vocabulary here
# must stay in sync with the tokens `_parse_actions` recognizes.
OSWORLD_SYSTEM_PROMPT = (
    "You are a computer-use agent operating an Ubuntu desktop through pyautogui. "
    "At each turn you will receive the user instruction and the current screenshot "
    "of the desktop. Produce a short plan and then a single ```python``` code block "
    "containing pyautogui code to execute. Use absolute pixel coordinates relative "
    "to the screenshot. When the task is fully completed, reply with the single "
    "token DONE on its own line (no code block). If the task is infeasible reply "
    "with FAIL."
)

# Matches fenced code blocks (optionally tagged ``python``) in a model reply.
_CODE_BLOCK_RE = re.compile(r"```(?:python)?\s*([\s\S]*?)```", re.IGNORECASE)


def _ensure_osworld_on_path(osworld_root: str | None) -> None:
    """Prepend the OSWorld checkout to ``sys.path`` (idempotent; no-op on None)."""
    if osworld_root and osworld_root not in sys.path:
        sys.path.insert(0, osworld_root)


def _screenshot_to_data_uri(screenshot: bytes) -> str:
    """Encode raw PNG bytes as a ``data:image/png;base64,...`` URI."""
    return "data:image/png;base64," + base64.b64encode(screenshot).decode("utf-8")


def _parse_actions(text: str) -> list[str]:
    """Extract executable actions from the model reply.

    Returns a list with one pyautogui snippet, or a single control token
    ("DONE" / "FAIL" / "WAIT") to be forwarded to DesktopEnv.step. An empty
    list means the reply contained neither.
    """
    text = text.strip()
    code = _CODE_BLOCK_RE.findall(text)
    if code:
        # Use the last code block — earlier ones may be quoted plans.
        snippet = code[-1].strip()
        if snippet:
            return [snippet]
    # No code block: scan bottom-up for a bare control-token line.
    for line in reversed(text.splitlines()):
        token = line.strip().upper()
        if token in {"DONE", "FAIL", "WAIT"}:
            return [token]
    return []


class OSWorldWorkflow(RolloutWorkflow):
    """Multi-turn VLM rollout workflow backed by OSWorld's DesktopEnv.

    Each episode spins up a fresh DesktopEnv for the given task, drives the
    agent<->env loop through ``ArealOpenAI`` so completions/tokens are cached
    for training, and returns interaction samples tagged with the final
    env.evaluate() reward.
    """

    def __init__(
        self,
        gconfig: GenerationHyperparameters,
        tokenizer: PreTrainedTokenizerFast,
        evaluation_examples_dir: str,
        osworld_root: str | None = None,
        provider_name: str = "docker",
        path_to_vm: str | None = None,
        os_type: str = "Ubuntu",
        headless: bool = True,
        screen_size: tuple[int, int] = (1920, 1080),
        observation_type: str = "screenshot",
        action_space: str = "pyautogui",
        cache_dir: str = "cache",
        max_steps: int = 15,
        n_trajs: int = 1,
        sleep_after_execution: float = 1.0,
        env_reset_wait_secs: float = 60.0,
        max_workers: int = 4,
        turn_discount: float = 0.9,
        dump_dir: str | None = None,
        rollout_stat_scope: str = "rollout",
        remote_server_url: str = "",
        remote_request_timeout_secs: float = 1800.0,
        gateway_endpoint: str = "",
        gateway_token: str = "",
        gateway_timeout_secs: int = 1800,
        text_only: bool = False,
        processor_path: str | None = None,
    ):
        # Three transport modes; pick one. Precedence: gateway sandbox >
        # custom remote server > in-process DesktopEnv. Only the in-process
        # path needs OSWorld already importable in this container, so we can
        # defer that for the remote modes — except the gateway path
        # re-imports OSWorld lazily for PythonController/DesktopEnv
        # subclassing, so we still need the path.
        self.remote_server_url = remote_server_url.strip()
        self.remote_request_timeout_secs = remote_request_timeout_secs
        self.gateway_endpoint = gateway_endpoint.strip()
        self.gateway_token = gateway_token.strip()
        self.gateway_timeout_secs = gateway_timeout_secs
        self.osworld_root = osworld_root
        if self.gateway_endpoint and self.gateway_token:
            # Gateway mode still subclasses DesktopEnv, so it needs the path.
            _ensure_osworld_on_path(osworld_root)
        elif not self.remote_server_url:
            # In-process mode: DesktopEnv is imported directly here.
            _ensure_osworld_on_path(osworld_root)

        # One sample per prompt; multiple trajectories come from n_trajs below.
        self.gconfig = gconfig.new(n_samples=1)
        self.tokenizer = tokenizer
        self.evaluation_examples_dir = evaluation_examples_dir
        self.provider_name = provider_name
        self.path_to_vm = path_to_vm
        self.os_type = os_type
        self.headless = headless
        self.screen_size = tuple(screen_size)
        self.observation_type = observation_type
        self.action_space = action_space
        self.cache_dir = cache_dir
        self.max_steps = max_steps
        self.n_trajs = n_trajs
        self.sleep_after_execution = sleep_after_execution
        self.env_reset_wait_secs = env_reset_wait_secs
        self.turn_discount = turn_discount
        self.rollout_stat_scope = rollout_stat_scope
        self.dump_dir = dump_dir
        if self.dump_dir and not os.path.exists(self.dump_dir):
            os.makedirs(self.dump_dir, exist_ok=True)
        self.text_only = text_only
        # Lazy-loaded multimodal processor for VL training. We only need it
        # when text_only=False; loading is deferred to first use so the
        # text-only smoke path doesn't pay the import / disk-read cost.
        self._processor_path = processor_path
        self._processor = None
        # Blocking env calls (reset/step/evaluate) run in this pool so the
        # asyncio event loop stays responsive.
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

    def _load_task_config(self, data: dict[str, Any]) -> dict[str, Any]:
        # An inline task config (already carrying config/evaluator/instruction)
        # is used as-is; otherwise resolve <dir>/examples/<domain>/<id>.json.
        if "config" in data and "evaluator" in data and "instruction" in data:
            return dict(data)
        domain = data["domain"]
        example_id = data.get("example_id") or data.get("id")
        path = (
            Path(self.evaluation_examples_dir)
            / "examples"
            / domain
            / f"{example_id}.json"
        )
        with open(path, encoding="utf-8") as f:
            return json.load(f)

    def _build_env(self):
        """Instantiate the DesktopEnv transport selected in __init__."""
        if self.gateway_endpoint and self.gateway_token:
            from .gateway_sandbox import make_sandbox_desktop_env

            return make_sandbox_desktop_env(
                osworld_root=self.osworld_root or "",
                cluster_endpoint=self.gateway_endpoint,
                secret_token=self.gateway_token,
                cache_dir=self.cache_dir,
                screen_size=self.screen_size,
                require_a11y_tree=self.observation_type != "screenshot",
                os_type=self.os_type,
            )

        if self.remote_server_url:
            from .remote_desktop_env import RemoteDesktopEnv

            return RemoteDesktopEnv(
                server_url=self.remote_server_url,
                provider_name=self.provider_name,
                path_to_vm=self.path_to_vm,
                action_space=self.action_space,
                cache_dir=self.cache_dir,
                screen_size=self.screen_size,
                headless=self.headless,
                os_type=self.os_type,
                require_a11y_tree=self.observation_type != "screenshot",
                request_timeout_secs=self.remote_request_timeout_secs,
            )

        from desktop_env.desktop_env import DesktopEnv

        return DesktopEnv(
            provider_name=self.provider_name,
            path_to_vm=self.path_to_vm,
            action_space=self.action_space,
            cache_dir=self.cache_dir,
            screen_size=self.screen_size,
            headless=self.headless,
            os_type=self.os_type,
            require_a11y_tree=self.observation_type != "screenshot",
        )

    def _build_user_turn(self, text: str, screenshot: bytes) -> dict[str, Any]:
        """Build one OpenAI-style user message, with or without the screenshot."""
        if self.text_only:
            # Smoke ablation: skip the screenshot. The agent operates blind;
            # we just want a real PPO step against a text-only base model.
            stub = (
                f"\n[screenshot omitted in text_only mode; "
                f"{len(screenshot or b'')} bytes available]"
            )
            return {"role": "user", "content": text + stub}
        return {
            "role": "user",
            "content": [
                {"type": "text", "text": text},
                {
                    "type": "image_url",
                    "image_url": {"url": _screenshot_to_data_uri(screenshot)},
                },
            ],
        }

    async def _run_single_trajectory(
        self,
        engine,
        task_config: dict[str, Any],
        traj_idx: int,
    ) -> tuple[dict[str, Any] | None, float | None]:
        """Run one full episode; returns (completions, reward) or (None, None)."""
        instruction = task_config["instruction"]
        task_id = task_config.get("id", "unknown")
        loop = asyncio.get_running_loop()

        env = None
        try:
            # Env construction and reset are blocking; run them in the pool.
            env = await loop.run_in_executor(self.executor, self._build_env)
            await loop.run_in_executor(
                self.executor, partial(env.reset, task_config=task_config)
            )
            # Give the VM time to settle after reset before the first capture.
            await asyncio.sleep(self.env_reset_wait_secs)
            obs = await loop.run_in_executor(self.executor, env._get_obs)

            client = ArealOpenAI(engine=engine, tokenizer=self.tokenizer)
            messages: list[dict[str, Any]] = [
                {"role": "system", "content": OSWORLD_SYSTEM_PROMPT},
                self._build_user_turn(
                    f"Task instruction: {instruction}\nHere is the current screenshot.",
                    obs["screenshot"],
                ),
            ]

            terminated = False
            step_idx = 0
            last_response_id: str | None = None
            for step_idx in range(self.max_steps):
                response = await client.chat.completions.create(
                    messages=messages,
                    **self.gconfig.to_openai_args_dict(),
                )
                last_response_id = response.id
                reply = response.choices[0].message
                reply_text = reply.content or ""
                messages.append(reply.model_dump(exclude_none=True))

                actions = _parse_actions(reply_text)
                if not actions:
                    logger.warning(
                        f"[task={task_id} traj={traj_idx}] step {step_idx + 1}: "
                        "no parseable action; feeding back a reminder."
                    )
                    # Unparseable reply: re-show the current screenshot with a
                    # nudge instead of stepping the env.
                    messages.append(
                        self._build_user_turn(
                            "Your previous reply did not contain a valid pyautogui "
                            "code block or control token. Please issue an action.",
                            obs["screenshot"],
                        )
                    )
                    continue

                for action in actions:
                    obs, _, done, info = await loop.run_in_executor(
                        self.executor,
                        partial(env.step, action, self.sleep_after_execution),
                    )
                    if done:
                        terminated = True
                        break
                if terminated:
                    break
                messages.append(
                    self._build_user_turn(
                        f"Step {step_idx + 1} executed. Here is the new screenshot.",
                        obs["screenshot"],
                    )
                )

            # Scalar reward comes from the task's own evaluate() function,
            # whether or not the episode terminated via DONE/FAIL.
            final_reward = float(
                await loop.run_in_executor(self.executor, env.evaluate)
            )
            logger.info(
                f"[task={task_id} traj={traj_idx}] finished after "
                f"{step_idx + 1} steps, reward={final_reward:.3f}"
            )

            stats_tracker.get(self.rollout_stat_scope).scalar(
                reward=final_reward, num_steps=step_idx + 1
            )

            # Attach the reward to the last completion and discount backwards
            # over earlier turns.
            if last_response_id is not None:
                client.set_reward(last_response_id, final_reward)
            client.apply_reward_discount(turn_discount=self.turn_discount)
            completions = client.export_interactions(style="individual")

            # Inject multimodal training tensors when running against a Qwen-VL
            # base. ArealOpenAI's cache only stores text token / logp / reward
            # data; FSDPEngine._prepare_mb_list (areal/engine/fsdp_engine.py)
            # additionally requires `mm_token_type_ids` and `multi_modal_input`
            # for any `is_qwen_vl_model`. We re-run the HF processor on each
            # turn's prefix to recover those fields and stash them on the
            # interaction's `_cache` so the downstream `to_tensor_dict()` call
            # returns them as-is. Skipped in text_only mode.
            if completions and not self.text_only:
                self._attach_vl_tensor_dicts(completions)

            if self.dump_dir is not None:
                self._dump_trajectory(task_id, traj_idx, messages, final_reward)

            return completions, final_reward
        except Exception as e:
            # Any failure voids the whole trajectory; arun_episode counts it.
            logger.error(f"[task={task_id} traj={traj_idx}] trajectory failed: {e!r}")
            return None, None
        finally:
            if env is not None:
                try:
                    await loop.run_in_executor(self.executor, env.close)
                except Exception as e:
                    logger.warning(f"Failed to close env: {e!r}")

    def _dump_trajectory(
        self,
        task_id: str,
        traj_idx: int,
        messages: list[dict[str, Any]],
        reward: float,
    ) -> None:
        """Write a JSON trace of one trajectory into ``dump_dir``."""

        # Strip base64 image bytes before dumping so the log stays readable.
        def _sanitize(msg: dict[str, Any]) -> dict[str, Any]:
            if isinstance(msg.get("content"), list):
                parts = []
                for part in msg["content"]:
                    if isinstance(part, dict) and part.get("type") == "image_url":
                        parts.append({"type": "image_url", "image_url": ""})
                    else:
                        parts.append(part)
                return {**msg, "content": parts}
            return msg

        path = os.path.join(
            self.dump_dir, f"{task_id}_traj{traj_idx}_{uuid.uuid4().hex[:6]}.json"
        )
        with open(path, "w", encoding="utf-8") as f:
            json.dump(
                {
                    "task_id": task_id,
                    "traj_idx": traj_idx,
                    "reward": reward,
                    "messages": [_sanitize(m) for m in messages],
                },
                f,
                ensure_ascii=False,
                indent=2,
            )

    # ------------------------------------------------------------------
    # VL bridge: hand FSDPEngine the multimodal tensors it expects.
    # ------------------------------------------------------------------

    def _get_processor(self):
        """Lazy-load HF processor for the VL base model."""
        if self._processor is None:
            path = self._processor_path
            if not path:
                raise RuntimeError(
                    "OSWorldWorkflow needs `processor_path` set when "
                    "text_only=False — point it at the same HuggingFace dir "
                    "as actor.path so we can recover mm_token_type_ids etc."
                )
            self._processor = AutoProcessor.from_pretrained(path)
        return self._processor

    @staticmethod
    def _decode_data_uri_image(url: str) -> Image.Image | None:
        """Decode a ``data:image/...;base64,`` URI into an RGB PIL image, or None."""
        if not url.startswith("data:image"):
            return None
        try:
            _, b64 = url.split(",", 1)
        except ValueError:
            return None
        try:
            return Image.open(BytesIO(base64.b64decode(b64))).convert("RGB")
        except Exception as e:
            logger.warning(f"Failed to decode image data URI: {e!r}")
            return None

    def _split_messages_to_text_and_images(
        self, messages: list[dict[str, Any]]
    ) -> tuple[list[dict[str, Any]], list[Image.Image]]:
        """Convert workflow messages → processor messages + PIL image list.

        Workflow messages carry images as ``image_url`` data URIs; the HF
        processor wants ``{"type": "image"}`` placeholders alongside an
        external ``images=`` list. Iterate the conversation, peel out each
        decoded image, replace the part with the placeholder.
        """
        out_messages: list[dict[str, Any]] = []
        images: list[Image.Image] = []
        for msg in messages:
            content = msg.get("content")
            if not isinstance(content, list):
                # Plain string content (system / assistant text) — pass through.
                out_messages.append(msg)
                continue
            new_parts: list[dict[str, Any]] = []
            for part in content:
                if not isinstance(part, dict):
                    continue
                if part.get("type") == "image_url":
                    img = self._decode_data_uri_image(
                        part.get("image_url", {}).get("url", "")
                    )
                    if img is None:
                        # Drop unreadable images rather than mis-aligning the
                        # token stream.
                        continue
                    images.append(img)
                    new_parts.append({"type": "image"})
                elif part.get("type") == "text":
                    new_parts.append({"type": "text", "text": part.get("text", "")})
                # ignore other part types (audio, video, etc.)
            out_messages.append({"role": msg["role"], "content": new_parts})
        return out_messages, images

    def _build_vl_tensor_dict(
        self, interaction
    ) -> dict[str, torch.Tensor | list] | None:
        """Run processor on the interaction prefix; produce a VL tensor dict.

        Returns a dict shaped like ``InteractionWithTokenLogpReward.to_tensor_dict``
        plus the multimodal extras (``mm_token_type_ids``, ``multi_modal_input``).
        Returns ``None`` if the interaction has no ``model_response`` (we can't
        contribute output tokens) — caller should fall back to the text-only
        export.
        """
        resp = getattr(interaction, "model_response", None)
        if resp is None:
            return None
        prefix_messages = list(interaction.messages or [])
        proc_messages, images = self._split_messages_to_text_and_images(prefix_messages)
        processor = self._get_processor()

        # Re-tokenize the prompt prefix exactly as the rollout engine saw it.
        text = processor.apply_chat_template(
            proc_messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        kwargs: dict[str, Any] = dict(text=[text], padding=False, return_tensors="pt")
        if images:
            kwargs["images"] = images
        proc_out = processor(**kwargs)

        input_ids = proc_out["input_ids"][0].tolist()
        if "mm_token_type_ids" in proc_out:
            mm_token_type_ids = proc_out["mm_token_type_ids"][0].tolist()
        else:
            # Processor did not emit the field (e.g. no images): all-text turn.
            mm_token_type_ids = [0] * len(input_ids)

        output_ids = list(resp.output_tokens or [])
        output_logprobs = list(resp.output_logprobs or [])
        output_versions = list(resp.output_versions or [])
        # Defensive: pad / truncate the per-token streams to match output_ids.
        # SGLang occasionally returns short logprob arrays; avoid blowing up
        # downstream tensor builds when that happens.
        if len(output_logprobs) < len(output_ids):
            output_logprobs += [0.0] * (len(output_ids) - len(output_logprobs))
        else:
            output_logprobs = output_logprobs[: len(output_ids)]
        if len(output_versions) < len(output_ids):
            output_versions += [-1] * (len(output_ids) - len(output_versions))
        else:
            output_versions = output_versions[: len(output_ids)]

        # Prompt tokens carry loss_mask=0 / logprob 0.0 / version -1; only the
        # sampled output tokens contribute to the loss.
        seq = input_ids + output_ids
        full_mm = mm_token_type_ids + [0] * len(output_ids)
        loss_mask = [0] * len(input_ids) + [1] * len(output_ids)
        logprobs = [0.0] * len(input_ids) + output_logprobs
        versions = [-1] * len(input_ids) + output_versions

        multi_modal_input: list[dict[str, torch.Tensor]] = []
        if images and "pixel_values" in proc_out:
            entry: dict[str, torch.Tensor] = {"pixel_values": proc_out["pixel_values"]}
            if "image_grid_thw" in proc_out:
                entry["image_grid_thw"] = proc_out["image_grid_thw"]
            multi_modal_input.append(entry)

        reward = float(interaction.reward) if interaction.reward is not None else 0.0
        return {
            "input_ids": torch.tensor(seq, dtype=torch.long).unsqueeze(0),
            "mm_token_type_ids": torch.tensor(full_mm, dtype=torch.long).unsqueeze(0),
            "loss_mask": torch.tensor(loss_mask, dtype=torch.int32).unsqueeze(0),
            "logprobs": torch.tensor(logprobs, dtype=torch.float32).unsqueeze(0),
            "versions": torch.tensor(versions, dtype=torch.int32).unsqueeze(0),
            "attention_mask": torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
            "rewards": torch.tensor([reward], dtype=torch.float32),
            "multi_modal_input": multi_modal_input,
        }

    def _attach_vl_tensor_dicts(self, completions: dict[str, Any]) -> None:
        """Pre-populate each interaction's `_cache` with VL-augmented tensors.

        ``InteractionWithTokenLogpReward.to_tensor_dict()`` short-circuits and
        returns ``self._cache`` if set, so this is a non-invasive override.
+ """ + bad_ids: list[str] = [] + for iid, interaction in completions.items(): + try: + tensor_dict = self._build_vl_tensor_dict(interaction) + except Exception as e: + logger.warning(f"VL tensor build failed for interaction {iid}: {e!r}") + tensor_dict = None + if tensor_dict is None: + bad_ids.append(iid) + continue + interaction._cache = tensor_dict + for iid in bad_ids: + # Drop interactions we couldn't bridge — they'd otherwise crash + # FSDPEngine on missing `mm_token_type_ids`. + completions.pop(iid, None) + + async def arun_episode(self, engine, data: dict[str, Any]): + task_config = self._load_task_config(data) + + results = await asyncio.gather( + *[ + self._run_single_trajectory(engine, task_config, i) + for i in range(self.n_trajs) + ] + ) + + stats_tracker.get(self.rollout_stat_scope).scalar( + num_trajectories_failed=sum(1 for r in results if r[0] is None), + num_full_passes=sum(1 for r in results if r[1] is not None and r[1] >= 1.0), + ) + + merged: dict[str, Any] = {} + for completions, _ in results: + if completions: + merged.update(completions) + + if not merged: + logger.warning(f"All trajectories failed for task {task_config.get('id')}.") + return None + return merged diff --git a/examples/osworld/workflow/remote_desktop_env.py b/examples/osworld/workflow/remote_desktop_env.py new file mode 100644 index 0000000000..2f87ba1689 --- /dev/null +++ b/examples/osworld/workflow/remote_desktop_env.py @@ -0,0 +1,170 @@ +"""HTTP proxy that mimics OSWorld's ``DesktopEnv`` on the training side. + +The training container cannot run docker itself, so a companion +``remote_server.py`` runs on a machine that *does* have docker and exposes +``DesktopEnv`` operations over a small JSON API. Instances of this class are +drop-in replacements for ``DesktopEnv`` from the workflow's point of view — +same blocking method names, same return shapes. 
+ +All payloads are JSON; screenshots travel as base64-encoded PNG strings under +``screenshot_b64`` and are decoded back to ``bytes`` on the client side so the +rest of the pipeline (which expects raw PNG bytes) stays unchanged. +""" + +from __future__ import annotations + +import base64 +from typing import Any + +import requests + +from areal.utils import logging + +logger = logging.getLogger("RemoteDesktopEnv") + + +class RemoteDesktopEnvError(RuntimeError): + """Raised when the remote server reports an error or is unreachable.""" + + +def _decode_obs(payload: dict[str, Any]) -> dict[str, Any]: + """Turn the server's JSON obs back into the in-process obs shape.""" + screenshot_b64 = payload.get("screenshot_b64") + obs: dict[str, Any] = { + "screenshot": base64.b64decode(screenshot_b64) if screenshot_b64 else b"", + "accessibility_tree": payload.get("accessibility_tree"), + "terminal": payload.get("terminal"), + "instruction": payload.get("instruction"), + } + return obs + + +class RemoteDesktopEnv: + """Client-side stand-in for ``desktop_env.desktop_env.DesktopEnv``. + + Parameters + ---------- + server_url + Base URL of the remote ``remote_server.py`` (e.g. ``http://10.0.0.5:8000``). + provider_name, path_to_vm, action_space, cache_dir, screen_size, headless, + os_type, require_a11y_tree + Forwarded verbatim to the server's ``DesktopEnv(...)`` constructor. + request_timeout_secs + Upper bound on every HTTP call. Reset/evaluate can take many minutes, + so default generously. 
+ """ + + def __init__( + self, + server_url: str, + *, + provider_name: str = "docker", + path_to_vm: str | None = None, + action_space: str = "pyautogui", + cache_dir: str = "cache", + screen_size: tuple[int, int] = (1920, 1080), + headless: bool = True, + os_type: str = "Ubuntu", + require_a11y_tree: bool = False, + request_timeout_secs: float = 1800.0, + ) -> None: + self.server_url = server_url.rstrip("/") + self.request_timeout_secs = request_timeout_secs + self._last_obs: dict[str, Any] | None = None + + resp = self._post( + "/envs", + { + "provider_name": provider_name, + "path_to_vm": path_to_vm, + "action_space": action_space, + "cache_dir": cache_dir, + "screen_size": list(screen_size), + "headless": headless, + "os_type": os_type, + "require_a11y_tree": require_a11y_tree, + }, + ) + self.session_id: str = resp["session_id"] + logger.info( + f"Opened remote DesktopEnv session {self.session_id} at {server_url}" + ) + + def reset( + self, task_config: dict[str, Any] | None = None, **_: Any + ) -> dict[str, Any]: + payload = self._post( + f"/envs/{self.session_id}/reset", + {"task_config": task_config}, + ) + self._last_obs = _decode_obs(payload["obs"]) + return self._last_obs + + def _get_obs(self) -> dict[str, Any]: + # The workflow calls `env._get_obs()` right after reset expecting the + # cached observation. We return whatever the server sent last, and + # only go back over the wire if nothing has been cached yet. 
+ if self._last_obs is not None: + return self._last_obs + payload = self._get(f"/envs/{self.session_id}/obs") + self._last_obs = _decode_obs(payload["obs"]) + return self._last_obs + + def step( + self, action: str, pause: float = 0.0 + ) -> tuple[dict[str, Any], float, bool, dict[str, Any]]: + payload = self._post( + f"/envs/{self.session_id}/step", + {"action": action, "pause": pause}, + ) + self._last_obs = _decode_obs(payload["obs"]) + return ( + self._last_obs, + float(payload.get("reward", 0.0)), + bool(payload.get("done", False)), + payload.get("info") or {}, + ) + + def evaluate(self) -> float: + payload = self._post(f"/envs/{self.session_id}/evaluate", {}) + return float(payload.get("reward", 0.0)) + + def close(self) -> None: + try: + self._post(f"/envs/{self.session_id}/close", {}) + except RemoteDesktopEnvError as e: + logger.warning(f"Remote close failed (session may already be gone): {e}") + self._last_obs = None + + # ------------------------------------------------------------------ HTTP + + def _post(self, path: str, body: dict[str, Any]) -> dict[str, Any]: + url = self.server_url + path + try: + r = requests.post(url, json=body, timeout=self.request_timeout_secs) + except requests.RequestException as e: + raise RemoteDesktopEnvError(f"POST {url} failed: {e!r}") from e + return self._unwrap(r, url) + + def _get(self, path: str) -> dict[str, Any]: + url = self.server_url + path + try: + r = requests.get(url, timeout=self.request_timeout_secs) + except requests.RequestException as e: + raise RemoteDesktopEnvError(f"GET {url} failed: {e!r}") from e + return self._unwrap(r, url) + + @staticmethod + def _unwrap(response: requests.Response, url: str) -> dict[str, Any]: + if response.status_code >= 400: + try: + msg = response.json().get("error", response.text) + except ValueError: + msg = response.text + raise RemoteDesktopEnvError(f"{url} returned {response.status_code}: {msg}") + try: + return response.json() + except ValueError as e: + raise 
RemoteDesktopEnvError( + f"{url} returned non-JSON body: {response.text[:200]}" + ) from e